优化全局数据存储方式,以节点 id 为 key 保存相关数据;修复节点下线监听未关闭问题

This commit is contained in:
2025-05-17 11:02:18 +08:00
parent 84e01d3b50
commit c1664aa898
10 changed files with 77 additions and 51 deletions

View File

@@ -130,19 +130,20 @@ func ctrl(ctx context.Context, id int32, host string) error {
// 异步等待连接命令 // 异步等待连接命令
slog.Info("等待用户连接") slog.Info("等待用户连接")
var cmdCh = make(chan ConnCmd) var cmdCh = make(chan ConnCmd)
var errCh = make(chan error)
go func() { go func() {
for { for {
cmd, err := reader.ReadByte() cmd, err := reader.ReadByte()
if errors.Is(err, net.ErrClosed) {
slog.Debug("控制通道关闭")
return
}
if errors.Is(err, io.EOF) {
slog.Debug("网关关闭了控制通道")
return
}
if err != nil { if err != nil {
slog.Error("读取命令失败", "err", err) switch {
case errors.Is(err, net.ErrClosed):
err = fmt.Errorf("控制通道关闭: %w", err)
case errors.Is(err, io.EOF):
err = fmt.Errorf("网关关闭了控制通道: %w", err)
default:
err = fmt.Errorf("读取命令失败: %w", err)
}
errCh <- err
return return
} }
@@ -168,6 +169,9 @@ func ctrl(ctx context.Context, id int32, host string) error {
select { select {
case <-ctx.Done(): case <-ctx.Done():
loop = false loop = false
case err = <-errCh:
slog.Error("读取控制命令失败", "err", err)
loop = false
case cmd := <-cmdCh: case cmd := <-cmdCh:
slog.Debug("建立数据通道", "tag", cmd.Tag, "addr", cmd.Addr) slog.Debug("建立数据通道", "tag", cmd.Tag, "addr", cmd.Addr)
go func() { go func() {

View File

@@ -4,14 +4,18 @@ import (
"proxy-server/gateway/core" "proxy-server/gateway/core"
) )
type Stoppable interface {
Stop()
}
var ( var (
Id int32 Id int32
Name string Name string
PlatformSecret string // 平台密钥,验证接收的请求是否属于平台 PlatformSecret string // 平台密钥,验证接收的请求是否属于平台
Edges = core.SyncMap[int32, uint16]{} // 节点 ID -> 转发端口 Assigns = core.SyncMap[uint16, int32]{} // 转发端口 -> 节点 ID
Assigns = core.SyncMap[uint16, int32]{} // 转发端口 -> 节点 ID Edges = core.SyncMap[int32, uint16]{} // 节点 ID -> 转发端口
Permits = core.SyncMap[uint16, *core.Permit]{} // 转发端口 -> 权限配置 Permits = core.SyncMap[int32, *core.Permit]{} // 转发端口 -> 权限配置
) )
func AddEdge(id int32, port uint16) { func AddEdge(id int32, port uint16) {
@@ -22,9 +26,19 @@ func AddEdge(id int32, port uint16) {
func DelEdge(port uint16) { func DelEdge(port uint16) {
id, _ := Assigns.LoadAndDelete(port) id, _ := Assigns.LoadAndDelete(port)
Edges.Delete(id) Edges.Delete(id)
Permits.Delete(port) Permits.Delete(id)
} }
func PermitEdge(port uint16, permit *core.Permit) { func LoadPermit(port uint16) *core.Permit {
Permits.Store(port, permit) id, ok := Assigns.Load(port)
if !ok {
return nil
}
permit, ok := Permits.Load(id)
if !ok {
return nil
}
return permit
} }

View File

@@ -36,8 +36,8 @@ func Protect(conn net.Conn, proto Protocol, username, password *string) (*core.A
} }
// 查找权限配置 // 查找权限配置
var permit, ok = app.Permits.Load(uint16(localPort)) var permit = app.LoadPermit(uint16(localPort))
if !ok { if permit == nil {
return nil, errors.New("没有权限") return nil, errors.New("没有权限")
} }

View File

@@ -29,7 +29,6 @@ const (
func (s *Service) listenCtrl() error { func (s *Service) listenCtrl() error {
ctrlPort := env.AppCtrlPort ctrlPort := env.AppCtrlPort
slog.Debug("监听控制通道", slog.Uint64("port", uint64(ctrlPort)))
// 监听端口 // 监听端口
ls, err := net.Listen("tcp", ":"+strconv.Itoa(int(ctrlPort))) ls, err := net.Listen("tcp", ":"+strconv.Itoa(int(ctrlPort)))
@@ -80,22 +79,19 @@ func (s *Service) listenCtrl() error {
} }
} }
func (s *Service) processCtrlConn(ctx context.Context, conn net.Conn) (err error) { func (s *Service) processCtrlConn(_ctx context.Context, conn net.Conn) (err error) {
// 通道上下文
ctx, cancel := context.WithCancel(_ctx)
// 结束后清理资源
var fwdPort uint16
defer func() { defer func() {
_, portStr, err := net.SplitHostPort(conn.LocalAddr().String()) slog.Debug("关闭控制通道", "port", fwdPort)
if err != nil { app.DelEdge(fwdPort)
slog.Error("获取控制通道端口失败", "err", err)
return
}
port, err := strconv.ParseUint(portStr, 10, 16)
if err != nil {
slog.Error("解析控制通道端口失败", "err", err)
return
}
app.DelEdge(uint16(port))
}() }()
// 处理控制命令
defer cancel()
reader := bufio.NewReader(conn) reader := bufio.NewReader(conn)
for { for {
// 循环等待直到服务关闭 // 循环等待直到服务关闭
@@ -130,7 +126,7 @@ func (s *Service) processCtrlConn(ctx context.Context, conn net.Conn) (err error
return fmt.Errorf("读取节点 ID 失败: %w", err) return fmt.Errorf("读取节点 ID 失败: %w", err)
} }
var client = int32(binary.BigEndian.Uint32(recv)) var client = int32(binary.BigEndian.Uint32(recv))
err = s.onOpen(conn, client) fwdPort, err = s.onOpen(ctx, conn, client)
if err != nil { if err != nil {
return fmt.Errorf("处理连接建立命令失败: %w", err) return fmt.Errorf("处理连接建立命令失败: %w", err)
} }
@@ -157,17 +153,16 @@ func (s *Service) processCtrlConn(ctx context.Context, conn net.Conn) (err error
} }
} }
func (s *Service) onOpen(writer io.Writer, edge int32) (err error) { func (s *Service) onOpen(ctx context.Context, writer io.Writer, edge int32) (port uint16, err error) {
// open 命令全局只执行一次 // open 命令全局只执行一次
_, ok := app.Edges.Load(edge) _, ok := app.Edges.Load(edge)
if ok { if ok {
return fmt.Errorf("节点 ID %d 已经连接", edge) return 0, fmt.Errorf("节点 ID %d 已经连接", edge)
} }
// 分配端口 // 分配端口
var minim uint16 = 20000 var minim uint16 = 20000
var maxim uint16 = 60000 var maxim uint16 = 60000
var port uint16
for i := minim; i < maxim; i++ { for i := minim; i < maxim; i++ {
var _, ok = app.Assigns.Load(i) var _, ok = app.Assigns.Load(i)
if !ok { if !ok {
@@ -177,17 +172,17 @@ func (s *Service) onOpen(writer io.Writer, edge int32) (err error) {
} }
} }
if port == 0 { if port == 0 {
return errors.New("没有可用的端口") return 0, errors.New("没有可用的端口")
} }
// 报告端口分配 // 报告端口分配
if err = report.Assigned(edge, port); err != nil { if err = report.Assigned(edge, port); err != nil {
return fmt.Errorf("报告端口分配失败: %w", err) return 0, fmt.Errorf("报告端口分配失败: %w", err)
} }
// 响应节点 // 响应节点
if err = s.sendPong(writer); err != nil { if err = s.sendPong(writer); err != nil {
return fmt.Errorf("响应节点失败: %w", err) return 0, fmt.Errorf("响应节点失败: %w", err)
} }
// 启动转发服务 // 启动转发服务
@@ -195,13 +190,13 @@ func (s *Service) onOpen(writer io.Writer, edge int32) (err error) {
go func() { go func() {
defer s.fwdLesWg.Done() defer s.fwdLesWg.Done()
slog.Info("监听转发端口", "port", port, "edge", edge) slog.Info("监听转发端口", "port", port, "edge", edge)
err = s.listenUser(port, writer) err = s.listenUser(ctx, port, writer)
if err != nil { if err != nil {
slog.Error("监听转发端口失败", "port", port, "edge", edge, "err", err) slog.Error("监听转发端口失败", "port", port, "edge", edge, "err", err)
} }
}() }()
return nil return port, nil
} }
func (s *Service) onPing(writer io.Writer) (err error) { func (s *Service) onPing(writer io.Writer) (err error) {

View File

@@ -19,7 +19,6 @@ import (
func (s *Service) listenData() error { func (s *Service) listenData() error {
dataPort := env.AppDataPort dataPort := env.AppDataPort
slog.Debug("监听数据通道", slog.Uint64("port", uint64(dataPort)))
// 监听端口 // 监听端口
ls, err := net.Listen("tcp", ":"+strconv.Itoa(int(dataPort))) ls, err := net.Listen("tcp", ":"+strconv.Itoa(int(dataPort)))

View File

@@ -42,7 +42,7 @@ func New(port uint16, readTimeout time.Duration) (*Server, error) {
}, nil }, nil
} }
func (s *Server) Close() { func (s *Server) Stop() {
s.cancel() s.cancel()
} }

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"log/slog" "log/slog"
"proxy-server/gateway/core" "proxy-server/gateway/core"
"proxy-server/gateway/env"
"proxy-server/utils" "proxy-server/utils"
"sync" "sync"
) )
@@ -29,7 +30,7 @@ func New() *Service {
} }
func (s *Service) Run() error { func (s *Service) Run() error {
slog.Info("启动转发服务") slog.Info("启动转发服务", "控制通道", env.AppCtrlPort, "数据通道", env.AppDataPort)
errQuit := make(chan struct{}, 2) errQuit := make(chan struct{}, 2)
defer close(errQuit) defer close(errQuit)

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"encoding/hex" "encoding/hex"
"errors" "errors"
"fmt"
"io" "io"
"log/slog" "log/slog"
"proxy-server/gateway/core" "proxy-server/gateway/core"
@@ -14,25 +15,33 @@ import (
"time" "time"
) )
func (s *Service) listenUser(port uint16, ctrl io.Writer) error { func (s *Service) listenUser(ctx context.Context, port uint16, ctrl io.Writer) error {
dspt, err := dispatcher.New(port, time.Duration(env.AppUserTimeout)*time.Second) dspt, err := dispatcher.New(port, time.Duration(env.AppUserTimeout)*time.Second)
if err != nil { if err != nil {
return err return err
} }
defer dspt.Close() defer dspt.Stop()
var errCh = make(chan error)
go func() { go func() {
err := dspt.Run() err := dspt.Run()
if err != nil { if err != nil {
slog.Error("代理服务运行失败", "err", err) // slog.Error("代理服务运行失败", "err", err)
err = fmt.Errorf("协议嗅探服务运行失败: %w", err)
} }
errCh <- err
}() }()
// 处理连接 // 处理连接
for { for {
select { select {
case <-s.ctx.Done(): case <-ctx.Done():
return nil return nil
case err := <-errCh:
if err != nil {
err = fmt.Errorf("监听转发端口失败: %w", err)
}
return err
case user := <-dspt.Conn: case user := <-dspt.Conn:
metrics.TimerAuth.Store(user.Conn, time.Now()) metrics.TimerAuth.Store(user.Conn, time.Now())
s.userConnWg.Add(1) s.userConnWg.Add(1)

View File

@@ -7,7 +7,7 @@ import (
) )
type AuthReq struct { type AuthReq struct {
Port uint16 `json:"port"` Id int32 `json:"id"`
core.Permit core.Permit
} }
@@ -26,7 +26,7 @@ func Auth(ctx *fiber.Ctx) (err error) {
} }
// 保存授权配置 // 保存授权配置
app.PermitEdge(req.Port, &req.Permit) app.Permits.Store(req.Id, &req.Permit)
return nil return nil
} }

View File

@@ -1,6 +1,7 @@
package web package web
import ( import (
"log/slog"
"proxy-server/gateway/env" "proxy-server/gateway/env"
"strconv" "strconv"
@@ -16,7 +17,10 @@ func New() *Server {
} }
func (s *Server) Run() error { func (s *Server) Run() error {
s.web = fiber.New() slog.Info("启动接口服务", "服务端口", env.AppWebPort)
s.web = fiber.New(fiber.Config{
DisableStartupMessage: true,
})
// 配置中间件和路由 // 配置中间件和路由
Router(s.web) Router(s.web)