package auth

import (
	"context"
	"encoding/base64"
	"errors"
	"log/slog"
	client2 "platform/web/domains/client"
	q "platform/web/queries"
	"slices"
	"strings"

	"github.com/gofiber/fiber/v2"
	"golang.org/x/crypto/bcrypt"
)

// ProtectBuilder collects the allowed payload types and the required scopes
// for one request and then runs Protect with them.
type ProtectBuilder struct {
	c      *fiber.Ctx
	types  []PayloadType
	scopes []string
}

// NewProtect returns a ProtectBuilder bound to c with no payload types and no
// scopes configured.
func NewProtect(c *fiber.Ctx) *ProtectBuilder {
	return &ProtectBuilder{c: c, types: []PayloadType{}, scopes: []string{}}
}

// Payload replaces the set of payload types that are allowed to pass and
// returns the builder for chaining.
func (p *ProtectBuilder) Payload(types ...PayloadType) *ProtectBuilder {
	p.types = types
	return p
}

// Scopes replaces the set of scopes passed to Protect as permissions and
// returns the builder for chaining.
func (p *ProtectBuilder) Scopes(scopes ...string) *ProtectBuilder {
	p.scopes = scopes
	return p
}

// Do runs Protect with the configured payload types and scopes.
func (p *ProtectBuilder) Do() (*Context, error) {
	return Protect(p.c, p.types, p.scopes)
}

// Protect authenticates the request from its Authorization header and checks
// the result against the allowed payload types and the requested permissions.
//
// The header must have the form "<scheme> <token>" with exactly one space.
// Supported schemes are "Bearer" and "Basic" (matched case-insensitively);
// Basic is only accepted when types contains PayloadInternalServer.
//
// It returns ErrUnauthorize when the header is malformed or authentication
// fails, and ErrForbidden when the authenticated payload type is not in types
// or, if permissions is non-empty, none of the permissions is granted. On
// success the resulting Context is stored on c via Locals and returned.
func Protect(c *fiber.Ctx, types []PayloadType, permissions []string) (*Context, error) {
	// Extract the token: exactly one space must separate scheme and token.
	scheme, rest, found := strings.Cut(c.Get("Authorization"), " ")
	if !found || strings.Contains(rest, " ") {
		slog.Debug("Authorization 头格式不正确")
		return nil, ErrUnauthorize
	}

	token := strings.TrimSpace(rest)
	if token == "" {
		slog.Debug("提供的令牌为空")
		return nil, ErrUnauthorize
	}

	var (
		auth *Context
		err  error
	)
	// HTTP authentication scheme names are case-insensitive.
	switch {
	case strings.EqualFold(scheme, "Bearer"):
		auth, err = authBearer(c.Context(), token)
		if err != nil {
			slog.Debug("Bearer 认证失败", "err", err)
			return nil, ErrUnauthorize
		}

	case strings.EqualFold(scheme, "Basic"):
		// Basic credentials are only accepted on endpoints that explicitly
		// allow the internal-server payload type.
		if !slices.Contains(types, PayloadInternalServer) {
			slog.Debug("禁止使用 Basic 认证方式")
			return nil, ErrUnauthorize
		}
		auth, err = authBasic(c.Context(), token)
		if err != nil {
			slog.Debug("Basic 认证失败", "err", err)
			return nil, ErrUnauthorize
		}

	default:
		slog.Debug("无效的认证方式", "method", scheme)
		return nil, ErrUnauthorize
	}

	// Guard against an authenticator returning (nil, nil); dereferencing it
	// below would panic.
	if auth == nil {
		slog.Debug("认证结果为空")
		return nil, ErrUnauthorize
	}

	// Check the payload type and the permissions.
	if !slices.Contains(types, auth.Payload.Type) {
		slog.Debug("无效的负载类型", "except", types, "actual", auth.Payload.Type)
		return nil, ErrForbidden
	}
	if len(permissions) > 0 && !auth.AnyPermission(permissions...) {
		slog.Debug("无效的认证权限", "except", permissions, "actual", auth.Permissions)
		return nil, ErrForbidden
	}

	// Store the result on the request context.
	Locals(c, auth)
	return auth, nil
}

// Locals stores auth on the fiber context under the "auth" key.
func Locals(c *fiber.Ctx, auth *Context) {
	c.Locals("auth", auth)
}

// authBearer resolves a bearer token to its Context via FindSession.
// The error is returned unlogged; the caller decides how to report it.
func authBearer(ctx context.Context, token string) (*Context, error) {
	auth, err := FindSession(ctx, token)
	if err != nil {
		return nil, err
	}
	return auth, nil
}

// decodeBasicToken decodes a base64 Basic credential. It accepts the standard
// alphabet used by HTTP Basic authentication as well as the URL-safe alphabet,
// each with or without padding. The error of the last attempt is returned when
// none of them succeeds.
func decodeBasicToken(token string) ([]byte, error) {
	encodings := [...]*base64.Encoding{
		base64.StdEncoding,
		base64.RawStdEncoding,
		base64.URLEncoding,
		base64.RawURLEncoding,
	}

	var lastErr error
	for _, enc := range encodings {
		raw, err := enc.DecodeString(token)
		if err == nil {
			return raw, nil
		}
		lastErr = err
	}
	return nil, lastErr
}

// authBasic authenticates a client from a base64 "clientID:clientSecret"
// token. The client is looked up by ID among clients matching the spec,
// grant and status filters in the query below, and the secret is verified
// against the stored bcrypt hash. The returned Context carries no permissions
// or metadata. The context parameter is currently unused.
func authBasic(_ context.Context, token string) (*Context, error) {
	// Parse the Basic credentials.
	raw, err := decodeBasicToken(token)
	if err != nil {
		return nil, errors.New("令牌格式错误,无法解析令牌")
	}

	// Split at the first colon only, so that secrets containing ':' work.
	clientID, clientSecret, ok := strings.Cut(string(raw), ":")
	if !ok {
		return nil, errors.New("令牌格式错误,必须是 : 格式")
	}

	// Load the client record.
	client, err := q.Client.
		Where(
			q.Client.ClientID.Eq(clientID),
			q.Client.Spec.In(int32(client2.SpecWeb), int32(client2.SpecTrusted)),
			q.Client.GrantClient.Is(true),
			q.Client.Status.Eq(1)).
		Take()
	if err != nil {
		return nil, err
	}

	// Verify the client secret.
	if err := bcrypt.CompareHashAndPassword([]byte(client.ClientSecret), []byte(clientSecret)); err != nil {
		return nil, errors.New("客户端密钥错误")
	}

	// TODO: load the permissions associated with the client.

	// Assemble the authorization info (used for this request only).
	return &Context{
		Payload: Payload{
			Id:     client.ID,
			Type:   PayloadTypeFromClientSpec(client2.Spec(client.Spec)),
			Name:   client.Name,
			Avatar: client.Icon,
		},
		Permissions: nil,
		Metadata:    nil,
	}, nil
}

// AuthenticationErr is a string-backed error type for authentication and
// authorization failures.
type AuthenticationErr string

// Error returns the error message.
func (e AuthenticationErr) Error() string {
	return string(e)
}

var (
	// ErrUnauthorize is returned when the request could not be authenticated.
	ErrUnauthorize = AuthenticationErr("令牌无效")
	// ErrForbidden is returned when the authenticated caller lacks the
	// required payload type or permission.
	ErrForbidden = AuthenticationErr("没有权限")
)