package socks

import (
	"context"
	"io"
	"log/slog"
	"proxy-server/pkg/utils"
	"slices"

	"github.com/pkg/errors"
)

// AuthMethod is a SOCKS authentication method code as exchanged during the
// method-selection handshake.
type AuthMethod byte

const (
	// AuthVersion, AuthSuccess and AuthFailure are not referenced in this file.
	// NOTE(review): they look like the username/password sub-negotiation
	// values from RFC 1929 (version 0x01, status 0x00 = success) — confirm
	// against the Authenticator implementations.
	AuthVersion = byte(1)
	AuthSuccess = byte(0)
	AuthFailure = byte(1)

	// NoAuth (0x00) and UserPassAuth (0x02) match the RFC 1928 method codes
	// "NO AUTHENTICATION REQUIRED" and "USERNAME/PASSWORD".
	NoAuth       = AuthMethod(0)
	UserPassAuth = AuthMethod(2)

	// NoAcceptable (0xFF, RFC 1928 "NO ACCEPTABLE METHODS") is sent in the
	// method-selection reply when no configured method was offered by the client.
	NoAcceptable = byte(0xFF)
)

// Authenticator implements one SOCKS authentication method.
type Authenticator interface {
	// Method returns the method code this authenticator handles; it is matched
	// against the method codes offered by the client.
	Method() AuthMethod
	// Authenticate runs the method-specific exchange on reader/writer after the
	// method-selection reply has been written. ctx carries the *Server under the
	// string key "service". A non-nil error aborts the handshake.
	Authenticate(ctx context.Context, reader io.Reader, writer io.Writer) (*Authentication, error)
}

// authenticate performs the SOCKS method-selection handshake and then runs the
// chosen Authenticator.
//
// After checkVersion succeeds it reads a count byte followed by that many
// method codes. It then walks s.config.AuthMethods in order and selects the
// first authenticator whose code the client offered, so the server's
// configuration order decides preference. It replies {Version, method} and
// delegates to that authenticator. If no configured method was offered it
// replies {Version, NoAcceptable} and returns an error.
//
// Read, write and authenticator errors are returned unwrapped so callers can
// still compare them directly.
func (s *Server) authenticate(reader io.Reader, writer io.Writer) (*Authentication, error) {
	// Version check (presumably consumes and validates the greeting's VER byte).
	if err := checkVersion(reader); err != nil {
		return nil, err
	}

	// Read the client's offered authentication methods: a count, then the codes.
	nAuth, err := utils.ReadByte(reader)
	if err != nil {
		return nil, err
	}
	methods, err := utils.ReadBuffer(reader, int(nAuth))
	if err != nil {
		return nil, err
	}

	for _, authenticator := range s.config.AuthMethods {
		method := authenticator.Method()
		if !slices.Contains(methods, byte(method)) {
			continue
		}

		// slog takes alternating key/value pairs; a bare value would be logged
		// under "!BADKEY".
		slog.Debug("使用的认证方式", "method", method)
		if _, err := writer.Write([]byte{Version, byte(method)}); err != nil {
			slog.Error("响应认证方式失败", "err", err)
			return nil, err
		}

		// NOTE(review): authenticators presumably read the server back via
		// ctx.Value("service"), so the plain string key is kept. staticcheck
		// SA1029 recommends an unexported key type; migrate it together with
		// every reader of this key.
		ctx := context.WithValue(context.Background(), "service", s)
		authContext, err := authenticator.Authenticate(ctx, reader, writer)
		if err != nil {
			// Never hand back a partial result alongside an error.
			return nil, err
		}
		return authContext, nil
	}

	// None of the configured methods was offered by the client.
	if _, err := writer.Write([]byte{Version, NoAcceptable}); err != nil {
		return nil, err
	}
	return nil, errors.New("没有适用的认证方式")
}

// Authentication is the result returned by a successful Authenticator.
type Authentication struct {
	Method  AuthMethod     // authentication method reported by the Authenticator
	Timeout uint           // NOTE(review): unit and meaning not visible here — confirm
	Payload Payload        // identity payload attached by the Authenticator
	Data    map[string]any // free-form additional data
}

// Payload carries identity information attached to an Authentication.
type Payload struct {
	ID uint // presumably the authenticated user's ID — TODO confirm
}