From c2dcae7af54b8ef5cc15e35705f85518940b16dd Mon Sep 17 00:00:00 2001 From: luorijun Date: Mon, 26 May 2025 16:37:54 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96=E8=BF=9E=E6=8E=A5=E5=A4=84?= =?UTF-8?q?=E7=90=86=E9=80=BB=E8=BE=91=EF=BC=8C=E5=A2=9E=E5=8A=A0=E8=B6=85?= =?UTF-8?q?=E6=97=B6=E8=AE=BE=E7=BD=AE=EF=BC=9B=E9=87=8D=E6=9E=84=E5=91=BD?= =?UTF-8?q?=E4=BB=A4=E8=AF=BB=E5=8F=96=E4=B8=8E=E9=94=99=E8=AF=AF=E5=A4=84?= =?UTF-8?q?=E7=90=86=EF=BC=9B=E6=96=B0=E5=A2=9E=E5=85=AC=E5=85=B1=E5=B7=A5?= =?UTF-8?q?=E5=85=B7=E5=87=BD=E6=95=B0=E4=BB=A5=E7=AE=80=E5=8C=96=E9=94=99?= =?UTF-8?q?=E8=AF=AF=E5=A4=84=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 2 +- edge/edge.go | 110 ++++++++++++-------------------- gateway/env/env.go | 50 ++++++++++++--- gateway/fwd/ctrl.go | 148 +++++++++++++++++++++++--------------------- gateway/fwd/data.go | 6 +- gateway/fwd/user.go | 2 +- utils/conn.go | 41 ++++++++++++ 7 files changed, 205 insertions(+), 154 deletions(-) create mode 100644 utils/conn.go diff --git a/README.md b/README.md index 4865bc9..219699f 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ ## TODO -- 连接断开时尽量由 +- 将协议内容抽离出公共包,gateway 和 edge 节点共同调用 ## 开发相关 diff --git a/edge/edge.go b/edge/edge.go index 73dde7f..f7202f3 100644 --- a/edge/edge.go +++ b/edge/edge.go @@ -21,6 +21,8 @@ import ( ) func Start() error { + var ctx, cancel = signal.NotifyContext(context.Background(), os.Interrupt, os.Kill) + defer cancel() // 初始化环境变量 slog.Debug("初始化环境变量...") @@ -44,36 +46,17 @@ func Start() error { } // 连接到网关 - var ctx, cancel = signal.NotifyContext(context.Background(), os.Interrupt, os.Kill) - defer cancel() - var errCh = make(chan error) go func() { - for { - err = ctrl(ctx, id, host) - if err == nil { - errCh <- nil - return - } - - select { - case <-ctx.Done(): - return - default: - slog.Error("建立控制通道失败", "err", err) - slog.Info(fmt.Sprintf("%d 秒后重试", core.RetryInterval)) - } - - select { - case <-ctx.Done(): - return - case <-time.After(time.Duration(core.RetryInterval) * time.Second): - } + err = ctrl(ctx, id, host) + if err == nil { + errCh <- err } }() // 等待退出 select { + case <-ctx.Done(): case err := <-errCh: if err != nil { slog.Error("控制通道发生错误", "err", err) @@ -102,7 +85,7 @@ func ctrl(ctx context.Context, id int32, host string) error { defer utils.Close(conn) var reader = bufio.NewReader(conn) - // 发送节点连接命令 + // 发送开启连接 slog.Debug("发送节点连接命令") err = sendOpen(conn, id) if err != nil { @@ -126,78 +109,55 @@ func ctrl(ctx context.Context, id int32, host string) error { } }() - // 异步等待连接命令 - slog.Info("等待用户连接") - var cmdCh = make(chan ConnCmd) + // 异步读取节点命令 var errCh = make(chan error) go func() { for { + // 读取命令 cmd, err := reader.ReadByte() - if err != nil { - switch { - case errors.Is(err, net.ErrClosed): - err = fmt.Errorf("控制通道关闭: %w", err) - case errors.Is(err, io.EOF): - err = fmt.Errorf("网关关闭了控制通道: %w", err) - default: - err = fmt.Errorf("读取命令失败: %w", err) - } + if err := utils.WarpConnErr(err); err != nil { errCh <- err return } - switch cmd { + + // pong 命令,忽略 case 1: - // 忽略网关响应的 pong 命令 + + // 代理命令 case 5: - tag, addr, err := onConn(reader) + err := onConn(reader, dataAddr) if err != nil { - slog.Error("接收连接命令失败", "err", err) + errCh <- fmt.Errorf("处理代理命令失败: %w", err) return } - cmdCh <- ConnCmd{ - Tag: tag, - Addr: addr, - } } } }() // 等待建立数据通道 - for loop := true; loop; { - select { - case <-ctx.Done(): - loop = false - case err = <-errCh: - slog.Error("读取控制命令失败", "err", err) - loop = false - case cmd := <-cmdCh: - slog.Debug("建立数据通道", "tag", cmd.Tag, "addr", cmd.Addr) - go func() { - err := data(dataAddr, cmd.Addr, cmd.Tag) - if err != nil { - slog.Error("建立数据通道失败", "err", err) - } - }() - } + select { + case <-ctx.Done(): + case err = <-errCh: } - // 发送关闭连接(不 return err,否则会重新连接) + // 发送关闭连接 slog.Debug("发送关闭连接") err = sendClose(conn) if err != nil { - slog.Error("发送关闭连接失败", "err", err) + return fmt.Errorf("发送关闭连接失败: %w", err) } return nil } -func data(proxy string, dest string, tag [16]byte) error { +func data(proxy string, destination string, tag [16]byte) error { + slog.Debug("建立数据通道", "tag", tag, "addr", destination) // 向目标地址建立连接 var result = 1 var dstErr error - dst, err := net.Dial("tcp", dest) + dst, err := net.Dial("tcp", destination) if err != nil { dstErr = fmt.Errorf("连接目标地址失败: %w", dstErr) result = 0 @@ -295,24 +255,34 @@ func sendPing(writer io.Writer) error { return nil } -func onConn(reader io.Reader) (tag [16]byte, addr string, err error) { +func onConn(reader io.Reader, dataAddr string) (err error) { + + // 读取连接命令 var buf = make([]byte, 16+2) _, err = io.ReadFull(reader, buf) if err != nil { - return [16]byte{}, "", err + return err } - tag = [16]byte(buf[0:16]) + var tag = [16]byte(buf[0:16]) var addrLen = binary.BigEndian.Uint16(buf[16:18]) var addrBuf = make([]byte, addrLen) _, err = io.ReadFull(reader, addrBuf) if err != nil { - return [16]byte{}, "", err + return err } + var addr = string(addrBuf) - addr = string(addrBuf) - return tag, addr, nil + // 异步建立数据通道 + go func() { + err := data(dataAddr, addr, tag) + if err != nil { + slog.Error("建立数据通道失败", "err", err) + } + }() + + return nil } type ConnCmd struct { diff --git a/gateway/env/env.go b/gateway/env/env.go index d4abea6..62042dd 100644 --- a/gateway/env/env.go +++ b/gateway/env/env.go @@ -14,13 +14,16 @@ import ( var ( RunMode = "dev" // 运行模式,dev: 开发模式,prod: 生产模式 - AppCtrlPort uint16 = 18080 - AppDataPort uint16 = 18081 - AppWebPort uint16 = 8848 - AppLogMode = "dev" - AppExitTimeout = 5 // 等待服务停止的超时时间 - AppDataTimeout = 10 // 等待数据通道连接的超时时间 - AppUserTimeout = 10 // 等待用户发送数据的超时时间(端口复用需要分析协议,如果用户长期不发送数据,将会阻塞分析协程) + AppCtrlPort uint16 = 18080 + AppDataPort uint16 = 18081 + AppWebPort uint16 = 8848 + AppLogMode = "dev" + AppExitTimeout = 5 // 等待服务停止的超时时间 + AppDataTimeout = 10 // 等待数据通道连接的超时时间 + AppUserRWTimeout = 10 // 等待用户连接读写超时时间 + AppDataRWTimeout = 10 // 等待数据通道读写超时时间 + AppCtrlRWTimeout = 10 // 等待控制通道读写超时时间 + AppCtrlHBTimeout = 30 // 控制通道心跳超时时间(断开连接等待时间为:心跳等待时间 * 2 + 读写等待时间) AuthWhitelist []net.IP // 全局白名单,可以将白名单 IP 视为一个可信任代理 @@ -103,13 +106,40 @@ func Init() { AppDataTimeout = appDataTimeout } - value = os.Getenv("APP_USER_TIMEOUT") + value = os.Getenv("APP_USER_RW_TIMEOUT") if value != "" { appUserTimeout, err := strconv.Atoi(value) if err != nil { - panic(fmt.Sprintf("环境变量 APP_USER_TIMEOUT 格式错误: %v", err)) + panic(fmt.Sprintf("环境变量 APP_USER_RW_TIMEOUT 格式错误: %v", err)) } - AppUserTimeout = appUserTimeout + AppUserRWTimeout = appUserTimeout + } + + value = os.Getenv("APP_DATA_RW_TIMEOUT") + if value != "" { + appDataRWTimeout, err := strconv.Atoi(value) + if err != nil { + panic(fmt.Sprintf("环境变量 APP_DATA_RW_TIMEOUT 格式错误: %v", err)) + } + AppDataRWTimeout = appDataRWTimeout + } + + value = os.Getenv("APP_CTRL_RW_TIMEOUT") + if value != "" { + appCtrlRWTimeout, err := strconv.Atoi(value) + if err != nil { + panic(fmt.Sprintf("环境变量 APP_CTRL_RW_TIMEOUT 格式错误: %v", err)) + } + AppCtrlRWTimeout = appCtrlRWTimeout + } + + value = os.Getenv("APP_CTRL_HB_TIMEOUT") + if value != "" { + appCtrlHBTimeout, err := strconv.Atoi(value) + if err != nil { + panic(fmt.Sprintf("环境变量 APP_CTRL_HB_TIMEOUT 格式错误: %v", err)) + } + AppCtrlHBTimeout = appCtrlHBTimeout } value = os.Getenv("AUTH_WHITELIST") diff --git a/gateway/fwd/ctrl.go b/gateway/fwd/ctrl.go index 9e093dc..ab5717b 100644 --- a/gateway/fwd/ctrl.go +++ b/gateway/fwd/ctrl.go @@ -14,7 +14,7 @@ import ( "proxy-server/gateway/report" "proxy-server/utils" "strconv" - "syscall" + "time" ) type CtrlCmdType int @@ -37,8 +37,7 @@ func ListenCtrl(ctx context.Context) error { } defer utils.Close(ls) - // 处理连接 - // 异步等待连接 + // 异步等待处理连接 var connCh = make(chan net.Conn) go func() { for { @@ -80,77 +79,88 @@ func ListenCtrl(ctx context.Context) error { } func processCtrlConn(_ctx context.Context, conn net.Conn) (err error) { - // 通道上下文 - ctx, cancel := context.WithCancel(_ctx) + reader := bufio.NewReader(conn) + + // 上下文与通道信息 + ctx, cancel := context.WithCancel(_ctx) + defer cancel() - // 结束后清理资源 var fwdPort uint16 - defer func() { - slog.Debug("关闭控制通道", "port", fwdPort) - app.DelEdge(fwdPort) + + // 处理连接命令 + var errCh = make(chan error) + go func() { + var err error + for { + // 读取命令 + var timeout = time.Duration(env.AppCtrlHBTimeout*2+env.AppCtrlRWTimeout) * time.Second + err = conn.SetReadDeadline(time.Now().Add(timeout)) + if err != nil { + errCh <- fmt.Errorf("设置读取超时失败: %w", err) + return + } + + cmd, err := reader.ReadByte() + if err := utils.WarpConnErr(err); err != nil { + errCh <- err + return + } + + // 处理节点命令 + err = conn.SetReadDeadline(time.Now().Add(time.Duration(env.AppCtrlRWTimeout) * time.Second)) + if err != nil { + errCh <- fmt.Errorf("设置读取超时失败: %w", err) + return + } + switch CtrlCmdType(cmd) { + + // 连接建立命令 + case CtrlCmdOpen: + var recv = make([]byte, 4) + _, err = io.ReadFull(reader, recv) + if err != nil { + errCh <- fmt.Errorf("读取节点 ID 失败: %w", err) + return + } + var client = int32(binary.BigEndian.Uint32(recv)) + fwdPort, err = onOpen(ctx, conn, client) + if err != nil { + errCh <- fmt.Errorf("处理连接建立命令失败: %w", err) + return + } + + // 心跳命令 + case CtrlCmdPing: + err = onPing(conn) + if err != nil { + errCh <- fmt.Errorf("处理心跳命令失败: %w", err) + return + } + + // 连接关闭命令 + case CtrlCmdClose: + err = onClose(conn) + if err != nil { + errCh <- fmt.Errorf("处理关闭命令失败: %w", err) + return + } + + // 忽略其他不应该由节点发起的命令 + default: + errCh <- fmt.Errorf("无法处理控制命令: %d", cmd) + return + } + } }() - // 处理控制命令 - defer cancel() - reader := bufio.NewReader(conn) - for { - // 循环等待直到服务关闭 - select { - case <-ctx.Done(): - return nil - default: - } - - // 读取命令 - cmd, err := reader.ReadByte() - if errors.Is(err, syscall.ECONNRESET) || errors.Is(err, syscall.WSAECONNRESET) { - slog.Debug("节点重置了控制通道连接(WSAECONNRESET)") - return nil - } - if errors.Is(err, io.EOF) { - slog.Debug("节点关闭了控制通道") - return nil - } - if err != nil { - return fmt.Errorf("读取节点命令失败: %w", err) - } - - // 处理节点命令 - switch CtrlCmdType(cmd) { - - // 连接建立命令 - case CtrlCmdOpen: - var recv = make([]byte, 4) - _, err = io.ReadFull(reader, recv) - if err != nil { - return fmt.Errorf("读取节点 ID 失败: %w", err) - } - var client = int32(binary.BigEndian.Uint32(recv)) - fwdPort, err = onOpen(ctx, conn, client) - if err != nil { - return fmt.Errorf("处理连接建立命令失败: %w", err) - } - - // 心跳命令 - case CtrlCmdPing: - err = onPing(conn) - if err != nil { - return fmt.Errorf("处理心跳命令失败: %w", err) - } - - // 连接关闭命令 - case CtrlCmdClose: - err = onClose(conn) - if err != nil { - return fmt.Errorf("处理关闭命令失败: %w", err) - } - return nil - - // 忽略其他不应该由节点发起的命令 - default: - return fmt.Errorf("无法处理控制命令: %d", cmd) - } + // 等待处理结束 + select { + case <-ctx.Done(): + case err = <-errCh: } + + app.DelEdge(fwdPort) + return } func onOpen(ctx context.Context, writer io.Writer, edge int32) (port uint16, err error) { diff --git a/gateway/fwd/data.go b/gateway/fwd/data.go index bca7dfe..420bd75 100644 --- a/gateway/fwd/data.go +++ b/gateway/fwd/data.go @@ -73,8 +73,8 @@ func ListenData(ctx context.Context) error { } } -func processDataConn(ctx context.Context, client net.Conn) error { - var reader = bufio.NewReader(client) +func processDataConn(ctx context.Context, edge net.Conn) error { + var reader = bufio.NewReader(edge) // 接收连接结果 var buf = make([]byte, 17) @@ -133,7 +133,7 @@ func processDataConn(ctx context.Context, client net.Conn) error { // 复制用户数据到节点 var waitUser = make(chan error) go func() { - _, err := io.Copy(client, teeUser) + _, err := io.Copy(edge, teeUser) switch { case errors.Is(err, net.ErrClosed): slog.Debug("用户连接意外关闭") diff --git a/gateway/fwd/user.go b/gateway/fwd/user.go index bf62035..70defe6 100644 --- a/gateway/fwd/user.go +++ b/gateway/fwd/user.go @@ -17,7 +17,7 @@ import ( ) func ListenUser(ctx context.Context, port uint16, ctrl io.Writer) error { - dspt, err := dispatcher.New(port, time.Duration(env.AppUserTimeout)*time.Second) + dspt, err := dispatcher.New(port, time.Duration(env.AppUserRWTimeout)*time.Second) if err != nil { return err } diff --git a/utils/conn.go b/utils/conn.go new file mode 100644 index 0000000..39ce9b2 --- /dev/null +++ b/utils/conn.go @@ -0,0 +1,41 @@ +package utils + +import ( + "errors" + "fmt" + "io" + "log/slog" + "net" + "syscall" +) + +func WarpConnErr(err error) error { + if err == nil { + return nil + } + + var opErr *net.OpError + switch { + + case errors.Is(err, net.ErrClosed): + slog.Debug("连接已关闭") + return nil + + case errors.Is(err, io.EOF): + slog.Debug("连接被对端关闭") + return nil + + case errors.As(err, &opErr): + switch { + case + errors.Is(opErr.Err, syscall.WSAECONNRESET), errors.Is(opErr.Err, syscall.WSAECONNABORTED), + errors.Is(opErr.Err, syscall.ECONNRESET), errors.Is(opErr.Err, syscall.ECONNABORTED): + slog.Debug("连接被对端重置") + return nil + case opErr.Timeout(): + slog.Debug("连接已超时") + return nil + } + } + return fmt.Errorf("连接发生未处理的错误: %w", err) +}