2025-04-24 10:52:13 +08:00
|
|
|
|
package handlers
|
|
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
|
"encoding/base64"
|
|
|
|
|
|
"errors"
|
|
|
|
|
|
"log/slog"
|
2025-05-10 13:38:47 +08:00
|
|
|
|
auth2 "platform/web/auth"
|
2025-05-09 18:56:17 +08:00
|
|
|
|
client2 "platform/web/domains/client"
|
2025-04-24 10:52:13 +08:00
|
|
|
|
m "platform/web/models"
|
|
|
|
|
|
q "platform/web/queries"
|
|
|
|
|
|
s "platform/web/services"
|
|
|
|
|
|
"strings"
|
|
|
|
|
|
"time"
|
|
|
|
|
|
|
|
|
|
|
|
"github.com/gofiber/fiber/v2"
|
|
|
|
|
|
"golang.org/x/crypto/bcrypt"
|
|
|
|
|
|
"gorm.io/gorm"
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
// region /token
|
|
|
|
|
|
|
|
|
|
|
|
// TokenReq holds the parameters of a token endpoint request. It is bound by
// c.BodyParser in Token, so it can come from a JSON or a form-encoded body
// (see the json/form tags). Grant-specific parameters (code, refresh token,
// username/password, ...) come from the embedded s.Grant*Data structs, which
// are defined in the services package.
type TokenReq struct {
	GrantType auth2.GrantType `json:"grant_type" form:"grant_type"`
	// Client credentials from the body. protect prefers an HTTP Basic
	// Authorization header when one is present.
	ClientID     string `json:"client_id" form:"client_id"`
	ClientSecret string `json:"client_secret" form:"client_secret"`
	// Requested scopes as one delimited string; Token splits it.
	Scope string `json:"scope" form:"scope"`
	s.GrantCodeData
	s.GrantClientData
	s.GrantRefreshData
	s.GrantPasswordData
}
|
|
|
|
|
|
|
|
|
|
|
|
// TokenResp is the JSON body of a successful token response, written by
// sendSuccess. Fields tagged omitempty are dropped when empty.
type TokenResp struct {
	AccessToken  string `json:"access_token"`
	RefreshToken string `json:"refresh_token,omitempty"`
	// Remaining lifetime of the access token in seconds (computed by sendSuccess).
	ExpiresIn int `json:"expires_in"`
	// sendSuccess always sets this to "Bearer".
	TokenType string `json:"token_type"`
	// NOTE(review): not populated by sendSuccess at present.
	Scope string `json:"scope,omitempty"`
}
|
|
|
|
|
|
|
|
|
|
|
|
// TokenErrResp is the JSON body of an OAuth error response, written by
// sendError.
type TokenErrResp struct {
	// Error code: the string value of the s.AuthServiceError being reported.
	Error string `json:"error"`
	// Optional human-readable explanation; omitted when empty.
	Description string `json:"error_description,omitempty"`
}
|
|
|
|
|
|
|
|
|
|
|
|
// Token 处理 OAuth2.0 授权请求
|
|
|
|
|
|
func Token(c *fiber.Ctx) error {
|
|
|
|
|
|
|
|
|
|
|
|
// 验证请求参数
|
|
|
|
|
|
req := new(TokenReq)
|
|
|
|
|
|
if err := c.BodyParser(req); err != nil {
|
|
|
|
|
|
return sendError(c, s.ErrOauthInvalidRequest, "无法解析请求参数")
|
|
|
|
|
|
}
|
|
|
|
|
|
if req.GrantType == "" {
|
|
|
|
|
|
return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数:grant_type")
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
slog.Debug("oauth token", slog.String("grant_type",
|
|
|
|
|
|
string(req.GrantType)),
|
|
|
|
|
|
slog.String("client_id", req.ClientID),
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
// 基于授权类型处理请求
|
|
|
|
|
|
switch req.GrantType {
|
|
|
|
|
|
|
|
|
|
|
|
// 授权码模式
|
2025-05-10 13:38:47 +08:00
|
|
|
|
case auth2.GrantAuthorizationCode:
|
2025-04-24 10:52:13 +08:00
|
|
|
|
if req.Code == "" {
|
|
|
|
|
|
return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数:code")
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
client, err := protect(c, req.GrantType, req.ClientID, req.ClientSecret)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return sendError(c, err)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
token, err := s.Auth.OauthAuthorizationCode(c.Context(), client, req.Code, req.RedirectURI, req.CodeVerifier)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return sendError(c, err.(s.AuthServiceError))
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
return sendSuccess(c, token)
|
|
|
|
|
|
|
|
|
|
|
|
// 客户端凭证模式
|
2025-05-10 13:38:47 +08:00
|
|
|
|
case auth2.GrantClientCredentials:
|
2025-04-24 10:52:13 +08:00
|
|
|
|
client, err := protect(c, req.GrantType, req.ClientID, req.ClientSecret)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return sendError(c, err)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
scope := strings.Split(req.Scope, ",")
|
|
|
|
|
|
token, err := s.Auth.OauthClientCredentials(c.Context(), client, scope...)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return sendError(c, err.(s.AuthServiceError))
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
return sendSuccess(c, token)
|
|
|
|
|
|
|
|
|
|
|
|
// 刷新令牌模式
|
2025-05-10 13:38:47 +08:00
|
|
|
|
case auth2.GrantRefreshToken:
|
2025-04-24 10:52:13 +08:00
|
|
|
|
if req.RefreshToken == "" {
|
|
|
|
|
|
return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数:refresh_token")
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
client, err := protect(c, req.GrantType, req.ClientID, req.ClientSecret)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return sendError(c, err)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
scope := strings.Split(req.Scope, ",")
|
|
|
|
|
|
token, err := s.Auth.OauthRefreshToken(c.Context(), client, req.RefreshToken, scope)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
if errors.Is(err, s.ErrInvalidToken) {
|
|
|
|
|
|
return sendError(c, s.ErrOauthInvalidGrant)
|
|
|
|
|
|
}
|
|
|
|
|
|
return sendError(c, err)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
return sendSuccess(c, token)
|
|
|
|
|
|
|
|
|
|
|
|
// 密码模式
|
2025-05-10 13:38:47 +08:00
|
|
|
|
case auth2.GrantPassword:
|
2025-04-24 10:52:13 +08:00
|
|
|
|
if req.LoginType == "" {
|
|
|
|
|
|
return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数:password_type")
|
|
|
|
|
|
}
|
|
|
|
|
|
if req.Username == "" {
|
|
|
|
|
|
return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数:username")
|
|
|
|
|
|
}
|
|
|
|
|
|
if req.Password == "" {
|
|
|
|
|
|
return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数:password")
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
client, err := protect(c, req.GrantType, req.ClientID, req.ClientSecret)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return sendError(c, err)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
token, err := s.Auth.OauthPassword(c.Context(), client, &req.GrantPasswordData, c.IP(), c.Get("User-Agent"))
|
|
|
|
|
|
if err != nil {
|
2025-05-08 13:18:54 +08:00
|
|
|
|
return sendError(c, err)
|
2025-04-24 10:52:13 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
return sendSuccess(c, token)
|
|
|
|
|
|
|
|
|
|
|
|
default:
|
|
|
|
|
|
return sendError(c, s.ErrOauthUnsupportedGrantType)
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 检查客户端凭证
|
2025-05-10 13:38:47 +08:00
|
|
|
|
func protect(c *fiber.Ctx, grant auth2.GrantType, clientId, clientSecret string) (*m.Client, error) {
|
2025-04-24 10:52:13 +08:00
|
|
|
|
header := c.Get("Authorization")
|
|
|
|
|
|
if header != "" {
|
|
|
|
|
|
basic := strings.TrimPrefix(header, "Basic ")
|
|
|
|
|
|
if basic != "" {
|
2025-04-28 11:44:54 +08:00
|
|
|
|
base, err := base64.RawURLEncoding.DecodeString(basic)
|
2025-04-24 10:52:13 +08:00
|
|
|
|
if err != nil {
|
|
|
|
|
|
return nil, err
|
|
|
|
|
|
}
|
|
|
|
|
|
parts := strings.SplitN(string(base), ":", 2)
|
|
|
|
|
|
if len(parts) == 2 {
|
|
|
|
|
|
clientId = parts[0]
|
|
|
|
|
|
clientSecret = parts[1]
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 查找客户端
|
|
|
|
|
|
if clientId == "" {
|
|
|
|
|
|
return nil, s.ErrOauthInvalidRequest
|
|
|
|
|
|
}
|
|
|
|
|
|
client, err := q.Client.Where(q.Client.ClientID.Eq(clientId)).Take()
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
|
|
|
|
return nil, s.ErrOauthInvalidClient
|
|
|
|
|
|
}
|
|
|
|
|
|
return nil, err
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 验证客户端状态
|
|
|
|
|
|
if client.Status != 1 {
|
|
|
|
|
|
return nil, s.ErrOauthUnauthorizedClient
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 验证授权类型
|
|
|
|
|
|
switch grant {
|
2025-05-10 13:38:47 +08:00
|
|
|
|
case auth2.GrantAuthorizationCode:
|
2025-04-24 10:52:13 +08:00
|
|
|
|
if !client.GrantCode {
|
|
|
|
|
|
return nil, s.ErrOauthUnauthorizedClient
|
|
|
|
|
|
}
|
2025-05-10 13:38:47 +08:00
|
|
|
|
case auth2.GrantClientCredentials:
|
2025-05-09 18:56:17 +08:00
|
|
|
|
if !client.GrantClient || client.Spec != int32(client2.SpecWeb) || client.Spec != int32(client2.SpecTrusted) {
|
2025-04-24 10:52:13 +08:00
|
|
|
|
return nil, s.ErrOauthUnauthorizedClient
|
|
|
|
|
|
}
|
2025-05-10 13:38:47 +08:00
|
|
|
|
case auth2.GrantRefreshToken:
|
2025-04-24 10:52:13 +08:00
|
|
|
|
if !client.GrantRefresh {
|
|
|
|
|
|
return nil, s.ErrOauthUnauthorizedClient
|
|
|
|
|
|
}
|
2025-05-10 13:38:47 +08:00
|
|
|
|
case auth2.GrantPassword:
|
2025-04-24 10:52:13 +08:00
|
|
|
|
if !client.GrantPassword {
|
|
|
|
|
|
return nil, s.ErrOauthUnauthorizedClient
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 如果客户端是 confidential,验证 client_secret,失败返回错误
|
2025-05-09 18:56:17 +08:00
|
|
|
|
if client.Spec == int32(client2.SpecWeb) || client.Spec == int32(client2.SpecTrusted) {
|
2025-04-24 10:52:13 +08:00
|
|
|
|
if clientSecret == "" {
|
|
|
|
|
|
return nil, s.ErrOauthInvalidRequest
|
|
|
|
|
|
}
|
|
|
|
|
|
if bcrypt.CompareHashAndPassword([]byte(client.ClientSecret), []byte(clientSecret)) != nil {
|
|
|
|
|
|
return nil, s.ErrOauthInvalidClient
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-05-08 13:18:54 +08:00
|
|
|
|
// 保存 auth 信息到上下文(以兼容通用 auth 处理逻辑)
|
2025-05-10 13:38:47 +08:00
|
|
|
|
auth2.Locals(c, &auth2.Context{
|
|
|
|
|
|
Payload: auth2.Payload{
|
2025-05-08 13:18:54 +08:00
|
|
|
|
Id: client.ID,
|
2025-05-10 13:38:47 +08:00
|
|
|
|
Type: auth2.PayloadSecuredServer,
|
2025-05-08 13:18:54 +08:00
|
|
|
|
Name: client.Name,
|
|
|
|
|
|
Avatar: client.Icon,
|
|
|
|
|
|
},
|
|
|
|
|
|
})
|
|
|
|
|
|
|
2025-04-24 10:52:13 +08:00
|
|
|
|
return client, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 发送成功响应
|
|
|
|
|
|
func sendSuccess(c *fiber.Ctx, details *s.TokenDetails) error {
|
|
|
|
|
|
return c.JSON(TokenResp{
|
|
|
|
|
|
AccessToken: details.AccessToken,
|
|
|
|
|
|
TokenType: "Bearer",
|
|
|
|
|
|
ExpiresIn: int(time.Until(details.AccessTokenExpires).Seconds()),
|
|
|
|
|
|
RefreshToken: details.RefreshToken,
|
|
|
|
|
|
})
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 发送错误响应
|
|
|
|
|
|
func sendError(c *fiber.Ctx, err error, description ...string) error {
|
|
|
|
|
|
var sErr s.AuthServiceError
|
|
|
|
|
|
if errors.As(err, &sErr) {
|
|
|
|
|
|
status := fiber.StatusBadRequest
|
|
|
|
|
|
var desc string
|
|
|
|
|
|
switch {
|
|
|
|
|
|
case errors.Is(sErr, s.ErrOauthInvalidRequest):
|
|
|
|
|
|
desc = "无效的请求"
|
|
|
|
|
|
case errors.Is(sErr, s.ErrOauthInvalidClient):
|
|
|
|
|
|
status = fiber.StatusUnauthorized
|
|
|
|
|
|
desc = "无效的客户端凭证"
|
|
|
|
|
|
case errors.Is(sErr, s.ErrOauthInvalidGrant):
|
|
|
|
|
|
desc = "无效的授权凭证"
|
|
|
|
|
|
case errors.Is(sErr, s.ErrOauthInvalidScope):
|
|
|
|
|
|
desc = "无效的授权范围"
|
|
|
|
|
|
case errors.Is(sErr, s.ErrOauthUnauthorizedClient):
|
|
|
|
|
|
desc = "未授权的客户端"
|
|
|
|
|
|
case errors.Is(sErr, s.ErrOauthUnsupportedGrantType):
|
|
|
|
|
|
desc = "不支持的授权类型"
|
|
|
|
|
|
}
|
|
|
|
|
|
if len(description) > 0 {
|
|
|
|
|
|
desc = description[0]
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
return c.Status(status).JSON(TokenErrResp{
|
|
|
|
|
|
Error: string(sErr),
|
|
|
|
|
|
Description: desc,
|
|
|
|
|
|
})
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
return err
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// endregion
|
|
|
|
|
|
|
|
|
|
|
|
// region /revoke
|
|
|
|
|
|
|
|
|
|
|
|
// RevokeReq holds the tokens of a revocation request. Revoke binds it with
// c.BodyParser. Like TokenReq, it carries both json and form tags, so JSON
// and form-encoded bodies (the encoding RFC 7009 uses) both bind. Without form
// tags, snake_case form keys would not match these fields.
type RevokeReq struct {
	AccessToken  string `json:"access_token" form:"access_token"`
	RefreshToken string `json:"refresh_token" form:"refresh_token"`
}
|
|
|
|
|
|
|
|
|
|
|
|
func Revoke(c *fiber.Ctx) error {
|
2025-05-10 13:38:47 +08:00
|
|
|
|
_, err := auth2.Protect(c, []auth2.PayloadType{auth2.PayloadUser}, []string{})
|
2025-04-24 10:52:13 +08:00
|
|
|
|
if err != nil {
|
|
|
|
|
|
// 用户未登录
|
|
|
|
|
|
return nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 解析请求参数
|
|
|
|
|
|
req := new(RevokeReq)
|
|
|
|
|
|
if err := c.BodyParser(req); err != nil {
|
|
|
|
|
|
return err
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 删除会话
|
|
|
|
|
|
err = s.Session.Remove(c.Context(), req.AccessToken, req.RefreshToken)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return err
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
return nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// endregion
|
2025-04-26 17:59:34 +08:00
|
|
|
|
|
|
|
|
|
|
// region /profile
|
|
|
|
|
|
|
|
|
|
|
|
// IntrospectResp is the response body of Introspect. It embeds m.User, so with
// encoding/json-style marshalling the user's fields are promoted to the top
// level of the JSON object. Introspect masks the phone and ID number before
// building it.
type IntrospectResp struct {
	m.User
}
|
|
|
|
|
|
|
|
|
|
|
|
func Introspect(c *fiber.Ctx) error {
|
|
|
|
|
|
// 验证权限
|
2025-05-10 13:38:47 +08:00
|
|
|
|
authCtx, err := auth2.Protect(c, []auth2.PayloadType{auth2.PayloadUser}, []string{})
|
2025-04-26 17:59:34 +08:00
|
|
|
|
if err != nil {
|
|
|
|
|
|
return err
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 获取用户信息
|
|
|
|
|
|
profile, err := q.User.
|
|
|
|
|
|
Where(q.User.ID.Eq(authCtx.Payload.Id)).
|
|
|
|
|
|
Omit(q.User.Password, q.User.DeletedAt).
|
|
|
|
|
|
Take()
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return err
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-04-29 18:48:14 +08:00
|
|
|
|
// 掩码敏感信息
|
|
|
|
|
|
if profile.Phone != "" {
|
|
|
|
|
|
profile.Phone = maskPhone(profile.Phone)
|
|
|
|
|
|
}
|
|
|
|
|
|
if profile.IDNo != "" {
|
|
|
|
|
|
profile.IDNo = maskIdNo(profile.IDNo)
|
|
|
|
|
|
}
|
2025-04-26 17:59:34 +08:00
|
|
|
|
return c.JSON(IntrospectResp{*profile})
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-04-29 18:48:14 +08:00
|
|
|
|
// maskPhone hides the middle of a phone number for display. It keeps the first
// 3 and last 4 characters and replaces every character in between with one
// '*' (e.g. "13812345678" -> "138****5678").
//
// Values of 7 characters or fewer, which are too short to keep 3+4 visible,
// are masked completely, so a short or malformed number never leaks. The
// input is handled as runes, so multi-byte characters are never split into
// invalid UTF-8.
func maskPhone(phone string) string {
	runes := []rune(phone)
	n := len(runes)
	if n <= 7 {
		return strings.Repeat("*", n)
	}
	return string(runes[:3]) + strings.Repeat("*", n-7) + string(runes[n-4:])
}
|
|
|
|
|
|
|
|
|
|
|
|
// maskIdNo hides the middle of an identity-document number for display. It
// keeps the first 3 and last 4 characters and replaces every character in
// between with one '*', so the masked value has the same length as the
// original (e.g. an 18-character ID becomes 3 + 11 '*' + 4).
//
// Any length is masked, including 15-character legacy IDs. Values of 7
// characters or fewer are masked completely. The input is handled as runes,
// so multi-byte characters are never split into invalid UTF-8.
func maskIdNo(idNo string) string {
	runes := []rune(idNo)
	n := len(runes)
	if n <= 7 {
		return strings.Repeat("*", n)
	}
	return string(runes[:3]) + strings.Repeat("*", n-7) + string(runes[n-4:])
}
|
|
|
|
|
|
|
2025-04-26 17:59:34 +08:00
|
|
|
|
// endregion
|