package socks import ( "bufio" "context" "fmt" "io" "log" "log/slog" "net" "os" "proxy-server/pkg/utils" "strconv" "time" "github.com/pkg/errors" ) const ( Version = byte(5) ) type Config struct { Name string Host string Port uint16 // 认证方法 AuthMethods []Authenticator // 域名解析 Resolver NameResolver // 自定义认证规则 Rules RuleSet // 地址重写 Rewriter AddressRewriter // 用于 bind 和 associate BindIP net.IP // Logger Logger *log.Logger // 自定义连接流程 Dial func(network, addr string) (net.Conn, error) } type Server struct { config *Config ctx context.Context cancel context.CancelFunc wg utils.CountWaitGroup Name string Port uint16 Conn chan ProxyConn } // New 创建服务器 func New(conf *Config) (*Server, error) { if conf == nil { conf = &Config{} } if len(conf.AuthMethods) == 0 { return nil, ConfigError("认证方法不能为空") } if conf.Resolver == nil { conf.Resolver = DNSResolver{} } if conf.Rules == nil { conf.Rules = PermitAll() } if conf.Logger == nil { conf.Logger = log.New(os.Stdout, "", log.LstdFlags) } if conf.Dial == nil { conf.Dial = func(network, addr string) (net.Conn, error) { return net.Dial(network, addr) } } ctx, cancel := context.WithCancel(context.Background()) return &Server{ config: conf, ctx: ctx, cancel: cancel, wg: utils.CountWaitGroup{}, Name: conf.Name, Port: conf.Port, Conn: make(chan ProxyConn), }, nil } // Run 监听端口 func (s *Server) Run() error { slog.Debug("启动 socks5 代理服务") // 监听端口 host := s.config.Host port := s.config.Port addr := net.JoinHostPort(host, strconv.Itoa(int(port))) ls, err := net.Listen("tcp", addr) if err != nil { return errors.Wrap(err, "监听端口失败") } defer utils.Close(ls) slog.Debug("正在监听端口", slog.Uint64("port", uint64(port))) // 处理连接 connCh := utils.ChanConnAccept(s.ctx, ls) defer close(connCh) err = nil for loop := true; loop; { select { case <-s.ctx.Done(): slog.Debug("socks 服务主动停止") loop = false case conn, ok := <-connCh: if !ok { err = errors.New("意外错误,无法获取连接") loop = false s.Close() break } s.wg.Add(1) go func() { defer s.wg.Done() // 连接要传出,不能在这里关闭连接 err := s.process(conn) if err != nil { slog.Error("处理连接失败", err) } }() } } // 关闭服务 timeout, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() wgCh := utils.ChanWgWait(timeout, &s.wg) err = nil select { case <-timeout.Done(): err = errors.New("关闭超时(强制关闭)") case <-wgCh: } close(s.Conn) return err } // Close 关闭服务 func (s *Server) Close() { s.cancel() } // process 建立连接 func (s *Server) process(conn net.Conn) error { slog.Info("收到来自" + conn.RemoteAddr().String() + "的连接") reader := bufio.NewReader(conn) // 认证 slog.Debug("开始认证流程") authContext, err := s.authenticate(reader, conn) if err != nil { utils.Close(conn) slog.Error("认证失败", err) return err } else { slog.Debug("认证完成") } // 处理连接请求 slog.Debug("处理连接请求") request, err := s.request(reader, conn) if err != nil { slog.Error("连接请求处理失败", err) return err } else { slog.Debug("连接请求处理完成") } request.Authentication = authContext user, ok := conn.RemoteAddr().(*net.TCPAddr) if !ok { return fmt.Errorf("获取用户地址失败") } request.RemoteAddr = &AddrSpec{ IP: user.IP, Port: user.Port, } // 处理请求 slog.Debug("开始代理流量") err = s.handle(request, conn) if err != nil { return err } return nil } // checkVersion 检查客户端版本 func checkVersion(reader io.Reader) error { version, err := utils.ReadByte(reader) if err != nil { return err } slog.Debug("客户端请求版本", "version", version) if version != Version { return errors.New("客户端版本不兼容") } return nil }