重新规划网关与节点的交互协议,实现统一命令位的识别和处理

This commit is contained in:
2025-05-16 15:13:16 +08:00
parent d65fe4db6f
commit 8a6a4833d4
22 changed files with 609 additions and 373 deletions

View File

@@ -2,7 +2,7 @@
### 目录结构 ### 目录结构
server/fwd: 服务端核心代码 server/fwd: 网关核心代码
- core: 核心代码,目前主要是连接管理 - core: 核心代码,目前主要是连接管理
- dispatcher: 请求处理器,负责解析传入协议,并将请求分发到对应的处理器 - dispatcher: 请求处理器,负责解析传入协议,并将请求分发到对应的处理器
@@ -24,7 +24,7 @@ server/fwd: 服务端核心代码
## 底层协议设计 ## 底层协议设计
### 步骤说明 ### 步骤说明(待更新)
1. 启动转发服务,尝试注册自身到后端服务,随后持续报告心跳 1. 启动转发服务,尝试注册自身到后端服务,随后持续报告心跳
2. 启动边缘节点后,尝试注册自身到后端服务,随后持续报告心跳 2. 启动边缘节点后,尝试注册自身到后端服务,随后持续报告心跳
@@ -36,33 +36,71 @@ server/fwd: 服务端核心代码
8. 当成功建立数据通道后,边缘节点将数据通道标识以及对目标地址的连接结果提供给转发服务 8. 当成功建立数据通道后,边缘节点将数据通道标识以及对目标地址的连接结果提供给转发服务
9. 如果连接成功建立,则开始代理流量,如果连接失败,则关闭数据通道 9. 如果连接成功建立,则开始代理流量,如果连接失败,则关闭数据通道
### 协议报文详情 ### 协议报文说明
协议中所有数值都以大端形式传输 协议中所有数值都以大端形式传输
#### 建立控制通道 控制通道命令列表:
1.客户端发送: | 编码 | 命令 | 发起方 | 备注 |
|----|-------|-------|--------------------|
| 1 | pong | proxy | 网关响应,和 status 相同 |
| 2 | ping | edge | 边缘节点心跳 |
| 3 | open | edge | 边缘节点连接到网关 |
| 4 | close | edge | 边缘节点断开连接 |
| 5 | proxy | proxy | 网关通知节点打开数据通道用来转发数据 |
| id(4) | ### 建立控制通道
|--------|
| 客户端 ID |
2.服务端发送: 1.节点发起
| cmd(1) | id(4) |
|---------|-------|
| 命令固定为 3 | 节点 ID |
2.网关响应
| status(1) | | status(1) |
|-----------| |-----------|
| 状态,固定为 1 | | 状态,固定为 1 |
#### 建立数据通道 ### 控制通道心跳
1.服务端发送: 1.节点发起
| tag(16) | dst_len(2) | dst_buf(n) | | cmd(1) |
|---------|------------|------------| |---------|
| 通道标识 | 目标地址长度 | 目标地址 | | 命令固定为 2 |
2.客户端发送 2.网关响应
| status(1) |
|-----------|
| 状态,固定为 1 |
### 关闭控制通道
1.节点发起
| cmd(1) |
|---------|
| 命令固定为 4 |
2.网关响应:
| status(1) |
|-----------|
| 状态,固定为 1 |
### 建立数据通道
1.网关发送:
| cmd(1) | tag(16) | dst_len(2) | dst_buf(n) |
|---------|---------|------------|------------|
| 命令固定为 5 | 通道标识 | 目标地址长度 | 目标地址 |
2.节点发送:
| tag(16) | status(1) | | tag(16) | status(1) |
|---------|-----------------------| |---------|-----------------------|
@@ -76,9 +114,12 @@ server/fwd: 服务端核心代码
```json5 ```json5
{ {
"content": "string", // base64 编码的请求体 "content": "string",
"nonce": "string", // 随机数值 // base64 编码的请求体
"timestamp": "number" // 时间戳,精确到毫秒 "nonce": "string",
// 随机数值
"timestamp": "number"
// 时间戳,精确到毫秒
} }
``` ```

View File

