From 7ee4ded08cde937a61e744c64c441732ac7149e0 Mon Sep 17 00:00:00 2001 From: luorijun Date: Wed, 26 Feb 2025 13:56:56 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96=20socks=20=E8=A7=A3=E6=9E=90?= =?UTF-8?q?=E6=B5=81=E7=A8=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 16 +++ pkg/utils/chan.go | 6 +- server/fwd/service.go | 53 ++++---- server/fwd/service_test.go | 20 +-- server/{pkg/socks5 => fwd/socks}/auth.go | 12 +- server/{pkg/socks5 => fwd/socks}/error.go | 2 +- server/{pkg/socks5 => fwd/socks}/request.go | 48 +++---- server/{pkg/socks5 => fwd/socks}/resolver.go | 2 +- server/{pkg/socks5 => fwd/socks}/ruleset.go | 2 +- server/{pkg/socks5 => fwd/socks}/server.go | 105 +++++++++++---- template/service/service.go | 135 +++++++++++++++++++ test/server/fwd/auth_test.go | 4 +- 12 files changed, 301 insertions(+), 104 deletions(-) rename server/{pkg/socks5 => fwd/socks}/auth.go (79%) rename server/{pkg/socks5 => fwd/socks}/error.go (84%) rename server/{pkg/socks5 => fwd/socks}/request.go (85%) rename server/{pkg/socks5 => fwd/socks}/resolver.go (96%) rename server/{pkg/socks5 => fwd/socks}/ruleset.go (98%) rename server/{pkg/socks5 => fwd/socks}/server.go (59%) create mode 100644 template/service/service.go diff --git a/README.md b/README.md index 8c3f1e0..4c68f02 100644 --- a/README.md +++ b/README.md @@ -18,12 +18,28 @@ fwd 使用自定义 context 实现在一个上下文中控制 cancel,errCh 和 网关根据代理节点对目标服务连接的反馈,决定向用户返回的 socks 响应 +数据通道池化 + ### 长期 +代理端口支持混合端口转发(支持 tcp_mux) + +数据通道支持 tcp 多路复用(分离逻辑流) + +👆 进阶黑魔法 multipath tcp + 多路复用 + 考虑一下连接安全性 内部接口 rtt 是否还有优化空间(当前30-300ms,根据内容大小增长) +### 代码清理 + +检查 slog 级别: + +ERR: 除非有必要,否则全部 error 都使用 `errors.Wrap()` 包裹(如果下游有返回 err),并附带本层业务信息,return 到上层统一打印 + +其他级别日志就地打印,Info 只用来跟踪关键流程 + ## 开发相关 ### 环境变量 diff --git a/pkg/utils/chan.go b/pkg/utils/chan.go index 0151d69..2f60185 100644 --- a/pkg/utils/chan.go +++ b/pkg/utils/chan.go @@ -8,12 +8,12 @@ import ( "github.com/pkg/errors" ) -func ConnChan(ctx context.Context, ls net.Listener) chan net.Conn { +func ChanConnAccept(ctx context.Context, ls net.Listener) chan net.Conn { connCh := make(chan net.Conn) go func() { for { conn, err := ls.Accept() - if err != nil { + if err != nil && !errors.Is(err, net.ErrClosed) { slog.Error("接受连接失败", err) // 临时错误重试连接 var ne net.Error @@ -35,7 +35,7 @@ func ConnChan(ctx context.Context, ls net.Listener) chan net.Conn { return connCh } -func WaitChan(ctx context.Context, wg *CountWaitGroup) chan struct{} { +func ChanWgWait(ctx context.Context, wg *CountWaitGroup) chan struct{} { ch := make(chan struct{}) go func() { wg.Wait() diff --git a/server/fwd/service.go b/server/fwd/service.go index 5171cf2..bbcd495 100644 --- a/server/fwd/service.go +++ b/server/fwd/service.go @@ -8,9 +8,9 @@ import ( "log/slog" "net" "proxy-server/pkg/utils" + "proxy-server/server/fwd/socks" "proxy-server/server/pkg/env" "proxy-server/server/pkg/orm" - "proxy-server/server/pkg/socks5" "proxy-server/server/web/app/models" "strconv" "time" @@ -23,7 +23,7 @@ type Config struct { type Service struct { Config *Config - connMap map[string]socks5.ProxyData + connMap map[string]socks.ProxyData ctrlConnWg utils.CountWaitGroup dataConnWg utils.CountWaitGroup } @@ -36,7 +36,7 @@ func New(config *Config) *Service { return &Service{ Config: _config, - connMap: make(map[string]socks5.ProxyData), + connMap: make(map[string]socks.ProxyData), ctrlConnWg: utils.CountWaitGroup{}, dataConnWg: utils.CountWaitGroup{}, } @@ -95,7 +95,7 @@ func (s *Service) startCtrlTun(ctx context.Context, errCh chan error) { defer utils.Close(ls) // 等待连接 - connCh := utils.ConnChan(ctx, ls) + connCh := utils.ChanConnAccept(ctx, ls) defer close(connCh) // 处理连接 @@ -119,7 +119,7 @@ loop: timeout, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - procCh := utils.WaitChan(timeout, &s.ctrlConnWg) + procCh := utils.ChanWgWait(timeout, &s.ctrlConnWg) defer close(procCh) select { @@ -144,20 +144,19 @@ func (s *Service) processCtrlConn(controller net.Conn) { reader := bufio.NewReader(controller) // 读取端口 - portBuf := make([]byte, 2) - _, err := io.ReadFull(reader, portBuf) + portBuf, err := utils.ReadBuffer(reader, 2) if err != nil { - slog.Error("读取转发端口失败", "err", err) + slog.Error("接收转发端口失败", "err", err) return } port := binary.BigEndian.Uint16(portBuf) // 新建代理服务 slog.Info("新建代理服务", "port", port) - proxy, err := socks5.New(&socks5.Config{ + proxy, err := socks.New(&socks.Config{ Name: strconv.Itoa(int(port)), Port: port, - AuthMethods: []socks5.Authenticator{ + AuthMethods: []socks.Authenticator{ &UserPassAuthenticator{}, &NoAuthAuthenticator{}, }, @@ -207,7 +206,7 @@ func (s *Service) startDataTun(ctx context.Context, errCh chan error) { defer utils.Close(lData) // 等待连接 - connCh := utils.ConnChan(ctx, lData) + connCh := utils.ChanConnAccept(ctx, lData) defer close(connCh) // 处理连接 @@ -231,7 +230,7 @@ loop: timeout, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - procCh := utils.WaitChan(timeout, &s.dataConnWg) + procCh := utils.ChanWgWait(timeout, &s.dataConnWg) defer close(procCh) select { @@ -275,7 +274,7 @@ func (s *Service) processDataConn(client net.Conn) { // 响应用户 user := data.Conn defer utils.Close(user) - socks5.SendSuccess(user, client) + socks.SendSuccess(user, client) // 写入目标地址 _, err = client.Write([]byte{byte(len(data.Dest))}) @@ -313,11 +312,11 @@ func (s *Service) processDataConn(client net.Conn) { type NoAuthAuthenticator struct { } -func (a *NoAuthAuthenticator) Method() socks5.AuthMethod { - return socks5.NoAuth +func (a *NoAuthAuthenticator) Method() socks.AuthMethod { + return socks.NoAuth } -func (a *NoAuthAuthenticator) Authenticate(ctx context.Context, reader io.Reader, writer io.Writer) (*socks5.AuthContext, error) { +func (a *NoAuthAuthenticator) Authenticate(ctx context.Context, reader io.Reader, writer io.Writer) (*socks.AuthContext, error) { // 获取用户地址 conn, ok := writer.(net.Conn) @@ -332,7 +331,7 @@ func (a *NoAuthAuthenticator) Authenticate(ctx context.Context, reader io.Reader slog.Debug("用户的地址为 " + client) // 获取服务 - server, ok := ctx.Value("service").(*socks5.Server) + server, ok := ctx.Value("service").(*socks.Server) if !ok { return nil, errors.New("noAuth 认证失败,无法获取服务信息") } @@ -373,8 +372,8 @@ func (a *NoAuthAuthenticator) Authenticate(ctx context.Context, reader io.Reader } slog.Debug("权限剩余时间", slog.Uint64("timeout", uint64(timeout))) - return &socks5.AuthContext{ - Method: socks5.NoAuth, + return &socks.AuthContext{ + Method: socks.NoAuth, Timeout: uint(timeout), Payload: nil, }, nil @@ -383,11 +382,11 @@ func (a *NoAuthAuthenticator) Authenticate(ctx context.Context, reader io.Reader type UserPassAuthenticator struct { } -func (a *UserPassAuthenticator) Method() socks5.AuthMethod { - return socks5.UserPassAuth +func (a *UserPassAuthenticator) Method() socks.AuthMethod { + return socks.UserPassAuth } -func (a *UserPassAuthenticator) Authenticate(ctx context.Context, reader io.Reader, writer io.Writer) (*socks5.AuthContext, error) { +func (a *UserPassAuthenticator) Authenticate(ctx context.Context, reader io.Reader, writer io.Writer) (*socks.AuthContext, error) { // 检查认证版本 slog.Debug("验证认证版本") @@ -395,8 +394,8 @@ func (a *UserPassAuthenticator) Authenticate(ctx context.Context, reader io.Read if err != nil { return nil, errors.Wrap(err, "读取版本号失败") } - if v != socks5.AuthVersion { - _, err := writer.Write([]byte{socks5.SocksVersion, socks5.AuthFailure}) + if v != socks.AuthVersion { + _, err := writer.Write([]byte{socks.Version, socks.AuthFailure}) if err != nil { return nil, errors.Wrap(err, "响应认证失败") } @@ -484,14 +483,14 @@ func (a *UserPassAuthenticator) Authenticate(ctx context.Context, reader io.Read } // 响应认证成功 - _, err = writer.Write([]byte{socks5.AuthVersion, socks5.AuthSuccess}) + _, err = writer.Write([]byte{socks.AuthVersion, socks.AuthSuccess}) if err != nil { slog.Error("响应认证失败", "err", err) return nil, err } - return &socks5.AuthContext{ - Method: socks5.UserPassAuth, + return &socks.AuthContext{ + Method: socks.UserPassAuth, Timeout: uint(timeout), Payload: nil, }, nil diff --git a/server/fwd/service_test.go b/server/fwd/service_test.go index f80c5ab..bc18ed0 100644 --- a/server/fwd/service_test.go +++ b/server/fwd/service_test.go @@ -6,7 +6,7 @@ import ( "io" "net" "proxy-server/pkg/utils" - "proxy-server/server/pkg/socks5" + socks6 "proxy-server/server/fwd/socks" "reflect" "testing" ) @@ -49,7 +49,7 @@ func TestNoAuthAuthenticator_Authenticate(t *testing.T) { name string args args wantWriter string - want *socks5.AuthContext + want *socks6.AuthContext wantErr bool }{ // TODO: Add test cases. @@ -76,7 +76,7 @@ func TestNoAuthAuthenticator_Authenticate(t *testing.T) { func TestNoAuthAuthenticator_Method(t *testing.T) { tests := []struct { name string - want socks5.AuthMethod + want socks6.AuthMethod }{ // TODO: Add test cases. } @@ -93,7 +93,7 @@ func TestNoAuthAuthenticator_Method(t *testing.T) { func TestService_Run(t *testing.T) { type fields struct { Config *Config - connMap map[string]socks5.ProxyData + connMap map[string]socks6.ProxyData ctrlConnWg utils.CountWaitGroup dataConnWg utils.CountWaitGroup } @@ -124,7 +124,7 @@ func TestService_Run(t *testing.T) { func TestService_processCtrlConn(t *testing.T) { type fields struct { Config *Config - connMap map[string]socks5.ProxyData + connMap map[string]socks6.ProxyData ctrlConnWg utils.CountWaitGroup dataConnWg utils.CountWaitGroup } @@ -154,7 +154,7 @@ func TestService_processCtrlConn(t *testing.T) { func TestService_processDataConn(t *testing.T) { type fields struct { Config *Config - connMap map[string]socks5.ProxyData + connMap map[string]socks6.ProxyData ctrlConnWg utils.CountWaitGroup dataConnWg utils.CountWaitGroup } @@ -184,7 +184,7 @@ func TestService_processDataConn(t *testing.T) { func TestService_startCtrlTun(t *testing.T) { type fields struct { Config *Config - connMap map[string]socks5.ProxyData + connMap map[string]socks6.ProxyData ctrlConnWg utils.CountWaitGroup dataConnWg utils.CountWaitGroup } @@ -215,7 +215,7 @@ func TestService_startCtrlTun(t *testing.T) { func TestService_startDataTun(t *testing.T) { type fields struct { Config *Config - connMap map[string]socks5.ProxyData + connMap map[string]socks6.ProxyData ctrlConnWg utils.CountWaitGroup dataConnWg utils.CountWaitGroup } @@ -252,7 +252,7 @@ func TestUserPassAuthenticator_Authenticate(t *testing.T) { name string args args wantWriter string - want *socks5.AuthContext + want *socks6.AuthContext wantErr bool }{ // TODO: Add test cases. @@ -279,7 +279,7 @@ func TestUserPassAuthenticator_Authenticate(t *testing.T) { func TestUserPassAuthenticator_Method(t *testing.T) { tests := []struct { name string - want socks5.AuthMethod + want socks6.AuthMethod }{ // TODO: Add test cases. } diff --git a/server/pkg/socks5/auth.go b/server/fwd/socks/auth.go similarity index 79% rename from server/pkg/socks5/auth.go rename to server/fwd/socks/auth.go index d7014ce..1cbdc03 100644 --- a/server/pkg/socks5/auth.go +++ b/server/fwd/socks/auth.go @@ -1,4 +1,4 @@ -package socks5 +package socks import ( "context" @@ -27,7 +27,7 @@ type Authenticator interface { } // authenticate 执行认证流程 -func (server *Server) authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error) { +func (s *Server) authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error) { // 版本检查 err := checkVersion(reader) @@ -46,18 +46,18 @@ func (server *Server) authenticate(reader io.Reader, writer io.Writer) (*AuthCon } // 认证客户端连接 - for _, authenticator := range server.config.AuthMethods { + for _, authenticator := range s.config.AuthMethods { method := authenticator.Method() if slices.Contains(methods, byte(method)) { slog.Debug("使用的认证方式", method) - _, err := writer.Write([]byte{SocksVersion, byte(method)}) + _, err := writer.Write([]byte{Version, byte(method)}) if err != nil { slog.Error("响应认证方式失败", err) return nil, err } - ctx := context.WithValue(context.Background(), "service", server) + ctx := context.WithValue(context.Background(), "service", s) authContext, err := authenticator.Authenticate(ctx, reader, writer) if err != nil { return nil, err @@ -67,7 +67,7 @@ func (server *Server) authenticate(reader io.Reader, writer io.Writer) (*AuthCon } // 无适用的认证方式 - _, err = writer.Write([]byte{SocksVersion, NoAcceptable}) + _, err = writer.Write([]byte{Version, NoAcceptable}) if err != nil { return nil, err } diff --git a/server/pkg/socks5/error.go b/server/fwd/socks/error.go similarity index 84% rename from server/pkg/socks5/error.go rename to server/fwd/socks/error.go index 54baaaa..3f3f24d 100644 --- a/server/pkg/socks5/error.go +++ b/server/fwd/socks/error.go @@ -1,4 +1,4 @@ -package socks5 +package socks type ConfigError string diff --git a/server/pkg/socks5/request.go b/server/fwd/socks/request.go similarity index 85% rename from server/pkg/socks5/request.go rename to server/fwd/socks/request.go index 4e3389a..508d976 100644 --- a/server/pkg/socks5/request.go +++ b/server/fwd/socks/request.go @@ -1,4 +1,4 @@ -package socks5 +package socks import ( "context" @@ -65,7 +65,7 @@ func (a AddrSpec) Address() string { return net.JoinHostPort(a.FQDN, strconv.Itoa(a.Port)) } -func (server *Server) request(reader io.Reader, writer io.Writer) (*Request, error) { +func (s *Server) request(reader io.Reader, writer io.Writer) (*Request, error) { // 检查版本 err := checkVersion(reader) @@ -95,13 +95,13 @@ func (server *Server) request(reader io.Reader, writer io.Writer) (*Request, err } // 获取目标地址 - dest, err := server.parseTarget(reader, writer) + dest, err := s.parseTarget(reader, writer) if err != nil { return nil, err } request := &Request{ - Version: SocksVersion, + Version: Version, Command: command, DestAddr: dest, bufConn: reader, @@ -110,7 +110,7 @@ func (server *Server) request(reader io.Reader, writer io.Writer) (*Request, err return request, nil } -func (server *Server) parseTarget(reader io.Reader, writer io.Writer) (*AddrSpec, error) { +func (s *Server) parseTarget(reader io.Reader, writer io.Writer) (*AddrSpec, error) { dest := &AddrSpec{} aTypeBuf := make([]byte, 1) @@ -152,7 +152,7 @@ func (server *Server) parseTarget(reader io.Reader, writer io.Writer) (*AddrSpec dest.FQDN = string(fqdnBuff) // 域名解析 - addr, err := server.config.Resolver.Resolve(dest.FQDN) + addr, err := s.config.Resolver.Resolve(dest.FQDN) if err != nil { err := sendReply(writer, hostUnreachable, nil) if err != nil { @@ -197,33 +197,33 @@ type Request struct { bufConn io.Reader } -func (server *Server) handle(req *Request, conn net.Conn) error { +func (s *Server) handle(req *Request, conn net.Conn) error { ctx := context.Background() // 目标地址重写 req.realDestAddr = req.DestAddr - if server.config.Rewriter != nil { - ctx, req.realDestAddr = server.config.Rewriter.Rewrite(ctx, req) + if s.config.Rewriter != nil { + ctx, req.realDestAddr = s.config.Rewriter.Rewrite(ctx, req) } // 根据协商方法建立连接 switch req.Command { case ConnectCommand: - return server.handleConnect(ctx, conn, req) + return s.handleConnect(ctx, conn, req) case BindCommand: - return server.handleBind(ctx, conn, req) + return s.handleBind(ctx, conn, req) case AssociateCommand: - return server.handleAssociate(ctx, conn, req) + return s.handleAssociate(ctx, conn, req) default: return fmt.Errorf("unsupported command: %v", req.Command) } } -func (server *Server) handleConnect(ctx context.Context, conn net.Conn, req *Request) error { +func (s *Server) handleConnect(ctx context.Context, conn net.Conn, req *Request) error { // 检查规则集约束 - server.config.Logger.Printf("检查约束规则\n") - if ctx_, ok := server.config.Rules.Allow(ctx, req); !ok { + s.config.Logger.Printf("检查约束规则\n") + if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok { if err := sendReply(conn, ruleFailure, nil); err != nil { return fmt.Errorf("failed to send reply: %v", err) } @@ -233,12 +233,12 @@ func (server *Server) handleConnect(ctx context.Context, conn net.Conn, req *Req } slog.Info("需要向 " + req.DestAddr.Address() + " 建立连接") - server.Conn <- ProxyData{conn, req.realDestAddr.Address()} + s.Conn <- ProxyData{conn, req.realDestAddr.Address()} return nil // 与目标服务器建立连接 - // server.config.Logger.Printf("与目标服务器建立连接\n") - // dial := server.config.Dial + // s.config.Logger.Printf("与目标服务器建立连接\n") + // dial := s.config.Dial // target, err := dial("tcp", req.realDestAddr.Address()) // if err != nil { // msg := err.Error() @@ -271,7 +271,7 @@ func (server *Server) handleConnect(ctx context.Context, conn net.Conn, req *Req // timeout := req.AuthContext.Timeout // slog.Debug("超时时间", "timeout", timeout) // - // timeoutCtx, cancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Second) + // timeoutCtx, cancel := ctx.WithTimeout(ctx, time.Duration(timeout)*time.Second) // defer cancel() // // // 代理流量 @@ -304,9 +304,9 @@ func (server *Server) handleConnect(ctx context.Context, conn net.Conn, req *Req } -func (server *Server) handleBind(ctx context.Context, conn net.Conn, req *Request) error { +func (s *Server) handleBind(ctx context.Context, conn net.Conn, req *Request) error { // Check if this is allowed - if ctx_, ok := server.config.Rules.Allow(ctx, req); !ok { + if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok { if err := sendReply(conn, ruleFailure, nil); err != nil { return fmt.Errorf("failed to send reply: %v", err) } @@ -322,9 +322,9 @@ func (server *Server) handleBind(ctx context.Context, conn net.Conn, req *Reques return nil } -func (server *Server) handleAssociate(ctx context.Context, conn net.Conn, req *Request) error { +func (s *Server) handleAssociate(ctx context.Context, conn net.Conn, req *Request) error { // Check if this is allowed - if ctx_, ok := server.config.Rules.Allow(ctx, req); !ok { + if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok { if err := sendReply(conn, ruleFailure, nil); err != nil { return fmt.Errorf("failed to send reply: %v", err) } @@ -370,7 +370,7 @@ func sendReply(w io.Writer, resp uint8, addr *AddrSpec) error { } msg := make([]byte, 6+len(addrBody)) - msg[0] = SocksVersion + msg[0] = Version msg[1] = resp msg[2] = 0 // Reserved msg[3] = addrType diff --git a/server/pkg/socks5/resolver.go b/server/fwd/socks/resolver.go similarity index 96% rename from server/pkg/socks5/resolver.go rename to server/fwd/socks/resolver.go index 230d6c5..04f45c6 100644 --- a/server/pkg/socks5/resolver.go +++ b/server/fwd/socks/resolver.go @@ -1,4 +1,4 @@ -package socks5 +package socks import ( "net" diff --git a/server/pkg/socks5/ruleset.go b/server/fwd/socks/ruleset.go similarity index 98% rename from server/pkg/socks5/ruleset.go rename to server/fwd/socks/ruleset.go index d65699d..00a6686 100644 --- a/server/pkg/socks5/ruleset.go +++ b/server/fwd/socks/ruleset.go @@ -1,4 +1,4 @@ -package socks5 +package socks import ( "context" diff --git a/server/pkg/socks5/server.go b/server/fwd/socks/server.go similarity index 59% rename from server/pkg/socks5/server.go rename to server/fwd/socks/server.go index caff406..69e780e 100644 --- a/server/pkg/socks5/server.go +++ b/server/fwd/socks/server.go @@ -1,8 +1,8 @@ -package socks5 +package socks import ( "bufio" - "errors" + "context" "fmt" "io" "log" @@ -11,10 +11,13 @@ import ( "os" "proxy-server/pkg/utils" "strconv" + "time" + + "github.com/pkg/errors" ) const ( - SocksVersion = byte(5) + Version = byte(5) ) type Config struct { @@ -47,6 +50,9 @@ type Config struct { type Server struct { config *Config + ctx context.Context + cancel context.CancelFunc + wg utils.CountWaitGroup Name string Port uint16 Conn chan ProxyData @@ -76,8 +82,12 @@ func New(conf *Config) (*Server, error) { } } + ctx, cancel := context.WithCancel(context.Background()) return &Server{ config: conf, + ctx: ctx, + cancel: cancel, + wg: utils.CountWaitGroup{}, Name: conf.Name, Port: conf.Port, Conn: make(chan ProxyData, 100), @@ -85,46 +95,83 @@ func New(conf *Config) (*Server, error) { } // Run 监听端口 -func (server *Server) Run() error { - host := server.config.Host - port := server.config.Port - addr := net.JoinHostPort(host, strconv.Itoa(int(port))) - +func (s *Server) Run() error { slog.Info("启动 socks5 代理服务") - listener, err := net.Listen("tcp", addr) + // 监听端口 + host := s.config.Host + port := s.config.Port + addr := net.JoinHostPort(host, strconv.Itoa(int(port))) + ls, err := net.Listen("tcp", addr) if err != nil { - return err + return errors.Wrap(err, "监听端口失败") } - defer utils.Close(listener) + defer utils.Close(ls) + slog.Info("正在监听端口", slog.Uint64("port", uint64(port))) - slog.Info("代理服务已启动,正在监听端口 " + addr) + // 处理连接 + connCh := utils.ChanConnAccept(s.ctx, ls) + defer close(connCh) - for { - conn, err := listener.Accept() - if err != nil { - slog.Error("客户端连接失败", err) - continue - } - - go func() { - err := server.serve(conn) - if err != nil { - slog.Error("连接异常退出", err) + err = nil + for loop := true; loop; { + select { + case <-s.ctx.Done(): + slog.Debug("服务主动停止") + loop = false + case conn, ok := <-connCh: + if !ok { + err = errors.New("意外错误,无法获取连接") + loop = false + break } - }() + s.wg.Add(1) + go func() { + defer s.wg.Done() + err := s.serve(conn) + if err != nil { + slog.Error("处理连接失败", err) + } + }() + } } + if err != nil { + s.Close() + } + + // 关闭服务 + timeout, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + wgCh := utils.ChanWgWait(timeout, &s.wg) + + err = nil + select { + case <-timeout.Done(): + err = errors.New("关闭超时(强制关闭)") + case <-wgCh: + } + + if s.Conn != nil { + close(s.Conn) + } + + return err +} + +// Close 关闭服务 +func (s *Server) Close() { + s.cancel() } // serve 建立连接 -func (server *Server) serve(conn net.Conn) error { +func (s *Server) serve(conn net.Conn) error { slog.Info("收到来自" + conn.RemoteAddr().String() + "的连接") reader := bufio.NewReader(conn) // 认证 slog.Debug("开始认证流程") - authContext, err := server.authenticate(reader, conn) + authContext, err := s.authenticate(reader, conn) if err != nil { utils.Close(conn) slog.Error("认证失败", err) @@ -135,7 +182,7 @@ func (server *Server) serve(conn net.Conn) error { // 处理连接请求 slog.Debug("处理连接请求") - request, err := server.request(reader, conn) + request, err := s.request(reader, conn) if err != nil { slog.Error("连接请求处理失败", err) return err @@ -156,7 +203,7 @@ func (server *Server) serve(conn net.Conn) error { // 处理请求 slog.Debug("开始代理流量") - err = server.handle(request, conn) + err = s.handle(request, conn) if err != nil { return err } @@ -173,7 +220,7 @@ func checkVersion(reader io.Reader) error { slog.Debug("客户端请求版本", "version", version) - if version != SocksVersion { + if version != Version { return errors.New("客户端版本不兼容") } diff --git a/template/service/service.go b/template/service/service.go new file mode 100644 index 0000000..efc1b45 --- /dev/null +++ b/template/service/service.go @@ -0,0 +1,135 @@ +package service + +import ( + "context" + "errors" + "io" + "log" + "net" + "strconv" + "sync" + "time" +) + +type Config struct { + Host string + Port uint16 + CloseWait time.Duration +} + +type Server struct { + config *Config + ctx context.Context + cancel context.CancelFunc + wg sync.WaitGroup +} + +func New(conf *Config) (*Server, error) { + + if conf.Host == "" { + conf.Host = "localhost" + } + if conf.Port == 0 { + return nil, errors.New("port is required") + } + + ctx, cancel := context.WithCancel(context.Background()) + return &Server{ + conf, + ctx, + cancel, + sync.WaitGroup{}, + }, nil +} + +func (s *Server) Run() error { + + // start listen + addr := net.JoinHostPort(s.config.Host, strconv.Itoa(int(s.config.Port))) + ls, err := net.Listen("tcp", addr) + if err != nil { + return errors.New("failed to listen") + } + defer closeRes(ls) + + // wait accept + connCh := make(chan net.Conn) + defer close(connCh) + go func() { + for { + conn, err := ls.Accept() + if err != nil { + if !errors.Is(err, net.ErrClosed) { + log.Println("accept failed", err) + } + // retry on temporary error + var ne net.Error + if errors.As(err, &ne) && ne.Temporary() { + continue + } + return + } + select { + case <-s.ctx.Done(): + closeRes(conn) + return + case connCh <- conn: + } + } + }() + + // handle accept + func() { + for { + select { + case <-s.ctx.Done(): + return + case conn, ok := <-connCh: + if !ok { + return + } + s.wg.Add(1) + go func() { + defer s.wg.Done() + _ = s.handle(conn) + }() + } + } + }() + + // close + timeout, cancel := context.WithTimeout(context.Background(), s.config.CloseWait) + defer cancel() + + waitCh := make(chan struct{}) + defer close(waitCh) + go func() { + s.wg.Wait() + select { + case <-timeout.Done(): // waitCh may be closed + case waitCh <- struct{}{}: + } + }() + + err = nil + select { + case <-timeout.Done(): + err = errors.New("close timeout") + case <-waitCh: + } + + return err +} + +func (s *Server) Close() { + s.cancel() +} + +func (s *Server) handle(conn net.Conn) error { + defer closeRes(conn) + return nil +} + +func closeRes[T io.Closer](res T) { + _ = res.Close() +} diff --git a/test/server/fwd/auth_test.go b/test/server/fwd/auth_test.go index 85a1476..d556615 100644 --- a/test/server/fwd/auth_test.go +++ b/test/server/fwd/auth_test.go @@ -2,7 +2,7 @@ package fwd import ( "net" - "proxy-server/server/pkg/socks5" + socks6 "proxy-server/server/fwd/socks" "testing" ) @@ -26,7 +26,7 @@ func fakeRequest() { } // 发送认证请求 - _, err = conn.Write([]byte{socks5.SocksVersion, byte(1), byte(socks5.NoAuth)}) + _, err = conn.Write([]byte{socks6.Version, byte(1), byte(socks6.NoAuth)}) if err != nil { panic(err) }