package fwd import ( "bufio" "context" "encoding/binary" "errors" "fmt" "io" "log/slog" "net" "proxy-server/pkg/utils" "proxy-server/server/app" "proxy-server/server/env" "proxy-server/server/report" "strconv" ) type CtrlCmdType int const ( CtrlCmdPong CtrlCmdType = iota + 1 CtrlCmdPing CtrlCmdOpen CtrlCmdClose CtrlCmdProxy ) func (s *Service) listenCtrl() error { ctrlPort := env.AppCtrlPort slog.Debug("监听控制通道", slog.Uint64("port", uint64(ctrlPort))) // 监听端口 ls, err := net.Listen("tcp", ":"+strconv.Itoa(int(ctrlPort))) if err != nil { return fmt.Errorf("监听控制通道失败: %w", err) } defer utils.Close(ls) // 处理连接 // 异步等待连接 var connCh = make(chan net.Conn) go func() { for { conn, err := ls.Accept() if errors.Is(err, net.ErrClosed) { slog.Debug("控制通道监听关闭") return } if err != nil { slog.Error("接受控制通道连接失败", "err", err) return } select { case connCh <- conn: case <-s.ctx.Done(): utils.Close(conn) return } } }() err = nil for { select { case <-s.ctx.Done(): return nil case conn := <-connCh: s.ctrlConnWg.Add(1) go func() { defer s.ctrlConnWg.Done() defer utils.Close(conn) err := s.processCtrlConn(s.ctx, conn) if err != nil { slog.Error("处理控制通道连接失败", "err", err) } }() } } } func (s *Service) processCtrlConn(ctx context.Context, conn net.Conn) (err error) { reader := bufio.NewReader(conn) for { // 循环等待直到服务关闭 select { case <-ctx.Done(): return nil default: } // 读取命令 cmdByte, err := reader.ReadByte() if err != nil { return fmt.Errorf("读取节点命令失败: %w", err) } var cmd = CtrlCmdType(cmdByte) switch cmd { // 连接建立命令 case CtrlCmdOpen: var recv = make([]byte, 4) _, err = io.ReadFull(reader, recv) if err != nil { return fmt.Errorf("读取节点 ID 失败: %w", err) } var client = int32(binary.BigEndian.Uint32(recv)) err = s.onOpen(conn, client) if err != nil { return fmt.Errorf("处理连接建立命令失败: %w", err) } // 心跳命令 case CtrlCmdPing: err = s.onPing(conn) if err != nil { return fmt.Errorf("处理心跳命令失败: %w", err) } // 连接关闭命令 case CtrlCmdClose: err = s.onClose(conn) if err != nil { return fmt.Errorf("处理关闭命令失败: %w", err) } return nil // 忽略其他不应该由节点发起的命令 default: return fmt.Errorf("无法处理控制命令: %d", cmd) } } } func (s *Service) onPing(conn net.Conn) (err error) { return s.sendPong(conn) } func (s *Service) onOpen(conn net.Conn, client int32) (err error) { // open 命令全局只执行一次 _, ok := app.Clients.Load(client) if ok { return fmt.Errorf("节点 ID %d 已经连接", client) } // 分配端口 var minim uint16 = 20000 var maxim uint16 = 60000 var port uint16 for i := minim; i < maxim; i++ { var _, ok = app.Assigns.Load(i) if !ok { port = i app.Assigns.Store(i, client) app.Clients.Store(client, i) break } } if port == 0 { return errors.New("没有可用的端口") } // 报告端口分配 if err = report.Assigned(client, port); err != nil { return fmt.Errorf("报告端口分配失败: %w", err) } // 响应客户端 if err = s.sendPong(conn); err != nil { return fmt.Errorf("响应客户端失败: %w", err) } // 启动转发服务 s.fwdLesWg.Add(1) go func() { defer s.fwdLesWg.Done() slog.Info("监听转发端口", "port", port, "client", client) err = s.listenUser(port, conn) if err != nil { slog.Error("监听转发端口失败", "port", port, "client", client, "err", err) } }() return nil } func (s *Service) onClose(conn net.Conn) (err error) { _, portStr, err := net.SplitHostPort(conn.LocalAddr().String()) if err != nil { return err } port, err := strconv.ParseUint(portStr, 10, 16) if err != nil { return err } id, _ := app.Assigns.LoadAndDelete(uint16(port)) app.Clients.Delete(id) app.Assigns.Delete(uint16(port)) app.Permits.Delete(uint16(port)) err = s.sendPong(conn) if err != nil { return err } return nil } func (s *Service) sendPong(conn net.Conn) (err error) { _, err = conn.Write([]byte{byte(CtrlCmdPong)}) if err != nil { return fmt.Errorf("响应客户端失败: %w", err) } return nil } func (s *Service) sendProxy(conn net.Conn, tag [16]byte, addr string) (err error) { if len(addr) > 65535 { return fmt.Errorf("代理地址过长: %s", addr) } buf := make([]byte, 1+16+2+len(addr)) buf[0] = byte(CtrlCmdProxy) copy(buf[1:], tag[:]) binary.BigEndian.PutUint16(buf[17:], uint16(len(addr))) copy(buf[19:], addr) _, err = conn.Write(buf) if err != nil { return fmt.Errorf("发送代理命令失败: %w", err) } return nil }