@@ -2,20 +2,22 @@ package client
import ( import (
"bufio" "bufio"
"context"
"encoding/binary" "encoding/binary"
"errors"
"fmt" "fmt"
"io" "io"
"log/slog" "log/slog"
"net" "net"
_ "net/http/pprof"
"os"
"os/signal"
"proxy-server/client/core" "proxy-server/client/core"
"proxy-server/client/env" "proxy-server/client/env"
"proxy-server/client/geo" "proxy-server/client/geo"
"proxy-server/client/report" "proxy-server/client/report"
"proxy-server/pkg/utils" "proxy-server/pkg/utils"
"time" "time"
"errors"
_ "net/http/pprof"
) )
func Start() error { func Start() error {
@@ -31,7 +33,7 @@ func Start() error {
slog.Debug("获取节点归属地...") slog.Debug("获取节点归属地...")
err = geo.Query() err = geo.Query()
if err != nil { if err != nil {
slog.Error("获取归属地失败", "err", err) return fmt.Errorf("获取节点归属地失败: %w", err)
} }
// 注册节点 // 注册节点
@@ -41,27 +43,38 @@ func Start() error {
return fmt.Errorf("注册节点失败: %w", err) return fmt.Errorf("注册节点失败: %w", err)
} }
// 性能监控
// go func() {
// runtime.SetBlockProfileRate(1)
// err := http.ListenAndServe(":7070", nil)
// if err != nil {
// slog.Error("性能监控服务启动失败", "err", err)
// }
// }()
// 建立控制通道 // 建立控制通道
for { var ctx, cancel = signal.NotifyContext(context.Background(), os.Interrupt, os.Kill)
err := ctrl(id, host) defer cancel()
if err != nil {
go func() {
for {
err = ctrl(ctx, id, host)
if err == nil {
return
}
select {
case <-ctx.Done():
return
default:
}
slog.Error("建立控制通道失败", "err", err) slog.Error("建立控制通道失败", "err", err)
slog.Info(fmt.Sprintf("%d 秒后重试", core.RetryInterval)) slog.Info(fmt.Sprintf("%d 秒后重试", core.RetryInterval))
time.Sleep(time.Duration(core.RetryInterval) * time.Second) time.Sleep(time.Duration(core.RetryInterval) * time.Second)
} }
}()
// 下线节点
slog.Debug("下线节点...")
err = report.Offline()
if err != nil {
slog.Error("下线节点失败", "err", err)
} }
return ctx.Err()
} }
func ctrl(id int32, host string) error { func ctrl(ctx context.Context, id int32, host string) error {
ctrlAddr := net.JoinHostPort(host, fmt.Sprintf("%d", core.FwdCtrlPort)) ctrlAddr := net.JoinHostPort(host, fmt.Sprintf("%d", core.FwdCtrlPort))
dataAddr := net.JoinHostPort(host, fmt.Sprintf("%d", core.FwdDataPort)) dataAddr := net.JoinHostPort(host, fmt.Sprintf("%d", core.FwdDataPort))
@@ -71,103 +84,94 @@ func ctrl(id int32, host string) error {
return errors.New("连接失败") return errors.New("连接失败")
} }
defer utils.Close(conn) defer utils.Close(conn)
var reader = bufio.NewReader(conn)
// 发送客户端信息 // 发送节点连接命令
var buf = make([]byte, 4) err = sendOpen(reader, conn, id)
_, err = binary.Encode(buf, binary.BigEndian, id)
if err != nil { if err != nil {
return fmt.Errorf("编码客户端 ID 失败: %w", err) return fmt.Errorf("发送节点信息失败: %w", err)
}
_, err = conn.Write(buf)
if err != nil {
return fmt.Errorf("发送客户端 ID 失败: %w", err)
} }
// 等待服务端响应 // 异步定时发送心跳
reader := bufio.NewReader(conn) go func() {
respBuf, err := reader.ReadByte() ticker := time.NewTicker(time.Duration(core.HeartbeatInterval) * time.Second)
if err != nil { defer ticker.Stop()
return errors.New("接收响应失败") for {
} select {
if respBuf != 1 { case <-ctx.Done():
return errors.New("服务端响应失败") return
} else { case tick := <-ticker.C:
slog.Info("成功建立连接") err := sendPing(reader, conn)
} if err != nil {
slog.Error("发送心跳失败", "time", tick, "err", err)
}
}
}
}()
// 等待用户连接 // 等待用户连接
// 读写失败后退出重连,防止后续数据读写顺序错位导致卡死控制通道 // 读写失败后退出重连,防止后续数据读写顺序错位导致卡死控制通道
slog.Info("等待用户连接") slog.Info("等待用户连接")
for { for loop := true; loop; {
select {
// 接收 dst case <-ctx.Done():
dstLen, err := reader.ReadByte() loop = false
if err != nil { default:
return errors.New("接收 dstLen 失败") // 接收 dst
} tag, addr, err := onConn(reader)
dstBuf, err := utils.ReadBuffer(reader, int(dstLen))
if err != nil {
return errors.New("接收 dstBuf 失败")
}
addr := string(dstBuf)
// 接收 tag
tagLen, err := reader.ReadByte()
if err != nil {
return errors.New("接收 tagLen 失败")
}
tagBuf, err := utils.ReadBuffer(reader, int(tagLen))
if err != nil {
return errors.New("接收 tagBuf 失败")
}
// 建立数据通道
go func() {
err := data(dataAddr, addr, tagBuf)
if err != nil { if err != nil {
slog.Error("建立数据通道失败", "err", err) return fmt.Errorf("接收连接命令失败: %w", err)
} }
}()
// 建立数据通道
go func() {
err := data(dataAddr, addr, tag)
if err != nil {
slog.Error("建立数据通道失败", "err", err)
}
}()
}
} }
// 发送关闭连接(不 return err否则会重新连接
err = sendClose(reader, conn)
if err != nil {
slog.Error("发送关闭连接失败", "err", err)
}
return nil
} }
func data(dataAddr string, dest string, tag []byte) error { func data(proxy string, dest string, tag [16]byte) error {
// 向目标地址建立连接
var result = 1
var dstErr error
dst, err := net.Dial("tcp", dest)
if err != nil {
dstErr = fmt.Errorf("连接目标地址失败: %w", dstErr)
result = 0
}
defer utils.Close(dst)
// 向服务端建立连接 // 向服务端建立连接
src, err := net.Dial("tcp", dataAddr) src, err := net.Dial("tcp", proxy)
if err != nil { if err != nil {
return errors.New("连接服务端失败") return errors.New("连接服务端失败")
} }
defer utils.Close(src)
tagLen := byte(len(tag))
tagBuf := make([]byte, 2+tagLen)
tagBuf[1] = tagLen
copy(tagBuf[2:], tag)
// 向目标地址建立连接
dst, dstErr := net.Dial("tcp", dest)
if dstErr != nil {
tagBuf[0] = 0
} else {
tagBuf[0] = 1
}
// 发送连接状态 // 发送连接状态
_, err = src.Write(tagBuf) var buf = make([]byte, 17)
copy(buf[0:16], tag[:])
buf[16] = byte(result)
_, err = src.Write(buf)
if err != nil { if err != nil {
utils.Close(src)
if dst != nil {
utils.Close(dst)
}
return errors.New("发送连接状态失败") return errors.New("发送连接状态失败")
} }
if tagBuf[0] == 0 { if result == 0 {
utils.Close(src) return dstErr
if dst != nil {
utils.Close(dst)
}
return errors.New("连接目标地址失败")
} }
go func() { go func() {
@@ -186,3 +190,91 @@ func data(dataAddr string, dest string, tag []byte) error {
}() }()
return nil return nil
} }
func sendOpen(reader io.Reader, writer io.Writer, id int32) error {
// 发送打开连接
var buf = make([]byte, 5)
buf[0] = 3
binary.BigEndian.PutUint32(buf[1:], uint32(id))
_, err := writer.Write(buf)
if err != nil {
return fmt.Errorf("发送打开连接失败: %w", err)
}
// 等待服务端响应
respBuf := make([]byte, 1)
_, err = io.ReadFull(reader, respBuf)
if err != nil {
return fmt.Errorf("接收服务端响应失败: %w", err)
}
if respBuf[0] != 1 {
return errors.New("服务端响应失败")
}
return nil
}
func sendClose(reader io.Reader, writer io.Writer) error {
// 发送关闭连接
_, err := writer.Write([]byte{4})
if err != nil {
return err
}
// 等待服务端响应
respBuf := make([]byte, 1)
_, err = io.ReadFull(reader, respBuf)
if err != nil {
return fmt.Errorf("接收服务端响应失败: %w", err)
}
if respBuf[0] != 1 {
return errors.New("服务端响应失败")
}
return nil
}
func sendPing(reader io.Reader, writer io.Writer) error {
_, err := writer.Write([]byte{2})
if err != nil {
return err
}
// 等待服务端响应
respBuf := make([]byte, 1)
_, err = io.ReadFull(reader, respBuf)
if err != nil {
return fmt.Errorf("接收服务端响应失败: %w", err)
}
if respBuf[0] != 1 {
return errors.New("服务端响应失败")
}
return nil
}
func onConn(reader io.Reader) (tag [16]byte, addr string, err error) {
var buf = make([]byte, 1+16+2)
_, err = io.ReadFull(reader, buf)
if err != nil {
return [16]byte{}, "", err
}
if buf[0] != 5 {
return [16]byte{}, "", errors.New("命令错误")
}
tag = [16]byte(buf[1:17])
var addrLen = binary.BigEndian.Uint16(buf[17:19])
var addrBuf = make([]byte, addrLen)
_, err = io.ReadFull(reader, addrBuf)
if err != nil {
return [16]byte{}, "", err
}
addr = string(addrBuf)
return tag, addr, nil
}

View File

@@ -6,3 +6,4 @@ const FwdCtrlPort uint = 18080
const FwdDataPort uint = 18081 const FwdDataPort uint = 18081
const RetryInterval uint = 5 const RetryInterval uint = 5
const HeartbeatInterval uint = 30

View File

@@ -67,3 +67,7 @@ func Online(prov, city, isp string) (id int32, host string, err error) {
return respBody.Id, respBody.Host, nil return respBody.Id, respBody.Host, nil
} }
func Offline() error {
return nil
}

