diff --git a/README.md b/README.md index 964eb5c..3d5bf59 100644 --- a/README.md +++ b/README.md @@ -1,17 +1,7 @@ ## todo -日志格式自定义转换 - -客户端断开后端口未释放问题 - -需要压测 - ProxyConn 直接实现 Conn 相同的接口,不再取出 Conn 使用 -配置退出等待时间 - -输出错误堆栈 - 读取 conn 时加上超时机制 代理节点超时控制 @@ -24,8 +14,12 @@ ProxyConn 直接实现 Conn 相同的接口,不再取出 Conn 使用 检查退出超时的问题 +补全测试 + ### 长期 +配置退出等待时间 + 需要测试,考虑是否切换到 gnet 实现一个 socks context 以在子组件中获取 socks 相关信息 diff --git a/server/fwd/dispatcher/dispatch.go b/server/fwd/dispatcher/dispatch.go index 9045ced..b01469b 100644 --- a/server/fwd/dispatcher/dispatch.go +++ b/server/fwd/dispatcher/dispatch.go @@ -9,6 +9,7 @@ import ( "proxy-server/server/fwd/http" "proxy-server/server/fwd/socks" "strconv" + "strings" "time" "github.com/pkg/errors" @@ -54,36 +55,43 @@ func (s *Server) Run() error { m.SetReadTimeout(5 * time.Second) defer m.Close() - go func() { - <-s.ctx.Done() - close(s.Conn) - m.Close() - }() - socksLs := m.Match(cmux.PrefixMatcher(string([]byte{0x05}))) - defer utils.Close(socksLs) go func() { err = s.acceptSocks(socksLs) if err != nil { - slog.Error("dispatcher socks accept error", "err", err) + if strings.Contains(err.Error(), "mux: server closed") { + return + } + slog.Warn("dispatcher socks accept error", "err", err) } }() httpLs := m.Match(cmux.HTTP1Fast("PATCH")) - defer utils.Close(httpLs) go func() { err = s.acceptHttp(httpLs) if err != nil { - slog.Error("dispatcher http accept error", "err", err) + if strings.Contains(err.Error(), "mux: server closed") { + return + } + slog.Warn("dispatcher http accept error", "err", err) } }() - err = m.Serve() - if err != nil { - return errors.Wrap(err, "dispatcher serve error") - } + errCh := make(chan error) + go func() { + err = m.Serve() + if err != nil { + err = errors.Wrap(err, "dispatcher serve error") + } + errCh <- err + }() - return nil + select { + case <-s.ctx.Done(): + return nil + case err := <-errCh: + return err + } } func (s *Server) acceptHttp(ls net.Listener) error { diff --git a/server/fwd/fwd.go b/server/fwd/fwd.go index 5a037b7..efb362d 100644 --- a/server/fwd/fwd.go +++ b/server/fwd/fwd.go @@ -12,6 +12,7 @@ import ( "proxy-server/server/fwd/dispatcher" "proxy-server/server/pkg/env" "strconv" + "strings" "sync" "github.com/pkg/errors" @@ -167,7 +168,6 @@ func (s *Service) startCtrlTun() error { func (s *Service) processCtrlConn(conn net.Conn) error { slog.Debug("客户端连入", "addr", conn.RemoteAddr().String()) - reader := bufio.NewReader(conn) // 获取转发端口 @@ -177,20 +177,58 @@ func (s *Service) processCtrlConn(conn net.Conn) error { } port := binary.BigEndian.Uint16(portBuf) - // 开放转发端口 + // 启动转发服务 + proxy, err := dispatcher.New(port) + if err != nil { + return errors.Wrap(err, "创建 socks 转发服务失败") + } + defer proxy.Close() + s.fwdLesWg.Add(1) go func() { defer s.fwdLesWg.Done() - err := s.startFwdTun(port) + err := proxy.Run() if err != nil { - slog.Error("代理服务启动失败", "err", err) + slog.Error("代理服务运行失败", "err", err) return } }() - // 记录控制连接 - s.ctrlConnMap.Store(port, conn) - return nil + // 监听客户端连接 + errCh := make(chan error) + defer close(errCh) + go func() { + _, err := reader.ReadByte() + errCh <- err + }() + + // 处理连接 + for { + select { + case <-s.ctx.Done(): + return nil + case err := <-errCh: + switch { + case strings.Contains(err.Error(), "An existing connection was forcibly closed by the remote host."): + slog.Debug("客户端主动断开连接") + return nil + case err == nil: + return errors.New("客户端握手失败") + default: + return errors.Wrap(err, "客户端意外断开连接") + } + case user := <-proxy.Conn: + s.userConnWg.Add(1) + go func() { + defer s.userConnWg.Done() + err := s.processUserConn(user, conn) + if err != nil { + slog.Error("处理用户连接失败", "err", err) + utils.Close(user) + } + }() + } + } } func (s *Service) startDataTun() error { @@ -312,71 +350,22 @@ func (s *Service) processDataConn(client net.Conn) error { return nil } -func (s *Service) startFwdTun(port uint16) error { - slog.Debug("监听转发通道", "port", port) - - proxy, err := dispatcher.New(port) - if err != nil { - return errors.Wrap(err, "创建 socks 转发服务失败") - } - defer proxy.Close() - - errCh := make(chan error) - defer close(errCh) - go func() { - err := proxy.Run() - errCh <- err - }() - - for { - select { - case <-s.ctx.Done(): - return nil - case err := <-errCh: - if err != nil { - return errors.Wrap(err, "转发服务发生错误") - } - case conn := <-proxy.Conn: - s.userConnWg.Add(1) - go func() { - defer s.userConnWg.Done() - err := s.processUserConn(conn, port) - if err != nil { - slog.Error("处理用户连接失败", "err", err) - } - }() - } - } -} - -func (s *Service) processUserConn(conn *core.Conn, port uint16) error { - - // 记录用户连接 - s.userConnMap.Store(conn.Tag, conn) - - // 通知客户端建立数据通道 - ctrlConnAny, ok := s.ctrlConnMap.Load(port) - if !ok { - return errors.New("查找控制连接失败") - } - ctrlConn := ctrlConnAny.(net.Conn) +func (s *Service) processUserConn(user *core.Conn, ctrl net.Conn) error { // 发送 tag - select { - case <-s.ctx.Done(): - return nil - default: - tag := conn.Tag - tagLen := len(tag) - tagBuf := make([]byte, 1+tagLen) - tagBuf[0] = byte(tagLen) - copy(tagBuf[1:], tag) - _, err := ctrlConn.Write(tagBuf) - if err != nil { - return errors.Wrap(err, "向控制通道发送 tag 失败") - } + tag := user.Tag + tagLen := len(tag) + tagBuf := make([]byte, 1+tagLen) + tagBuf[0] = byte(tagLen) + copy(tagBuf[1:], tag) + + _, err := ctrl.Write(tagBuf) + if err != nil { + return errors.Wrap(err, "向控制通道发送 tag 失败") } - return nil + // 记录用户连接 + s.userConnMap.Store(user.Tag, user) + return nil } diff --git a/server/pkg/env/env.go b/server/pkg/env/env.go index f11514d..48e7345 100644 --- a/server/pkg/env/env.go +++ b/server/pkg/env/env.go @@ -2,11 +2,8 @@ package env import ( "fmt" - "log/slog" "os" "strconv" - - "github.com/joho/godotenv" ) var ( @@ -24,12 +21,6 @@ var ( func Init() { - // 加载 .env 文件 - err := godotenv.Load() - if err != nil { - slog.Debug("没有本地环境变量文件") - } - // AppCtrlPort appCtrlPortStr := os.Getenv("APP_CTRL_PORT") if appCtrlPortStr == "" { diff --git a/server/pkg/log/logs.go b/server/pkg/log/logs.go new file mode 100644 index 0000000..ff7fec0 --- /dev/null +++ b/server/pkg/log/logs.go @@ -0,0 +1,41 @@ +package log + +import ( + "log/slog" + "os" + "time" + + "github.com/lmittmann/tint" + "github.com/mattn/go-colorable" +) + +func Init() { + mode := os.Getenv("APP_LOG_MODE") + if mode == "" { + mode = "dev" + } + + switch mode { + case "dev": + writer := colorable.NewColorable(os.Stdout) + logger := slog.New(tint.NewHandler(writer, &tint.Options{ + Level: slog.LevelDebug, + TimeFormat: time.RFC3339, + ReplaceAttr: func(_ []string, attr slog.Attr) slog.Attr { + err, ok := attr.Value.Any().(error) + if ok { + return tint.Err(err) + } + return attr + }, + })) + slog.SetDefault(logger) + case "test": + logger := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{ + Level: slog.LevelInfo, + })) + slog.SetDefault(logger) + default: + panic("日志模式错误") + } +} diff --git a/server/server.go b/server/server.go index ef95183..b775ead 100644 --- a/server/server.go +++ b/server/server.go @@ -8,13 +8,13 @@ import ( "proxy-server/pkg/utils" "proxy-server/server/fwd" "proxy-server/server/pkg/env" + "proxy-server/server/pkg/log" "proxy-server/server/pkg/orm" "sync" "syscall" "time" - "github.com/lmittmann/tint" - "github.com/mattn/go-colorable" + "github.com/joho/godotenv" ) type Context struct { @@ -25,7 +25,12 @@ type Context struct { func Start() { // 初始化 - initLog() + err := godotenv.Load() + if err != nil { + println("没有本地环境变量文件") + } + + log.Init() env.Init() orm.Init() @@ -77,28 +82,6 @@ func Start() { time.Sleep(3 * time.Second) } -func initLog() { - switch env.AppLogMode { - case "dev": - writer := colorable.NewColorable(os.Stdout) - logger := slog.New(tint.NewHandler(writer, &tint.Options{ - Level: slog.LevelDebug, - TimeFormat: time.RFC3339, - ReplaceAttr: func(_ []string, attr slog.Attr) slog.Attr { - err, ok := attr.Value.Any().(error) - if ok { - return tint.Err(err) - } - return attr - }, - })) - slog.SetDefault(logger) - case "test": - logger := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{})) - slog.SetDefault(logger) - } -} - func startFwdServer(ctx context.Context) error { server := fwd.New(nil) go func() {