package fwd import ( "bufio" "context" "encoding/binary" "io" "log/slog" "net" "proxy-server/pkg/utils" "proxy-server/server/pkg/env" "proxy-server/server/pkg/orm" "proxy-server/server/pkg/socks5" "proxy-server/server/web/app/models" "strconv" "time" "github.com/pkg/errors" ) type Config struct { } type Service struct { Config *Config connMap map[string]socks5.ProxyData ctrlConnWg utils.CountWaitGroup dataConnWg utils.CountWaitGroup } func New(config *Config) *Service { _config := config if _config == nil { _config = &Config{} } return &Service{ Config: _config, connMap: make(map[string]socks5.ProxyData), ctrlConnWg: utils.CountWaitGroup{}, dataConnWg: utils.CountWaitGroup{}, } } func (s *Service) Run(ctx context.Context, errCh chan error) { defer func() { err := recover() if err != nil { slog.Error("服务由于意外的 panic 导致退出", err) } }() slog.Info("启动 fwd 服务") // 启动工作协程 subCtx, cancel := context.WithCancel(ctx) defer cancel() goNum := 2 subErrCh := make(chan error, goNum) defer close(subErrCh) go s.startCtrlTun(subCtx, subErrCh) go s.startDataTun(subCtx, subErrCh) // 等待结束 var firstSubErr error = nil for i := 0; i < goNum; i++ { err := <-subErrCh if err != nil { slog.Error("隧道错误关闭", "err", err) if firstSubErr == nil { firstSubErr = err cancel() } } else { slog.Info("隧道关闭") } } slog.Info("fwd 服务已结束") errCh <- firstSubErr } func (s *Service) startCtrlTun(ctx context.Context, errCh chan error) { ctrlPort := env.AppCtrlPort slog.Debug("监听控制通道", slog.Uint64("port", uint64(ctrlPort))) // 监听端口 ls, err := net.Listen("tcp", ":"+strconv.Itoa(int(ctrlPort))) if err != nil { slog.Error("监听控制通道失败", "err", err) return } defer utils.Close(ls) // 等待连接 connCh := utils.ConnChan(ctx, ls) defer close(connCh) // 处理连接 loop: for { select { case <-ctx.Done(): slog.Debug("结束处理连接,由于上下文取消") break loop case conn, ok := <-connCh: if !ok { slog.Debug("结束处理连接,由于获取连接失败") break loop } s.ctrlConnWg.Add(1) go s.processCtrlConn(conn) } } // 等待子协程结束 todo 可配置等待时间 timeout, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() procCh := utils.WaitChan(timeout, &s.ctrlConnWg) defer close(procCh) select { case <-timeout.Done(): slog.Warn("等待控制通道子协程结束超时") case <-procCh: slog.Info("控制通道子协程结束") } slog.Debug("关闭控制通道") errCh <- nil } func (s *Service) processCtrlConn(controller net.Conn) { defer func() { s.ctrlConnWg.Done() utils.Close(controller) }() slog.Info("收到客户端控制连接 " + controller.RemoteAddr().String()) reader := bufio.NewReader(controller) // 读取端口 portBuf := make([]byte, 2) _, err := io.ReadFull(reader, portBuf) if err != nil { slog.Error("读取转发端口失败", "err", err) return } port := binary.BigEndian.Uint16(portBuf) // 新建代理服务 slog.Info("新建代理服务", "port", port) proxy, err := socks5.New(&socks5.Config{ Name: strconv.Itoa(int(port)), Port: port, AuthMethods: []socks5.Authenticator{ &UserPassAuthenticator{}, &NoAuthAuthenticator{}, }, }) if err != nil { slog.Error("代理服务创建失败", "err", err) return } go func() { err := proxy.Run() if err != nil { slog.Error("代理服务建立失败", "err", err) return } }() slog.Info("代理服务已建立", "port", port) for { user := <-proxy.Conn tag := user.Tag() _, err := controller.Write([]byte{byte(len(tag))}) if err != nil { slog.Error("write error", "err", err) return } _, err = controller.Write([]byte(tag)) slog.Info("已通知客户端建立数据通道") if err != nil { slog.Error("write error", "err", err) return } s.connMap[tag] = user } } func (s *Service) startDataTun(ctx context.Context, errCh chan error) { dataPort := env.AppDataPort slog.Debug("监听数据通道", slog.Uint64("port", uint64(dataPort))) // 监听端口 lData, err := net.Listen("tcp", ":"+strconv.Itoa(int(dataPort))) if err != nil { slog.Error("listen error", "err", err) return } defer utils.Close(lData) // 等待连接 connCh := utils.ConnChan(ctx, lData) defer close(connCh) // 处理连接 loop: for { select { case <-ctx.Done(): slog.Debug("结束处理连接,由于上下文取消") break loop case conn, ok := <-connCh: if !ok { slog.Debug("结束处理连接,由于获取连接失败") break loop } s.dataConnWg.Add(1) go s.processDataConn(conn) } } // 等待子协程结束 todo 可配置等待时间 timeout, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() procCh := utils.WaitChan(timeout, &s.dataConnWg) defer close(procCh) select { case <-timeout.Done(): slog.Warn("等待数据通道子协程结束超时") case <-procCh: slog.Info("数据通道子协程结束") } slog.Debug("关闭数据通道") errCh <- nil } func (s *Service) processDataConn(client net.Conn) { defer func() { s.dataConnWg.Done() utils.Close(client) }() slog.Info("已建立客户端数据通道 " + client.RemoteAddr().String()) // 读取 tag tagLen, err := utils.ReadByte(client) if err != nil { slog.Error("read error", "err", err) return } tagBuf, err := utils.ReadBuffer(client, int(tagLen)) if err != nil { slog.Error("read error", "err", err) return } tag := string(tagBuf) // 找到用户连接 data, ok := s.connMap[tag] if !ok { slog.Error("no such connection") return } // 响应用户 user := data.Conn defer utils.Close(user) socks5.SendSuccess(user, client) // 写入目标地址 _, err = client.Write([]byte{byte(len(data.Dest))}) if err != nil { slog.Error("写入目标地址失败", "err", err) return } _, err = client.Write([]byte(data.Dest)) if err != nil { slog.Error("写入目标地址失败", "err", err) return } // 数据转发 slog.Info("开始数据转发 " + client.RemoteAddr().String() + " <-> " + data.Dest) errCh := make(chan error) go func() { _, err := io.Copy(client, user) if err != nil { slog.Error("processDataConn error c2u", "err", err) } errCh <- err }() go func() { _, err := io.Copy(user, client) if err != nil { slog.Error("processDataConn error u2c", "err", err) } errCh <- err }() <-errCh slog.Info("数据转发结束 " + client.RemoteAddr().String() + " <-> " + data.Dest) } type NoAuthAuthenticator struct { } func (a *NoAuthAuthenticator) Method() socks5.AuthMethod { return socks5.NoAuth } func (a *NoAuthAuthenticator) Authenticate(ctx context.Context, reader io.Reader, writer io.Writer) (*socks5.AuthContext, error) { // 获取用户地址 conn, ok := writer.(net.Conn) if !ok { return nil, errors.New("noAuth 认证失败,无法获取连接信息") } addr := conn.RemoteAddr().String() client, _, err := net.SplitHostPort(addr) if err != nil { return nil, errors.Wrap(err, "noAuth 认证失败") } slog.Debug("用户的地址为 " + client) // 获取服务 server, ok := ctx.Value("service").(*socks5.Server) if !ok { return nil, errors.New("noAuth 认证失败,无法获取服务信息") } node := server.Name slog.Debug("服务的名称为 " + server.Name) // 查询权限记录 slog.Info(" 客户端 " + client + " 请求连接到 " + node) var channels []models.Channel err = orm.DB. Joins("INNER JOIN public.nodes n ON channels.node_id = n.id AND n.name = ?", node). Joins("INNER JOIN public.users u ON channels.user_id = u.id"). Joins("INNER JOIN public.user_ips ip ON u.id = ip.user_id AND ip.ip_address = ?", client). Where(&models.Channel{ AuthIp: true, }). Find(&channels).Error if err != nil { return nil, errors.New("noAuth 查询用户权限失败") } // 记录应该只有一条 channel, err := orm.MaySingle(channels) if err != nil { return nil, errors.Wrap(err, "noAuth 没有权限") } // 检查是否需要密码认证 if channel.AuthPass { return nil, errors.New("noAuth 没有权限,需要密码认证") } // 检查权限是否过期 timeout := channel.Expiration.Sub(time.Now()).Seconds() slog.Info("用户剩余时间", "timeout", timeout) if timeout <= 0 { return nil, errors.New("noAuth 权限已过期") } slog.Debug("权限剩余时间", slog.Uint64("timeout", uint64(timeout))) return &socks5.AuthContext{ Method: socks5.NoAuth, Timeout: uint(timeout), Payload: nil, }, nil } type UserPassAuthenticator struct { } func (a *UserPassAuthenticator) Method() socks5.AuthMethod { return socks5.UserPassAuth } func (a *UserPassAuthenticator) Authenticate(ctx context.Context, reader io.Reader, writer io.Writer) (*socks5.AuthContext, error) { // 检查认证版本 slog.Debug("验证认证版本") v, err := utils.ReadByte(reader) if err != nil { return nil, errors.Wrap(err, "读取版本号失败") } if v != socks5.AuthVersion { _, err := writer.Write([]byte{socks5.SocksVersion, socks5.AuthFailure}) if err != nil { return nil, errors.Wrap(err, "响应认证失败") } return nil, errors.New("认证版本参数不正确") } // 读取账号 slog.Debug("验证用户账号") uLen, err := utils.ReadByte(reader) if err != nil { return nil, errors.Wrap(err, "读取用户名长度失败") } usernameBuf, err := utils.ReadBuffer(reader, int(uLen)) if err != nil { return nil, errors.Wrap(err, "读取用户名失败") } username := string(usernameBuf) // 读取密码 pLen, err := utils.ReadByte(reader) if err != nil { return nil, errors.Wrap(err, "读取密码长度失败") } passwordBuf, err := utils.ReadBuffer(reader, int(pLen)) if err != nil { return nil, errors.Wrap(err, "读取密码失败") } password := string(passwordBuf) // 查询通道配置 var channel models.Channel err = orm.DB. Where(&models.Channel{ Username: username, AuthPass: true, }). First(&channel).Error if err != nil { return nil, errors.Wrap(err, "查询用户失败") } // 检查密码 todo 哈希 if channel.Password != password { return nil, errors.New("密码错误") } // 检查权限是否过期 timeout := channel.Expiration.Sub(time.Now()).Seconds() slog.Info("用户剩余时间", "timeout", timeout) if timeout <= 0 { return nil, errors.New("权限已过期") } // 如果用户设置了双验证则检查 ip 是否在白名单中 if channel.AuthIp { slog.Debug("验证用户 ip") // 获取用户地址 conn, ok := writer.(net.Conn) if !ok { return nil, errors.New("无法获取连接信息") } addr := conn.RemoteAddr().String() client, _, err := net.SplitHostPort(addr) if err != nil { return nil, errors.Wrap(err, "无法获取连接信息") } // 查询通道配置 var ips []models.UserIp err = orm.DB. Where(&models.UserIp{ UserId: channel.UserId, IpAddress: client, }). Find(&ips).Error if err != nil { return nil, errors.Wrap(err, "查询用户 ip 失败") } // 检查是否在白名单中 if len(ips) == 0 { return nil, errors.New("没有权限") } } // 响应认证成功 _, err = writer.Write([]byte{socks5.AuthVersion, socks5.AuthSuccess}) if err != nil { slog.Error("响应认证失败", "err", err) return nil, err } return &socks5.AuthContext{ Method: socks5.UserPassAuth, Timeout: uint(timeout), Payload: nil, }, nil }