package fwd

import (
	"bufio"
	"context"
	"log/slog"
	"net"
	"strconv"
	"strings"
	"time"

	"proxy-server/pkg/utils"
	"proxy-server/server/fwd/core"
	"proxy-server/server/fwd/dispatcher"
	"proxy-server/server/fwd/metrics"
	"proxy-server/server/fwd/repo"
	"proxy-server/server/pkg/env"
	"proxy-server/server/pkg/orm"

	"github.com/pkg/errors"
)

// CtrlCmd is a pending write to a client's control connection: buf is written
// verbatim to conn by a control-channel writer goroutine.
type CtrlCmd struct {
	conn net.Conn
	buf  []byte
}

// ctrlCmdChan queues control-channel writes produced by processUserConn. It is
// shared by all control connections; each command carries its own target conn.
var ctrlCmdChan = make(chan CtrlCmd, 1024)

// startCtrlTun listens on env.AppCtrlPort for client control connections and
// handles each accepted connection in its own goroutine (tracked by
// s.ctrlConnWg). It blocks until s.ctx is cancelled (returning nil) or the
// accept channel is closed (returning an error).
func (s *Service) startCtrlTun() error {
	ctrlPort := env.AppCtrlPort
	slog.Debug("监听控制通道", slog.Uint64("port", uint64(ctrlPort)))

	// Listen on the control port.
	ls, err := net.Listen("tcp", ":"+strconv.Itoa(int(ctrlPort)))
	if err != nil {
		return errors.Wrap(err, "监听控制通道失败")
	}
	defer utils.Close(ls)

	// Accept connections until shutdown.
	connCh := utils.ChanConnAccept(s.ctx, ls)
	for {
		select {
		case <-s.ctx.Done():
			return nil
		case conn, ok := <-connCh:
			if !ok {
				// The accept channel is closed: conn is nil here and must not be
				// handed to processCtrlConn.
				return errors.New("获取连接失败")
			}
			s.ctrlConnWg.Add(1)
			go func() {
				defer s.ctrlConnWg.Done()
				defer utils.Close(conn)
				if err := s.processCtrlConn(conn); err != nil {
					slog.Error("处理控制通道连接失败", "err", err)
				}
			}()
		}
	}
}

// processCtrlConn runs the handshake for one client control connection and then
// serves it until the client disconnects or the service shuts down.
//
// Handshake (client -> server): 1 byte version, 1 byte name length, name bytes.
// The name is looked up in the node table and the version must equal the
// stored node.Version; the server answers with a single CtrlResult byte.
// On success a dispatcher is started on node.FwdPort, and every user
// connection it yields is announced to the client via processUserConn.
func (s *Service) processCtrlConn(conn net.Conn) error {
	reader := bufio.NewReader(conn)

	// Protocol version byte.
	version, err := reader.ReadByte()
	if err != nil {
		_ = ctrlResp(conn, CtrlFail) // best effort: the connection may already be broken
		return errors.Wrap(err, "获取版本号失败")
	}

	// Client name: 1-byte length followed by that many bytes.
	nameLen, err := reader.ReadByte()
	if err != nil {
		_ = ctrlResp(conn, CtrlFail)
		return errors.Wrap(err, "获取 name 失败")
	}
	nameBuf, err := utils.ReadBuffer(reader, int(nameLen))
	if err != nil {
		_ = ctrlResp(conn, CtrlFail)
		return errors.Wrap(err, "获取 name 失败")
	}
	name := string(nameBuf)
	if name == "" {
		_ = ctrlResp(conn, CtrlFail)
		return errors.New("客户端名称不能为空")
	}

	// Look up the registered node and check its version.
	var node repo.Node
	err = orm.DB.Take(&node, &repo.Node{
		Name: name,
	}).Error
	if err != nil {
		_ = ctrlResp(conn, CtrlFail)
		return errors.Wrap(err, "查询客户端失败")
	}
	if version != node.Version {
		_ = ctrlResp(conn, CtrlFail)
		return errors.New("客户端版本不匹配")
	}
	if err := ctrlResp(conn, CtrlDone); err != nil {
		return errors.Wrap(err, "向客户端发送响应失败")
	}

	port := node.FwdPort
	slog.Info("监听转发端口", "port", port, "client", name)

	// Start the forwarding (dispatcher) service on the node's port.
	proxy, err := dispatcher.New(port)
	if err != nil {
		return errors.Wrap(err, "创建 socks 转发服务失败")
	}
	defer proxy.Close()
	s.fwdLesWg.Add(1)
	go func() {
		defer s.fwdLesWg.Done()
		if err := proxy.Run(); err != nil {
			slog.Error("代理服务运行失败", "err", err)
		}
	}()

	// Scope the helper goroutines below to this connection so they exit when it
	// ends, instead of lingering until the whole service shuts down.
	ctx, cancel := context.WithCancel(s.ctx)
	defer cancel()

	// Watch the control connection. Any byte received after the handshake is
	// treated as a protocol error below; otherwise the read returns only when
	// the connection fails or is closed.
	errCh := make(chan error, 1)
	go func() {
		defer close(errCh)
		_, err := reader.ReadByte()
		errCh <- err
	}()

	// Drain queued control commands and write them out. Commands for any
	// control connection may be picked up here; each carries its own conn.
	go func() {
		for {
			select {
			case <-ctx.Done():
				return
			case cmd := <-ctrlCmdChan:
				if _, err := cmd.conn.Write(cmd.buf); err != nil {
					slog.Error("批量写入失败", "err", err)
					utils.Close(cmd.conn)
				}
			}
		}
	}()

	// Serve user connections until the client goes away or shutdown.
	for {
		select {
		case <-ctx.Done():
			return nil
		case err := <-errCh:
			switch {
			case err == nil:
				// Must be checked first: calling err.Error() on a nil error panics.
				return errors.New("客户端握手失败")
			case strings.Contains(err.Error(), "An existing connection was forcibly closed by the remote host."):
				slog.Debug("客户端主动断开连接")
				return nil
			default:
				return errors.Wrap(err, "客户端意外断开连接")
			}
		case user, ok := <-proxy.Conn:
			if !ok {
				// A closed channel would otherwise yield a nil user on every receive.
				return errors.New("转发服务连接通道已关闭")
			}
			metrics.TimerAuth.Store(user.Conn, time.Now())
			s.userConnWg.Add(1)
			go func() {
				defer s.userConnWg.Done()
				if err := s.processUserConn(user, conn); err != nil {
					slog.Error("处理用户连接失败", "err", err)
					utils.Close(user)
				}
			}()
		}
	}
}

