优化连接处理逻辑,增加超时设置;重构命令读取与错误处理;新增公共工具函数以简化错误处理

This commit is contained in:
2025-05-26 16:37:54 +08:00
parent 8c928a8321
commit c2dcae7af5
7 changed files with 205 additions and 154 deletions

View File

@@ -1,6 +1,6 @@
## TODO ## TODO
- 连接断开时尽量由 - 将协议内容抽离出公共包gateway 和 edge 节点共同调用
## 开发相关 ## 开发相关

View File

@@ -21,6 +21,8 @@ import (
) )
func Start() error { func Start() error {
var ctx, cancel = signal.NotifyContext(context.Background(), os.Interrupt, os.Kill)
defer cancel()
// 初始化环境变量 // 初始化环境变量
slog.Debug("初始化环境变量...") slog.Debug("初始化环境变量...")
@@ -44,36 +46,17 @@ func Start() error {
} }
// 连接到网关 // 连接到网关
var ctx, cancel = signal.NotifyContext(context.Background(), os.Interrupt, os.Kill)
defer cancel()
var errCh = make(chan error) var errCh = make(chan error)
go func() { go func() {
for {
err = ctrl(ctx, id, host) err = ctrl(ctx, id, host)
if err == nil { if err == nil {
errCh <- nil errCh <- err
return
}
select {
case <-ctx.Done():
return
default:
slog.Error("建立控制通道失败", "err", err)
slog.Info(fmt.Sprintf("%d 秒后重试", core.RetryInterval))
}
select {
case <-ctx.Done():
return
case <-time.After(time.Duration(core.RetryInterval) * time.Second):
}
} }
}() }()
// 等待退出 // 等待退出
select { select {
case <-ctx.Done():
case err := <-errCh: case err := <-errCh:
if err != nil { if err != nil {
slog.Error("控制通道发生错误", "err", err) slog.Error("控制通道发生错误", "err", err)
@@ -102,7 +85,7 @@ func ctrl(ctx context.Context, id int32, host string) error {
defer utils.Close(conn) defer utils.Close(conn)
var reader = bufio.NewReader(conn) var reader = bufio.NewReader(conn)
// 发送节点连接命令 // 发送开启连接
slog.Debug("发送节点连接命令") slog.Debug("发送节点连接命令")
err = sendOpen(conn, id) err = sendOpen(conn, id)
if err != nil { if err != nil {
@@ -126,78 +109,55 @@ func ctrl(ctx context.Context, id int32, host string) error {
} }
}() }()
// 异步等待连接命令 // 异步读取节点命令
slog.Info("等待用户连接")
var cmdCh = make(chan ConnCmd)
var errCh = make(chan error) var errCh = make(chan error)
go func() { go func() {
for { for {
// 读取命令
cmd, err := reader.ReadByte() cmd, err := reader.ReadByte()
if err != nil { if err := utils.WarpConnErr(err); err != nil {
switch {
case errors.Is(err, net.ErrClosed):
err = fmt.Errorf("控制通道关闭: %w", err)
case errors.Is(err, io.EOF):
err = fmt.Errorf("网关关闭了控制通道: %w", err)
default:
err = fmt.Errorf("读取命令失败: %w", err)
}
errCh <- err errCh <- err
return return
} }
switch cmd { switch cmd {
// pong 命令,忽略
case 1: case 1:
// 忽略网关响应的 pong 命令
// 代理命令
case 5: case 5:
tag, addr, err := onConn(reader) err := onConn(reader, dataAddr)
if err != nil { if err != nil {
slog.Error("接收连接命令失败", "err", err) errCh <- fmt.Errorf("处理代理命令失败: %w", err)
return return
} }
cmdCh <- ConnCmd{
Tag: tag,
Addr: addr,
}
} }
} }
}() }()
// 等待建立数据通道 // 等待建立数据通道
for loop := true; loop; {
select { select {
case <-ctx.Done(): case <-ctx.Done():
loop = false
case err = <-errCh: case err = <-errCh:
slog.Error("读取控制命令失败", "err", err)
loop = false
case cmd := <-cmdCh:
slog.Debug("建立数据通道", "tag", cmd.Tag, "addr", cmd.Addr)
go func() {
err := data(dataAddr, cmd.Addr, cmd.Tag)
if err != nil {
slog.Error("建立数据通道失败", "err", err)
}
}()
}
} }
// 发送关闭连接(不 return err否则会重新连接 // 发送关闭连接
slog.Debug("发送关闭连接") slog.Debug("发送关闭连接")
err = sendClose(conn) err = sendClose(conn)
if err != nil { if err != nil {
slog.Error("发送关闭连接失败", "err", err) return fmt.Errorf("发送关闭连接失败: %w", err)
} }
return nil return nil
} }
func data(proxy string, dest string, tag [16]byte) error { func data(proxy string, destination string, tag [16]byte) error {
slog.Debug("建立数据通道", "tag", tag, "addr", destination)
// 向目标地址建立连接 // 向目标地址建立连接
var result = 1 var result = 1
var dstErr error var dstErr error
dst, err := net.Dial("tcp", dest) dst, err := net.Dial("tcp", destination)
if err != nil { if err != nil {
dstErr = fmt.Errorf("连接目标地址失败: %w", dstErr) dstErr = fmt.Errorf("连接目标地址失败: %w", dstErr)
result = 0 result = 0
@@ -295,24 +255,34 @@ func sendPing(writer io.Writer) error {
return nil return nil
} }
func onConn(reader io.Reader) (tag [16]byte, addr string, err error) { func onConn(reader io.Reader, dataAddr string) (err error) {
// 读取连接命令
var buf = make([]byte, 16+2) var buf = make([]byte, 16+2)
_, err = io.ReadFull(reader, buf) _, err = io.ReadFull(reader, buf)
if err != nil { if err != nil {
return [16]byte{}, "", err return err
} }
tag = [16]byte(buf[0:16]) var tag = [16]byte(buf[0:16])
var addrLen = binary.BigEndian.Uint16(buf[16:18]) var addrLen = binary.BigEndian.Uint16(buf[16:18])
var addrBuf = make([]byte, addrLen) var addrBuf = make([]byte, addrLen)
_, err = io.ReadFull(reader, addrBuf) _, err = io.ReadFull(reader, addrBuf)
if err != nil { if err != nil {
return [16]byte{}, "", err return err
} }
var addr = string(addrBuf)
addr = string(addrBuf) // 异步建立数据通道
return tag, addr, nil go func() {
err := data(dataAddr, addr, tag)
if err != nil {
slog.Error("建立数据通道失败", "err", err)
}
}()
return nil
} }
type ConnCmd struct { type ConnCmd struct {

38
gateway/env/env.go vendored
View File

@@ -20,7 +20,10 @@ var (
AppLogMode = "dev" AppLogMode = "dev"
AppExitTimeout = 5 // 等待服务停止的超时时间 AppExitTimeout = 5 // 等待服务停止的超时时间
AppDataTimeout = 10 // 等待数据通道连接的超时时间 AppDataTimeout = 10 // 等待数据通道连接的超时时间
AppUserTimeout = 10 // 等待用户发送数据的超时时间(端口复用需要分析协议,如果用户长期不发送数据,将会阻塞分析协程) AppUserRWTimeout = 10 // 等待用户连接读写超时时间
AppDataRWTimeout = 10 // 等待数据通道读写超时时间
AppCtrlRWTimeout = 10 // 等待控制通道读写超时时间
AppCtrlHBTimeout = 30 // 控制通道心跳超时时间(断开连接等待时间为:心跳等待时间 * 2 + 读写等待时间)
AuthWhitelist []net.IP // 全局白名单,可以将白名单 IP 视为一个可信任代理 AuthWhitelist []net.IP // 全局白名单,可以将白名单 IP 视为一个可信任代理
@@ -103,13 +106,40 @@ func Init() {
AppDataTimeout = appDataTimeout AppDataTimeout = appDataTimeout
} }
value = os.Getenv("APP_USER_TIMEOUT") value = os.Getenv("APP_USER_RW_TIMEOUT")
if value != "" { if value != "" {
appUserTimeout, err := strconv.Atoi(value) appUserTimeout, err := strconv.Atoi(value)
if err != nil { if err != nil {
panic(fmt.Sprintf("环境变量 APP_USER_TIMEOUT 格式错误: %v", err)) panic(fmt.Sprintf("环境变量 APP_USER_RW_TIMEOUT 格式错误: %v", err))
} }
AppUserTimeout = appUserTimeout AppUserRWTimeout = appUserTimeout
}
value = os.Getenv("APP_DATA_RW_TIMEOUT")
if value != "" {
appDataRWTimeout, err := strconv.Atoi(value)
if err != nil {
panic(fmt.Sprintf("环境变量 APP_DATA_RW_TIMEOUT 格式错误: %v", err))
}
AppDataRWTimeout = appDataRWTimeout
}
value = os.Getenv("APP_CTRL_RW_TIMEOUT")
if value != "" {
appCtrlRWTimeout, err := strconv.Atoi(value)
if err != nil {
panic(fmt.Sprintf("环境变量 APP_CTRL_RW_TIMEOUT 格式错误: %v", err))
}
AppCtrlRWTimeout = appCtrlRWTimeout
}
value = os.Getenv("APP_CTRL_HB_TIMEOUT")
if value != "" {
appCtrlHBTimeout, err := strconv.Atoi(value)
if err != nil {
panic(fmt.Sprintf("环境变量 APP_CTRL_HB_TIMEOUT 格式错误: %v", err))
}
AppCtrlHBTimeout = appCtrlHBTimeout
} }
value = os.Getenv("AUTH_WHITELIST") value = os.Getenv("AUTH_WHITELIST")

View File

@@ -14,7 +14,7 @@ import (
"proxy-server/gateway/report" "proxy-server/gateway/report"
"proxy-server/utils" "proxy-server/utils"
"strconv" "strconv"
"syscall" "time"
) )
type CtrlCmdType int type CtrlCmdType int
@@ -37,8 +37,7 @@ func ListenCtrl(ctx context.Context) error {
} }
defer utils.Close(ls) defer utils.Close(ls)
// 处理连接 // 异步等待处理连接
// 异步等待连接
var connCh = make(chan net.Conn) var connCh = make(chan net.Conn)
go func() { go func() {
for { for {
@@ -80,42 +79,39 @@ func ListenCtrl(ctx context.Context) error {
} }
func processCtrlConn(_ctx context.Context, conn net.Conn) (err error) { func processCtrlConn(_ctx context.Context, conn net.Conn) (err error) {
// 通道上下文
ctx, cancel := context.WithCancel(_ctx)
// 结束后清理资源
var fwdPort uint16
defer func() {
slog.Debug("关闭控制通道", "port", fwdPort)
app.DelEdge(fwdPort)
}()
// 处理控制命令
defer cancel()
reader := bufio.NewReader(conn) reader := bufio.NewReader(conn)
// 上下文与通道信息
ctx, cancel := context.WithCancel(_ctx)
defer cancel()
var fwdPort uint16
// 处理连接命令
var errCh = make(chan error)
go func() {
var err error
for { for {
// 循环等待直到服务关闭 // 读取命令
select { var timeout = time.Duration(env.AppCtrlHBTimeout*2+env.AppCtrlRWTimeout) * time.Second
case <-ctx.Done(): err = conn.SetReadDeadline(time.Now().Add(timeout))
return nil if err != nil {
default: errCh <- fmt.Errorf("设置读取超时失败: %w", err)
return
} }
// 读取命令
cmd, err := reader.ReadByte() cmd, err := reader.ReadByte()
if errors.Is(err, syscall.ECONNRESET) || errors.Is(err, syscall.WSAECONNRESET) { if err := utils.WarpConnErr(err); err != nil {
slog.Debug("节点重置了控制通道连接(WSAECONNRESET)") errCh <- err
return nil return
}
if errors.Is(err, io.EOF) {
slog.Debug("节点关闭了控制通道")
return nil
}
if err != nil {
return fmt.Errorf("读取节点命令失败: %w", err)
} }
// 处理节点命令 // 处理节点命令
err = conn.SetReadDeadline(time.Now().Add(time.Duration(env.AppCtrlRWTimeout) * time.Second))
if err != nil {
errCh <- fmt.Errorf("设置读取超时失败: %w", err)
return
}
switch CtrlCmdType(cmd) { switch CtrlCmdType(cmd) {
// 连接建立命令 // 连接建立命令
@@ -123,34 +119,48 @@ func processCtrlConn(_ctx context.Context, conn net.Conn) (err error) {
var recv = make([]byte, 4) var recv = make([]byte, 4)
_, err = io.ReadFull(reader, recv) _, err = io.ReadFull(reader, recv)
if err != nil { if err != nil {
return fmt.Errorf("读取节点 ID 失败: %w", err) errCh <- fmt.Errorf("读取节点 ID 失败: %w", err)
return
} }
var client = int32(binary.BigEndian.Uint32(recv)) var client = int32(binary.BigEndian.Uint32(recv))
fwdPort, err = onOpen(ctx, conn, client) fwdPort, err = onOpen(ctx, conn, client)
if err != nil { if err != nil {
return fmt.Errorf("处理连接建立命令失败: %w", err) errCh <- fmt.Errorf("处理连接建立命令失败: %w", err)
return
} }
// 心跳命令 // 心跳命令
case CtrlCmdPing: case CtrlCmdPing:
err = onPing(conn) err = onPing(conn)
if err != nil { if err != nil {
return fmt.Errorf("处理心跳命令失败: %w", err) errCh <- fmt.Errorf("处理心跳命令失败: %w", err)
return
} }
// 连接关闭命令 // 连接关闭命令
case CtrlCmdClose: case CtrlCmdClose:
err = onClose(conn) err = onClose(conn)
if err != nil { if err != nil {
return fmt.Errorf("处理关闭命令失败: %w", err) errCh <- fmt.Errorf("处理关闭命令失败: %w", err)
return
} }
return nil
// 忽略其他不应该由节点发起的命令 // 忽略其他不应该由节点发起的命令
default: default:
return fmt.Errorf("无法处理控制命令: %d", cmd) errCh <- fmt.Errorf("无法处理控制命令: %d", cmd)
return
} }
} }
}()
// 等待处理结束
select {
case <-ctx.Done():
case err = <-errCh:
}
app.DelEdge(fwdPort)
return
} }
func onOpen(ctx context.Context, writer io.Writer, edge int32) (port uint16, err error) { func onOpen(ctx context.Context, writer io.Writer, edge int32) (port uint16, err error) {

View File

@@ -73,8 +73,8 @@ func ListenData(ctx context.Context) error {
} }
} }
func processDataConn(ctx context.Context, client net.Conn) error { func processDataConn(ctx context.Context, edge net.Conn) error {
var reader = bufio.NewReader(client) var reader = bufio.NewReader(edge)
// 接收连接结果 // 接收连接结果
var buf = make([]byte, 17) var buf = make([]byte, 17)
@@ -133,7 +133,7 @@ func processDataConn(ctx context.Context, client net.Conn) error {
// 复制用户数据到节点 // 复制用户数据到节点
var waitUser = make(chan error) var waitUser = make(chan error)
go func() { go func() {
_, err := io.Copy(client, teeUser) _, err := io.Copy(edge, teeUser)
switch { switch {
case errors.Is(err, net.ErrClosed): case errors.Is(err, net.ErrClosed):
slog.Debug("用户连接意外关闭") slog.Debug("用户连接意外关闭")

View File

@@ -17,7 +17,7 @@ import (
) )
func ListenUser(ctx context.Context, port uint16, ctrl io.Writer) error { func ListenUser(ctx context.Context, port uint16, ctrl io.Writer) error {
dspt, err := dispatcher.New(port, time.Duration(env.AppUserTimeout)*time.Second) dspt, err := dispatcher.New(port, time.Duration(env.AppUserRWTimeout)*time.Second)
if err != nil { if err != nil {
return err return err
} }

41
utils/conn.go Normal file
View File

@@ -0,0 +1,41 @@
package utils
import (
"errors"
"fmt"
"io"
"log/slog"
"net"
"syscall"
)
func WarpConnErr(err error) error {
if err == nil {
return nil
}
var opErr *net.OpError
switch {
case errors.Is(err, net.ErrClosed):
slog.Debug("连接已关闭")
return nil
case errors.Is(err, io.EOF):
slog.Debug("连接被对端关闭")
return nil
case errors.As(err, &opErr):
switch {
case
errors.Is(opErr.Err, syscall.WSAECONNRESET), errors.Is(opErr.Err, syscall.WSAECONNABORTED),
errors.Is(opErr.Err, syscall.ECONNRESET), errors.Is(opErr.Err, syscall.ECONNABORTED):
slog.Debug("连接被对端重置")
return nil
case opErr.Timeout():
slog.Debug("连接已超时")
return nil
}
}
return fmt.Errorf("连接发生未处理的错误: %w", err)
}