View File

@@ -39,15 +39,11 @@ func ChanConnAccept(ctx context.Context, ls net.Listener) chan net.Conn {
return ch return ch
} }
func ChanWgWait[T WaitGroup](ctx context.Context, wg T) chan struct{} { func WgWait[T WaitGroup](wg T) <-chan struct{} {
ch := make(chan struct{}) ch := make(chan struct{})
go func() { go func() {
defer close(ch)
wg.Wait() wg.Wait()
select { ch <- struct{}{}
case <-ctx.Done():
case ch <- struct{}{}:
}
}() }()
return ch return ch
} }

View File

@@ -1,12 +1,15 @@
package app package app
import "proxy-server/server/core" import (
"proxy-server/server/core"
)
var ( var (
Id int32 Id int32
Name string Name string
PlatformSecret string // 平台密钥,验证接收的请求是否属于平台 PlatformSecret string // 平台密钥,验证接收的请求是否属于平台
Assigns = make(map[uint16]int32) // 转发端口 -> 转发服务ID Clients = core.SyncMap[int32, uint16]{} // 节点 ID -> 转发端口
Permits = make(map[uint16]core.Permit) // 转发端口 -> 权限配置 Assigns = core.SyncMap[uint16, int32]{} // 转发端口 -> 节点 ID
Permits = core.SyncMap[uint16, *core.Permit]{} // 转发端口 -> 权限配置
) )

View File

