From 8a6a4833d4ab930c18207239c8b88b6715778d7e Mon Sep 17 00:00:00 2001 From: luorijun Date: Fri, 16 May 2025 15:13:16 +0800 Subject: [PATCH] =?UTF-8?q?=E9=87=8D=E6=96=B0=E8=A7=84=E5=88=92=E7=BD=91?= =?UTF-8?q?=E5=85=B3=E4=B8=8E=E8=8A=82=E7=82=B9=E7=9A=84=E4=BA=A4=E4=BA=92?= =?UTF-8?q?=E5=8D=8F=E8=AE=AE=EF=BC=8C=E5=AE=9E=E7=8E=B0=E7=BB=9F=E4=B8=80?= =?UTF-8?q?=E5=91=BD=E4=BB=A4=E4=BD=8D=E7=9A=84=E8=AF=86=E5=88=AB=E5=92=8C?= =?UTF-8?q?=E5=A4=84=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 77 +++++-- client/client.go | 272 +++++++++++++++++-------- client/core/{env.go => consts.go} | 1 + client/report/{online.go => report.go} | 4 + pkg/utils/chan.go | 8 +- server/app/app.go | 9 +- server/{fwd => }/core/conn.go | 33 +-- server/core/map.go | 65 ++++++ server/env/env.go | 29 ++- server/fwd/analysis.go | 4 +- server/fwd/auth/auth.go | 7 +- server/fwd/ctrl.go | 241 +++++++++++----------- server/fwd/data.go | 40 ++-- server/fwd/dispatcher/dispatch.go | 16 +- server/fwd/fwd.go | 15 +- server/fwd/http/http.go | 7 +- server/fwd/repo/channel.go | 22 -- server/fwd/repo/node.go | 15 -- server/fwd/socks/socks.go | 5 +- server/fwd/user.go | 82 ++++++++ server/server.go | 28 +-- server/web/handlers/auth.go | 2 +- 22 files changed, 609 insertions(+), 373 deletions(-) rename client/core/{env.go => consts.go} (79%) rename client/report/{online.go => report.go} (97%) rename server/{fwd => }/core/conn.go (66%) create mode 100644 server/core/map.go delete mode 100644 server/fwd/repo/channel.go delete mode 100644 server/fwd/repo/node.go create mode 100644 server/fwd/user.go diff --git a/README.md b/README.md index ba167a5..b35a56c 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ ### 目录结构 -server/fwd: 服务端核心代码 +server/fwd: 网关核心代码 - core: 核心代码,目前主要是连接管理 - dispatcher: 请求处理器,负责解析传入协议,并将请求分发到对应的处理器 @@ -24,7 +24,7 @@ server/fwd: 服务端核心代码 ## 底层协议设计 -### 步骤说明 +### 步骤说明(待更新) 1. 启动转发服务,尝试注册自身到后端服务,随后持续报告心跳 2. 启动边缘节点后,尝试注册自身到后端服务,随后持续报告心跳 @@ -36,33 +36,71 @@ server/fwd: 服务端核心代码 8. 当成功建立数据通道后,边缘节点将数据通道标识以及对目标地址的连接结果提供给转发服务 9. 如果连接成功建立,则开始代理流量,如果连接失败,则关闭数据通道 -### 协议报文详情 +### 协议报文说明 协议中所有数值都以大端形式传输 -#### 建立控制通道 +控制通道命令列表: -1.客户端发送: +| 编码 | 命令 | 发起方 | 备注 | +|----|-------|-------|--------------------| +| 1 | pong | proxy | 网关响应,和 status 相同 | +| 2 | ping | edge | 边缘节点心跳 | +| 3 | open | edge | 边缘节点连接到网关 | +| 4 | close | edge | 边缘节点断开连接 | +| 5 | proxy | proxy | 网关通知节点打开数据通道用来转发数据 | -| id(4) | -|--------| -| 客户端 ID | +### 建立控制通道 -2.服务端发送: +1.节点发起 + +| cmd(1) | id(4) | +|---------|-------| +| 命令固定为 3 | 节点 ID | + +2.网关响应 | status(1) | |-----------| | 状态,固定为 1 | -#### 建立数据通道 +### 控制通道心跳 -1.服务端发送: +1.节点发起 -| tag(16) | dst_len(2) | dst_buf(n) | -|---------|------------|------------| -| 通道标识 | 目标地址长度 | 目标地址 | +| cmd(1) | +|---------| +| 命令固定为 2 | -2.客户端发送: +2.网关响应: + +| status(1) | +|-----------| +| 状态,固定为 1 | + +### 关闭控制通道 + +1.节点发起 + +| cmd(1) | +|---------| +| 命令固定为 4 | + +2.网关响应: + +| status(1) | +|-----------| +| 状态,固定为 1 | + +### 建立数据通道 + +1.网关发送: + +| cmd(1) | tag(16) | dst_len(2) | dst_buf(n) | +|---------|---------|------------|------------| +| 命令固定为 5 | 通道标识 | 目标地址长度 | 目标地址 | + +2.节点发送: | tag(16) | status(1) | |---------|-----------------------| @@ -76,9 +114,12 @@ server/fwd: 服务端核心代码 ```json5 { - "content": "string", // base64 编码的请求体 - "nonce": "string", // 随机数值 - "timestamp": "number" // 时间戳,精确到毫秒 + "content": "string", + // base64 编码的请求体 + "nonce": "string", + // 随机数值 + "timestamp": "number" + // 时间戳,精确到毫秒 } ``` diff --git a/client/client.go b/client/client.go index b2e1ab2..b701947 100644 --- a/client/client.go +++ b/client/client.go @@ -2,20 +2,22 @@ package client import ( "bufio" + "context" "encoding/binary" + "errors" "fmt" "io" "log/slog" "net" + _ "net/http/pprof" + "os" + "os/signal" "proxy-server/client/core" "proxy-server/client/env" "proxy-server/client/geo" "proxy-server/client/report" "proxy-server/pkg/utils" "time" - - "errors" - _ "net/http/pprof" ) func Start() error { @@ -31,7 +33,7 @@ func Start() error { slog.Debug("获取节点归属地...") err = geo.Query() if err != nil { - slog.Error("获取归属地失败", "err", err) + return fmt.Errorf("获取节点归属地失败: %w", err) } // 注册节点 @@ -41,27 +43,38 @@ func Start() error { return fmt.Errorf("注册节点失败: %w", err) } - // 性能监控 - // go func() { - // runtime.SetBlockProfileRate(1) - // err := http.ListenAndServe(":7070", nil) - // if err != nil { - // slog.Error("性能监控服务启动失败", "err", err) - // } - // }() - // 建立控制通道 - for { - err := ctrl(id, host) - if err != nil { + var ctx, cancel = signal.NotifyContext(context.Background(), os.Interrupt, os.Kill) + defer cancel() + + go func() { + for { + err = ctrl(ctx, id, host) + if err == nil { + return + } + select { + case <-ctx.Done(): + return + default: + } slog.Error("建立控制通道失败", "err", err) slog.Info(fmt.Sprintf("%d 秒后重试", core.RetryInterval)) time.Sleep(time.Duration(core.RetryInterval) * time.Second) } + }() + + // 下线节点 + slog.Debug("下线节点...") + err = report.Offline() + if err != nil { + slog.Error("下线节点失败", "err", err) } + + return ctx.Err() } -func ctrl(id int32, host string) error { +func ctrl(ctx context.Context, id int32, host string) error { ctrlAddr := net.JoinHostPort(host, fmt.Sprintf("%d", core.FwdCtrlPort)) dataAddr := net.JoinHostPort(host, fmt.Sprintf("%d", core.FwdDataPort)) @@ -71,103 +84,94 @@ func ctrl(id int32, host string) error { return errors.New("连接失败") } defer utils.Close(conn) + var reader = bufio.NewReader(conn) - // 发送客户端信息 - var buf = make([]byte, 4) - _, err = binary.Encode(buf, binary.BigEndian, id) + // 发送节点连接命令 + err = sendOpen(reader, conn, id) if err != nil { - return fmt.Errorf("编码客户端 ID 失败: %w", err) - } - _, err = conn.Write(buf) - if err != nil { - return fmt.Errorf("发送客户端 ID 失败: %w", err) + return fmt.Errorf("发送节点信息失败: %w", err) } - // 等待服务端响应 - reader := bufio.NewReader(conn) - respBuf, err := reader.ReadByte() - if err != nil { - return errors.New("接收响应失败") - } - if respBuf != 1 { - return errors.New("服务端响应失败") - } else { - slog.Info("成功建立连接") - } + // 异步定时发送心跳 + go func() { + ticker := time.NewTicker(time.Duration(core.HeartbeatInterval) * time.Second) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case tick := <-ticker.C: + err := sendPing(reader, conn) + if err != nil { + slog.Error("发送心跳失败", "time", tick, "err", err) + } + } + } + }() // 等待用户连接 // 读写失败后退出重连,防止后续数据读写顺序错位导致卡死控制通道 slog.Info("等待用户连接") - for { - - // 接收 dst - dstLen, err := reader.ReadByte() - if err != nil { - return errors.New("接收 dstLen 失败") - } - dstBuf, err := utils.ReadBuffer(reader, int(dstLen)) - if err != nil { - return errors.New("接收 dstBuf 失败") - } - addr := string(dstBuf) - - // 接收 tag - tagLen, err := reader.ReadByte() - if err != nil { - return errors.New("接收 tagLen 失败") - } - tagBuf, err := utils.ReadBuffer(reader, int(tagLen)) - if err != nil { - return errors.New("接收 tagBuf 失败") - } - - // 建立数据通道 - go func() { - err := data(dataAddr, addr, tagBuf) + for loop := true; loop; { + select { + case <-ctx.Done(): + loop = false + default: + // 接收 dst + tag, addr, err := onConn(reader) if err != nil { - slog.Error("建立数据通道失败", "err", err) + return fmt.Errorf("接收连接命令失败: %w", err) } - }() + + // 建立数据通道 + go func() { + err := data(dataAddr, addr, tag) + if err != nil { + slog.Error("建立数据通道失败", "err", err) + } + }() + } } + + // 发送关闭连接(不 return err,否则会重新连接) + err = sendClose(reader, conn) + if err != nil { + slog.Error("发送关闭连接失败", "err", err) + } + + return nil } -func data(dataAddr string, dest string, tag []byte) error { +func data(proxy string, dest string, tag [16]byte) error { + + // 向目标地址建立连接 + var result = 1 + var dstErr error + dst, err := net.Dial("tcp", dest) + if err != nil { + dstErr = fmt.Errorf("连接目标地址失败: %w", dstErr) + result = 0 + } + defer utils.Close(dst) // 向服务端建立连接 - src, err := net.Dial("tcp", dataAddr) + src, err := net.Dial("tcp", proxy) if err != nil { return errors.New("连接服务端失败") } - - tagLen := byte(len(tag)) - tagBuf := make([]byte, 2+tagLen) - tagBuf[1] = tagLen - copy(tagBuf[2:], tag) - - // 向目标地址建立连接 - dst, dstErr := net.Dial("tcp", dest) - if dstErr != nil { - tagBuf[0] = 0 - } else { - tagBuf[0] = 1 - } + defer utils.Close(src) // 发送连接状态 - _, err = src.Write(tagBuf) + var buf = make([]byte, 17) + copy(buf[0:16], tag[:]) + buf[16] = byte(result) + _, err = src.Write(buf) if err != nil { - utils.Close(src) - if dst != nil { - utils.Close(dst) - } return errors.New("发送连接状态失败") } - if tagBuf[0] == 0 { - utils.Close(src) - if dst != nil { - utils.Close(dst) - } - return errors.New("连接目标地址失败") + if result == 0 { + return dstErr } go func() { @@ -186,3 +190,91 @@ func data(dataAddr string, dest string, tag []byte) error { }() return nil } + +func sendOpen(reader io.Reader, writer io.Writer, id int32) error { + + // 发送打开连接 + var buf = make([]byte, 5) + buf[0] = 3 + binary.BigEndian.PutUint32(buf[1:], uint32(id)) + + _, err := writer.Write(buf) + if err != nil { + return fmt.Errorf("发送打开连接失败: %w", err) + } + + // 等待服务端响应 + respBuf := make([]byte, 1) + _, err = io.ReadFull(reader, respBuf) + if err != nil { + return fmt.Errorf("接收服务端响应失败: %w", err) + } + if respBuf[0] != 1 { + return errors.New("服务端响应失败") + } + + return nil +} + +func sendClose(reader io.Reader, writer io.Writer) error { + // 发送关闭连接 + _, err := writer.Write([]byte{4}) + if err != nil { + return err + } + + // 等待服务端响应 + respBuf := make([]byte, 1) + _, err = io.ReadFull(reader, respBuf) + if err != nil { + return fmt.Errorf("接收服务端响应失败: %w", err) + } + if respBuf[0] != 1 { + return errors.New("服务端响应失败") + } + + return nil +} + +func sendPing(reader io.Reader, writer io.Writer) error { + _, err := writer.Write([]byte{2}) + if err != nil { + return err + } + + // 等待服务端响应 + respBuf := make([]byte, 1) + _, err = io.ReadFull(reader, respBuf) + if err != nil { + return fmt.Errorf("接收服务端响应失败: %w", err) + } + if respBuf[0] != 1 { + return errors.New("服务端响应失败") + } + + return nil +} + +func onConn(reader io.Reader) (tag [16]byte, addr string, err error) { + var buf = make([]byte, 1+16+2) + _, err = io.ReadFull(reader, buf) + if err != nil { + return [16]byte{}, "", err + } + + if buf[0] != 5 { + return [16]byte{}, "", errors.New("命令错误") + } + + tag = [16]byte(buf[1:17]) + + var addrLen = binary.BigEndian.Uint16(buf[17:19]) + var addrBuf = make([]byte, addrLen) + _, err = io.ReadFull(reader, addrBuf) + if err != nil { + return [16]byte{}, "", err + } + + addr = string(addrBuf) + return tag, addr, nil +} diff --git a/client/core/env.go b/client/core/consts.go similarity index 79% rename from client/core/env.go rename to client/core/consts.go index 403fa3c..d4de7d6 100644 --- a/client/core/env.go +++ b/client/core/consts.go @@ -6,3 +6,4 @@ const FwdCtrlPort uint = 18080 const FwdDataPort uint = 18081 const RetryInterval uint = 5 +const HeartbeatInterval uint = 30 diff --git a/client/report/online.go b/client/report/report.go similarity index 97% rename from client/report/online.go rename to client/report/report.go index e9a14bc..4e4c9ec 100644 --- a/client/report/online.go +++ b/client/report/report.go @@ -67,3 +67,7 @@ func Online(prov, city, isp string) (id int32, host string, err error) { return respBody.Id, respBody.Host, nil } + +func Offline() error { + return nil +} diff --git a/pkg/utils/chan.go b/pkg/utils/chan.go index 2e563f9..ef7b4be 100644 --- a/pkg/utils/chan.go +++ b/pkg/utils/chan.go @@ -39,15 +39,11 @@ func ChanConnAccept(ctx context.Context, ls net.Listener) chan net.Conn { return ch } -func ChanWgWait[T WaitGroup](ctx context.Context, wg T) chan struct{} { +func WgWait[T WaitGroup](wg T) <-chan struct{} { ch := make(chan struct{}) go func() { - defer close(ch) wg.Wait() - select { - case <-ctx.Done(): - case ch <- struct{}{}: - } + ch <- struct{}{} }() return ch } diff --git a/server/app/app.go b/server/app/app.go index f6c8fce..8ff7619 100644 --- a/server/app/app.go +++ b/server/app/app.go @@ -1,12 +1,15 @@ package app -import "proxy-server/server/core" +import ( + "proxy-server/server/core" +) var ( Id int32 Name string PlatformSecret string // 平台密钥,验证接收的请求是否属于平台 - Assigns = make(map[uint16]int32) // 转发端口 -> 转发服务ID - Permits = make(map[uint16]core.Permit) // 转发端口 -> 权限配置 + Clients = core.SyncMap[int32, uint16]{} // 节点 ID -> 转发端口 + Assigns = core.SyncMap[uint16, int32]{} // 转发端口 -> 节点 ID + Permits = core.SyncMap[uint16, *core.Permit]{} // 转发端口 -> 权限配置 ) diff --git a/server/fwd/core/conn.go b/server/core/conn.go similarity index 66% rename from server/fwd/core/conn.go rename to server/core/conn.go index a1551b2..90fc288 100644 --- a/server/fwd/core/conn.go +++ b/server/core/conn.go @@ -4,40 +4,13 @@ import ( "bufio" "fmt" "net" - "sync" "time" ) -type ConnMap struct { - _map sync.Map -} - -func (c *ConnMap) LoadAndDelete(key string) (*Conn, bool) { - _value, ok := c._map.LoadAndDelete(key) - if !ok { - return nil, false - } - return _value.(*Conn), true -} - -func (c *ConnMap) Store(key string, value *Conn) { - c._map.Store(key, value) -} - -func (c *ConnMap) Range(f func(key string, value *Conn) bool) { - c._map.Range(func(key, value any) bool { - return f(key.(string), value.(*Conn)) - }) -} - -func (c *ConnMap) Clear() { - c._map.Clear() -} - type Conn struct { Conn net.Conn Reader *bufio.Reader - Tag string + Tag [16]byte Protocol string Dest *FwdAddr Auth *AuthContext @@ -75,10 +48,6 @@ func (c Conn) SetWriteDeadline(t time.Time) error { return c.Conn.SetWriteDeadline(t) } -func (c Conn) DestAddr() net.Addr { - return c.Dest -} - type FwdAddr struct { IP net.IP Port int diff --git a/server/core/map.go b/server/core/map.go new file mode 100644 index 0000000..e069f77 --- /dev/null +++ b/server/core/map.go @@ -0,0 +1,65 @@ +package core + +import "sync" + +type SyncMap[K any, V any] struct { + _map sync.Map +} + +func (m *SyncMap[K, V]) Store(key K, value V) { + m._map.Store(key, value) +} + +func (m *SyncMap[K, V]) Load(key K) (value V, ok bool) { + v, ok := m._map.Load(key) + if ok { + value = v.(V) + } + return +} + +func (m *SyncMap[K, V]) Swap(key K, value V) (previous V, loaded bool) { + v, loaded := m._map.Swap(key, value) + if loaded { + previous = v.(V) + } + return +} + +func (m *SyncMap[K, V]) Delete(key K) { + m._map.Delete(key) +} + +func (m *SyncMap[K, V]) Clear() { + m._map.Clear() +} + +func (m *SyncMap[K, V]) Range(f func(key K, value V) bool) { + m._map.Range(func(k, v any) bool { + return f(k.(K), v.(V)) + }) +} + +func (m *SyncMap[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) { + v, loaded := m._map.LoadOrStore(key, value) + if loaded { + actual = v.(V) + } + return +} + +func (m *SyncMap[K, V]) LoadAndDelete(key K) (value V, ok bool) { + v, ok := m._map.LoadAndDelete(key) + if ok { + value = v.(V) + } + return +} + +func (m *SyncMap[K, V]) CompareAndSwap(key K, old, new V) (swapped bool) { + return m._map.CompareAndSwap(key, old, new) +} + +func (m *SyncMap[K, V]) CompareAndDelete(key K, old V) (deleted bool) { + return m._map.CompareAndDelete(key, old) +} diff --git a/server/env/env.go b/server/env/env.go index 38160b5..99004fb 100644 --- a/server/env/env.go +++ b/server/env/env.go @@ -10,10 +10,13 @@ import ( ) var ( - AppCtrlPort uint16 = 18080 - AppDataPort uint16 = 18081 - AppWebPort uint16 = 8848 - AppLogMode = "dev" + AppCtrlPort uint16 = 18080 + AppDataPort uint16 = 18081 + AppWebPort uint16 = 8848 + AppLogMode = "dev" + AppExitTimeout = 5 // 等待服务停止的超时时间 + AppDataTimeout = 10 // 等待数据通道连接的超时时间 + AppUserTimeout = 10 // 等待用户发送数据的超时时间(端口复用需要分析协议,如果用户长期不发送数据,将会阻塞分析协程) ClientId string ClientSecret string @@ -67,6 +70,24 @@ func Init() { AppLogMode = value } + value = os.Getenv("APP_EXIT_TIMEOUT") + if value != "" { + appExitTimeout, err := strconv.Atoi(value) + if err != nil { + panic(fmt.Sprintf("环境变量 APP_EXIT_TIMEOUT 格式错误: %v", err)) + } + AppExitTimeout = appExitTimeout + } + + value = os.Getenv("APP_DATA_TIMEOUT") + if value != "" { + appDataTimeout, err := strconv.Atoi(value) + if err != nil { + panic(fmt.Sprintf("环境变量 APP_DATA_TIMEOUT 格式错误: %v", err)) + } + AppDataTimeout = appDataTimeout + } + value = os.Getenv("CLIENT_ID") if value != "" { ClientId = value diff --git a/server/fwd/analysis.go b/server/fwd/analysis.go index aa06f08..182c1d2 100644 --- a/server/fwd/analysis.go +++ b/server/fwd/analysis.go @@ -7,7 +7,7 @@ import ( "io" "log/slog" "proxy-server/pkg/utils" - "proxy-server/server/fwd/core" + "proxy-server/server/core" "strings" "errors" @@ -27,7 +27,7 @@ func analysisAndLog(conn *core.Conn, reader io.Reader) error { slog.String("proxy", conn.Protocol), slog.String("node", conn.LocalAddr().String()), slog.String("proto", proto), - slog.String("dest", conn.DestAddr().String()), + slog.String("dest", conn.Dest.String()), slog.String("domain", domain), ) } diff --git a/server/fwd/auth/auth.go b/server/fwd/auth/auth.go index 0c4034a..0a8d106 100644 --- a/server/fwd/auth/auth.go +++ b/server/fwd/auth/auth.go @@ -4,7 +4,7 @@ import ( "fmt" "net" "proxy-server/server/app" - "proxy-server/server/fwd/core" + "proxy-server/server/core" "strconv" "time" @@ -36,7 +36,7 @@ func Protect(conn net.Conn, proto Protocol, username, password *string) (*core.A } // 查找权限配置 - var permit, ok = app.Permits[uint16(localPort)] + var permit, ok = app.Permits.Load(uint16(localPort)) if !ok { return nil, errors.New("没有权限") } @@ -68,10 +68,11 @@ func Protect(conn net.Conn, proto Protocol, username, password *string) (*core.A } } + var id, _ = app.Assigns.Load(uint16(localPort)) return &core.AuthContext{ Timeout: time.Since(permit.Expire).Seconds(), Payload: core.Payload{ - ID: app.Assigns[uint16(localPort)], + ID: id, }, }, nil } diff --git a/server/fwd/ctrl.go b/server/fwd/ctrl.go index b6c657f..1fa17cd 100644 --- a/server/fwd/ctrl.go +++ b/server/fwd/ctrl.go @@ -4,6 +4,7 @@ import ( "bufio" "context" "encoding/binary" + "errors" "fmt" "io" "log/slog" @@ -11,25 +12,21 @@ import ( "proxy-server/pkg/utils" "proxy-server/server/app" "proxy-server/server/env" - "proxy-server/server/fwd/core" - "proxy-server/server/fwd/dispatcher" - "proxy-server/server/fwd/metrics" "proxy-server/server/report" "strconv" - "strings" - "time" - - "errors" ) -type CtrlCmd struct { - conn net.Conn - buf []byte -} +type CtrlCmdType int -var ctrlCmdChan = make(chan CtrlCmd, 1024) +const ( + CtrlCmdPong CtrlCmdType = iota + 1 + CtrlCmdPing + CtrlCmdOpen + CtrlCmdClose + CtrlCmdProxy +) -func (s *Service) startCtrlTun() error { +func (s *Service) listenCtrl() error { ctrlPort := env.AppCtrlPort slog.Debug("监听控制通道", slog.Uint64("port", uint64(ctrlPort))) @@ -56,7 +53,7 @@ func (s *Service) startCtrlTun() error { go func() { defer s.ctrlConnWg.Done() defer utils.Close(conn) - err := s.processCtrlConn(conn) + err := s.processCtrlConn(s.ctx, conn) if err != nil { slog.Error("处理控制通道连接失败", "err", err) } @@ -67,25 +64,80 @@ func (s *Service) startCtrlTun() error { return err } -func (s *Service) processCtrlConn(conn net.Conn) error { +func (s *Service) processCtrlConn(ctx context.Context, conn net.Conn) (err error) { reader := bufio.NewReader(conn) + for { + // 循环等待直到服务关闭 + select { + case <-ctx.Done(): + return nil + default: + } - var recv = make([]byte, 4) - _, err := io.ReadFull(reader, recv) - if err != nil { - return fmt.Errorf("读取客户端 ID 失败: %w", err) + // 读取命令 + cmdByte, err := reader.ReadByte() + if err != nil { + return fmt.Errorf("读取节点命令失败: %w", err) + } + var cmd = CtrlCmdType(cmdByte) + switch cmd { + + // 连接建立命令 + case CtrlCmdOpen: + var recv = make([]byte, 4) + _, err = io.ReadFull(reader, recv) + if err != nil { + return fmt.Errorf("读取节点 ID 失败: %w", err) + } + var client = int32(binary.BigEndian.Uint32(recv)) + err = s.onOpen(conn, client) + if err != nil { + return fmt.Errorf("处理连接建立命令失败: %w", err) + } + + // 心跳命令 + case CtrlCmdPing: + err = s.onPing(conn) + if err != nil { + return fmt.Errorf("处理心跳命令失败: %w", err) + } + + // 连接关闭命令 + case CtrlCmdClose: + err = s.onClose(conn) + if err != nil { + return fmt.Errorf("处理关闭命令失败: %w", err) + } + return nil + + // 忽略其他不应该由节点发起的命令 + default: + return fmt.Errorf("无法处理控制命令: %d", cmd) + } + } +} + +func (s *Service) onPing(conn net.Conn) (err error) { + return s.sendPong(conn) +} + +func (s *Service) onOpen(conn net.Conn, client int32) (err error) { + // open 命令全局只执行一次 + _, ok := app.Clients.Load(client) + if ok { + return fmt.Errorf("节点 ID %d 已经连接", client) } - var client = int32(binary.BigEndian.Uint32(recv)) // 分配端口 var minim uint16 = 20000 var maxim uint16 = 60000 var port uint16 for i := minim; i < maxim; i++ { - var _, ok = app.Assigns[i] + var _, ok = app.Assigns.Load(i) if !ok { port = i - app.Assigns[i] = client + app.Assigns.Store(i, client) + app.Clients.Store(client, i) break } } @@ -94,126 +146,75 @@ func (s *Service) processCtrlConn(conn net.Conn) error { } // 报告端口分配 - err = report.Assigned(client, port) - if err != nil { + if err = report.Assigned(client, port); err != nil { return fmt.Errorf("报告端口分配失败: %w", err) } // 响应客户端 - _, err = conn.Write([]byte{1}) - if err != nil { + if err = s.sendPong(conn); err != nil { return fmt.Errorf("响应客户端失败: %w", err) } // 启动转发服务 - slog.Info("监听转发端口", "port", port, "client", client) - proxy, err := dispatcher.New(port) - if err != nil { - return err - } - defer proxy.Close() - s.fwdLesWg.Add(1) go func() { defer s.fwdLesWg.Done() - err := proxy.Run() + slog.Info("监听转发端口", "port", port, "client", client) + err = s.listenUser(port, conn) if err != nil { - slog.Error("代理服务运行失败", "err", err) + slog.Error("监听转发端口失败", "port", port, "client", client, "err", err) } }() - // 监听控制通道连接 - errCh := make(chan error, 1) - go func() { - defer close(errCh) - _, err := reader.ReadByte() - errCh <- err - }() - - // 批量同步写入 - go func() { - for { - select { - case <-s.ctx.Done(): - return - case cmd := <-ctrlCmdChan: - _, err := cmd.conn.Write(cmd.buf) - if err != nil { - slog.Error("批量写入失败", "err", err) - utils.Close(cmd.conn) - } - } - } - }() - - // 处理连接 - for { - select { - case <-s.ctx.Done(): - return nil - case err := <-errCh: - switch { - case strings.Contains(err.Error(), "An existing connection was forcibly closed by the remote host."): - slog.Debug("客户端主动断开连接") - return nil - case err == nil: - return errors.New("客户端握手失败") - default: - return fmt.Errorf("客户端意外断开连接: %w", err) - } - case user := <-proxy.Conn: - metrics.TimerAuth.Store(user.Conn, time.Now()) - s.userConnWg.Add(1) - go func() { - defer s.userConnWg.Done() - err := s.processUserConn(user, conn) - if err != nil { - slog.Error("处理用户连接失败", "err", err) - utils.Close(user) - } - }() - } - } + return nil } -func (s *Service) processUserConn(user *core.Conn, ctrl net.Conn) error { - - // 组织写入信息 - dst := user.DestAddr().String() - dstLen := len(dst) - - tag := user.Tag - tagLen := len(tag) - - writeBuf := make([]byte, 2+dstLen+tagLen) - writeBuf[0] = byte(dstLen) - copy(writeBuf[1:], dst) - writeBuf[1+dstLen] = byte(tagLen) - copy(writeBuf[2+dstLen:], tag) - - // 异步写入命令 - ctrlCmdChan <- CtrlCmd{ - conn: ctrl, - buf: writeBuf, +func (s *Service) onClose(conn net.Conn) (err error) { + _, portStr, err := net.SplitHostPort(conn.LocalAddr().String()) + if err != nil { + return err } - // 记录用户连接 - s.userConnMap.Store(user.Tag, user) + port, err := strconv.ParseUint(portStr, 10, 16) + if err != nil { + return err + } - // 如果限定时间内没有建立数据通道,则关闭连接 - timeout, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() + id, _ := app.Assigns.LoadAndDelete(uint16(port)) + app.Clients.Delete(id) + app.Assigns.Delete(uint16(port)) + app.Permits.Delete(uint16(port)) - select { - case <-s.ctx.Done(): - // 服务会在退出时统一关闭未消费的连接 - case <-timeout.Done(): - storedUser, ok := s.userConnMap.LoadAndDelete(user.Tag) - if ok { - slog.Debug("建立数据通道超时", "tag", user.Tag) - utils.Close(storedUser) - } + err = s.sendPong(conn) + if err != nil { + return err } return nil } + +func (s *Service) sendPong(conn net.Conn) (err error) { + _, err = conn.Write([]byte{byte(CtrlCmdPong)}) + if err != nil { + return fmt.Errorf("响应客户端失败: %w", err) + } + return nil +} + +func (s *Service) sendProxy(conn net.Conn, tag [16]byte, addr string) (err error) { + if len(addr) > 65535 { + return fmt.Errorf("代理地址过长: %s", addr) + } + + buf := make([]byte, 1+16+2+len(addr)) + buf[0] = byte(CtrlCmdProxy) + copy(buf[1:], tag[:]) + binary.BigEndian.PutUint16(buf[17:], uint16(len(addr))) + copy(buf[19:], addr) + + _, err = conn.Write(buf) + if err != nil { + return fmt.Errorf("发送代理命令失败: %w", err) + } + return nil +} diff --git a/server/fwd/data.go b/server/fwd/data.go index 0513573..b083250 100644 --- a/server/fwd/data.go +++ b/server/fwd/data.go @@ -1,7 +1,9 @@ package fwd import ( + "bufio" "fmt" + "github.com/google/uuid" "io" "log/slog" "net" @@ -16,7 +18,7 @@ import ( "errors" ) -func (s *Service) startDataTun() error { +func (s *Service) listenData() error { dataPort := env.AppDataPort slog.Debug("监听数据通道", slog.Uint64("port", uint64(dataPort))) @@ -57,38 +59,34 @@ func (s *Service) startDataTun() error { } func (s *Service) processDataConn(client net.Conn) error { + var reader = bufio.NewReader(client) - // 接收 status - status, err := utils.ReadByte(client) + // 接收连接结果 + var buf = make([]byte, 17) + _, err := io.ReadFull(reader, buf) if err != nil { - return fmt.Errorf("从客户端获取 status 失败: %w", err) + return fmt.Errorf("从客户端获取连接结果失败: %w", err) } - // 接收 tag - tagLen, err := utils.ReadByte(client) - if err != nil { - return fmt.Errorf("从客户端获取 tag 失败: %w", err) - } - tagBuf, err := utils.ReadBuffer(client, int(tagLen)) - if err != nil { - return fmt.Errorf("从客户端获取 tag 失败: %w", err) - } - tag := string(tagBuf) + tag := buf[0:16] + status := buf[16] - // 找到用户连接 - user, ok := s.userConnMap.LoadAndDelete(tag) + // 加载用户连接 + var tagStr = uuid.UUID(tag).String() + user, ok := s.userConnMap.LoadAndDelete(tagStr) if !ok { - return errors.New("用户连接已关闭,tag:" + tag) + return fmt.Errorf("用户连接已关闭,tag:%s", tagStr) } defer utils.Close(user) - data := time.Now() // 检查状态 if status != 1 { return errors.New("目标地址建立连接失败") } - // 数据转发 + // 转发数据 + data := time.Now() + userPipeReader, userPipeWriter := io.Pipe() defer utils.Close(userPipeWriter) teeUser := io.TeeReader(user, userPipeWriter) @@ -110,7 +108,7 @@ func (s *Service) processDataConn(client net.Conn) error { }() go func() { defer wg.Done() - _, err := io.Copy(user, client) + _, err := io.Copy(user, reader) if err != nil { slog.Error("数据转发失败 client->user", "err", err) } @@ -118,7 +116,7 @@ func (s *Service) processDataConn(client net.Conn) error { select { case <-s.ctx.Done(): - case <-utils.ChanWgWait(s.ctx, &wg): + case <-utils.WgWait(&wg): } proxy := time.Now() diff --git a/server/fwd/dispatcher/dispatch.go b/server/fwd/dispatcher/dispatch.go index 626e3c6..130e7d6 100644 --- a/server/fwd/dispatcher/dispatch.go +++ b/server/fwd/dispatcher/dispatch.go @@ -6,7 +6,7 @@ import ( "log/slog" "net" "proxy-server/pkg/utils" - "proxy-server/server/fwd/core" + "proxy-server/server/core" "proxy-server/server/fwd/http" "proxy-server/server/fwd/metrics" "proxy-server/server/fwd/socks" @@ -19,13 +19,14 @@ import ( ) type Server struct { - ctx context.Context - cancel context.CancelFunc - Port uint16 - Conn chan *core.Conn + ctx context.Context + cancel context.CancelFunc + readTimeout time.Duration + Port uint16 + Conn chan *core.Conn } -func New(port uint16) (*Server, error) { +func New(port uint16, readTimeout time.Duration) (*Server, error) { if port == 0 { return nil, errors.New("port is required") @@ -35,6 +36,7 @@ func New(port uint16) (*Server, error) { return &Server{ ctx, cancel, + readTimeout, port, make(chan *core.Conn), }, nil @@ -54,7 +56,7 @@ func (s *Server) Run() error { defer utils.Close(ls) m := cmux.New(ls) - m.SetReadTimeout(5 * time.Second) + m.SetReadTimeout(s.readTimeout) defer m.Close() socksLs := m.Match(cmux.PrefixMatcher(string([]byte{0x05}))) diff --git a/server/fwd/fwd.go b/server/fwd/fwd.go index 1d6758f..b0a80f2 100644 --- a/server/fwd/fwd.go +++ b/server/fwd/fwd.go @@ -4,7 +4,7 @@ import ( "context" "log/slog" "proxy-server/pkg/utils" - "proxy-server/server/fwd/core" + "proxy-server/server/core" "sync" ) @@ -12,7 +12,7 @@ type Service struct { ctx context.Context cancel context.CancelFunc - userConnMap core.ConnMap + userConnMap core.SyncMap[string, *core.Conn] fwdLesWg utils.CountWaitGroup ctrlConnWg utils.CountWaitGroup @@ -40,7 +40,7 @@ func (s *Service) Run() error { wg.Add(1) go func() { defer wg.Done() - err := s.startCtrlTun() + err := s.listenCtrl() if err != nil { slog.Error("fwd 控制通道监听发生错误", "err", err) errQuit <- struct{}{} @@ -52,7 +52,7 @@ func (s *Service) Run() error { wg.Add(1) go func() { defer wg.Done() - err := s.startDataTun() + err := s.listenData() if err != nil { slog.Error("fwd 数据通道监听发生错误", "err", err) errQuit <- struct{}{} @@ -75,13 +75,6 @@ func (s *Service) Run() error { s.fwdLesWg.Wait() s.userConnWg.Wait() - // 清理资源 - s.userConnMap.Range(func(key string, value *core.Conn) bool { - utils.Close(value) - return true - }) - s.userConnMap.Clear() - s.ctrlConnWg.Wait() slog.Debug("控制通道连接已关闭") s.dataConnWg.Wait() diff --git a/server/fwd/http/http.go b/server/fwd/http/http.go index 098e1dd..b464a50 100644 --- a/server/fwd/http/http.go +++ b/server/fwd/http/http.go @@ -5,12 +5,13 @@ import ( "context" "encoding/base64" "fmt" + "github.com/google/uuid" "io" "net" "net/textproto" "net/url" + "proxy-server/server/core" "proxy-server/server/fwd/auth" - "proxy-server/server/fwd/core" "strings" "errors" @@ -132,7 +133,7 @@ func processHttps(ctx context.Context, req *Request) (*core.Conn, error) { return &core.Conn{ Conn: req.conn, Reader: req.reader, - Tag: req.conn.RemoteAddr().String() + "_" + req.conn.LocalAddr().String(), + Tag: uuid.New(), Protocol: "http", Dest: req.dest, Auth: req.auth, @@ -176,7 +177,7 @@ func processHttp(ctx context.Context, req *Request) (*core.Conn, error) { return &core.Conn{ Conn: req.conn, Reader: newReader, - Tag: req.conn.RemoteAddr().String() + "_" + req.conn.LocalAddr().String(), + Tag: uuid.New(), Protocol: "http", Dest: req.dest, Auth: req.auth, diff --git a/server/fwd/repo/channel.go b/server/fwd/repo/channel.go deleted file mode 100644 index 8710f8d..0000000 --- a/server/fwd/repo/channel.go +++ /dev/null @@ -1,22 +0,0 @@ -package repo - -import ( - "time" - - "gorm.io/gorm" -) - -// Channel 连接认证模型 -type Channel struct { - gorm.Model - UserId uint - NodeId uint - UserAddr string - NodePort int - AuthIp bool - AuthPass bool - Protocol string - Username string - Password string - Expiration time.Time -} diff --git a/server/fwd/repo/node.go b/server/fwd/repo/node.go deleted file mode 100644 index f2bae94..0000000 --- a/server/fwd/repo/node.go +++ /dev/null @@ -1,15 +0,0 @@ -package repo - -import "gorm.io/gorm" - -// Node 客户端模型 -type Node struct { - gorm.Model - Name string - Version byte - FwdPort uint16 - Provider string - Location string - - Channels []Channel `gorm:"foreignKey:NodeId"` -} diff --git a/server/fwd/socks/socks.go b/server/fwd/socks/socks.go index fa28dd0..a5757d6 100644 --- a/server/fwd/socks/socks.go +++ b/server/fwd/socks/socks.go @@ -6,12 +6,13 @@ import ( "encoding/binary" "errors" "fmt" + "github.com/google/uuid" "io" "log/slog" "net" "proxy-server/pkg/utils" + "proxy-server/server/core" "proxy-server/server/fwd/auth" - "proxy-server/server/fwd/core" "slices" ) @@ -83,7 +84,7 @@ func Process(ctx context.Context, conn net.Conn) (*core.Conn, error) { Conn: conn, Reader: reader, Protocol: "socks5", - Tag: conn.RemoteAddr().String() + "_" + conn.LocalAddr().String(), + Tag: uuid.New(), Dest: request.DestAddr, Auth: authCtx, }, nil diff --git a/server/fwd/user.go b/server/fwd/user.go new file mode 100644 index 0000000..15df768 --- /dev/null +++ b/server/fwd/user.go @@ -0,0 +1,82 @@ +package fwd + +import ( + "context" + "encoding/hex" + "errors" + "log/slog" + "net" + "proxy-server/pkg/utils" + "proxy-server/server/core" + "proxy-server/server/env" + "proxy-server/server/fwd/dispatcher" + "proxy-server/server/fwd/metrics" + "time" +) + +func (s *Service) listenUser(port uint16, ctrl net.Conn) error { + dspt, err := dispatcher.New(port, time.Duration(env.AppUserTimeout)*time.Second) + if err != nil { + return err + } + defer dspt.Close() + + go func() { + err := dspt.Run() + if err != nil { + slog.Error("代理服务运行失败", "err", err) + } + }() + + // 处理连接 + for { + select { + case <-s.ctx.Done(): + return nil + case user := <-dspt.Conn: + metrics.TimerAuth.Store(user.Conn, time.Now()) + s.userConnWg.Add(1) + go func() { + defer s.userConnWg.Done() + err := s.processUserConn(user, ctrl) + if err != nil { + slog.Error("处理用户连接失败", "err", err) + utils.Close(user) + } + }() + } + } +} + +func (s *Service) processUserConn(user *core.Conn, ctrl net.Conn) (err error) { + + // 发送代理命令 + err = s.sendProxy(ctrl, user.Tag, user.Dest.String()) + if err != nil { + return err + } + + // 保存用户连接 + s.userConnMap.Store(hex.EncodeToString(user.Tag[:]), user) + + // 如果限定时间内没有建立数据通道,则关闭连接 + var timeout, cancel = context.WithTimeout(context.Background(), time.Duration(env.AppDataTimeout)*time.Second) + defer cancel() + + select { + case <-timeout.Done(): + err = timeout.Err() + case <-s.ctx.Done(): + err = s.ctx.Err() + } + + _, ok := s.userConnMap.LoadAndDelete(hex.EncodeToString(user.Tag[:])) + if ok { + utils.Close(user) + if errors.Is(err, context.DeadlineExceeded) { + slog.Error("用户连接超时", "tag", hex.EncodeToString(user.Tag[:]), "addr", user.RemoteAddr().String()) + } + } + + return nil +} diff --git a/server/server.go b/server/server.go index 6c5f983..db2c9c4 100644 --- a/server/server.go +++ b/server/server.go @@ -109,24 +109,26 @@ func (s *server) Run() (err error) { } cancel() - timeout, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() + // 主协程退出流程 + wg.Add(1) + go func() { + defer wg.Done() + // 报告下线 + slog.Debug("报告服务下线") + err = report.Offline(app.Name) + if err != nil { + slog.Error("服务下线失败", "err", err) + } - // 报告下线 - slog.Debug("报告服务下线") - err = report.Offline(app.Name) - if err != nil { - slog.Error("服务下线失败", "err", err) - } - - // 关闭 redis - g.ExitRedis() + // 关闭 redis + g.ExitRedis() + }() // 等待其它服务关闭 select { - case <-utils.ChanWgWait(timeout, &wg): + case <-utils.WgWait(&wg): slog.Info("服务正常关闭") - case <-timeout.Done(): + case <-time.After(time.Duration(env.AppExitTimeout) * time.Second): slog.Warn("超时强制关闭") } diff --git a/server/web/handlers/auth.go b/server/web/handlers/auth.go index d181435..9800577 100644 --- a/server/web/handlers/auth.go +++ b/server/web/handlers/auth.go @@ -26,7 +26,7 @@ func Auth(ctx *fiber.Ctx) (err error) { } // 保存授权配置 - app.Permits[req.Port] = req.Permit + app.Permits.Store(req.Port, &req.Permit) return nil }