实现自定义 wg 以统计协程数量

This commit is contained in:
2025-02-25 15:44:09 +08:00
parent 7f23e2741f
commit 9a8680a221
4 changed files with 92 additions and 64 deletions

48
pkg/utils/chan.go Normal file
View File

@@ -0,0 +1,48 @@
package utils
import (
"context"
"log/slog"
"net"
"github.com/pkg/errors"
)
func ConnChan(ctx context.Context, ls net.Listener) chan net.Conn {
connCh := make(chan net.Conn)
go func() {
for {
conn, err := ls.Accept()
if err != nil {
slog.Error("接受连接失败", err)
// 临时错误重试连接
var ne net.Error
if errors.As(err, &ne) && ne.Temporary() {
slog.Debug("临时错误重试")
continue
}
return
}
// ctx 取消后退出
select {
case <-ctx.Done():
Close(conn)
return
case connCh <- conn:
}
}
}()
return connCh
}
func WaitChan(ctx context.Context, wg *CountWaitGroup) chan struct{} {
ch := make(chan struct{})
go func() {
wg.Wait()
select {
case <-ctx.Done():
case ch <- struct{}{}:
}
}()
return ch
}

29
pkg/utils/sync.go Normal file
View File

@@ -0,0 +1,29 @@
package utils
import (
"sync"
"sync/atomic"
)
type CountWaitGroup struct {
wg sync.WaitGroup
num atomic.Uint64
}
func (c *CountWaitGroup) Add(delta uint64) {
c.wg.Add(int(delta))
c.num.Add(delta)
}
func (c *CountWaitGroup) Done() {
c.wg.Done()
c.num.Add(-1)
}
func (c *CountWaitGroup) Wait() {
c.wg.Wait()
}
func (c *CountWaitGroup) Count() uint64 {
return c.num.Load()
}

View File

@@ -1,13 +1,8 @@
package utils
import (
"context"
"io"
"log/slog"
"net"
"sync"
"github.com/pkg/errors"
)
func ReadByte(reader io.Reader) (byte, error) {
@@ -36,42 +31,3 @@ func Close[T io.Closer](v T) {
slog.Warn("对象关闭失败", "err", err)
}
}
func ConnChan(ctx context.Context, ls net.Listener) chan net.Conn {
connCh := make(chan net.Conn)
go func() {
for {
conn, err := ls.Accept()
if err != nil {
slog.Error("接受连接失败", err)
// 临时错误重试连接
var ne net.Error
if errors.As(err, &ne) && ne.Temporary() {
slog.Debug("临时错误重试")
continue
}
return
}
// ctx 取消后退出
select {
case <-ctx.Done():
Close(conn)
return
case connCh <- conn:
}
}
}()
return connCh
}
func WaitChan(ctx context.Context, wg *sync.WaitGroup) chan struct{} {
ch := make(chan struct{})
go func() {
wg.Wait()
select {
case <-ctx.Done():
case ch <- struct{}{}:
}
}()
return ch
}

View File

@@ -13,7 +13,6 @@ import (
"proxy-server/server/pkg/socks5"
"proxy-server/server/web/app/models"
"strconv"
"sync"
"time"
"github.com/pkg/errors"
@@ -24,9 +23,9 @@ type Config struct {
type Service struct {
Config *Config
ConnMap map[string]socks5.ProxyData
ctrlConnWg sync.WaitGroup
dataConnWg sync.WaitGroup
connMap map[string]socks5.ProxyData
ctrlConnWg utils.CountWaitGroup
dataConnWg utils.CountWaitGroup
}
func New(config *Config) *Service {
@@ -37,9 +36,9 @@ func New(config *Config) *Service {
return &Service{
Config: _config,
ConnMap: make(map[string]socks5.ProxyData),
ctrlConnWg: sync.WaitGroup{},
dataConnWg: sync.WaitGroup{},
connMap: make(map[string]socks5.ProxyData),
ctrlConnWg: utils.CountWaitGroup{},
dataConnWg: utils.CountWaitGroup{},
}
}
@@ -111,6 +110,7 @@ loop:
slog.Debug("结束处理连接,由于获取连接失败")
break loop
}
s.ctrlConnWg.Add(1)
go s.processCtrlConn(conn)
}
}
@@ -190,7 +190,7 @@ func (s *Service) processCtrlConn(controller net.Conn) {
slog.Error("write error", err)
return
}
s.ConnMap[tag] = user
s.connMap[tag] = user
}
}
@@ -222,6 +222,7 @@ loop:
slog.Debug("结束处理连接,由于获取连接失败")
break loop
}
s.dataConnWg.Add(1)
go s.processDataConn(conn)
}
}
@@ -245,7 +246,10 @@ loop:
}
func (s *Service) processDataConn(client net.Conn) {
defer func() {
s.dataConnWg.Done()
utils.Close(client)
}()
slog.Info("已建立客户端数据通道 " + client.RemoteAddr().String())
// 读取 tag
@@ -262,7 +266,7 @@ func (s *Service) processDataConn(client net.Conn) {
tag := string(tagBuf)
// 找到用户连接
data, ok := s.ConnMap[tag]
data, ok := s.connMap[tag]
if !ok {
slog.Error("no such connection")
return
@@ -270,6 +274,7 @@ func (s *Service) processDataConn(client net.Conn) {
// 响应用户
user := data.Conn
defer utils.Close(user)
socks5.SendSuccess(user, client)
// 写入目标地址
@@ -303,16 +308,6 @@ func (s *Service) processDataConn(client net.Conn) {
}()
<-errCh
slog.Info("数据转发结束 " + client.RemoteAddr().String() + " <-> " + data.Dest)
defer func() {
err := user.Close()
if err != nil {
slog.Error("close error", err)
}
err = client.Close()
if err != nil {
slog.Error("close error", err)
}
}()
}
type NoAuthAuthenticator struct {