package handlers

import (
	"encoding/base64"
	"errors"
	"log/slog"
	"platform/web/auth"
	m "platform/web/models"
	q "platform/web/queries"
	s "platform/web/services"
	"strings"
	"time"

	"github.com/gofiber/fiber/v2"
	"golang.org/x/crypto/bcrypt"
	"gorm.io/gorm"
)

// region /token

// TokenReq is the body of a token request. It is filled by fiber's BodyParser
// from either JSON or form data. The grant-specific fields read by Token
// (Code, RedirectURI, CodeVerifier, RefreshToken, LoginType, Username,
// Password) are promoted from the embedded s.Grant*Data structs.
type TokenReq struct {
	GrantType    s.OauthGrantType `json:"grant_type" form:"grant_type"`
	ClientID     string           `json:"client_id" form:"client_id"`
	ClientSecret string           `json:"client_secret" form:"client_secret"`
	Scope        string           `json:"scope" form:"scope"`
	s.GrantCodeData
	s.GrantClientData
	s.GrantRefreshData
	s.GrantPasswordData
}

// TokenResp is the JSON body returned for a successful token request.
type TokenResp struct {
	AccessToken  string `json:"access_token"`
	RefreshToken string `json:"refresh_token,omitempty"`
	ExpiresIn    int    `json:"expires_in"`
	TokenType    string `json:"token_type"`
	Scope        string `json:"scope,omitempty"`
}

// TokenErrResp is the JSON body returned when a token request fails with an
// s.AuthServiceError.
type TokenErrResp struct {
	Error       string `json:"error"`
	Description string `json:"error_description,omitempty"`
}

// Token handles OAuth 2.0 token requests.
//
// It parses the request body, checks the parameters required by the requested
// grant type, authenticates the client via protect and then delegates to the
// matching s.Auth method. Failures are reported through sendError, successes
// through sendSuccess.
func Token(c *fiber.Ctx) error {
	// Parse and validate the request parameters.
	req := new(TokenReq)
	if err := c.BodyParser(req); err != nil {
		return sendError(c, s.ErrOauthInvalidRequest, "无法解析请求参数")
	}
	if req.GrantType == "" {
		return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数:grant_type")
	}

	slog.Debug("oauth token",
		slog.String("grant_type", string(req.GrantType)),
		slog.String("client_id", req.ClientID),
	)

	// Dispatch on the grant type.
	switch req.GrantType {

	// Authorization code grant.
	case s.OauthGrantTypeAuthorizationCode:
		if req.Code == "" {
			return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数:code")
		}
		client, err := protect(c, req.GrantType, req.ClientID, req.ClientSecret)
		if err != nil {
			return sendError(c, err)
		}
		token, err := s.Auth.OauthAuthorizationCode(c.Context(), client, req.Code, req.RedirectURI, req.CodeVerifier)
		if err != nil {
			// Pass the error through untouched: sendError unwraps
			// s.AuthServiceError itself, and a hard type assertion here
			// would panic on any other error type.
			return sendError(c, err)
		}
		return sendSuccess(c, token)

	// Client credentials grant.
	case s.OauthGrantTypeClientCredentials:
		client, err := protect(c, req.GrantType, req.ClientID, req.ClientSecret)
		if err != nil {
			return sendError(c, err)
		}
		// NOTE(review): an omitted scope yields []string{""} here, not an
		// empty slice — confirm the service treats that as "no scope".
		scope := strings.Split(req.Scope, ",")
		token, err := s.Auth.OauthClientCredentials(c.Context(), client, scope...)
		if err != nil {
			// See the authorization code branch: no type assertion, so
			// unexpected error types cannot panic the handler.
			return sendError(c, err)
		}
		return sendSuccess(c, token)

	// Refresh token grant.
	case s.OauthGrantTypeRefreshToken:
		if req.RefreshToken == "" {
			return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数:refresh_token")
		}
		client, err := protect(c, req.GrantType, req.ClientID, req.ClientSecret)
		if err != nil {
			return sendError(c, err)
		}
		// NOTE(review): same empty-scope caveat as the client credentials
		// branch.
		scope := strings.Split(req.Scope, ",")
		token, err := s.Auth.OauthRefreshToken(c.Context(), client, req.RefreshToken, scope)
		if err != nil {
			// An invalid refresh token is reported to the client as
			// invalid_grant.
			if errors.Is(err, s.ErrInvalidToken) {
				return sendError(c, s.ErrOauthInvalidGrant)
			}
			return sendError(c, err)
		}
		return sendSuccess(c, token)

	// Resource owner password grant.
	case s.OauthGrantTypePassword:
		// NOTE(review): the message names password_type while the field
		// checked is LoginType — confirm which request key clients send.
		if req.LoginType == "" {
			return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数:password_type")
		}
		if req.Username == "" {
			return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数:username")
		}
		if req.Password == "" {
			return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数:password")
		}
		client, err := protect(c, req.GrantType, req.ClientID, req.ClientSecret)
		if err != nil {
			return sendError(c, err)
		}
		token, err := s.Auth.OauthPassword(c.Context(), client, &req.GrantPasswordData, c.IP(), c.Get("User-Agent"))
		if err != nil {
			return sendError(c, err)
		}
		return sendSuccess(c, token)

	default:
		return sendError(c, s.ErrOauthUnsupportedGrantType)
	}
}

