package fwd

import (
	"bufio"
	"context"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"net"
	"strconv"
	"sync/atomic"
	"time"

	"proxy-server/gateway/app"
	"proxy-server/gateway/env"
	"proxy-server/utils"
)

// CtrlCmdType identifies a control-channel command. Each command is sent on
// the wire as a single byte.
type CtrlCmdType int

// Control-channel commands.
const (
	// CtrlCmdPong is the gateway's reply to an open, ping or close command.
	CtrlCmdPong CtrlCmdType = iota + 1
	// CtrlCmdPing is a heartbeat sent by an edge node.
	CtrlCmdPing
	// CtrlCmdOpen registers an edge node; it is followed by a 4-byte
	// big-endian edge ID.
	CtrlCmdOpen
	// CtrlCmdClose is sent by an edge node; the gateway replies with a pong.
	CtrlCmdClose
	// CtrlCmdProxy is sent by the gateway to an edge node (see sendProxy for
	// its wire layout).
	CtrlCmdProxy
)

// ListenCtrl accepts edge-node control connections on env.AppCtrlPort and
// serves each one in its own goroutine tracked by app.CtrlConnWg.
//
// It returns nil once ctx is cancelled. It returns an error if the listener
// cannot be opened or if accepting a connection fails.
func ListenCtrl(ctx context.Context) error {
	ls, err := net.Listen("tcp", ":"+strconv.Itoa(int(env.AppCtrlPort)))
	if err != nil {
		return fmt.Errorf("监听控制通道失败: %w", err)
	}
	// Closing the listener on return also stops the accept goroutine below.
	defer utils.Close(ls)

	// Accept asynchronously so the loop below can also watch ctx.
	connCh := make(chan net.Conn)
	// Buffered so the accept goroutine can report its failure and exit even
	// if ListenCtrl has already returned because ctx was cancelled.
	acceptErrCh := make(chan error, 1)
	go func() {
		for {
			conn, err := ls.Accept()
			if errors.Is(err, net.ErrClosed) {
				slog.Debug("控制通道监听关闭")
				return
			}
			if err != nil {
				acceptErrCh <- fmt.Errorf("接受控制通道连接失败: %w", err)
				return
			}
			select {
			case connCh <- conn:
			case <-ctx.Done():
				utils.Close(conn)
				return
			}
		}
	}()

	for {
		select {
		case <-ctx.Done():
			return nil
		case err := <-acceptErrCh:
			// The accept loop is dead; surface it instead of idling forever.
			return err
		case conn := <-connCh:
			app.CtrlConnWg.Add(1)
			go func() {
				defer app.CtrlConnWg.Done()
				defer utils.Close(conn)
				if err := processCtrlConn(ctx, conn); err != nil {
					slog.Error("处理控制通道连接失败", "err", err)
				}
			}()
		}
	}
}

// processCtrlConn serves one control connection. It returns when the command
// loop ends (peer disconnect, I/O or protocol error) or when _ctx is
// cancelled. If the peer sent an open command, the status of the most
// recently announced edge is reset to 0 on return.
func processCtrlConn(_ctx context.Context, conn net.Conn) (err error) {
	reader := bufio.NewReader(conn)

	ctx, cancel := context.WithCancel(_ctx)
	defer cancel()

	// Set by the reader goroutine and read by the cleanup below, so it must be
	// accessed atomically. It stays nil until an open command has been received.
	var openedEdge atomic.Pointer[int32]
	defer func() {
		id := openedEdge.Load()
		if id == nil {
			// No edge was announced on this connection; nothing to reset.
			return
		}
		if edge, ok := app.Edges.Load(*id); ok {
			*edge.Status = 0
		}
	}()

	// Buffered so the reader goroutine can always deliver its result and exit,
	// even after we stop waiting because ctx was cancelled. The caller then
	// closes conn, which unblocks the goroutine's pending read.
	errCh := make(chan error, 1)
	go func() {
		errCh <- serveCtrlCommands(ctx, conn, reader, &openedEdge)
	}()

	select {
	case <-ctx.Done():
	case err = <-errCh:
	}
	return err
}