@@ -4,40 +4,13 @@ import (
"bufio" "bufio"
"fmt" "fmt"
"net" "net"
"sync"
"time" "time"
) )
type ConnMap struct {
_map sync.Map
}
func (c *ConnMap) LoadAndDelete(key string) (*Conn, bool) {
_value, ok := c._map.LoadAndDelete(key)
if !ok {
return nil, false
}
return _value.(*Conn), true
}
func (c *ConnMap) Store(key string, value *Conn) {
c._map.Store(key, value)
}
func (c *ConnMap) Range(f func(key string, value *Conn) bool) {
c._map.Range(func(key, value any) bool {
return f(key.(string), value.(*Conn))
})
}
func (c *ConnMap) Clear() {
c._map.Clear()
}
type Conn struct { type Conn struct {
Conn net.Conn Conn net.Conn
Reader *bufio.Reader Reader *bufio.Reader
Tag string Tag [16]byte
Protocol string Protocol string
Dest *FwdAddr Dest *FwdAddr
Auth *AuthContext Auth *AuthContext
@@ -75,10 +48,6 @@ func (c Conn) SetWriteDeadline(t time.Time) error {
return c.Conn.SetWriteDeadline(t) return c.Conn.SetWriteDeadline(t)
} }
func (c Conn) DestAddr() net.Addr {
return c.Dest
}
type FwdAddr struct { type FwdAddr struct {
IP net.IP IP net.IP
Port int Port int

65
server/core/map.go Normal file
View File

@@ -0,0 +1,65 @@
package core
import "sync"
type SyncMap[K any, V any] struct {
_map sync.Map
}
func (m *SyncMap[K, V]) Store(key K, value V) {
m._map.Store(key, value)
}
func (m *SyncMap[K, V]) Load(key K) (value V, ok bool) {
v, ok := m._map.Load(key)
if ok {
value = v.(V)
}
return
}
func (m *SyncMap[K, V]) Swap(key K, value V) (previous V, loaded bool) {
v, loaded := m._map.Swap(key, value)
if loaded {
previous = v.(V)
}
return
}
func (m *SyncMap[K, V]) Delete(key K) {
m._map.Delete(key)
}
func (m *SyncMap[K, V]) Clear() {
m._map.Clear()
}
func (m *SyncMap[K, V]) Range(f func(key K, value V) bool) {
m._map.Range(func(k, v any) bool {
return f(k.(K), v.(V))
})
}
func (m *SyncMap[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) {
v, loaded := m._map.LoadOrStore(key, value)
if loaded {
actual = v.(V)
}
return
}
func (m *SyncMap[K, V]) LoadAndDelete(key K) (value V, ok bool) {
v, ok := m._map.LoadAndDelete(key)
if ok {
value = v.(V)
}
return
}
func (m *SyncMap[K, V]) CompareAndSwap(key K, old, new V) (swapped bool) {
return m._map.CompareAndSwap(key, old, new)
}
func (m *SyncMap[K, V]) CompareAndDelete(key K, old V) (deleted bool) {
return m._map.CompareAndDelete(key, old)
}

29
server/env/env.go vendored
View File

@@ -10,10 +10,13 @@ import (
) )
var ( var (
AppCtrlPort uint16 = 18080 AppCtrlPort uint16 = 18080
AppDataPort uint16 = 18081 AppDataPort uint16 = 18081
AppWebPort uint16 = 8848 AppWebPort uint16 = 8848
AppLogMode = "dev" AppLogMode = "dev"
AppExitTimeout = 5 // 等待服务停止的超时时间
AppDataTimeout = 10 // 等待数据通道连接的超时时间
AppUserTimeout = 10 // 等待用户发送数据的超时时间(端口复用需要分析协议,如果用户长期不发送数据,将会阻塞分析协程)
ClientId string ClientId string
ClientSecret string ClientSecret string
@@ -67,6 +70,24 @@ func Init() {
AppLogMode = value AppLogMode = value
} }
value = os.Getenv("APP_EXIT_TIMEOUT")
if value != "" {
appExitTimeout, err := strconv.Atoi(value)
if err != nil {
panic(fmt.Sprintf("环境变量 APP_EXIT_TIMEOUT 格式错误: %v", err))
}
AppExitTimeout = appExitTimeout
}
value = os.Getenv("APP_DATA_TIMEOUT")
if value != "" {
appDataTimeout, err := strconv.Atoi(value)
if err != nil {
panic(fmt.Sprintf("环境变量 APP_DATA_TIMEOUT 格式错误: %v", err))
}
AppDataTimeout = appDataTimeout
}
value = os.Getenv("CLIENT_ID") value = os.Getenv("CLIENT_ID")
if value != "" { if value != "" {
ClientId = value ClientId = value

View File

@@ -7,7 +7,7 @@ import (
"io" "io"
"log/slog" "log/slog"
"proxy-server/pkg/utils" "proxy-server/pkg/utils"
"proxy-server/server/fwd/core" "proxy-server/server/core"
"strings" "strings"
"errors" "errors"
@@ -27,7 +27,7 @@ func analysisAndLog(conn *core.Conn, reader io.Reader) error {
slog.String("proxy", conn.Protocol), slog.String("proxy", conn.Protocol),
slog.String("node", conn.LocalAddr().String()), slog.String("node", conn.LocalAddr().String()),
slog.String("proto", proto), slog.String("proto", proto),
slog.String("dest", conn.DestAddr().String()), slog.String("dest", conn.Dest.String()),
slog.String("domain", domain), slog.String("domain", domain),
) )
} }

View File

@@ -4,7 +4,7 @@ import (
"fmt" "fmt"
"net" "net"
"proxy-server/server/app" "proxy-server/server/app"
"proxy-server/server/fwd/core" "proxy-server/server/core"
"strconv" "strconv"
"time" "time"
@@ -36,7 +36,7 @@ func Protect(conn net.Conn, proto Protocol, username, password *string) (*core.A
} }
// 查找权限配置 // 查找权限配置
var permit, ok = app.Permits[uint16(localPort)] var permit, ok = app.Permits.Load(uint16(localPort))
if !ok { if !ok {
return nil, errors.New("没有权限") return nil, errors.New("没有权限")
} }
@@ -68,10 +68,11 @@ func Protect(conn net.Conn, proto Protocol, username, password *string) (*core.A
} }
} }
var id, _ = app.Assigns.Load(uint16(localPort))
return &core.AuthContext{ return &core.AuthContext{
Timeout: time.Since(permit.Expire).Seconds(), Timeout: time.Since(permit.Expire).Seconds(),
Payload: core.Payload{ Payload: core.Payload{
ID: app.Assigns[uint16(localPort)], ID: id,
}, },
}, nil }, nil
} }

View File