// decodeBasicPayload decodes the credentials part of a Basic Authorization
// header. Padded standard base64 is tried first; unpadded URL-safe base64 is
// accepted as a fallback so clients that relied on the previous decoder keep
// working.
func decodeBasicPayload(payload string) ([]byte, error) {
	if raw, err := base64.StdEncoding.DecodeString(payload); err == nil {
		return raw, nil
	}
	return base64.RawURLEncoding.DecodeString(payload)
}

// protect authenticates the OAuth client making a token request.
//
// Credentials in a Basic Authorization header override the clientId and
// clientSecret taken from the request body. The client is then loaded by its
// client ID, and its status and permission to use grant are checked. When
// client.Spec is 3 (confidential client) the secret is verified against the
// stored bcrypt hash. On success the client is also stored in the request
// locals through auth.Locals.
//
// It returns s.ErrOauthInvalidRequest for missing or undecodable credentials,
// s.ErrOauthInvalidClient for an unknown client or a wrong secret,
// s.ErrOauthUnauthorizedClient for a client that is not enabled or not allowed
// to use grant, and the raw query error for any other lookup failure.
func protect(c *fiber.Ctx, grant s.OauthGrantType, clientId, clientSecret string) (*m.Client, error) {
	// Only a Basic scheme carries client credentials; any other scheme (for
	// example Bearer) is ignored so the body credentials are used instead.
	header := c.Get("Authorization")
	if scheme, payload, found := strings.Cut(header, " "); found && strings.EqualFold(scheme, "Basic") {
		if payload = strings.TrimSpace(payload); payload != "" {
			raw, err := decodeBasicPayload(payload)
			if err != nil {
				// Report malformed credentials as an OAuth error so that
				// sendError can render them instead of leaking a raw
				// decoding error.
				return nil, s.ErrOauthInvalidRequest
			}
			// The secret may itself contain ':'; split on the first one only.
			if id, secret, ok := strings.Cut(string(raw), ":"); ok {
				clientId = id
				clientSecret = secret
			}
		}
	}

	// Look up the client.
	if clientId == "" {
		return nil, s.ErrOauthInvalidRequest
	}
	client, err := q.Client.Where(q.Client.ClientID.Eq(clientId)).Take()
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, s.ErrOauthInvalidClient
		}
		return nil, err
	}

	// Check the client status (presumably 1 means enabled — TODO confirm).
	if client.Status != 1 {
		return nil, s.ErrOauthUnauthorizedClient
	}

	// Check that the client may use the requested grant type.
	switch grant {
	case s.OauthGrantTypeAuthorizationCode:
		if !client.GrantCode {
			return nil, s.ErrOauthUnauthorizedClient
		}
	case s.OauthGrantTypeClientCredentials:
		// Client credentials additionally require a confidential client.
		if !client.GrantClient || client.Spec != 3 {
			return nil, s.ErrOauthUnauthorizedClient
		}
	case s.OauthGrantTypeRefreshToken:
		if !client.GrantRefresh {
			return nil, s.ErrOauthUnauthorizedClient
		}
	case s.OauthGrantTypePassword:
		if !client.GrantPassword {
			return nil, s.ErrOauthUnauthorizedClient
		}
	}

	// Confidential clients must present a secret matching the stored hash.
	if client.Spec == 3 {
		if clientSecret == "" {
			return nil, s.ErrOauthInvalidRequest
		}
		if bcrypt.CompareHashAndPassword([]byte(client.ClientSecret), []byte(clientSecret)) != nil {
			return nil, s.ErrOauthInvalidClient
		}
	}

	// Store the auth info in the request context so the generic auth handling
	// keeps working.
	// NOTE(review): the payload type is PayloadClientConfidential even when
	// client.Spec is not 3 — confirm this is intended.
	auth.Locals(c, &auth.Context{
		Payload: auth.Payload{
			Id:     client.ID,
			Type:   auth.PayloadClientConfidential,
			Name:   client.Name,
			Avatar: client.Icon,
		},
	})

	return client, nil
}

