244 lines
6.5 KiB
Go
244 lines
6.5 KiB
Go
package handlers
|
||
|
||
import (
|
||
"encoding/base64"
|
||
"errors"
|
||
"platform/web/models"
|
||
q "platform/web/queries"
|
||
"platform/web/services"
|
||
"strings"
|
||
"time"
|
||
|
||
"github.com/gofiber/fiber/v2"
|
||
"golang.org/x/crypto/bcrypt"
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
// region Token
|
||
|
||
type TokenReq struct {
|
||
ClientID string `json:"client_id" form:"client_id"`
|
||
ClientSecret string `json:"client_secret" form:"client_secret"`
|
||
GrantType TokenGrantType `json:"grant_type" form:"grant_type"`
|
||
Code string `json:"code" form:"code"`
|
||
RedirectURI string `json:"redirect_uri" form:"redirect_uri"`
|
||
CodeVerifier string `json:"code_verifier" form:"code_verifier"`
|
||
RefreshToken string `json:"refresh_token" form:"refresh_token"`
|
||
Scope string `json:"scope" form:"scope"`
|
||
}
|
||
|
||
// TokenResp is the JSON body of a successful token response.
type TokenResp struct {
	AccessToken  string `json:"access_token"`            // the issued access token
	RefreshToken string `json:"refresh_token,omitempty"` // omitted when empty
	TokenType    string `json:"token_type"`              // token scheme, e.g. "Bearer"
	Scope        string `json:"scope,omitempty"`         // granted scope; omitted when empty
	ExpiresIn    int    `json:"expires_in"`              // access token lifetime in seconds
}
|
||
|
||
// TokenErrResp is the JSON body of a token error response: a machine-readable
// error code plus an optional human-readable description.
type TokenErrResp struct {
	Error       string `json:"error"`                       // OAuth error code
	Description string `json:"error_description,omitempty"` // omitted when empty
}
|
||
|
||
// TokenGrantType is the value of the grant_type request parameter.
type TokenGrantType string

// Grant types that Token knows how to dispatch.
const (
	AuthorizationCode TokenGrantType = "authorization_code" // exchange an authorization code
	ClientCredentials TokenGrantType = "client_credentials" // authenticate as the client itself
	RefreshToken      TokenGrantType = "refresh_token"      // exchange a refresh token
)
|
||
|
||
// Token 处理 OAuth2.0 授权请求
|
||
func Token(c *fiber.Ctx) error {
|
||
|
||
// 验证请求参数
|
||
req := new(TokenReq)
|
||
if err := c.BodyParser(req); err != nil {
|
||
return sendError(c, services.ErrOauthInvalidRequest, "无法解析请求参数")
|
||
}
|
||
if req.GrantType == "" {
|
||
return sendError(c, services.ErrOauthInvalidRequest, "缺少必要参数:grant_type")
|
||
}
|
||
|
||
// 基于授权类型处理请求
|
||
switch req.GrantType {
|
||
|
||
case AuthorizationCode:
|
||
return authorizationCode(c, req)
|
||
|
||
case ClientCredentials:
|
||
return clientCredentials(c, req)
|
||
|
||
case RefreshToken:
|
||
return refreshToken(c, req)
|
||
|
||
default:
|
||
return sendError(c, services.ErrOauthUnsupportedGrantType)
|
||
}
|
||
}
|
||
|
||
// 授权码
|
||
func authorizationCode(c *fiber.Ctx, req *TokenReq) error {
|
||
if req.Code == "" {
|
||
return sendError(c, services.ErrOauthInvalidRequest, "缺少必要参数:code")
|
||
}
|
||
|
||
client, err := protect(c, services.GrantTypeAuthorizationCode, req.ClientID, req.ClientSecret)
|
||
if err != nil {
|
||
return sendError(c, err)
|
||
}
|
||
|
||
token, err := services.Auth.OauthAuthorizationCode(c.Context(), client, req.Code, req.RedirectURI, req.CodeVerifier)
|
||
if err != nil {
|
||
return sendError(c, err.(services.AuthServiceOauthError))
|
||
}
|
||
|
||
return sendSuccess(c, token)
|
||
}
|
||
|
||
// 客户端凭证
|
||
func clientCredentials(c *fiber.Ctx, req *TokenReq) error {
|
||
client, err := protect(c, services.GrantTypeClientCredentials, req.ClientID, req.ClientSecret)
|
||
if err != nil {
|
||
return sendError(c, err)
|
||
}
|
||
|
||
scope := strings.Split(req.Scope, ",")
|
||
token, err := services.Auth.OauthClientCredentials(c.Context(), client, scope...)
|
||
if err != nil {
|
||
return sendError(c, err.(services.AuthServiceOauthError))
|
||
}
|
||
|
||
return sendSuccess(c, token)
|
||
}
|
||
|
||
// 刷新令牌
|
||
func refreshToken(c *fiber.Ctx, req *TokenReq) error {
|
||
if req.RefreshToken == "" {
|
||
return sendError(c, services.ErrOauthInvalidRequest, "缺少必要参数:refresh_token")
|
||
}
|
||
|
||
client, err := protect(c, services.GrantTypeRefreshToken, req.ClientID, req.ClientSecret)
|
||
if err != nil {
|
||
return sendError(c, err)
|
||
}
|
||
|
||
scope := strings.Split(req.Scope, ",")
|
||
token, err := services.Auth.OauthRefreshToken(c.Context(), client, req.RefreshToken, scope)
|
||
if err != nil {
|
||
return sendError(c, err.(services.AuthServiceOauthError))
|
||
}
|
||
|
||
return sendSuccess(c, token)
|
||
}
|
||
|
||
// 检查客户端凭证
|
||
func protect(c *fiber.Ctx, grant services.GrantType, clientId, clientSecret string) (*models.Client, error) {
|
||
header := c.Get("Authorization")
|
||
if header != "" {
|
||
basic := strings.TrimPrefix(header, "Basic ")
|
||
if basic != "" {
|
||
base, err := base64.URLEncoding.DecodeString(basic)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
parts := strings.SplitN(string(base), ":", 2)
|
||
if len(parts) == 2 {
|
||
clientId = parts[0]
|
||
clientSecret = parts[1]
|
||
}
|
||
}
|
||
}
|
||
|
||
// 查找客户端
|
||
if clientId == "" {
|
||
return nil, services.ErrOauthInvalidRequest
|
||
}
|
||
client, err := q.Client.Where(q.Client.ClientID.Eq(clientId)).Take()
|
||
if err != nil {
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
return nil, services.ErrOauthInvalidClient
|
||
}
|
||
return nil, err
|
||
}
|
||
|
||
// 验证客户端状态
|
||
if client.Status != 1 {
|
||
return nil, services.ErrOauthUnauthorizedClient
|
||
}
|
||
|
||
// 验证授权类型
|
||
switch grant {
|
||
case services.GrantTypeAuthorizationCode:
|
||
if !client.GrantCode {
|
||
return nil, services.ErrOauthUnauthorizedClient
|
||
}
|
||
case services.GrantTypeClientCredentials:
|
||
if !client.GrantClient || client.Spec != 0 {
|
||
return nil, services.ErrOauthUnauthorizedClient
|
||
}
|
||
case services.GrantTypeRefreshToken:
|
||
if !client.GrantRefresh {
|
||
return nil, services.ErrOauthUnauthorizedClient
|
||
}
|
||
}
|
||
|
||
// 如果客户端是 confidential,验证 client_secret,失败返回错误
|
||
if client.Spec == 0 {
|
||
if clientSecret == "" {
|
||
return nil, services.ErrOauthInvalidRequest
|
||
}
|
||
if bcrypt.CompareHashAndPassword([]byte(client.ClientSecret), []byte(clientSecret)) != nil {
|
||
return nil, services.ErrOauthInvalidClient
|
||
}
|
||
}
|
||
|
||
return client, nil
|
||
}
|
||
|
||
// 发送成功响应
|
||
func sendSuccess(c *fiber.Ctx, details *services.TokenDetails) error {
|
||
return c.JSON(TokenResp{
|
||
AccessToken: details.AccessToken,
|
||
TokenType: "Bearer",
|
||
ExpiresIn: int(time.Until(details.AccessTokenExpires).Seconds()),
|
||
RefreshToken: details.RefreshToken,
|
||
})
|
||
}
|
||
|
||
// 发送错误响应
|
||
func sendError(c *fiber.Ctx, err error, description ...string) error {
|
||
var sErr services.AuthServiceOauthError
|
||
if errors.As(err, &sErr) {
|
||
status := fiber.StatusBadRequest
|
||
var desc string
|
||
switch {
|
||
case errors.Is(sErr, services.ErrOauthInvalidRequest):
|
||
desc = "无效的请求"
|
||
case errors.Is(sErr, services.ErrOauthInvalidClient):
|
||
status = fiber.StatusUnauthorized
|
||
desc = "无效的客户端凭证"
|
||
case errors.Is(sErr, services.ErrOauthInvalidGrant):
|
||
desc = "无效的授权凭证"
|
||
case errors.Is(sErr, services.ErrOauthInvalidScope):
|
||
desc = "无效的授权范围"
|
||
case errors.Is(sErr, services.ErrOauthUnauthorizedClient):
|
||
desc = "未授权的客户端"
|
||
case errors.Is(sErr, services.ErrOauthUnsupportedGrantType):
|
||
desc = "不支持的授权类型"
|
||
}
|
||
if len(description) > 0 {
|
||
desc = description[0]
|
||
}
|
||
|
||
return c.Status(status).JSON(TokenErrResp{
|
||
Error: string(sErr),
|
||
Description: desc,
|
||
})
|
||
}
|
||
|
||
return err
|
||
}
|
||
|
||
// endregion
|