重新规划网关与节点的交互协议,实现统一命令位的识别和处理
This commit is contained in:
@@ -4,6 +4,7 @@ import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
@@ -11,25 +12,21 @@ import (
|
||||
"proxy-server/pkg/utils"
|
||||
"proxy-server/server/app"
|
||||
"proxy-server/server/env"
|
||||
"proxy-server/server/fwd/core"
|
||||
"proxy-server/server/fwd/dispatcher"
|
||||
"proxy-server/server/fwd/metrics"
|
||||
"proxy-server/server/report"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"errors"
|
||||
)
|
||||
|
||||
type CtrlCmd struct {
|
||||
conn net.Conn
|
||||
buf []byte
|
||||
}
|
||||
type CtrlCmdType int
|
||||
|
||||
var ctrlCmdChan = make(chan CtrlCmd, 1024)
|
||||
const (
|
||||
CtrlCmdPong CtrlCmdType = iota + 1
|
||||
CtrlCmdPing
|
||||
CtrlCmdOpen
|
||||
CtrlCmdClose
|
||||
CtrlCmdProxy
|
||||
)
|
||||
|
||||
func (s *Service) startCtrlTun() error {
|
||||
func (s *Service) listenCtrl() error {
|
||||
ctrlPort := env.AppCtrlPort
|
||||
slog.Debug("监听控制通道", slog.Uint64("port", uint64(ctrlPort)))
|
||||
|
||||
@@ -56,7 +53,7 @@ func (s *Service) startCtrlTun() error {
|
||||
go func() {
|
||||
defer s.ctrlConnWg.Done()
|
||||
defer utils.Close(conn)
|
||||
err := s.processCtrlConn(conn)
|
||||
err := s.processCtrlConn(s.ctx, conn)
|
||||
if err != nil {
|
||||
slog.Error("处理控制通道连接失败", "err", err)
|
||||
}
|
||||
@@ -67,25 +64,80 @@ func (s *Service) startCtrlTun() error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Service) processCtrlConn(conn net.Conn) error {
|
||||
func (s *Service) processCtrlConn(ctx context.Context, conn net.Conn) (err error) {
|
||||
reader := bufio.NewReader(conn)
|
||||
for {
|
||||
// 循环等待直到服务关闭
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
default:
|
||||
}
|
||||
|
||||
var recv = make([]byte, 4)
|
||||
_, err := io.ReadFull(reader, recv)
|
||||
if err != nil {
|
||||
return fmt.Errorf("读取客户端 ID 失败: %w", err)
|
||||
// 读取命令
|
||||
cmdByte, err := reader.ReadByte()
|
||||
if err != nil {
|
||||
return fmt.Errorf("读取节点命令失败: %w", err)
|
||||
}
|
||||
var cmd = CtrlCmdType(cmdByte)
|
||||
switch cmd {
|
||||
|
||||
// 连接建立命令
|
||||
case CtrlCmdOpen:
|
||||
var recv = make([]byte, 4)
|
||||
_, err = io.ReadFull(reader, recv)
|
||||
if err != nil {
|
||||
return fmt.Errorf("读取节点 ID 失败: %w", err)
|
||||
}
|
||||
var client = int32(binary.BigEndian.Uint32(recv))
|
||||
err = s.onOpen(conn, client)
|
||||
if err != nil {
|
||||
return fmt.Errorf("处理连接建立命令失败: %w", err)
|
||||
}
|
||||
|
||||
// 心跳命令
|
||||
case CtrlCmdPing:
|
||||
err = s.onPing(conn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("处理心跳命令失败: %w", err)
|
||||
}
|
||||
|
||||
// 连接关闭命令
|
||||
case CtrlCmdClose:
|
||||
err = s.onClose(conn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("处理关闭命令失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
|
||||
// 忽略其他不应该由节点发起的命令
|
||||
default:
|
||||
return fmt.Errorf("无法处理控制命令: %d", cmd)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) onPing(conn net.Conn) (err error) {
|
||||
return s.sendPong(conn)
|
||||
}
|
||||
|
||||
func (s *Service) onOpen(conn net.Conn, client int32) (err error) {
|
||||
// open 命令全局只执行一次
|
||||
_, ok := app.Clients.Load(client)
|
||||
if ok {
|
||||
return fmt.Errorf("节点 ID %d 已经连接", client)
|
||||
}
|
||||
var client = int32(binary.BigEndian.Uint32(recv))
|
||||
|
||||
// 分配端口
|
||||
var minim uint16 = 20000
|
||||
var maxim uint16 = 60000
|
||||
var port uint16
|
||||
for i := minim; i < maxim; i++ {
|
||||
var _, ok = app.Assigns[i]
|
||||
var _, ok = app.Assigns.Load(i)
|
||||
if !ok {
|
||||
port = i
|
||||
app.Assigns[i] = client
|
||||
app.Assigns.Store(i, client)
|
||||
app.Clients.Store(client, i)
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -94,126 +146,75 @@ func (s *Service) processCtrlConn(conn net.Conn) error {
|
||||
}
|
||||
|
||||
// 报告端口分配
|
||||
err = report.Assigned(client, port)
|
||||
if err != nil {
|
||||
if err = report.Assigned(client, port); err != nil {
|
||||
return fmt.Errorf("报告端口分配失败: %w", err)
|
||||
}
|
||||
|
||||
// 响应客户端
|
||||
_, err = conn.Write([]byte{1})
|
||||
if err != nil {
|
||||
if err = s.sendPong(conn); err != nil {
|
||||
return fmt.Errorf("响应客户端失败: %w", err)
|
||||
}
|
||||
|
||||
// 启动转发服务
|
||||
slog.Info("监听转发端口", "port", port, "client", client)
|
||||
proxy, err := dispatcher.New(port)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer proxy.Close()
|
||||
|
||||
s.fwdLesWg.Add(1)
|
||||
go func() {
|
||||
defer s.fwdLesWg.Done()
|
||||
err := proxy.Run()
|
||||
slog.Info("监听转发端口", "port", port, "client", client)
|
||||
err = s.listenUser(port, conn)
|
||||
if err != nil {
|
||||
slog.Error("代理服务运行失败", "err", err)
|
||||
slog.Error("监听转发端口失败", "port", port, "client", client, "err", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// 监听控制通道连接
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
defer close(errCh)
|
||||
_, err := reader.ReadByte()
|
||||
errCh <- err
|
||||
}()
|
||||
|
||||
// 批量同步写入
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
case cmd := <-ctrlCmdChan:
|
||||
_, err := cmd.conn.Write(cmd.buf)
|
||||
if err != nil {
|
||||
slog.Error("批量写入失败", "err", err)
|
||||
utils.Close(cmd.conn)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// 处理连接
|
||||
for {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return nil
|
||||
case err := <-errCh:
|
||||
switch {
|
||||
case strings.Contains(err.Error(), "An existing connection was forcibly closed by the remote host."):
|
||||
slog.Debug("客户端主动断开连接")
|
||||
return nil
|
||||
case err == nil:
|
||||
return errors.New("客户端握手失败")
|
||||
default:
|
||||
return fmt.Errorf("客户端意外断开连接: %w", err)
|
||||
}
|
||||
case user := <-proxy.Conn:
|
||||
metrics.TimerAuth.Store(user.Conn, time.Now())
|
||||
s.userConnWg.Add(1)
|
||||
go func() {
|
||||
defer s.userConnWg.Done()
|
||||
err := s.processUserConn(user, conn)
|
||||
if err != nil {
|
||||
slog.Error("处理用户连接失败", "err", err)
|
||||
utils.Close(user)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) processUserConn(user *core.Conn, ctrl net.Conn) error {
|
||||
|
||||
// 组织写入信息
|
||||
dst := user.DestAddr().String()
|
||||
dstLen := len(dst)
|
||||
|
||||
tag := user.Tag
|
||||
tagLen := len(tag)
|
||||
|
||||
writeBuf := make([]byte, 2+dstLen+tagLen)
|
||||
writeBuf[0] = byte(dstLen)
|
||||
copy(writeBuf[1:], dst)
|
||||
writeBuf[1+dstLen] = byte(tagLen)
|
||||
copy(writeBuf[2+dstLen:], tag)
|
||||
|
||||
// 异步写入命令
|
||||
ctrlCmdChan <- CtrlCmd{
|
||||
conn: ctrl,
|
||||
buf: writeBuf,
|
||||
func (s *Service) onClose(conn net.Conn) (err error) {
|
||||
_, portStr, err := net.SplitHostPort(conn.LocalAddr().String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 记录用户连接
|
||||
s.userConnMap.Store(user.Tag, user)
|
||||
port, err := strconv.ParseUint(portStr, 10, 16)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 如果限定时间内没有建立数据通道,则关闭连接
|
||||
timeout, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
id, _ := app.Assigns.LoadAndDelete(uint16(port))
|
||||
app.Clients.Delete(id)
|
||||
app.Assigns.Delete(uint16(port))
|
||||
app.Permits.Delete(uint16(port))
|
||||
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
// 服务会在退出时统一关闭未消费的连接
|
||||
case <-timeout.Done():
|
||||
storedUser, ok := s.userConnMap.LoadAndDelete(user.Tag)
|
||||
if ok {
|
||||
slog.Debug("建立数据通道超时", "tag", user.Tag)
|
||||
utils.Close(storedUser)
|
||||
}
|
||||
err = s.sendPong(conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) sendPong(conn net.Conn) (err error) {
|
||||
_, err = conn.Write([]byte{byte(CtrlCmdPong)})
|
||||
if err != nil {
|
||||
return fmt.Errorf("响应客户端失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) sendProxy(conn net.Conn, tag [16]byte, addr string) (err error) {
|
||||
if len(addr) > 65535 {
|
||||
return fmt.Errorf("代理地址过长: %s", addr)
|
||||
}
|
||||
|
||||
buf := make([]byte, 1+16+2+len(addr))
|
||||
buf[0] = byte(CtrlCmdProxy)
|
||||
copy(buf[1:], tag[:])
|
||||
binary.BigEndian.PutUint16(buf[17:], uint16(len(addr)))
|
||||
copy(buf[19:], addr)
|
||||
|
||||
_, err = conn.Write(buf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("发送代理命令失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user