package handlers import ( "encoding/base64" "errors" "platform/web/models" q "platform/web/queries" "platform/web/services" "strings" "time" "github.com/gofiber/fiber/v2" "golang.org/x/crypto/bcrypt" "gorm.io/gorm" ) // region Token type TokenReq struct { ClientID string `json:"client_id" form:"client_id"` ClientSecret string `json:"client_secret" form:"client_secret"` GrantType TokenGrantType `json:"grant_type" form:"grant_type"` Code string `json:"code" form:"code"` RedirectURI string `json:"redirect_uri" form:"redirect_uri"` CodeVerifier string `json:"code_verifier" form:"code_verifier"` RefreshToken string `json:"refresh_token" form:"refresh_token"` Scope string `json:"scope" form:"scope"` } type TokenResp struct { AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token,omitempty"` TokenType string `json:"token_type"` Scope string `json:"scope,omitempty"` ExpiresIn int `json:"expires_in"` } type TokenErrResp struct { Error string `json:"error"` Description string `json:"error_description,omitempty"` } type TokenGrantType string const ( AuthorizationCode = TokenGrantType("authorization_code") ClientCredentials = TokenGrantType("client_credentials") RefreshToken = TokenGrantType("refresh_token") ) // Token 处理 OAuth2.0 授权请求 func Token(c *fiber.Ctx) error { // 验证请求参数 req := new(TokenReq) if err := c.BodyParser(req); err != nil { return sendError(c, services.ErrOauthInvalidRequest, "无法解析请求参数") } if req.GrantType == "" { return sendError(c, services.ErrOauthInvalidRequest, "缺少必要参数:grant_type") } // 基于授权类型处理请求 switch req.GrantType { case AuthorizationCode: return authorizationCode(c, req) case ClientCredentials: return clientCredentials(c, req) case RefreshToken: return refreshToken(c, req) default: return sendError(c, services.ErrOauthUnsupportedGrantType) } } // 授权码 func authorizationCode(c *fiber.Ctx, req *TokenReq) error { if req.Code == "" { return sendError(c, services.ErrOauthInvalidRequest, "缺少必要参数:code") } client, err := protect(c, services.GrantTypeAuthorizationCode, req.ClientID, req.ClientSecret) if err != nil { return sendError(c, err) } token, err := services.Auth.OauthAuthorizationCode(c.Context(), client, req.Code, req.RedirectURI, req.CodeVerifier) if err != nil { return sendError(c, err.(services.AuthServiceOauthError)) } return sendSuccess(c, token) } // 客户端凭证 func clientCredentials(c *fiber.Ctx, req *TokenReq) error { client, err := protect(c, services.GrantTypeClientCredentials, req.ClientID, req.ClientSecret) if err != nil { return sendError(c, err) } scope := strings.Split(req.Scope, ",") token, err := services.Auth.OauthClientCredentials(c.Context(), client, scope...) if err != nil { return sendError(c, err.(services.AuthServiceOauthError)) } return sendSuccess(c, token) } // 刷新令牌 func refreshToken(c *fiber.Ctx, req *TokenReq) error { if req.RefreshToken == "" { return sendError(c, services.ErrOauthInvalidRequest, "缺少必要参数:refresh_token") } client, err := protect(c, services.GrantTypeRefreshToken, req.ClientID, req.ClientSecret) if err != nil { return sendError(c, err) } scope := strings.Split(req.Scope, ",") token, err := services.Auth.OauthRefreshToken(c.Context(), client, req.RefreshToken, scope) if err != nil { return sendError(c, err.(services.AuthServiceOauthError)) } return sendSuccess(c, token) } // 检查客户端凭证 func protect(c *fiber.Ctx, grant services.GrantType, clientId, clientSecret string) (*models.Client, error) { header := c.Get("Authorization") if header != "" { basic := strings.TrimPrefix(header, "Basic ") if basic != "" { base, err := base64.URLEncoding.DecodeString(basic) if err != nil { return nil, err } parts := strings.SplitN(string(base), ":", 2) if len(parts) == 2 { clientId = parts[0] clientSecret = parts[1] } } } // 查找客户端 if clientId == "" { return nil, services.ErrOauthInvalidRequest } client, err := q.Client.Where(q.Client.ClientID.Eq(clientId)).Take() if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, services.ErrOauthInvalidClient } return nil, err } // 验证客户端状态 if client.Status != 1 { return nil, services.ErrOauthUnauthorizedClient } // 验证授权类型 switch grant { case services.GrantTypeAuthorizationCode: if !client.GrantCode { return nil, services.ErrOauthUnauthorizedClient } case services.GrantTypeClientCredentials: if !client.GrantClient || client.Spec != 0 { return nil, services.ErrOauthUnauthorizedClient } case services.GrantTypeRefreshToken: if !client.GrantRefresh { return nil, services.ErrOauthUnauthorizedClient } } // 如果客户端是 confidential,验证 client_secret,失败返回错误 if client.Spec == 0 { if clientSecret == "" { return nil, services.ErrOauthInvalidRequest } if bcrypt.CompareHashAndPassword([]byte(client.ClientSecret), []byte(clientSecret)) != nil { return nil, services.ErrOauthInvalidClient } } return client, nil } // 发送成功响应 func sendSuccess(c *fiber.Ctx, details *services.TokenDetails) error { return c.JSON(TokenResp{ AccessToken: details.AccessToken, TokenType: "Bearer", ExpiresIn: int(time.Until(details.AccessTokenExpires).Seconds()), RefreshToken: details.RefreshToken, }) } // 发送错误响应 func sendError(c *fiber.Ctx, err error, description ...string) error { var sErr services.AuthServiceOauthError if errors.As(err, &sErr) { status := fiber.StatusBadRequest var desc string switch { case errors.Is(sErr, services.ErrOauthInvalidRequest): desc = "无效的请求" case errors.Is(sErr, services.ErrOauthInvalidClient): status = fiber.StatusUnauthorized desc = "无效的客户端凭证" case errors.Is(sErr, services.ErrOauthInvalidGrant): desc = "无效的授权凭证" case errors.Is(sErr, services.ErrOauthInvalidScope): desc = "无效的授权范围" case errors.Is(sErr, services.ErrOauthUnauthorizedClient): desc = "未授权的客户端" case errors.Is(sErr, services.ErrOauthUnsupportedGrantType): desc = "不支持的授权类型" } if len(description) > 0 { desc = description[0] } return c.Status(status).JSON(TokenErrResp{ Error: string(sErr), Description: desc, }) } return err } // endregion