Files
proxy/gateway/fwd/ctrl.go

245 lines
5.0 KiB
Go

package fwd
import (
"bufio"
"context"
"encoding/binary"
"errors"
"fmt"
"io"
"log/slog"
"net"
"proxy-server/gateway/app"
"proxy-server/gateway/env"
"proxy-server/gateway/report"
"proxy-server/utils"
"strconv"
"time"
)
// CtrlCmdType identifies a command on the control channel. In this file the
// value is converted to a single byte when written (sendPong, sendProxy) and
// from a single byte when read (processCtrlConn).
type CtrlCmdType int

// Control-channel command codes, spelled out explicitly because they are the
// byte values that appear on the connection.
const (
	// CtrlCmdPong is the acknowledgement byte written by sendPong.
	CtrlCmdPong CtrlCmdType = 1
	// CtrlCmdPing is the heartbeat command handled by onPing.
	CtrlCmdPing CtrlCmdType = 2
	// CtrlCmdOpen is the connection-establishment command handled by onOpen.
	CtrlCmdOpen CtrlCmdType = 3
	// CtrlCmdClose is the close command handled by onClose.
	CtrlCmdClose CtrlCmdType = 4
	// CtrlCmdProxy is the leading byte of the frame written by sendProxy.
	CtrlCmdProxy CtrlCmdType = 5
)
// ListenCtrl listens on the TCP control port (env.AppCtrlPort) and handles
// every accepted connection with processCtrlConn in its own goroutine. Each
// handler is registered with app.CtrlConnWg and closes its connection when
// it finishes.
//
// It blocks until ctx is cancelled, in which case it returns nil. It returns
// a wrapped error when the listener cannot be created, or when Accept fails
// for any reason other than the listener having been closed. Previously an
// Accept failure only stopped the accept goroutine while this function kept
// blocking on a listener that no longer accepted anything.
func ListenCtrl(ctx context.Context) error {
	ls, err := net.Listen("tcp", ":"+strconv.Itoa(int(env.AppCtrlPort)))
	if err != nil {
		return fmt.Errorf("监听控制通道失败: %w", err)
	}
	// Closing the listener on return also unblocks the pending Accept below.
	defer utils.Close(ls)

	connCh := make(chan net.Conn)
	// Buffered so the accept goroutine can always report its failure and exit,
	// even if ListenCtrl is returning for another reason at the same moment.
	acceptErrCh := make(chan error, 1)

	// Accept in the background so the loop below can also watch ctx.
	go func() {
		for {
			conn, err := ls.Accept()
			if errors.Is(err, net.ErrClosed) {
				slog.Debug("控制通道监听关闭")
				return
			}
			if err != nil {
				acceptErrCh <- err
				return
			}
			select {
			case connCh <- conn:
			case <-ctx.Done():
				// Nobody will pick this connection up any more.
				utils.Close(conn)
				return
			}
		}
	}()

	for {
		select {
		case <-ctx.Done():
			return nil
		case err := <-acceptErrCh:
			return fmt.Errorf("接受控制通道连接失败: %w", err)
		case conn := <-connCh:
			app.CtrlConnWg.Add(1)
			go func() {
				defer app.CtrlConnWg.Done()
				defer utils.Close(conn)
				if err := processCtrlConn(ctx, conn); err != nil {
					slog.Error("处理控制通道连接失败", "err", err)
				}
			}()
		}
	}
}
func processCtrlConn(_ctx context.Context, conn net.Conn) (err error) {
reader := bufio.NewReader(conn)
// 上下文与通道信息
ctx, cancel := context.WithCancel(_ctx)
defer cancel()
var fwdPort uint16
// 处理连接命令
var errCh = make(chan error)
go func() {
var err error
for {
// 读取命令
var timeout = time.Duration(env.AppCtrlHBTimeout*2+env.AppCtrlRWTimeout) * time.Second
err = conn.SetReadDeadline(time.Now().Add(timeout))
if err != nil {
errCh <- fmt.Errorf("设置读取超时失败: %w", err)
return
}
cmd, err := reader.ReadByte()
if err := utils.WarpConnErr(err); err != nil {
errCh <- err
return
}
// 处理节点命令
err = conn.SetReadDeadline(time.Now().Add(time.Duration(env.AppCtrlRWTimeout) * time.Second))
if err != nil {
errCh <- fmt.Errorf("设置读取超时失败: %w", err)
return
}
switch CtrlCmdType(cmd) {
// 连接建立命令
case CtrlCmdOpen:
var recv = make([]byte, 4)
_, err = io.ReadFull(reader, recv)
if err != nil {
errCh <- fmt.Errorf("读取节点 ID 失败: %w", err)
return
}
var client = int32(binary.BigEndian.Uint32(recv))
fwdPort, err = onOpen(ctx, conn, client)
if err != nil {
errCh <- fmt.Errorf("处理连接建立命令失败: %w", err)
return
}
// 心跳命令
case CtrlCmdPing:
err = onPing(conn)
if err != nil {
errCh <- fmt.Errorf("处理心跳命令失败: %w", err)
return
}
// 连接关闭命令
case CtrlCmdClose:
err = onClose(conn)
if err != nil {
errCh <- fmt.Errorf("处理关闭命令失败: %w", err)
return
}
// 忽略其他不应该由节点发起的命令
default:
errCh <- fmt.Errorf("无法处理控制命令: %d", cmd)
return
}
}
}()
// 等待处理结束
select {
case <-ctx.Done():
case err = <-errCh:
}
app.DelEdge(fwdPort)
return
}
// onOpen handles an Open command for node edge: it picks a free forwarding
// port, registers it with app.AddEdge, reports it through report.Assigned,
// acknowledges the node with a pong on writer, and starts ListenUser for the
// port in a background goroutine tracked by app.FwdLesWg.
//
// It returns the assigned port, or 0 and an error when the node ID is already
// present in app.Edges, when no port in [20000, 60000) is free, or when
// reporting or acknowledging fails. On those last two failures the
// registration made by app.AddEdge is released again; previously it was left
// behind, because the caller only ever sees port 0 on error.
func onOpen(ctx context.Context, writer io.Writer, edge int32) (port uint16, err error) {
	// A node ID may only be open once.
	// NOTE(review): this check and the AddEdge call below are separate steps,
	// so two concurrent Open commands for the same ID could both pass — confirm
	// whether app guards against that.
	if _, ok := app.Edges.Load(edge); ok {
		return 0, fmt.Errorf("节点 ID %d 已经连接", edge)
	}

	// Take the first port that has no entry in app.Assigns.
	const (
		minPort uint16 = 20000
		maxPort uint16 = 60000
	)
	for p := minPort; p < maxPort; p++ {
		if _, used := app.Assigns.Load(p); !used {
			port = p
			app.AddEdge(edge, port)
			break
		}
	}
	if port == 0 {
		return 0, errors.New("没有可用的端口")
	}

	// From here on the edge is registered, so every failure path must release
	// it. app.DelEdge is the same call processCtrlConn uses for the port
	// returned by this function — presumably it undoes AddEdge; confirm in app.
	if err = report.Assigned(edge, port); err != nil {
		app.DelEdge(port)
		return 0, fmt.Errorf("报告端口分配失败: %w", err)
	}
	if err = sendPong(writer); err != nil {
		app.DelEdge(port)
		return 0, fmt.Errorf("响应节点失败: %w", err)
	}

	// Start the forwarding listener for this port.
	app.FwdLesWg.Add(1)
	go func() {
		defer app.FwdLesWg.Done()
		slog.Info("监听转发端口", "port", port, "edge", edge)
		// Use a goroutine-local error: the named result belongs to a call that
		// has already returned by the time ListenUser finishes.
		if err := ListenUser(ctx, port, writer); err != nil {
			slog.Error("监听转发端口失败", "port", port, "edge", edge, "err", err)
		}
	}()
	return port, nil
}
// onPing answers a heartbeat command by writing a pong to writer and returns
// whatever error sendPong reports.
func onPing(writer io.Writer) (err error) {
	err = sendPong(writer)
	return err
}
// onClose acknowledges a close command by writing a pong to writer. It does
// not close anything itself; it returns whatever error sendPong reports.
func onClose(writer io.Writer) (err error) {
	err = sendPong(writer)
	return err
}
// sendPong writes the single CtrlCmdPong byte to writer. A write failure is
// returned wrapped; on success the result is nil.
func sendPong(writer io.Writer) (err error) {
	pong := [1]byte{byte(CtrlCmdPong)}
	if _, werr := writer.Write(pong[:]); werr != nil {
		return fmt.Errorf("响应节点失败: %w", werr)
	}
	return nil
}
// sendProxy writes one proxy command frame to writer in a single Write call.
//
// Frame layout: 1 byte CtrlCmdProxy, the 16-byte tag, the address length as a
// big-endian uint16, then the address bytes. An address longer than 65535
// bytes does not fit the length field and is rejected before anything is
// written. A write failure is returned wrapped.
func sendProxy(writer io.Writer, tag [16]byte, addr string) (err error) {
	const maxAddrLen = 1<<16 - 1
	if len(addr) > maxAddrLen {
		return fmt.Errorf("代理地址过长: %s", addr)
	}

	frame := make([]byte, 0, 1+len(tag)+2+len(addr))
	frame = append(frame, byte(CtrlCmdProxy))
	frame = append(frame, tag[:]...)
	frame = binary.BigEndian.AppendUint16(frame, uint16(len(addr)))
	frame = append(frame, addr...)

	if _, werr := writer.Write(frame); werr != nil {
		return fmt.Errorf("发送代理命令失败: %w", werr)
	}
	return nil
}