From 84e01d3b5086ac9f49ea255ede582f3fed176348 Mon Sep 17 00:00:00 2001 From: luorijun Date: Sat, 17 May 2025 10:00:28 +0800 Subject: [PATCH] =?UTF-8?q?=E9=87=8D=E5=91=BD=E5=90=8D=E5=AE=A2=E6=88=B7?= =?UTF-8?q?=E7=AB=AF=E7=9B=B8=E5=85=B3=E6=9C=AF=E8=AF=AD=E4=B8=BA=E8=8A=82?= =?UTF-8?q?=E7=82=B9=EF=BC=9B=E7=A7=BB=E5=8A=A8=20utils=20=E5=8C=85?= =?UTF-8?q?=E5=88=B0=E6=A0=B9=E8=B7=AF=E5=BE=84=EF=BC=9B=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E7=BD=91=E5=85=B3=E5=AF=B9=E8=8A=82=E7=82=B9=E5=90=84=E7=A7=8D?= =?UTF-8?q?=E8=BF=9E=E6=8E=A5=E7=8A=B6=E6=80=81=E7=9A=84=E5=A4=84=E7=90=86?= =?UTF-8?q?=EF=BC=8C=E5=B9=B6=E5=9C=A8=E8=8A=82=E7=82=B9=E6=96=AD=E8=81=94?= =?UTF-8?q?=E5=90=8E=E7=BB=9F=E4=B8=80=E6=B8=85=E7=90=86=E8=B5=84=E6=BA=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- edge/edge.go | 112 +++++++++++++++-------------- edge/env/env.go | 4 +- gateway/app/app.go | 17 ++++- gateway/fwd/analysis.go | 2 +- gateway/fwd/ctrl.go | 98 +++++++++++++------------ gateway/fwd/data.go | 16 ++--- gateway/fwd/dispatcher/dispatch.go | 2 +- gateway/fwd/fwd.go | 4 +- gateway/fwd/socks/socks.go | 8 +-- gateway/fwd/user.go | 8 +-- gateway/gateway.go | 6 +- gateway/web/handlers/auth.go | 2 +- {pkg/utils => utils}/sync.go | 0 {pkg/utils => utils}/utils.go | 0 14 files changed, 150 insertions(+), 129 deletions(-) rename {pkg/utils => utils}/sync.go (100%) rename {pkg/utils => utils}/utils.go (100%) diff --git a/edge/edge.go b/edge/edge.go index dc4da2f..5de0028 100644 --- a/edge/edge.go +++ b/edge/edge.go @@ -16,7 +16,7 @@ import ( "proxy-server/edge/env" "proxy-server/edge/geo" "proxy-server/edge/report" - "proxy-server/pkg/utils" + "proxy-server/utils" "time" ) @@ -103,7 +103,8 @@ func ctrl(ctx context.Context, id int32, host string) error { var reader = bufio.NewReader(conn) // 发送节点连接命令 - err = sendOpen(reader, conn, id) + slog.Debug("发送节点连接命令") + err = sendOpen(conn, id) if err != nil { return fmt.Errorf("发送节点信息失败: %w", err) } @@ -117,7 +118,8 @@ func ctrl(ctx context.Context, id int32, host string) error { case <-ctx.Done(): return case tick := <-ticker.C: - err := sendPing(reader, conn) + slog.Debug("发送心跳", "time", tick) + err := sendPing(conn) if err != nil { slog.Error("发送心跳失败", "time", tick, "err", err) } @@ -125,23 +127,51 @@ func ctrl(ctx context.Context, id int32, host string) error { } }() - // 等待用户连接 - // 读写失败后退出重连,防止后续数据读写顺序错位导致卡死控制通道 + // 异步等待连接命令 slog.Info("等待用户连接") + var cmdCh = make(chan ConnCmd) + go func() { + for { + cmd, err := reader.ReadByte() + if errors.Is(err, net.ErrClosed) { + slog.Debug("控制通道关闭") + return + } + if errors.Is(err, io.EOF) { + slog.Debug("网关关闭了控制通道") + return + } + if err != nil { + slog.Error("读取命令失败", "err", err) + return + } + + switch cmd { + case 1: + // 忽略网关响应的 pong 命令 + case 5: + tag, addr, err := onConn(reader) + if err != nil { + slog.Error("接收连接命令失败", "err", err) + return + } + cmdCh <- ConnCmd{ + Tag: tag, + Addr: addr, + } + } + } + }() + + // 等待建立数据通道 for loop := true; loop; { select { case <-ctx.Done(): loop = false - default: - // 接收 dst - tag, addr, err := onConn(reader) - if err != nil { - return fmt.Errorf("接收连接命令失败: %w", err) - } - - // 建立数据通道 + case cmd := <-cmdCh: + slog.Debug("建立数据通道", "tag", cmd.Tag, "addr", cmd.Addr) go func() { - err := data(dataAddr, addr, tag) + err := data(dataAddr, cmd.Addr, cmd.Tag) if err != nil { slog.Error("建立数据通道失败", "err", err) } @@ -150,7 +180,8 @@ func ctrl(ctx context.Context, id int32, host string) error { } // 发送关闭连接(不 return err,否则会重新连接) - err = sendClose(reader, conn) + slog.Debug("发送关闭连接") + err = sendClose(conn) if err != nil { slog.Error("发送关闭连接失败", "err", err) } @@ -207,7 +238,7 @@ func data(proxy string, dest string, tag [16]byte) error { return nil } -func sendOpen(reader io.Reader, writer io.Writer, id int32) error { +func sendOpen(writer io.Writer, id int32) error { // 发送打开连接 var buf = make([]byte, 5) @@ -219,72 +250,38 @@ func sendOpen(reader io.Reader, writer io.Writer, id int32) error { return fmt.Errorf("发送打开连接失败: %w", err) } - // 等待服务端响应 - respBuf := make([]byte, 1) - _, err = io.ReadFull(reader, respBuf) - if err != nil { - return fmt.Errorf("接收服务端响应失败: %w", err) - } - if respBuf[0] != 1 { - return errors.New("服务端响应失败") - } - return nil } -func sendClose(reader io.Reader, writer io.Writer) error { +func sendClose(writer io.Writer) error { // 发送关闭连接 _, err := writer.Write([]byte{4}) if err != nil { return err } - // 等待服务端响应 - respBuf := make([]byte, 1) - _, err = io.ReadFull(reader, respBuf) - if err != nil { - return fmt.Errorf("接收服务端响应失败: %w", err) - } - if respBuf[0] != 1 { - return errors.New("服务端响应失败") - } - return nil } -func sendPing(reader io.Reader, writer io.Writer) error { +func sendPing(writer io.Writer) error { _, err := writer.Write([]byte{2}) if err != nil { return err } - // 等待服务端响应 - respBuf := make([]byte, 1) - _, err = io.ReadFull(reader, respBuf) - if err != nil { - return fmt.Errorf("接收服务端响应失败: %w", err) - } - if respBuf[0] != 1 { - return errors.New("服务端响应失败") - } - return nil } func onConn(reader io.Reader) (tag [16]byte, addr string, err error) { - var buf = make([]byte, 1+16+2) + var buf = make([]byte, 16+2) _, err = io.ReadFull(reader, buf) if err != nil { return [16]byte{}, "", err } - if buf[0] != 5 { - return [16]byte{}, "", errors.New("命令错误") - } + tag = [16]byte(buf[0:16]) - tag = [16]byte(buf[1:17]) - - var addrLen = binary.BigEndian.Uint16(buf[17:19]) + var addrLen = binary.BigEndian.Uint16(buf[16:18]) var addrBuf = make([]byte, addrLen) _, err = io.ReadFull(reader, addrBuf) if err != nil { @@ -294,3 +291,8 @@ func onConn(reader io.Reader) (tag [16]byte, addr string, err error) { addr = string(addrBuf) return tag, addr, nil } + +type ConnCmd struct { + Tag [16]byte + Addr string +} diff --git a/edge/env/env.go b/edge/env/env.go index 548de09..4f95a41 100644 --- a/edge/env/env.go +++ b/edge/env/env.go @@ -15,7 +15,7 @@ var EndpointOffline = "https://api.lanhuip.com/api/edge/offline" func Init() error { var env = flag.String("e", "dev", "环境变量,可选值 dev 或 prod") - var name = flag.String("n", "", "客户端唯一标识") + var name = flag.String("n", "", "节点唯一标识") var online = flag.String("online", "", "服务注册地址") var offline = flag.String("offline", "", "服务注销地址") @@ -32,7 +32,7 @@ func Init() error { if name != nil && *name != "" { Name = *name } else { - return errors.New("客户端唯一标识不能为空") + return errors.New("节点唯一标识不能为空") } if online != nil && *online != "" { diff --git a/gateway/app/app.go b/gateway/app/app.go index fd19cb2..be53915 100644 --- a/gateway/app/app.go +++ b/gateway/app/app.go @@ -9,7 +9,22 @@ var ( Name string PlatformSecret string // 平台密钥,验证接收的请求是否属于平台 - Clients = core.SyncMap[int32, uint16]{} // 节点 ID -> 转发端口 + Edges = core.SyncMap[int32, uint16]{} // 节点 ID -> 转发端口 Assigns = core.SyncMap[uint16, int32]{} // 转发端口 -> 节点 ID Permits = core.SyncMap[uint16, *core.Permit]{} // 转发端口 -> 权限配置 ) + +func AddEdge(id int32, port uint16) { + Edges.Store(id, port) + Assigns.Store(port, id) +} + +func DelEdge(port uint16) { + id, _ := Assigns.LoadAndDelete(port) + Edges.Delete(id) + Permits.Delete(port) +} + +func PermitEdge(port uint16, permit *core.Permit) { + Permits.Store(port, permit) +} diff --git a/gateway/fwd/analysis.go b/gateway/fwd/analysis.go index 75be0b3..cebf5e8 100644 --- a/gateway/fwd/analysis.go +++ b/gateway/fwd/analysis.go @@ -7,7 +7,7 @@ import ( "io" "log/slog" "proxy-server/gateway/core" - "proxy-server/pkg/utils" + "proxy-server/utils" "strings" "errors" diff --git a/gateway/fwd/ctrl.go b/gateway/fwd/ctrl.go index 73cb4b2..5109c3b 100644 --- a/gateway/fwd/ctrl.go +++ b/gateway/fwd/ctrl.go @@ -12,8 +12,9 @@ import ( "proxy-server/gateway/app" "proxy-server/gateway/env" "proxy-server/gateway/report" - "proxy-server/pkg/utils" + "proxy-server/utils" "strconv" + "syscall" ) type CtrlCmdType int @@ -80,6 +81,21 @@ func (s *Service) listenCtrl() error { } func (s *Service) processCtrlConn(ctx context.Context, conn net.Conn) (err error) { + defer func() { + _, portStr, err := net.SplitHostPort(conn.LocalAddr().String()) + if err != nil { + slog.Error("获取控制通道端口失败", "err", err) + return + } + + port, err := strconv.ParseUint(portStr, 10, 16) + if err != nil { + slog.Error("解析控制通道端口失败", "err", err) + return + } + + app.DelEdge(uint16(port)) + }() reader := bufio.NewReader(conn) for { // 循环等待直到服务关闭 @@ -90,12 +106,21 @@ func (s *Service) processCtrlConn(ctx context.Context, conn net.Conn) (err error } // 读取命令 - cmdByte, err := reader.ReadByte() + cmd, err := reader.ReadByte() + if errors.Is(err, syscall.WSAECONNRESET) { + slog.Debug("节点重置了控制通道连接(WSAECONNRESET)") + return nil + } + if errors.Is(err, io.EOF) { + slog.Debug("节点关闭了控制通道") + return nil + } if err != nil { return fmt.Errorf("读取节点命令失败: %w", err) } - var cmd = CtrlCmdType(cmdByte) - switch cmd { + + // 处理节点命令 + switch CtrlCmdType(cmd) { // 连接建立命令 case CtrlCmdOpen: @@ -132,15 +157,11 @@ func (s *Service) processCtrlConn(ctx context.Context, conn net.Conn) (err error } } -func (s *Service) onPing(conn net.Conn) (err error) { - return s.sendPong(conn) -} - -func (s *Service) onOpen(conn net.Conn, client int32) (err error) { +func (s *Service) onOpen(writer io.Writer, edge int32) (err error) { // open 命令全局只执行一次 - _, ok := app.Clients.Load(client) + _, ok := app.Edges.Load(edge) if ok { - return fmt.Errorf("节点 ID %d 已经连接", client) + return fmt.Errorf("节点 ID %d 已经连接", edge) } // 分配端口 @@ -151,8 +172,7 @@ func (s *Service) onOpen(conn net.Conn, client int32) (err error) { var _, ok = app.Assigns.Load(i) if !ok { port = i - app.Assigns.Store(i, client) - app.Clients.Store(client, i) + app.AddEdge(edge, port) break } } @@ -161,62 +181,46 @@ func (s *Service) onOpen(conn net.Conn, client int32) (err error) { } // 报告端口分配 - if err = report.Assigned(client, port); err != nil { + if err = report.Assigned(edge, port); err != nil { return fmt.Errorf("报告端口分配失败: %w", err) } - // 响应客户端 - if err = s.sendPong(conn); err != nil { - return fmt.Errorf("响应客户端失败: %w", err) + // 响应节点 + if err = s.sendPong(writer); err != nil { + return fmt.Errorf("响应节点失败: %w", err) } // 启动转发服务 s.fwdLesWg.Add(1) go func() { defer s.fwdLesWg.Done() - slog.Info("监听转发端口", "port", port, "client", client) - err = s.listenUser(port, conn) + slog.Info("监听转发端口", "port", port, "edge", edge) + err = s.listenUser(port, writer) if err != nil { - slog.Error("监听转发端口失败", "port", port, "client", client, "err", err) + slog.Error("监听转发端口失败", "port", port, "edge", edge, "err", err) } }() return nil } -func (s *Service) onClose(conn net.Conn) (err error) { - _, portStr, err := net.SplitHostPort(conn.LocalAddr().String()) - if err != nil { - return err - } - - port, err := strconv.ParseUint(portStr, 10, 16) - if err != nil { - return err - } - - id, _ := app.Assigns.LoadAndDelete(uint16(port)) - app.Clients.Delete(id) - app.Assigns.Delete(uint16(port)) - app.Permits.Delete(uint16(port)) - - err = s.sendPong(conn) - if err != nil { - return err - } - - return nil +func (s *Service) onPing(writer io.Writer) (err error) { + return s.sendPong(writer) } -func (s *Service) sendPong(conn net.Conn) (err error) { - _, err = conn.Write([]byte{byte(CtrlCmdPong)}) +func (s *Service) onClose(writer io.Writer) (err error) { + return s.sendPong(writer) +} + +func (s *Service) sendPong(writer io.Writer) (err error) { + _, err = writer.Write([]byte{byte(CtrlCmdPong)}) if err != nil { - return fmt.Errorf("响应客户端失败: %w", err) + return fmt.Errorf("响应节点失败: %w", err) } return nil } -func (s *Service) sendProxy(conn net.Conn, tag [16]byte, addr string) (err error) { +func (s *Service) sendProxy(writer io.Writer, tag [16]byte, addr string) (err error) { if len(addr) > 65535 { return fmt.Errorf("代理地址过长: %s", addr) } @@ -227,7 +231,7 @@ func (s *Service) sendProxy(conn net.Conn, tag [16]byte, addr string) (err error binary.BigEndian.PutUint16(buf[17:], uint16(len(addr))) copy(buf[19:], addr) - _, err = conn.Write(buf) + _, err = writer.Write(buf) if err != nil { return fmt.Errorf("发送代理命令失败: %w", err) } diff --git a/gateway/fwd/data.go b/gateway/fwd/data.go index fcd1a33..20d0655 100644 --- a/gateway/fwd/data.go +++ b/gateway/fwd/data.go @@ -11,7 +11,7 @@ import ( "proxy-server/gateway/debug" "proxy-server/gateway/env" "proxy-server/gateway/fwd/metrics" - "proxy-server/pkg/utils" + utils2 "proxy-server/utils" "strconv" "sync" "time" @@ -26,7 +26,7 @@ func (s *Service) listenData() error { if err != nil { return fmt.Errorf("监听数据通道失败: %w", err) } - defer utils.Close(ls) + defer utils2.Close(ls) // 异步等待连接 var connCh = make(chan net.Conn) @@ -44,7 +44,7 @@ func (s *Service) listenData() error { select { case connCh <- conn: case <-s.ctx.Done(): - utils.Close(conn) + utils2.Close(conn) return } } @@ -59,7 +59,7 @@ func (s *Service) listenData() error { s.dataConnWg.Add(1) go func() { defer s.dataConnWg.Done() - defer utils.Close(conn) + defer utils2.Close(conn) err := s.processDataConn(conn) if err != nil { slog.Error("处理数据通道连接失败", "err", err) @@ -76,7 +76,7 @@ func (s *Service) processDataConn(client net.Conn) error { var buf = make([]byte, 17) _, err := io.ReadFull(reader, buf) if err != nil { - return fmt.Errorf("从客户端获取连接结果失败: %w", err) + return fmt.Errorf("从节点获取连接结果失败: %w", err) } tag := buf[0:16] @@ -88,7 +88,7 @@ func (s *Service) processDataConn(client net.Conn) error { if !ok { return fmt.Errorf("用户连接已关闭,tag:%s", tagStr) } - defer utils.Close(user) + defer utils2.Close(user) // 检查状态 if status != 1 { @@ -99,7 +99,7 @@ func (s *Service) processDataConn(client net.Conn) error { data := time.Now() userPipeReader, userPipeWriter := io.Pipe() - defer utils.Close(userPipeWriter) + defer utils2.Close(userPipeWriter) teeUser := io.TeeReader(user, userPipeWriter) go func() { @@ -131,7 +131,7 @@ func (s *Service) processDataConn(client net.Conn) error { case <-s.ctx.Done(): return nil - case <-utils.WgWait(&wg): + case <-utils2.WgWait(&wg): proxy := time.Now() start, startOk := metrics.TimerStart.Load(user.Conn) diff --git a/gateway/fwd/dispatcher/dispatch.go b/gateway/fwd/dispatcher/dispatch.go index 714fb1b..7e76ef3 100644 --- a/gateway/fwd/dispatcher/dispatch.go +++ b/gateway/fwd/dispatcher/dispatch.go @@ -9,7 +9,7 @@ import ( "proxy-server/gateway/fwd/http" "proxy-server/gateway/fwd/metrics" "proxy-server/gateway/fwd/socks" - "proxy-server/pkg/utils" + "proxy-server/utils" "strconv" "strings" "time" diff --git a/gateway/fwd/fwd.go b/gateway/fwd/fwd.go index 3243b5f..3a6925a 100644 --- a/gateway/fwd/fwd.go +++ b/gateway/fwd/fwd.go @@ -4,7 +4,7 @@ import ( "context" "log/slog" "proxy-server/gateway/core" - "proxy-server/pkg/utils" + "proxy-server/utils" "sync" ) @@ -29,7 +29,7 @@ func New() *Service { } func (s *Service) Run() error { - slog.Debug("启动转发服务") + slog.Info("启动转发服务") errQuit := make(chan struct{}, 2) defer close(errQuit) diff --git a/gateway/fwd/socks/socks.go b/gateway/fwd/socks/socks.go index 64c7a06..ace8fc0 100644 --- a/gateway/fwd/socks/socks.go +++ b/gateway/fwd/socks/socks.go @@ -12,7 +12,7 @@ import ( "net" "proxy-server/gateway/core" "proxy-server/gateway/fwd/auth" - "proxy-server/pkg/utils" + "proxy-server/utils" "slices" ) @@ -90,7 +90,7 @@ func Process(ctx context.Context, conn net.Conn) (*core.Conn, error) { }, nil } -// checkVersion 检查客户端版本 +// checkVersion 检查节点版本 func checkVersion(reader io.Reader) error { version, err := utils.ReadByte(reader) if err != nil { @@ -98,7 +98,7 @@ func checkVersion(reader io.Reader) error { } if version != Version { - return errors.New("客户端版本不兼容") + return errors.New("节点版本不兼容") } return nil @@ -113,7 +113,7 @@ func authenticate(ctx context.Context, reader *bufio.Reader, conn net.Conn) (aut return nil, err } - // 获取客户端认证方式 + // 获取节点认证方式 nAuth, err := utils.ReadByte(reader) if err != nil { return nil, err diff --git a/gateway/fwd/user.go b/gateway/fwd/user.go index 05f1b06..cc89a8a 100644 --- a/gateway/fwd/user.go +++ b/gateway/fwd/user.go @@ -4,17 +4,17 @@ import ( "context" "encoding/hex" "errors" + "io" "log/slog" - "net" "proxy-server/gateway/core" "proxy-server/gateway/env" "proxy-server/gateway/fwd/dispatcher" "proxy-server/gateway/fwd/metrics" - "proxy-server/pkg/utils" + "proxy-server/utils" "time" ) -func (s *Service) listenUser(port uint16, ctrl net.Conn) error { +func (s *Service) listenUser(port uint16, ctrl io.Writer) error { dspt, err := dispatcher.New(port, time.Duration(env.AppUserTimeout)*time.Second) if err != nil { return err @@ -48,7 +48,7 @@ func (s *Service) listenUser(port uint16, ctrl net.Conn) error { } } -func (s *Service) processUserConn(user *core.Conn, ctrl net.Conn) (err error) { +func (s *Service) processUserConn(user *core.Conn, ctrl io.Writer) (err error) { // 发送代理命令 err = s.sendProxy(ctrl, user.Tag, user.Dest.String()) diff --git a/gateway/gateway.go b/gateway/gateway.go index 9236648..addbfa9 100644 --- a/gateway/gateway.go +++ b/gateway/gateway.go @@ -15,7 +15,7 @@ import ( "proxy-server/gateway/log" "proxy-server/gateway/report" "proxy-server/gateway/web" - "proxy-server/pkg/utils" + "proxy-server/utils" "sync" "time" @@ -88,7 +88,7 @@ func (s *server) Run() (err error) { // }() // 报告上线 - slog.Debug("报告服务上线") + slog.Info("报告服务上线") err = report.Online(app.Name) if err != nil { return fmt.Errorf("服务上线失败: %w", err) @@ -125,11 +125,11 @@ func (s *server) Run() (err error) { // 等待其它服务关闭 select { case <-utils.WgWait(&wg): - slog.Info("服务正常关闭") case <-time.After(time.Duration(env.AppExitTimeout) * time.Second): slog.Warn("超时强制关闭") } + slog.Info("服务已退出") return nil } diff --git a/gateway/web/handlers/auth.go b/gateway/web/handlers/auth.go index 3be90ce..943dcd5 100644 --- a/gateway/web/handlers/auth.go +++ b/gateway/web/handlers/auth.go @@ -26,7 +26,7 @@ func Auth(ctx *fiber.Ctx) (err error) { } // 保存授权配置 - app.Permits.Store(req.Port, &req.Permit) + app.PermitEdge(req.Port, &req.Permit) return nil } diff --git a/pkg/utils/sync.go b/utils/sync.go similarity index 100% rename from pkg/utils/sync.go rename to utils/sync.go diff --git a/pkg/utils/utils.go b/utils/utils.go similarity index 100% rename from pkg/utils/utils.go rename to utils/utils.go