From 9a8680a221f1bda8716a4c730c80f76a49b2e17a Mon Sep 17 00:00:00 2001 From: luorijun Date: Tue, 25 Feb 2025 15:44:09 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AE=9E=E7=8E=B0=E8=87=AA=E5=AE=9A=E4=B9=89?= =?UTF-8?q?=20wg=20=E4=BB=A5=E7=BB=9F=E8=AE=A1=E5=8D=8F=E7=A8=8B=E6=95=B0?= =?UTF-8?q?=E9=87=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/utils/chan.go | 48 +++++++++++++++++++++++++++++++++++++++++++ pkg/utils/sync.go | 29 ++++++++++++++++++++++++++ pkg/utils/utils.go | 44 --------------------------------------- server/fwd/service.go | 35 ++++++++++++++----------------- 4 files changed, 92 insertions(+), 64 deletions(-) create mode 100644 pkg/utils/chan.go create mode 100644 pkg/utils/sync.go diff --git a/pkg/utils/chan.go b/pkg/utils/chan.go new file mode 100644 index 0000000..0151d69 --- /dev/null +++ b/pkg/utils/chan.go @@ -0,0 +1,48 @@ +package utils + +import ( + "context" + "log/slog" + "net" + + "github.com/pkg/errors" +) + +func ConnChan(ctx context.Context, ls net.Listener) chan net.Conn { + connCh := make(chan net.Conn) + go func() { + for { + conn, err := ls.Accept() + if err != nil { + slog.Error("接受连接失败", err) + // 临时错误重试连接 + var ne net.Error + if errors.As(err, &ne) && ne.Temporary() { + slog.Debug("临时错误重试") + continue + } + return + } + // ctx 取消后退出 + select { + case <-ctx.Done(): + Close(conn) + return + case connCh <- conn: + } + } + }() + return connCh +} + +func WaitChan(ctx context.Context, wg *CountWaitGroup) chan struct{} { + ch := make(chan struct{}) + go func() { + wg.Wait() + select { + case <-ctx.Done(): + case ch <- struct{}{}: + } + }() + return ch +} diff --git a/pkg/utils/sync.go b/pkg/utils/sync.go new file mode 100644 index 0000000..984ab54 --- /dev/null +++ b/pkg/utils/sync.go @@ -0,0 +1,29 @@ +package utils + +import ( + "sync" + "sync/atomic" +) + +type CountWaitGroup struct { + wg sync.WaitGroup + num atomic.Uint64 +} + +func (c *CountWaitGroup) Add(delta uint64) { + c.wg.Add(int(delta)) + c.num.Add(delta) +} + +func (c *CountWaitGroup) Done() { + c.wg.Done() + c.num.Add(-1) +} + +func (c *CountWaitGroup) Wait() { + c.wg.Wait() +} + +func (c *CountWaitGroup) Count() uint64 { + return c.num.Load() +} diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go index c21399d..c288da2 100644 --- a/pkg/utils/utils.go +++ b/pkg/utils/utils.go @@ -1,13 +1,8 @@ package utils import ( - "context" "io" "log/slog" - "net" - "sync" - - "github.com/pkg/errors" ) func ReadByte(reader io.Reader) (byte, error) { @@ -36,42 +31,3 @@ func Close[T io.Closer](v T) { slog.Warn("对象关闭失败", "err", err) } } - -func ConnChan(ctx context.Context, ls net.Listener) chan net.Conn { - connCh := make(chan net.Conn) - go func() { - for { - conn, err := ls.Accept() - if err != nil { - slog.Error("接受连接失败", err) - // 临时错误重试连接 - var ne net.Error - if errors.As(err, &ne) && ne.Temporary() { - slog.Debug("临时错误重试") - continue - } - return - } - // ctx 取消后退出 - select { - case <-ctx.Done(): - Close(conn) - return - case connCh <- conn: - } - } - }() - return connCh -} - -func WaitChan(ctx context.Context, wg *sync.WaitGroup) chan struct{} { - ch := make(chan struct{}) - go func() { - wg.Wait() - select { - case <-ctx.Done(): - case ch <- struct{}{}: - } - }() - return ch -} diff --git a/server/fwd/service.go b/server/fwd/service.go index 8ed4500..913ff14 100644 --- a/server/fwd/service.go +++ b/server/fwd/service.go @@ -13,7 +13,6 @@ import ( "proxy-server/server/pkg/socks5" "proxy-server/server/web/app/models" "strconv" - "sync" "time" "github.com/pkg/errors" @@ -24,9 +23,9 @@ type Config struct { type Service struct { Config *Config - ConnMap map[string]socks5.ProxyData - ctrlConnWg sync.WaitGroup - dataConnWg sync.WaitGroup + connMap map[string]socks5.ProxyData + ctrlConnWg utils.CountWaitGroup + dataConnWg utils.CountWaitGroup } func New(config *Config) *Service { @@ -37,9 +36,9 @@ func New(config *Config) *Service { return &Service{ Config: _config, - ConnMap: make(map[string]socks5.ProxyData), - ctrlConnWg: sync.WaitGroup{}, - dataConnWg: sync.WaitGroup{}, + connMap: make(map[string]socks5.ProxyData), + ctrlConnWg: utils.CountWaitGroup{}, + dataConnWg: utils.CountWaitGroup{}, } } @@ -111,6 +110,7 @@ loop: slog.Debug("结束处理连接,由于获取连接失败") break loop } + s.ctrlConnWg.Add(1) go s.processCtrlConn(conn) } } @@ -190,7 +190,7 @@ func (s *Service) processCtrlConn(controller net.Conn) { slog.Error("write error", err) return } - s.ConnMap[tag] = user + s.connMap[tag] = user } } @@ -222,6 +222,7 @@ loop: slog.Debug("结束处理连接,由于获取连接失败") break loop } + s.dataConnWg.Add(1) go s.processDataConn(conn) } } @@ -245,7 +246,10 @@ loop: } func (s *Service) processDataConn(client net.Conn) { - + defer func() { + s.dataConnWg.Done() + utils.Close(client) + }() slog.Info("已建立客户端数据通道 " + client.RemoteAddr().String()) // 读取 tag @@ -262,7 +266,7 @@ func (s *Service) processDataConn(client net.Conn) { tag := string(tagBuf) // 找到用户连接 - data, ok := s.ConnMap[tag] + data, ok := s.connMap[tag] if !ok { slog.Error("no such connection") return @@ -270,6 +274,7 @@ func (s *Service) processDataConn(client net.Conn) { // 响应用户 user := data.Conn + defer utils.Close(user) socks5.SendSuccess(user, client) // 写入目标地址 @@ -303,16 +308,6 @@ func (s *Service) processDataConn(client net.Conn) { }() <-errCh slog.Info("数据转发结束 " + client.RemoteAddr().String() + " <-> " + data.Dest) - defer func() { - err := user.Close() - if err != nil { - slog.Error("close error", err) - } - err = client.Close() - if err != nil { - slog.Error("close error", err) - } - }() } type NoAuthAuthenticator struct {