diff --git a/pkg/utils/chan.go b/pkg/utils/chan.go new file mode 100644 index 0000000..0151d69 --- /dev/null +++ b/pkg/utils/chan.go @@ -0,0 +1,48 @@ +package utils + +import ( + "context" + "log/slog" + "net" + + "github.com/pkg/errors" +) + +func ConnChan(ctx context.Context, ls net.Listener) chan net.Conn { + connCh := make(chan net.Conn) + go func() { + for { + conn, err := ls.Accept() + if err != nil { + slog.Error("接受连接失败", err) + // 临时错误重试连接 + var ne net.Error + if errors.As(err, &ne) && ne.Temporary() { + slog.Debug("临时错误重试") + continue + } + return + } + // ctx 取消后退出 + select { + case <-ctx.Done(): + Close(conn) + return + case connCh <- conn: + } + } + }() + return connCh +} + +func WaitChan(ctx context.Context, wg *CountWaitGroup) chan struct{} { + ch := make(chan struct{}) + go func() { + wg.Wait() + select { + case <-ctx.Done(): + case ch <- struct{}{}: + } + }() + return ch +} diff --git a/pkg/utils/sync.go b/pkg/utils/sync.go new file mode 100644 index 0000000..984ab54 --- /dev/null +++ b/pkg/utils/sync.go @@ -0,0 +1,29 @@ +package utils + +import ( + "sync" + "sync/atomic" +) + +type CountWaitGroup struct { + wg sync.WaitGroup + num atomic.Uint64 +} + +func (c *CountWaitGroup) Add(delta uint64) { + c.wg.Add(int(delta)) + c.num.Add(delta) +} + +func (c *CountWaitGroup) Done() { + c.wg.Done() + c.num.Add(-1) +} + +func (c *CountWaitGroup) Wait() { + c.wg.Wait() +} + +func (c *CountWaitGroup) Count() uint64 { + return c.num.Load() +} diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go index c21399d..c288da2 100644 --- a/pkg/utils/utils.go +++ b/pkg/utils/utils.go @@ -1,13 +1,8 @@ package utils import ( - "context" "io" "log/slog" - "net" - "sync" - - "github.com/pkg/errors" ) func ReadByte(reader io.Reader) (byte, error) { @@ -36,42 +31,3 @@ func Close[T io.Closer](v T) { slog.Warn("对象关闭失败", "err", err) } } - -func ConnChan(ctx context.Context, ls net.Listener) chan net.Conn { - connCh := make(chan net.Conn) - go func() { - for { - conn, err := ls.Accept() - if err != nil { - slog.Error("接受连接失败", err) - // 临时错误重试连接 - var ne net.Error - if errors.As(err, &ne) && ne.Temporary() { - slog.Debug("临时错误重试") - continue - } - return - } - // ctx 取消后退出 - select { - case <-ctx.Done(): - Close(conn) - return - case connCh <- conn: - } - } - }() - return connCh -} - -func WaitChan(ctx context.Context, wg *sync.WaitGroup) chan struct{} { - ch := make(chan struct{}) - go func() { - wg.Wait() - select { - case <-ctx.Done(): - case ch <- struct{}{}: - } - }() - return ch -} diff --git a/server/fwd/service.go b/server/fwd/service.go index 8ed4500..913ff14 100644 --- a/server/fwd/service.go +++ b/server/fwd/service.go @@ -13,7 +13,6 @@ import ( "proxy-server/server/pkg/socks5" "proxy-server/server/web/app/models" "strconv" - "sync" "time" "github.com/pkg/errors" @@ -24,9 +23,9 @@ type Config struct { type Service struct { Config *Config - ConnMap map[string]socks5.ProxyData - ctrlConnWg sync.WaitGroup - dataConnWg sync.WaitGroup + connMap map[string]socks5.ProxyData + ctrlConnWg utils.CountWaitGroup + dataConnWg utils.CountWaitGroup } func New(config *Config) *Service { @@ -37,9 +36,9 @@ func New(config *Config) *Service { return &Service{ Config: _config, - ConnMap: make(map[string]socks5.ProxyData), - ctrlConnWg: sync.WaitGroup{}, - dataConnWg: sync.WaitGroup{}, + connMap: make(map[string]socks5.ProxyData), + ctrlConnWg: utils.CountWaitGroup{}, + dataConnWg: utils.CountWaitGroup{}, } } @@ -111,6 +110,7 @@ loop: slog.Debug("结束处理连接,由于获取连接失败") break loop } + s.ctrlConnWg.Add(1) go s.processCtrlConn(conn) } } @@ -190,7 +190,7 @@ func (s *Service) processCtrlConn(controller net.Conn) { slog.Error("write error", err) return } - s.ConnMap[tag] = user + s.connMap[tag] = user } } @@ -222,6 +222,7 @@ loop: slog.Debug("结束处理连接,由于获取连接失败") break loop } + s.dataConnWg.Add(1) go s.processDataConn(conn) } } @@ -245,7 +246,10 @@ loop: } func (s *Service) processDataConn(client net.Conn) { - + defer func() { + s.dataConnWg.Done() + utils.Close(client) + }() slog.Info("已建立客户端数据通道 " + client.RemoteAddr().String()) // 读取 tag @@ -262,7 +266,7 @@ func (s *Service) processDataConn(client net.Conn) { tag := string(tagBuf) // 找到用户连接 - data, ok := s.ConnMap[tag] + data, ok := s.connMap[tag] if !ok { slog.Error("no such connection") return @@ -270,6 +274,7 @@ func (s *Service) processDataConn(client net.Conn) { // 响应用户 user := data.Conn + defer utils.Close(user) socks5.SendSuccess(user, client) // 写入目标地址 @@ -303,16 +308,6 @@ func (s *Service) processDataConn(client net.Conn) { }() <-errCh slog.Info("数据转发结束 " + client.RemoteAddr().String() + " <-> " + data.Dest) - defer func() { - err := user.Close() - if err != nil { - slog.Error("close error", err) - } - err = client.Close() - if err != nil { - slog.Error("close error", err) - } - }() } type NoAuthAuthenticator struct {