// sendSuccess writes the issued token as a TokenResp JSON body. ExpiresIn is
// the number of whole seconds left until details.AccessTokenExpires.
func sendSuccess(c *fiber.Ctx, details *s.TokenDetails) error {
	return c.JSON(TokenResp{
		AccessToken:  details.AccessToken,
		TokenType:    "Bearer",
		ExpiresIn:    int(time.Until(details.AccessTokenExpires).Seconds()),
		RefreshToken: details.RefreshToken,
	})
}

// sendError writes err as a TokenErrResp when it is, or wraps, an
// s.AuthServiceError. The status is 400, except for s.ErrOauthInvalidClient
// which uses 401. The first element of description, when given, replaces the
// default description. Any other error is returned unchanged to the caller.
func sendError(c *fiber.Ctx, err error, description ...string) error {
	var sErr s.AuthServiceError
	if errors.As(err, &sErr) {
		status := fiber.StatusBadRequest
		var desc string
		switch {
		case errors.Is(sErr, s.ErrOauthInvalidRequest):
			desc = "无效的请求"
		case errors.Is(sErr, s.ErrOauthInvalidClient):
			status = fiber.StatusUnauthorized
			desc = "无效的客户端凭证"
		case errors.Is(sErr, s.ErrOauthInvalidGrant):
			desc = "无效的授权凭证"
		case errors.Is(sErr, s.ErrOauthInvalidScope):
			desc = "无效的授权范围"
		case errors.Is(sErr, s.ErrOauthUnauthorizedClient):
			desc = "未授权的客户端"
		case errors.Is(sErr, s.ErrOauthUnsupportedGrantType):
			desc = "不支持的授权类型"
		}
		if len(description) > 0 {
			desc = description[0]
		}
		return c.Status(status).JSON(TokenErrResp{
			Error:       string(sErr),
			Description: desc,
		})
	}
	return err
}

// endregion

// region /revoke

// RevokeReq is the JSON body of a revoke request.
type RevokeReq struct {
	AccessToken  string `json:"access_token"`
	RefreshToken string `json:"refresh_token"`
}

// Revoke removes the session identified by the access and refresh tokens in
// the request body. A request that fails user authentication is answered
// without error and without removing anything.
func Revoke(c *fiber.Ctx) error {
	_, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
	if err != nil {
		// The user is not logged in: nothing to revoke, and deliberately not
		// reported as an error.
		return nil
	}

	// Parse the request parameters.
	req := new(RevokeReq)
	if err := c.BodyParser(req); err != nil {
		return err
	}

	// Remove the session.
	err = s.Session.Remove(c.Context(), req.AccessToken, req.RefreshToken)
	if err != nil {
		return err
	}

	return nil
}

// endregion

// region /profile

// IntrospectResp is the JSON body returned by Introspect: the user record
// with its fields promoted to the top level.
type IntrospectResp struct {
	m.User
}

// Introspect returns the profile of the authenticated user. The password and
// deletion timestamp columns are omitted from the query, and the phone and ID
// numbers are masked before the record is sent.
func Introspect(c *fiber.Ctx) error {
	// Require an authenticated user.
	authCtx, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
	if err != nil {
		return err
	}

	// Load the user record.
	profile, err := q.User.
		Where(q.User.ID.Eq(authCtx.Payload.Id)).
		Omit(q.User.Password, q.User.DeletedAt).
		Take()
	if err != nil {
		return err
	}

	// Mask sensitive fields.
	if profile.Phone != "" {
		profile.Phone = maskPhone(profile.Phone)
	}
	if profile.IDNo != "" {
		profile.IDNo = maskIdNo(profile.IDNo)
	}
	return c.JSON(IntrospectResp{*profile})
}

// maskPhone replaces bytes 3 through 6 of phone with "****". Values shorter
// than 11 bytes are returned unchanged.
func maskPhone(phone string) string {
	if len(phone) < 11 {
		return phone
	}
	return phone[:3] + "****" + phone[7:]
}

// maskIdNo keeps the first 3 bytes of idNo and everything from byte 14 on,
// with nine asterisks in between. Values shorter than 18 bytes are returned
// unchanged.
//
// NOTE(review): nine asterisks stand in for eleven hidden bytes, so the
// result is shorter than the input — confirm this is intended.
func maskIdNo(idNo string) string {
	if len(idNo) < 18 {
		return idNo
	}
	return idNo[:3] + "*********" + idNo[14:]
}

// endregion