@@ -4,6 +4,7 @@ import (
"bufio" "bufio"
"context" "context"
"encoding/binary" "encoding/binary"
"errors"
"fmt" "fmt"
"io" "io"
"log/slog" "log/slog"
@@ -11,25 +12,21 @@ import (
"proxy-server/pkg/utils" "proxy-server/pkg/utils"
"proxy-server/server/app" "proxy-server/server/app"
"proxy-server/server/env" "proxy-server/server/env"
"proxy-server/server/fwd/core"
"proxy-server/server/fwd/dispatcher"
"proxy-server/server/fwd/metrics"
"proxy-server/server/report" "proxy-server/server/report"
"strconv" "strconv"
"strings"
"time"
"errors"
) )
type CtrlCmd struct { type CtrlCmdType int
conn net.Conn
buf []byte
}
var ctrlCmdChan = make(chan CtrlCmd, 1024) const (
CtrlCmdPong CtrlCmdType = iota + 1
CtrlCmdPing
CtrlCmdOpen
CtrlCmdClose
CtrlCmdProxy
)
func (s *Service) startCtrlTun() error { func (s *Service) listenCtrl() error {
ctrlPort := env.AppCtrlPort ctrlPort := env.AppCtrlPort
slog.Debug("监听控制通道", slog.Uint64("port", uint64(ctrlPort))) slog.Debug("监听控制通道", slog.Uint64("port", uint64(ctrlPort)))
@@ -56,7 +53,7 @@ func (s *Service) startCtrlTun() error {
go func() { go func() {
defer s.ctrlConnWg.Done() defer s.ctrlConnWg.Done()
defer utils.Close(conn) defer utils.Close(conn)
err := s.processCtrlConn(conn) err := s.processCtrlConn(s.ctx, conn)
if err != nil { if err != nil {
slog.Error("处理控制通道连接失败", "err", err) slog.Error("处理控制通道连接失败", "err", err)
} }
@@ -67,25 +64,80 @@ func (s *Service) startCtrlTun() error {
return err return err
} }
func (s *Service) processCtrlConn(conn net.Conn) error { func (s *Service) processCtrlConn(ctx context.Context, conn net.Conn) (err error) {
reader := bufio.NewReader(conn) reader := bufio.NewReader(conn)
for {
// 循环等待直到服务关闭
select {
case <-ctx.Done():
return nil
default:
}
var recv = make([]byte, 4) // 读取命令
_, err := io.ReadFull(reader, recv) cmdByte, err := reader.ReadByte()
if err != nil { if err != nil {
return fmt.Errorf("读取客户端 ID 失败: %w", err) return fmt.Errorf("读取节点命令失败: %w", err)
}
var cmd = CtrlCmdType(cmdByte)
switch cmd {
// 连接建立命令
case CtrlCmdOpen:
var recv = make([]byte, 4)
_, err = io.ReadFull(reader, recv)
if err != nil {
return fmt.Errorf("读取节点 ID 失败: %w", err)
}
var client = int32(binary.BigEndian.Uint32(recv))
err = s.onOpen(conn, client)
if err != nil {
return fmt.Errorf("处理连接建立命令失败: %w", err)
}
// 心跳命令
case CtrlCmdPing:
err = s.onPing(conn)
if err != nil {
return fmt.Errorf("处理心跳命令失败: %w", err)
}
// 连接关闭命令
case CtrlCmdClose:
err = s.onClose(conn)
if err != nil {
return fmt.Errorf("处理关闭命令失败: %w", err)
}
return nil
// 忽略其他不应该由节点发起的命令
default:
return fmt.Errorf("无法处理控制命令: %d", cmd)
}
}
}
func (s *Service) onPing(conn net.Conn) (err error) {
return s.sendPong(conn)
}
func (s *Service) onOpen(conn net.Conn, client int32) (err error) {
// open 命令全局只执行一次
_, ok := app.Clients.Load(client)
if ok {
return fmt.Errorf("节点 ID %d 已经连接", client)
} }
var client = int32(binary.BigEndian.Uint32(recv))
// 分配端口 // 分配端口
var minim uint16 = 20000 var minim uint16 = 20000
var maxim uint16 = 60000 var maxim uint16 = 60000
var port uint16 var port uint16
for i := minim; i < maxim; i++ { for i := minim; i < maxim; i++ {
var _, ok = app.Assigns[i] var _, ok = app.Assigns.Load(i)
if !ok { if !ok {
port = i port = i
app.Assigns[i] = client app.Assigns.Store(i, client)
app.Clients.Store(client, i)
break break
} }
} }
@@ -94,126 +146,75 @@ func (s *Service) processCtrlConn(conn net.Conn) error {
} }
// 报告端口分配 // 报告端口分配
err = report.Assigned(client, port) if err = report.Assigned(client, port); err != nil {
if err != nil {
return fmt.Errorf("报告端口分配失败: %w", err) return fmt.Errorf("报告端口分配失败: %w", err)
} }
// 响应客户端 // 响应客户端
_, err = conn.Write([]byte{1}) if err = s.sendPong(conn); err != nil {
if err != nil {
return fmt.Errorf("响应客户端失败: %w", err) return fmt.Errorf("响应客户端失败: %w", err)
} }
// 启动转发服务 // 启动转发服务
slog.Info("监听转发端口", "port", port, "client", client)
proxy, err := dispatcher.New(port)
if err != nil {
return err
}
defer proxy.Close()
s.fwdLesWg.Add(1) s.fwdLesWg.Add(1)
go func() { go func() {
defer s.fwdLesWg.Done() defer s.fwdLesWg.Done()
err := proxy.Run() slog.Info("监听转发端口", "port", port, "client", client)
err = s.listenUser(port, conn)
if err != nil { if err != nil {
slog.Error("代理服务运行失败", "err", err) slog.Error("监听转发端口失败", "port", port, "client", client, "err", err)
} }
}() }()
// 监听控制通道连接 return nil
errCh := make(chan error, 1)
go func() {
defer close(errCh)
_, err := reader.ReadByte()
errCh <- err
}()
// 批量同步写入
go func() {
for {
select {
case <-s.ctx.Done():
return
case cmd := <-ctrlCmdChan:
_, err := cmd.conn.Write(cmd.buf)
if err != nil {
slog.Error("批量写入失败", "err", err)
utils.Close(cmd.conn)
}
}
}
}()
// 处理连接
for {
select {
case <-s.ctx.Done():
return nil
case err := <-errCh:
switch {
case strings.Contains(err.Error(), "An existing connection was forcibly closed by the remote host."):
slog.Debug("客户端主动断开连接")
return nil
case err == nil:
return errors.New("客户端握手失败")
default:
return fmt.Errorf("客户端意外断开连接: %w", err)
}
case user := <-proxy.Conn:
metrics.TimerAuth.Store(user.Conn, time.Now())
s.userConnWg.Add(1)
go func() {
defer s.userConnWg.Done()
err := s.processUserConn(user, conn)
if err != nil {
slog.Error("处理用户连接失败", "err", err)
utils.Close(user)
}
}()
}
}
} }
func (s *Service) processUserConn(user *core.Conn, ctrl net.Conn) error { func (s *Service) onClose(conn net.Conn) (err error) {
_, portStr, err := net.SplitHostPort(conn.LocalAddr().String())
// 组织写入信息 if err != nil {
dst := user.DestAddr().String() return err
dstLen := len(dst)
tag := user.Tag
tagLen := len(tag)
writeBuf := make([]byte, 2+dstLen+tagLen)
writeBuf[0] = byte(dstLen)
copy(writeBuf[1:], dst)
writeBuf[1+dstLen] = byte(tagLen)
copy(writeBuf[2+dstLen:], tag)
// 异步写入命令
ctrlCmdChan <- CtrlCmd{
conn: ctrl,
buf: writeBuf,
} }
// 记录用户连接 port, err := strconv.ParseUint(portStr, 10, 16)
s.userConnMap.Store(user.Tag, user) if err != nil {
return err
}
// 如果限定时间内没有建立数据通道,则关闭连接 id, _ := app.Assigns.LoadAndDelete(uint16(port))
timeout, cancel := context.WithTimeout(context.Background(), 30*time.Second) app.Clients.Delete(id)
defer cancel() app.Assigns.Delete(uint16(port))
app.Permits.Delete(uint16(port))
select { err = s.sendPong(conn)
case <-s.ctx.Done(): if err != nil {
// 服务会在退出时统一关闭未消费的连接 return err
case <-timeout.Done():
storedUser, ok := s.userConnMap.LoadAndDelete(user.Tag)
if ok {
slog.Debug("建立数据通道超时", "tag", user.Tag)
utils.Close(storedUser)
}
} }
return nil return nil
} }
func (s *Service) sendPong(conn net.Conn) (err error) {
_, err = conn.Write([]byte{byte(CtrlCmdPong)})
if err != nil {
return fmt.Errorf("响应客户端失败: %w", err)
}
return nil
}
func (s *Service) sendProxy(conn net.Conn, tag [16]byte, addr string) (err error) {
if len(addr) > 65535 {
return fmt.Errorf("代理地址过长: %s", addr)
}
buf := make([]byte, 1+16+2+len(addr))
buf[0] = byte(CtrlCmdProxy)
copy(buf[1:], tag[:])
binary.BigEndian.PutUint16(buf[17:], uint16(len(addr)))
copy(buf[19:], addr)
_, err = conn.Write(buf)
if err != nil {
return fmt.Errorf("发送代理命令失败: %w", err)
}
return nil
}

View File

@@ -1,7 +1,9 @@
package fwd package fwd
import ( import (
"bufio"
"fmt" "fmt"
"github.com/google/uuid"
"io" "io"
"log/slog" "log/slog"
"net" "net"
@@ -16,7 +18,7 @@ import (
"errors" "errors"
) )
func (s *Service) startDataTun() error { func (s *Service) listenData() error {
dataPort := env.AppDataPort dataPort := env.AppDataPort
slog.Debug("监听数据通道", slog.Uint64("port", uint64(dataPort))) slog.Debug("监听数据通道", slog.Uint64("port", uint64(dataPort)))
@@ -57,38 +59,34 @@ func (s *Service) startDataTun() error {
} }
func (s *Service) processDataConn(client net.Conn) error { func (s *Service) processDataConn(client net.Conn) error {
var reader = bufio.NewReader(client)
// 接收 status // 接收连接结果
status, err := utils.ReadByte(client) var buf = make([]byte, 17)
_, err := io.ReadFull(reader, buf)
if err != nil { if err != nil {
return fmt.Errorf("从客户端获取 status 失败: %w", err) return fmt.Errorf("从客户端获取连接结果失败: %w", err)
} }
// 接收 tag tag := buf[0:16]
tagLen, err := utils.ReadByte(client) status := buf[16]
if err != nil {
return fmt.Errorf("从客户端获取 tag 失败: %w", err)
}
tagBuf, err := utils.ReadBuffer(client, int(tagLen))
if err != nil {
return fmt.Errorf("从客户端获取 tag 失败: %w", err)
}
tag := string(tagBuf)
// 找到用户连接 // 加载用户连接
user, ok := s.userConnMap.LoadAndDelete(tag) var tagStr = uuid.UUID(tag).String()
user, ok := s.userConnMap.LoadAndDelete(tagStr)
if !ok { if !ok {
return errors.New("用户连接已关闭tag" + tag) return fmt.Errorf("用户连接已关闭tag%s", tagStr)
} }
defer utils.Close(user) defer utils.Close(user)
data := time.Now()
// 检查状态 // 检查状态
if status != 1 { if status != 1 {
return errors.New("目标地址建立连接失败") return errors.New("目标地址建立连接失败")
} }
// 数据转发 // 转发数据
data := time.Now()
userPipeReader, userPipeWriter := io.Pipe() userPipeReader, userPipeWriter := io.Pipe()
defer utils.Close(userPipeWriter) defer utils.Close(userPipeWriter)
teeUser := io.TeeReader(user, userPipeWriter) teeUser := io.TeeReader(user, userPipeWriter)
@@ -110,7 +108,7 @@ func (s *Service) processDataConn(client net.Conn) error {
}() }()
go func() { go func() {
defer wg.Done() defer wg.Done()
_, err := io.Copy(user, client) _, err := io.Copy(user, reader)
if err != nil { if err != nil {
slog.Error("数据转发失败 client->user", "err", err) slog.Error("数据转发失败 client->user", "err", err)
} }
@@ -118,7 +116,7 @@ func (s *Service) processDataConn(client net.Conn) error {
select { select {
case <-s.ctx.Done(): case <-s.ctx.Done():
case <-utils.ChanWgWait(s.ctx, &wg): case <-utils.WgWait(&wg):
} }
proxy := time.Now() proxy := time.Now()

View File

@@ -6,7 +6,7 @@ import (
"log/slog" "log/slog"
"net" "net"
"proxy-server/pkg/utils" "proxy-server/pkg/utils"
"proxy-server/server/fwd/core" "proxy-server/server/core"
"proxy-server/server/fwd/http" "proxy-server/server/fwd/http"
"proxy-server/server/fwd/metrics" "proxy-server/server/fwd/metrics"
"proxy-server/server/fwd/socks" "proxy-server/server/fwd/socks"
@@ -19,13 +19,14 @@ import (
) )
type Server struct { type Server struct {
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
Port uint16 readTimeout time.Duration
Conn chan *core.Conn Port uint16
Conn chan *core.Conn
} }
func New(port uint16) (*Server, error) { func New(port uint16, readTimeout time.Duration) (*Server, error) {
if port == 0 { if port == 0 {
return nil, errors.New("port is required") return nil, errors.New("port is required")
@@ -35,6 +36,7 @@ func New(port uint16) (*Server, error) {
return &Server{ return &Server{
ctx, ctx,
cancel, cancel,
readTimeout,
port, port,
make(chan *core.Conn), make(chan *core.Conn),
}, nil }, nil
@@ -54,7 +56,7 @@ func (s *Server) Run() error {
defer utils.Close(ls) defer utils.Close(ls)
m := cmux.New(ls) m := cmux.New(ls)
m.SetReadTimeout(5 * time.Second) m.SetReadTimeout(s.readTimeout)
defer m.Close() defer m.Close()
socksLs := m.Match(cmux.PrefixMatcher(string([]byte{0x05}))) socksLs := m.Match(cmux.PrefixMatcher(string([]byte{0x05})))

View File

@@ -4,7 +4,7 @@ import (
"context" "context"
"log/slog" "log/slog"
"proxy-server/pkg/utils" "proxy-server/pkg/utils"
"proxy-server/server/fwd/core" "proxy-server/server/core"
"sync" "sync"
) )
@@ -12,7 +12,7 @@ type Service struct {
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
userConnMap core.ConnMap userConnMap core.SyncMap[string, *core.Conn]
fwdLesWg utils.CountWaitGroup fwdLesWg utils.CountWaitGroup
ctrlConnWg utils.CountWaitGroup ctrlConnWg utils.CountWaitGroup
@@ -40,7 +40,7 @@ func (s *Service) Run() error {
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
err := s.startCtrlTun() err := s.listenCtrl()
if err != nil { if err != nil {
slog.Error("fwd 控制通道监听发生错误", "err", err) slog.Error("fwd 控制通道监听发生错误", "err", err)
errQuit <- struct{}{} errQuit <- struct{}{}
@@ -52,7 +52,7 @@ func (s *Service) Run() error {
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
err := s.startDataTun() err := s.listenData()
if err != nil { if err != nil {
slog.Error("fwd 数据通道监听发生错误", "err", err) slog.Error("fwd 数据通道监听发生错误", "err", err)
errQuit <- struct{}{} errQuit <- struct{}{}
@@ -75,13 +75,6 @@ func (s *Service) Run() error {
s.fwdLesWg.Wait() s.fwdLesWg.Wait()
s.userConnWg.Wait() s.userConnWg.Wait()
// 清理资源
s.userConnMap.Range(func(key string, value *core.Conn) bool {
utils.Close(value)
return true
})
s.userConnMap.Clear()
s.ctrlConnWg.Wait() s.ctrlConnWg.Wait()
slog.Debug("控制通道连接已关闭") slog.Debug("控制通道连接已关闭")
s.dataConnWg.Wait() s.dataConnWg.Wait()

View File

@@ -5,12 +5,13 @@ import (
"context" "context"
"encoding/base64" "encoding/base64"
"fmt" "fmt"
"github.com/google/uuid"
"io" "io"
"net" "net"
"net/textproto" "net/textproto"
"net/url" "net/url"
"proxy-server/server/core"
"proxy-server/server/fwd/auth" "proxy-server/server/fwd/auth"
"proxy-server/server/fwd/core"
"strings" "strings"
"errors" "errors"
@@ -132,7 +133,7 @@ func processHttps(ctx context.Context, req *Request) (*core.Conn, error) {
return &core.Conn{ return &core.Conn{
Conn: req.conn, Conn: req.conn,
Reader: req.reader, Reader: req.reader,
Tag: req.conn.RemoteAddr().String() + "_" + req.conn.LocalAddr().String(), Tag: uuid.New(),
Protocol: "http", Protocol: "http",
Dest: req.dest, Dest: req.dest,
Auth: req.auth, Auth: req.auth,
@@ -176,7 +177,7 @@ func processHttp(ctx context.Context, req *Request) (*core.Conn, error) {
return &core.Conn{ return &core.Conn{
Conn: req.conn, Conn: req.conn,
Reader: newReader, Reader: newReader,
Tag: req.conn.RemoteAddr().String() + "_" + req.conn.LocalAddr().String(), Tag: uuid.New(),
Protocol: "http", Protocol: "http",
Dest: req.dest, Dest: req.dest,
Auth: req.auth, Auth: req.auth,

View File

@@ -1,22 +0,0 @@
package repo
import (
"time"
"gorm.io/gorm"
)
// Channel 连接认证模型
type Channel struct {
gorm.Model
UserId uint
NodeId uint
UserAddr string
NodePort int
AuthIp bool
AuthPass bool
Protocol string
Username string
Password string
Expiration time.Time
}

View File

@@ -1,15 +0,0 @@
package repo
import "gorm.io/gorm"
// Node 客户端模型
type Node struct {
gorm.Model
Name string
Version byte
FwdPort uint16
Provider string
Location string
Channels []Channel `gorm:"foreignKey:NodeId"`
}

View File

@@ -6,12 +6,13 @@ import (
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
"github.com/google/uuid"
"io" "io"
"log/slog" "log/slog"
"net" "net"
"proxy-server/pkg/utils" "proxy-server/pkg/utils"
"proxy-server/server/core"
"proxy-server/server/fwd/auth" "proxy-server/server/fwd/auth"
"proxy-server/server/fwd/core"
"slices" "slices"
) )
@@ -83,7 +84,7 @@ func Process(ctx context.Context, conn net.Conn) (*core.Conn, error) {
Conn: conn, Conn: conn,
Reader: reader, Reader: reader,
Protocol: "socks5", Protocol: "socks5",
Tag: conn.RemoteAddr().String() + "_" + conn.LocalAddr().String(), Tag: uuid.New(),
Dest: request.DestAddr, Dest: request.DestAddr,
Auth: authCtx, Auth: authCtx,
}, nil }, nil

82
server/fwd/user.go Normal file
View File

@@ -0,0 +1,82 @@
package fwd
import (
"context"
"encoding/hex"
"errors"
"log/slog"
"net"
"proxy-server/pkg/utils"
"proxy-server/server/core"
"proxy-server/server/env"
"proxy-server/server/fwd/dispatcher"
"proxy-server/server/fwd/metrics"
"time"
)
func (s *Service) listenUser(port uint16, ctrl net.Conn) error {
dspt, err := dispatcher.New(port, time.Duration(env.AppUserTimeout)*time.Second)
if err != nil {
return err
}
defer dspt.Close()
go func() {
err := dspt.Run()
if err != nil {
slog.Error("代理服务运行失败", "err", err)
}
}()
// 处理连接
for {
select {
case <-s.ctx.Done():
return nil
case user := <-dspt.Conn:
metrics.TimerAuth.Store(user.Conn, time.Now())
s.userConnWg.Add(1)
go func() {
defer s.userConnWg.Done()
err := s.processUserConn(user, ctrl)
if err != nil {
slog.Error("处理用户连接失败", "err", err)
utils.Close(user)
}
}()
}
}
}
func (s *Service) processUserConn(user *core.Conn, ctrl net.Conn) (err error) {
// 发送代理命令
err = s.sendProxy(ctrl, user.Tag, user.Dest.String())
if err != nil {
return err
}
// 保存用户连接
s.userConnMap.Store(hex.EncodeToString(user.Tag[:]), user)
// 如果限定时间内没有建立数据通道,则关闭连接
var timeout, cancel = context.WithTimeout(context.Background(), time.Duration(env.AppDataTimeout)*time.Second)
defer cancel()
select {
case <-timeout.Done():
err = timeout.Err()
case <-s.ctx.Done():
err = s.ctx.Err()
}
_, ok := s.userConnMap.LoadAndDelete(hex.EncodeToString(user.Tag[:]))
if ok {
utils.Close(user)
if errors.Is(err, context.DeadlineExceeded) {
slog.Error("用户连接超时", "tag", hex.EncodeToString(user.Tag[:]), "addr", user.RemoteAddr().String())
}
}
return nil
}

View File

@@ -109,24 +109,26 @@ func (s *server) Run() (err error) {
} }
cancel() cancel()
timeout, cancel := context.WithTimeout(context.Background(), 5*time.Second) // 主协程退出流程
defer cancel() wg.Add(1)
go func() {
defer wg.Done()
// 报告下线
slog.Debug("报告服务下线")
err = report.Offline(app.Name)
if err != nil {
slog.Error("服务下线失败", "err", err)
}
// 报告下线 // 关闭 redis
slog.Debug("报告服务下线") g.ExitRedis()
err = report.Offline(app.Name) }()
if err != nil {
slog.Error("服务下线失败", "err", err)
}
// 关闭 redis
g.ExitRedis()
// 等待其它服务关闭 // 等待其它服务关闭
select { select {
case <-utils.ChanWgWait(timeout, &wg): case <-utils.WgWait(&wg):
slog.Info("服务正常关闭") slog.Info("服务正常关闭")
case <-timeout.Done(): case <-time.After(time.Duration(env.AppExitTimeout) * time.Second):
slog.Warn("超时强制关闭") slog.Warn("超时强制关闭")
} }

View File

@@ -26,7 +26,7 @@ func Auth(ctx *fiber.Ctx) (err error) {
} }
// 保存授权配置 // 保存授权配置
app.Permits[req.Port] = req.Permit app.Permits.Store(req.Port, &req.Permit)
return nil return nil
} }