// processUserConn announces a new user connection to the client over the
// control connection ctrl, then waits up to 30s for it to be claimed.
//
// The announcement frame is: 1 byte len(dst), dst, 1 byte len(tag), tag, where
// dst is the user's destination address and tag is user.Tag. The user
// connection is registered in s.userConnMap under its tag; if it is still
// there when the timeout fires it is removed and closed. On shutdown the
// connection is left in the map for the service to close.
//
// It returns an error, without registering or announcing anything, when dst or
// tag does not fit in a one-byte length prefix; processCtrlConn closes user in
// that case.
func (s *Service) processUserConn(user *core.Conn, ctrl net.Conn) error {
	const maxFieldLen = 0xFF // largest length representable by the 1-byte prefix

	dst := user.DestAddr().String()
	tag := user.Tag
	if len(dst) > maxFieldLen {
		return errors.Errorf("目标地址过长: %d", len(dst))
	}
	if len(tag) > maxFieldLen {
		return errors.Errorf("连接标签过长: %d", len(tag))
	}

	// Build the command frame.
	writeBuf := make([]byte, 0, 2+len(dst)+len(tag))
	writeBuf = append(writeBuf, byte(len(dst)))
	writeBuf = append(writeBuf, dst...)
	writeBuf = append(writeBuf, byte(len(tag)))
	writeBuf = append(writeBuf, tag...)

	// Register before announcing, so the connection is already in userConnMap by
	// the time the client can act on the command (presumably by opening a data
	// channel that looks it up by tag — confirm against the data-channel handler).
	s.userConnMap.Store(user.Tag, user)

	// Queue the command for asynchronous writing.
	select {
	case ctrlCmdChan <- CtrlCmd{conn: ctrl, buf: writeBuf}:
	case <-s.ctx.Done():
		// The service closes connections still left in userConnMap on exit.
		return nil
	}

	// Close the user connection if no data channel claims it in time.
	timer := time.NewTimer(30 * time.Second)
	defer timer.Stop()

	select {
	case <-s.ctx.Done():
		// The service closes connections still left in userConnMap on exit.
	case <-timer.C:
		if storedUser, ok := s.userConnMap.LoadAndDelete(user.Tag); ok {
			slog.Debug("建立数据通道超时", "tag", user.Tag)
			utils.Close(storedUser)
		}
	}
	return nil
}

// CtrlResult is the single-byte handshake response sent on a control connection.
type CtrlResult byte

const (
	// CtrlFail (0) rejects the handshake.
	CtrlFail CtrlResult = iota
	// CtrlDone (1) accepts the handshake.
	CtrlDone
)

// ctrlResp writes result to conn as a single byte.
func ctrlResp(conn net.Conn, result CtrlResult) error {
	_, err := conn.Write([]byte{byte(result)})
	return err
}