// serveCtrlCommands reads and dispatches edge-node commands from reader until
// an error ends the loop. The returned error may be nil when
// utils.WarpConnErr reports the read failure as not-ok without an error.
// The ID from each open command is published through openedEdge before the
// command is handled.
func serveCtrlCommands(ctx context.Context, conn net.Conn, reader *bufio.Reader, openedEdge *atomic.Pointer[int32]) error {
	for {
		// Idle deadline for the next command byte. It is two
		// env.AppCtrlHBTimeout periods plus one env.AppCtrlRWTimeout, in seconds.
		idle := time.Duration(env.AppCtrlHBTimeout*2+env.AppCtrlRWTimeout) * time.Second
		if err := conn.SetReadDeadline(time.Now().Add(idle)); err != nil {
			return fmt.Errorf("设置读取超时失败: %w", err)
		}
		cmd, err := reader.ReadByte()
		if ok, wrapped := utils.WarpConnErr(err); !ok {
			return wrapped
		}

		// Once a command has started, its payload must arrive within the
		// shorter read/write timeout.
		if err := conn.SetReadDeadline(time.Now().Add(time.Duration(env.AppCtrlRWTimeout) * time.Second)); err != nil {
			return fmt.Errorf("设置读取超时失败: %w", err)
		}

		switch CtrlCmdType(cmd) {
		case CtrlCmdOpen:
			var recv [4]byte
			if _, err := io.ReadFull(reader, recv[:]); err != nil {
				return fmt.Errorf("读取节点 ID 失败: %w", err)
			}
			edgeID := int32(binary.BigEndian.Uint32(recv[:]))
			// Publish before handling so cleanup also covers a failed open.
			openedEdge.Store(&edgeID)
			if err := onOpen(ctx, conn, edgeID, conn.RemoteAddr()); err != nil {
				return fmt.Errorf("处理连接建立命令失败: %w", err)
			}
		case CtrlCmdPing:
			if err := onPing(conn); err != nil {
				return fmt.Errorf("处理心跳命令失败: %w", err)
			}
		case CtrlCmdClose:
			if err := onClose(conn); err != nil {
				return fmt.Errorf("处理关闭命令失败: %w", err)
			}
		default:
			// Any other command (including pong/proxy) must not come from an
			// edge node; treat it as a protocol error.
			return fmt.Errorf("无法处理控制命令: %d", cmd)
		}
	}
}

// onOpen handles an open command for edgeId.
//
// If the edge already has a port, it keeps that port and updates the edge's
// address. Otherwise it registers the edge on a free port. It then starts
// ListenUser on that port in a goroutine tracked by app.FwdLesWg and replies
// to the node with a pong.
func onOpen(ctx context.Context, writer io.Writer, edgeId int32, addr net.Addr) (err error) {
	tcpAddr, ok := addr.(*net.TCPAddr)
	if !ok {
		return fmt.Errorf("无效的地址类型: %T", addr)
	}

	var port uint16
	if edge, ok := app.Edges.Load(edgeId); ok && edge.Port != nil {
		// Known edge: keep its port, refresh its address.
		port = *edge.Port
		if err := app.TryUpdateEdge(edgeId, tcpAddr); err != nil {
			return fmt.Errorf("尝试更新边缘节点失败: %w", err)
		}
	} else {
		port, err = registerEdgeOnFreePort(edgeId, tcpAddr)
		if err != nil {
			return err
		}
	}

	app.FwdLesWg.Add(1)
	go func() {
		defer app.FwdLesWg.Done()
		slog.Info("监听转发端口", "port", port, "edge", edgeId)
		// Goroutine-local err: must not write onOpen's named result, which
		// onOpen itself is still using concurrently.
		if err := ListenUser(ctx, port, writer); err != nil {
			slog.Error("监听转发端口失败", "port", port, "edge", edgeId, "err", err)
		}
	}()

	if err := sendPong(writer); err != nil {
		return fmt.Errorf("响应节点失败: %w", err)
	}
	return nil
}

// registerEdgeOnFreePort registers edgeID on the lowest port in
// [20000, 60000) that has no entry in app.Assigns and returns that port.
//
// app.LockPortAssign is held for the whole scan-and-register step. The
// deferred unlock releases it on every path, including failures.
func registerEdgeOnFreePort(edgeID int32, addr *net.TCPAddr) (uint16, error) {
	const minPort, maxPort uint16 = 20000, 60000

	app.LockPortAssign.Lock()
	defer app.LockPortAssign.Unlock()

	for port := minPort; port < maxPort; port++ {
		if _, used := app.Assigns.Load(port); used {
			continue
		}
		if err := app.NewEdge(edgeID, port, addr); err != nil {
			return 0, fmt.Errorf("新增边缘节点失败:%w", err)
		}
		return port, nil
	}
	return 0, errors.New("没有可用的端口")
}

// onPing answers a heartbeat with a pong.
func onPing(writer io.Writer) (err error) {
	return sendPong(writer)
}

// onClose answers a close command with a pong.
func onClose(writer io.Writer) (err error) {
	return sendPong(writer)
}

// sendPong writes the single-byte pong command.
func sendPong(writer io.Writer) (err error) {
	_, err = writer.Write([]byte{byte(CtrlCmdPong)})
	if err != nil {
		return fmt.Errorf("响应节点失败: %w", err)
	}
	return nil
}

// sendProxy writes a proxy command in a single Write call. The layout is the
// command byte, the 16-byte tag, the address length as a big-endian uint16,
// then the address bytes.
func sendProxy(writer io.Writer, tag [16]byte, addr string) (err error) {
	if len(addr) > 65535 {
		return fmt.Errorf("代理地址过长: %s", addr)
	}
	buf := make([]byte, 1+16+2+len(addr))
	buf[0] = byte(CtrlCmdProxy)
	copy(buf[1:], tag[:])
	binary.BigEndian.PutUint16(buf[17:], uint16(len(addr)))
	copy(buf[19:], addr)
	_, err = writer.Write(buf)
	if err != nil {
		return fmt.Errorf("发送代理命令失败: %w", err)
	}
	return nil
}