package fwd import ( "bufio" "context" "encoding/binary" "io" "log/slog" "net" "proxy-server/pkg/utils" "proxy-server/server/fwd/socks" "proxy-server/server/pkg/env" "strconv" "sync" "time" "github.com/pkg/errors" ) type Config struct { } type Service struct { Config *Config ctx context.Context cancel context.CancelFunc userConnMap map[string]socks.ProxyConn ctrlConnWg utils.CountWaitGroup dataConnWg utils.CountWaitGroup } func New(config *Config) *Service { if config == nil { config = &Config{} } ctx, cancel := context.WithCancel(context.Background()) return &Service{ Config: config, ctx: ctx, cancel: cancel, userConnMap: make(map[string]socks.ProxyConn), ctrlConnWg: utils.CountWaitGroup{}, dataConnWg: utils.CountWaitGroup{}, } } func (s *Service) Close() { s.cancel() for _, conn := range s.userConnMap { utils.Close(conn) } clear(s.userConnMap) } func (s *Service) Run() { slog.Debug("启动 fwd 服务") errQuit := make(chan struct{}) defer close(errQuit) wg := sync.WaitGroup{} // 启动工作协程 wg.Add(1) go func() { defer wg.Done() err := s.startCtrlTun() if err != nil { slog.Error("控制通道发生错误", "err", err) errQuit <- struct{}{} return } }() wg.Add(1) go func() { defer wg.Done() err := s.startDataTun() if err != nil { slog.Error("数据通道发生错误", "err", err) errQuit <- struct{}{} return } }() // 等待结束 select { case <-s.ctx.Done(): slog.Debug("服务关闭") case <-errQuit: slog.Debug("服务异常退出") } // 退出 s.Close() timeout, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() wgCh := utils.ChanWgWait(timeout, &wg) defer close(wgCh) select { case <-timeout.Done(): slog.Warn("关闭超时,强制关闭") case <-wgCh: slog.Debug("服务已退出") } } func (s *Service) startCtrlTun() error { ctrlPort := env.AppCtrlPort slog.Debug("监听控制通道", slog.Uint64("port", uint64(ctrlPort))) // 监听端口 ls, err := net.Listen("tcp", ":"+strconv.Itoa(int(ctrlPort))) if err != nil { return errors.Wrap(err, "监听控制通道失败") } defer utils.Close(ls) // 等待连接 connCh := utils.ChanConnAccept(s.ctx, ls) defer close(connCh) // 处理连接 for loop := true; loop; { select { case <-s.ctx.Done(): slog.Debug("结束处理连接,由于上下文取消") loop = false case conn, ok := <-connCh: if !ok { slog.Debug("结束处理连接,由于获取连接失败") loop = false } s.ctrlConnWg.Add(1) go func() { defer s.ctrlConnWg.Done() defer utils.Close(conn) err := s.processCtrlConn(conn) if err != nil { slog.Error("处理控制通道连接失败", err) } }() } } // 等待子协程结束 timeout, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() procCh := utils.ChanWgWait(timeout, &s.ctrlConnWg) defer close(procCh) select { case <-timeout.Done(): slog.Warn("等待控制通道子协程结束超时") case <-procCh: slog.Debug("控制通道子协程结束") } slog.Debug("关闭控制通道") return nil } func (s *Service) processCtrlConn(controller net.Conn) error { slog.Info("客户端连入", "addr", controller.RemoteAddr().String()) reader := bufio.NewReader(controller) // 获取转发端口 portBuf, err := utils.ReadBuffer(reader, 2) if err != nil { return errors.Wrap(err, "获取转发端口失败") } port := binary.BigEndian.Uint16(portBuf) // 开放转发端口 todo 混合转发 slog.Info("开放转发端口", "port", port) proxy, err := socks.New(&socks.Config{ Name: strconv.Itoa(int(port)), Port: port, AuthMethods: []socks.Authenticator{ &UserPassAuthenticator{}, &NoAuthAuthenticator{}, }, }) if err != nil { return errors.Wrap(err, "创建 socks 转发服务失败") } defer proxy.Close() go func() { err := proxy.Run() if err != nil { slog.Error("代理服务启动失败", "err", err) return } }() // 等待用户连接 wg := sync.WaitGroup{} for loop := true; loop; { select { case <-s.ctx.Done(): loop = false case user, ok := <-proxy.Conn: if !ok { loop = false err = errors.New("无法获取连接") } wg.Add(1) go func() { defer wg.Done() tag := user.Tag() tagLen := len(tag) tagBuf := make([]byte, 1+tagLen) tagBuf[0] = byte(tagLen) copy(tagBuf[1:], tag) _, err := controller.Write(tagBuf) if err != nil { utils.Close(user) slog.Error("向客户端发送 tag 失败", "err", err) return } s.userConnMap[tag] = user }() } } wg.Wait() return nil } func (s *Service) startDataTun() error { dataPort := env.AppDataPort slog.Debug("监听数据通道", slog.Uint64("port", uint64(dataPort))) // 监听端口 ls, err := net.Listen("tcp", ":"+strconv.Itoa(int(dataPort))) if err != nil { return errors.Wrap(err, "监听数据通道失败") } defer utils.Close(ls) // 等待连接 connCh := utils.ChanConnAccept(s.ctx, ls) defer close(connCh) // 处理连接 for loop := true; loop; { select { case <-s.ctx.Done(): slog.Debug("结束处理连接,由于上下文取消") loop = false case conn, ok := <-connCh: if !ok { slog.Debug("结束处理连接,由于获取连接失败") loop = false } s.dataConnWg.Add(1) go func() { defer s.dataConnWg.Done() defer utils.Close(conn) err := s.processDataConn(conn) if err != nil { slog.Error("处理数据通道失败", err) } }() } } // 等待子协程结束 todo 可配置等待时间 timeout, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() procCh := utils.ChanWgWait(timeout, &s.dataConnWg) defer close(procCh) select { case <-timeout.Done(): slog.Warn("等待数据通道子协程结束超时") case <-procCh: slog.Debug("数据通道子协程结束") } slog.Debug("关闭数据通道") return nil } func (s *Service) processDataConn(client net.Conn) error { slog.Info("客户端准备接收数据 " + client.RemoteAddr().String()) // 读取 tag tagLen, err := utils.ReadByte(client) if err != nil { return errors.Wrap(err, "从客户端获取 tag 失败") } tagBuf, err := utils.ReadBuffer(client, int(tagLen)) if err != nil { return errors.Wrap(err, "从客户端获取 tag 失败") } tag := string(tagBuf) select { case <-s.ctx.Done(): return nil default: } // 找到用户连接 data, ok := s.userConnMap[tag] if !ok { return errors.New("查找用户连接失败") } defer func() { delete(s.userConnMap, tag) utils.Close(data) }() // 响应用户 user := data.Conn err = socks.SendSuccess(user, client) if err != nil { // todo 考虑是否需要处理服务关闭后导致用户连接被关闭的情况 return errors.Wrap(err, "向用户发送成功消息失败") } // 发送目标地址 dest := data.Dest destLen := len(dest) destBuf := make([]byte, 1+destLen) destBuf[0] = byte(destLen) copy(destBuf[1:], dest) _, err = client.Write(destBuf) if err != nil { return errors.Wrap(err, "向客户端发送目标地址失败") } // 数据转发 slog.Info("开始数据转发 " + client.RemoteAddr().String() + " <-> " + data.Dest) // userPipeReader, userPipeWriter := io.Pipe() // defer utils.Close(userPipeWriter) // teeUser := io.TeeReader(user, userPipeWriter) wg := sync.WaitGroup{} wg.Add(1) go func() { defer wg.Done() _, err := io.Copy(client, user) if err != nil { slog.Error("数据转发失败 user->client", "err", err) } }() wg.Add(1) go func() { defer wg.Done() _, err := io.Copy(user, client) if err != nil { slog.Error("数据转发失败 client->user", "err", err) } }() wg.Wait() slog.Info("数据转发结束 " + client.RemoteAddr().String() + " <-> " + data.Dest) return nil }