Files
platform/web/handlers/oauth.go

244 lines
6.5 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package handlers
import (
"encoding/base64"
"errors"
"platform/web/models"
q "platform/web/queries"
"platform/web/services"
"strings"
"time"
"github.com/gofiber/fiber/v2"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
// region Token
// TokenReq holds the token endpoint request parameters. Token binds it from
// the request body using the json/form tags below.
type TokenReq struct {
// ClientID and ClientSecret are the client credentials sent in the body;
// protect uses them unless credentials parsed from the Authorization header
// replace them.
ClientID string `json:"client_id" form:"client_id"`
ClientSecret string `json:"client_secret" form:"client_secret"`
// GrantType selects the grant handler Token dispatches to; it is required.
GrantType TokenGrantType `json:"grant_type" form:"grant_type"`
// Code is the authorization code; required by the authorization_code grant.
Code string `json:"code" form:"code"`
// RedirectURI and CodeVerifier are forwarded unchanged to the
// authorization_code service call.
RedirectURI string `json:"redirect_uri" form:"redirect_uri"`
CodeVerifier string `json:"code_verifier" form:"code_verifier"`
// RefreshToken is required by the refresh_token grant.
RefreshToken string `json:"refresh_token" form:"refresh_token"`
// Scope is the requested scope list; it is forwarded to the
// client_credentials and refresh_token grants.
Scope string `json:"scope" form:"scope"`
}
// TokenResp is the JSON body written by sendSuccess for an issued token.
type TokenResp struct {
// AccessToken is the issued access token.
AccessToken string `json:"access_token"`
// RefreshToken is left out of the JSON when empty.
RefreshToken string `json:"refresh_token,omitempty"`
// TokenType is set to "Bearer" by sendSuccess.
TokenType string `json:"token_type"`
// Scope is left out of the JSON when empty; sendSuccess does not set it.
Scope string `json:"scope,omitempty"`
// ExpiresIn is the number of whole seconds until the access token expires.
ExpiresIn int `json:"expires_in"`
}
// TokenErrResp is the JSON body written by sendError for an OAuth error.
type TokenErrResp struct {
// Error is the OAuth error value converted to a string.
Error string `json:"error"`
// Description is a human-readable explanation; left out of the JSON when empty.
Description string `json:"error_description,omitempty"`
}
// TokenGrantType is the value of the grant_type request parameter.
type TokenGrantType string

// Grant types that the Token handler dispatches on.
const (
	AuthorizationCode TokenGrantType = "authorization_code"
	ClientCredentials TokenGrantType = "client_credentials"
	RefreshToken      TokenGrantType = "refresh_token"
)
// Token handles OAuth 2.0 token requests. It binds the request body into a
// TokenReq and hands it to the handler for the requested grant type.
func Token(c *fiber.Ctx) error {
	var req TokenReq
	if err := c.BodyParser(&req); err != nil {
		return sendError(c, services.ErrOauthInvalidRequest, "无法解析请求参数")
	}

	// A missing grant_type is a malformed request; an unrecognised one is an
	// unsupported grant.
	switch req.GrantType {
	case "":
		return sendError(c, services.ErrOauthInvalidRequest, "缺少必要参数grant_type")
	case AuthorizationCode:
		return authorizationCode(c, &req)
	case ClientCredentials:
		return clientCredentials(c, &req)
	case RefreshToken:
		return refreshToken(c, &req)
	}
	return sendError(c, services.ErrOauthUnsupportedGrantType)
}
// authorizationCode handles the authorization_code grant.
//
// It requires the code parameter, authenticates the client through protect,
// and exchanges the code (together with redirect_uri and code_verifier) for a
// token via services.Auth.OauthAuthorizationCode. Every failure is reported
// through sendError.
func authorizationCode(c *fiber.Ctx, req *TokenReq) error {
	if req.Code == "" {
		return sendError(c, services.ErrOauthInvalidRequest, "缺少必要参数code")
	}

	client, err := protect(c, services.GrantTypeAuthorizationCode, req.ClientID, req.ClientSecret)
	if err != nil {
		return sendError(c, err)
	}

	token, err := services.Auth.OauthAuthorizationCode(c.Context(), client, req.Code, req.RedirectURI, req.CodeVerifier)
	if err != nil {
		// Pass the error through as-is: sendError extracts OAuth errors with
		// errors.As and returns anything else unchanged. An unchecked type
		// assertion here would panic on a non-OAuth (or wrapped) error.
		return sendError(c, err)
	}
	return sendSuccess(c, token)
}
// clientCredentials handles the client_credentials grant.
//
// It authenticates the client through protect and requests a token for the
// given scopes via services.Auth.OauthClientCredentials. Every failure is
// reported through sendError.
func clientCredentials(c *fiber.Ctx, req *TokenReq) error {
	client, err := protect(c, services.GrantTypeClientCredentials, req.ClientID, req.ClientSecret)
	if err != nil {
		return sendError(c, err)
	}

	// Accept both the comma-separated form this endpoint has always taken and
	// the space-delimited form of RFC 6749 §3.3. FieldsFunc drops empty
	// entries, so an absent scope yields no scopes instead of one empty scope.
	scope := strings.FieldsFunc(req.Scope, func(r rune) bool {
		return r == ',' || r == ' '
	})

	token, err := services.Auth.OauthClientCredentials(c.Context(), client, scope...)
	if err != nil {
		// Pass the error through as-is: sendError extracts OAuth errors with
		// errors.As and returns anything else unchanged. An unchecked type
		// assertion here would panic on a non-OAuth (or wrapped) error.
		return sendError(c, err)
	}
	return sendSuccess(c, token)
}
// refreshToken handles the refresh_token grant.
//
// It requires the refresh_token parameter, authenticates the client through
// protect, and exchanges the refresh token for a new token via
// services.Auth.OauthRefreshToken. Every failure is reported through
// sendError.
func refreshToken(c *fiber.Ctx, req *TokenReq) error {
	if req.RefreshToken == "" {
		return sendError(c, services.ErrOauthInvalidRequest, "缺少必要参数refresh_token")
	}

	client, err := protect(c, services.GrantTypeRefreshToken, req.ClientID, req.ClientSecret)
	if err != nil {
		return sendError(c, err)
	}

	// Accept both the comma-separated form this endpoint has always taken and
	// the space-delimited form of RFC 6749 §3.3. FieldsFunc drops empty
	// entries, so an absent scope yields no scopes instead of one empty scope.
	scope := strings.FieldsFunc(req.Scope, func(r rune) bool {
		return r == ',' || r == ' '
	})

	token, err := services.Auth.OauthRefreshToken(c.Context(), client, req.RefreshToken, scope)
	if err != nil {
		// Pass the error through as-is: sendError extracts OAuth errors with
		// errors.As and returns anything else unchanged. An unchecked type
		// assertion here would panic on a non-OAuth (or wrapped) error.
		return sendError(c, err)
	}
	return sendSuccess(c, token)
}
// protect authenticates the calling OAuth client and checks that it is allowed
// to use the requested grant.
//
// Credentials come from an HTTP Basic Authorization header when one is present
// and contains "id:secret"; otherwise the clientId and clientSecret arguments
// (the request-body values passed by the callers) are used.
//
// It returns the client record on success. On failure it returns:
//   - services.ErrOauthInvalidRequest when the Basic payload cannot be decoded,
//     no client id is supplied, or a confidential client sends no secret;
//   - services.ErrOauthInvalidClient when the client is unknown or the secret
//     does not match;
//   - services.ErrOauthUnauthorizedClient when the client's status is not 1 or
//     the grant is not enabled for it;
//   - the underlying error for any other lookup failure.
func protect(c *fiber.Ctx, grant services.GrantType, clientId, clientSecret string) (*models.Client, error) {
	const basicPrefix = "Basic "

	// Only a Basic header carries client credentials. The scheme name is
	// case-insensitive, and any other scheme (e.g. Bearer) is ignored rather
	// than being fed to the base64 decoder.
	header := c.Get("Authorization")
	if len(header) > len(basicPrefix) && strings.EqualFold(header[:len(basicPrefix)], basicPrefix) {
		payload := strings.TrimSpace(header[len(basicPrefix):])
		if payload != "" {
			// RFC 7617 uses the standard base64 alphabet. The URL-safe alphabet
			// is still accepted as a fallback because it was the only one this
			// endpoint used to accept.
			raw, err := base64.StdEncoding.DecodeString(payload)
			if err != nil {
				raw, err = base64.URLEncoding.DecodeString(payload)
			}
			if err != nil {
				// Report a malformed header as an OAuth error instead of
				// leaking the decoder error as a server failure.
				return nil, services.ErrOauthInvalidRequest
			}
			// Split on the first colon only: the secret may itself contain ':'.
			parts := strings.SplitN(string(raw), ":", 2)
			if len(parts) == 2 {
				clientId, clientSecret = parts[0], parts[1]
			}
		}
	}

	// Look up the client.
	if clientId == "" {
		return nil, services.ErrOauthInvalidRequest
	}
	client, err := q.Client.Where(q.Client.ClientID.Eq(clientId)).Take()
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, services.ErrOauthInvalidClient
		}
		return nil, err
	}

	// Check the client status (1 is presumably "enabled" — confirm against the model).
	if client.Status != 1 {
		return nil, services.ErrOauthUnauthorizedClient
	}

	// Check that the requested grant is enabled for this client.
	switch grant {
	case services.GrantTypeAuthorizationCode:
		if !client.GrantCode {
			return nil, services.ErrOauthUnauthorizedClient
		}
	case services.GrantTypeClientCredentials:
		// client_credentials is additionally restricted to clients with Spec == 0.
		if !client.GrantClient || client.Spec != 0 {
			return nil, services.ErrOauthUnauthorizedClient
		}
	case services.GrantTypeRefreshToken:
		if !client.GrantRefresh {
			return nil, services.ErrOauthUnauthorizedClient
		}
	}

	// Confidential clients (Spec == 0) must present a client_secret matching
	// the stored bcrypt hash.
	if client.Spec == 0 {
		if clientSecret == "" {
			return nil, services.ErrOauthInvalidRequest
		}
		if err := bcrypt.CompareHashAndPassword([]byte(client.ClientSecret), []byte(clientSecret)); err != nil {
			return nil, services.ErrOauthInvalidClient
		}
	}
	return client, nil
}
// sendSuccess writes the issued token as a JSON TokenResp.
//
// The token type is always "Bearer" and expires_in is the number of whole
// seconds remaining until details.AccessTokenExpires.
func sendSuccess(c *fiber.Ctx, details *services.TokenDetails) error {
	// RFC 6749 §5.1: responses containing tokens must not be cached.
	c.Set("Cache-Control", "no-store")
	c.Set("Pragma", "no-cache")

	return c.JSON(TokenResp{
		AccessToken:  details.AccessToken,
		TokenType:    "Bearer",
		ExpiresIn:    int(time.Until(details.AccessTokenExpires).Seconds()),
		RefreshToken: details.RefreshToken,
	})
}
// sendError writes an OAuth error response.
//
// If err is (or wraps) a services.AuthServiceOauthError, a JSON TokenErrResp
// is written: HTTP 401 for an invalid client and HTTP 400 otherwise, with a
// default description per error that the first description argument, when
// given, overrides. Any other error is returned to the caller unchanged.
func sendError(c *fiber.Ctx, err error, description ...string) error {
	var oauthErr services.AuthServiceOauthError
	if !errors.As(err, &oauthErr) {
		return err
	}

	code, text := fiber.StatusBadRequest, ""
	switch {
	case errors.Is(oauthErr, services.ErrOauthInvalidRequest):
		text = "无效的请求"
	case errors.Is(oauthErr, services.ErrOauthInvalidClient):
		code = fiber.StatusUnauthorized
		text = "无效的客户端凭证"
	case errors.Is(oauthErr, services.ErrOauthInvalidGrant):
		text = "无效的授权凭证"
	case errors.Is(oauthErr, services.ErrOauthInvalidScope):
		text = "无效的授权范围"
	case errors.Is(oauthErr, services.ErrOauthUnauthorizedClient):
		text = "未授权的客户端"
	case errors.Is(oauthErr, services.ErrOauthUnsupportedGrantType):
		text = "不支持的授权类型"
	}

	// A caller-supplied description takes precedence over the default text.
	if len(description) != 0 {
		text = description[0]
	}

	body := TokenErrResp{
		Error:       string(oauthErr),
		Description: text,
	}
	return c.Status(code).JSON(body)
}
// endregion