完善登录逻辑,登录接口统一到 /token

This commit is contained in:
2025-04-23 19:01:08 +08:00
parent b181864a2f
commit 1374757eab
28 changed files with 404 additions and 266 deletions

View File

@@ -51,7 +51,7 @@ func ListBill(c *fiber.Ctx) error {
do = do.Where(q.Bill.BillNo.Eq(*req.BillNo))
}
bills, err := do.Debug().
bills, err := do.
Preload(q.Bill.Resource, q.Bill.Trade, q.Bill.Refund).
Preload(q.Bill.Resource.Pss).
Order(q.Bill.CreatedAt.Desc()).

View File

@@ -3,9 +3,10 @@ package handlers
import (
"encoding/base64"
"errors"
"platform/web/models"
"log/slog"
m "platform/web/models"
q "platform/web/queries"
"platform/web/services"
s "platform/web/services"
"strings"
"time"
@@ -17,22 +18,42 @@ import (
// region Token
type TokenReq struct {
ClientID string `json:"client_id" form:"client_id"`
ClientSecret string `json:"client_secret" form:"client_secret"`
GrantType TokenGrantType `json:"grant_type" form:"grant_type"`
Code string `json:"code" form:"code"`
RedirectURI string `json:"redirect_uri" form:"redirect_uri"`
CodeVerifier string `json:"code_verifier" form:"code_verifier"`
RefreshToken string `json:"refresh_token" form:"refresh_token"`
Scope string `json:"scope" form:"scope"`
GrantType s.OauthGrantType `json:"grant_type" form:"grant_type"`
ClientID string `json:"client_id" form:"client_id"`
ClientSecret string `json:"client_secret" form:"client_secret"`
Scope string `json:"scope" form:"scope"`
TokenReqCode
TokenReqClient
TokenReqRefresh
TokenReqPassword
}
type TokenReqCode struct {
Code string `json:"code" form:"code"`
RedirectURI string `json:"redirect_uri" form:"redirect_uri"`
CodeVerifier string `json:"code_verifier" form:"code_verifier"`
}
type TokenReqClient struct {
}
type TokenReqRefresh struct {
RefreshToken string `json:"refresh_token" form:"refresh_token"`
}
type TokenReqPassword struct {
LoginType s.OauthGrantLoginType `json:"login_type" form:"login_type"`
Username string `json:"username" form:"username"`
Password string `json:"password" form:"password"`
Remember bool `json:"remember" form:"remember"`
}
type TokenResp struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token,omitempty"`
ExpiresIn int `json:"expires_in"`
TokenType string `json:"token_type"`
Scope string `json:"scope,omitempty"`
ExpiresIn int `json:"expires_in"`
}
type TokenErrResp struct {
@@ -40,57 +61,57 @@ type TokenErrResp struct {
Description string `json:"error_description,omitempty"`
}
type TokenGrantType string
const (
AuthorizationCode = TokenGrantType("authorization_code")
ClientCredentials = TokenGrantType("client_credentials")
RefreshToken = TokenGrantType("refresh_token")
)
// Token 处理 OAuth2.0 授权请求
func Token(c *fiber.Ctx) error {
// 验证请求参数
req := new(TokenReq)
if err := c.BodyParser(req); err != nil {
return sendError(c, services.ErrOauthInvalidRequest, "无法解析请求参数")
return sendError(c, s.ErrOauthInvalidRequest, "无法解析请求参数")
}
if req.GrantType == "" {
return sendError(c, services.ErrOauthInvalidRequest, "缺少必要参数grant_type")
return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数grant_type")
}
slog.Debug("oauth token", slog.String("grant_type",
string(req.GrantType)),
slog.String("client_id", req.ClientID),
)
// 基于授权类型处理请求
switch req.GrantType {
case AuthorizationCode:
case s.OauthGrantTypeAuthorizationCode:
return authorizationCode(c, req)
case ClientCredentials:
case s.OauthGrantTypeClientCredentials:
return clientCredentials(c, req)
case RefreshToken:
case s.OauthGrantTypeRefreshToken:
return refreshToken(c, req)
case s.OauthGrantTypePassword:
return password(c, req)
default:
return sendError(c, services.ErrOauthUnsupportedGrantType)
return sendError(c, s.ErrOauthUnsupportedGrantType)
}
}
// 授权码
func authorizationCode(c *fiber.Ctx, req *TokenReq) error {
if req.Code == "" {
return sendError(c, services.ErrOauthInvalidRequest, "缺少必要参数code")
return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数code")
}
client, err := protect(c, services.GrantTypeAuthorizationCode, req.ClientID, req.ClientSecret)
client, err := protect(c, s.OauthGrantTypeAuthorizationCode, req.ClientID, req.ClientSecret)
if err != nil {
return sendError(c, err)
}
token, err := services.Auth.OauthAuthorizationCode(c.Context(), client, req.Code, req.RedirectURI, req.CodeVerifier)
token, err := s.Auth.OauthAuthorizationCode(c.Context(), client, req.Code, req.RedirectURI, req.CodeVerifier)
if err != nil {
return sendError(c, err.(services.AuthServiceOauthError))
return sendError(c, err.(s.AuthServiceOauthError))
}
return sendSuccess(c, token)
@@ -98,15 +119,15 @@ func authorizationCode(c *fiber.Ctx, req *TokenReq) error {
// 客户端凭证
func clientCredentials(c *fiber.Ctx, req *TokenReq) error {
client, err := protect(c, services.GrantTypeClientCredentials, req.ClientID, req.ClientSecret)
client, err := protect(c, s.OauthGrantTypeClientCredentials, req.ClientID, req.ClientSecret)
if err != nil {
return sendError(c, err)
}
scope := strings.Split(req.Scope, ",")
token, err := services.Auth.OauthClientCredentials(c.Context(), client, scope...)
token, err := s.Auth.OauthClientCredentials(c.Context(), client, scope...)
if err != nil {
return sendError(c, err.(services.AuthServiceOauthError))
return sendError(c, err.(s.AuthServiceOauthError))
}
return sendSuccess(c, token)
@@ -115,19 +136,19 @@ func clientCredentials(c *fiber.Ctx, req *TokenReq) error {
// 刷新令牌
func refreshToken(c *fiber.Ctx, req *TokenReq) error {
if req.RefreshToken == "" {
return sendError(c, services.ErrOauthInvalidRequest, "缺少必要参数refresh_token")
return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数refresh_token")
}
client, err := protect(c, services.GrantTypeRefreshToken, req.ClientID, req.ClientSecret)
client, err := protect(c, s.OauthGrantTypeRefreshToken, req.ClientID, req.ClientSecret)
if err != nil {
return sendError(c, err)
}
scope := strings.Split(req.Scope, ",")
token, err := services.Auth.OauthRefreshToken(c.Context(), client, req.RefreshToken, scope)
token, err := s.Auth.OauthRefreshToken(c.Context(), client, req.RefreshToken, scope)
if err != nil {
if errors.Is(err, services.ErrInvalidToken) {
return sendError(c, services.ErrOauthInvalidGrant)
if errors.Is(err, s.ErrInvalidToken) {
return sendError(c, s.ErrOauthInvalidGrant)
}
return sendError(c, err)
}
@@ -135,8 +156,108 @@ func refreshToken(c *fiber.Ctx, req *TokenReq) error {
return sendSuccess(c, token)
}
func password(c *fiber.Ctx, req *TokenReq) error {
if req.LoginType == "" {
return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数password_type")
}
if req.Username == "" {
return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数username")
}
if req.Password == "" {
return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数password")
}
// 验证客户端凭证
_, err := protect(c, s.OauthGrantTypePassword, req.ClientID, req.ClientSecret)
if err != nil {
return sendError(c, err)
}
// 验证验证码
err = s.Verifier.VerifySms(c.Context(), req.Username, req.Password)
if err != nil {
if errors.Is(err, s.ErrVerifierServiceInvalid) {
return fiber.NewError(fiber.StatusBadRequest, "验证码错误")
}
return err
}
// 查找用户
var user *m.User
err = q.Q.Transaction(func(tx *q.Query) error {
switch req.LoginType {
case s.OauthGrantPasswordTypePhoneCode:
user, err = tx.User.Where(tx.User.Phone.Eq(req.Username)).Take()
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return err
}
case s.OauthGrantPasswordTypeEmailCode:
user, err = tx.User.Where(tx.User.Email.Eq(req.Username)).Take()
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return err
}
case s.OauthGrantPasswordTypePassword:
user, err = tx.User.
Where(tx.User.Or(
tx.User.Phone.Eq(req.Username),
tx.User.Email.Eq(req.Username),
tx.User.Username.Eq(req.Username),
)).
Take()
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return err
}
default:
return sendError(c, s.ErrOauthInvalidRequest, "无效的登录类型")
}
// 如果用户不存在,初始化用户 todo 初始化默认权限信息
if user == nil {
user = &m.User{
Phone: req.Username,
Username: req.Username,
}
}
// 更新用户的登录时间
user.LastLogin = time.Now()
user.LastLoginHost = c.IP()
user.LastLoginAgent = c.Get("User-Agent")
if err := tx.User.Omit(q.User.AdminID).Save(user); err != nil {
return err
}
return nil
})
if err != nil {
return err
}
// 保存到会话
auth := s.AuthContext{
Payload: s.Payload{
Id: user.ID,
Type: s.PayloadUser,
Name: user.Name,
Avatar: user.Avatar,
},
}
duration := s.DefaultSessionConfig
if !req.Remember {
duration.RefreshTokenDuration = 0
}
token, err := s.Session.Create(c.Context(), auth)
if err != nil {
return err
}
return sendSuccess(c, token)
}
// 检查客户端凭证
func protect(c *fiber.Ctx, grant services.GrantType, clientId, clientSecret string) (*models.Client, error) {
func protect(c *fiber.Ctx, grant s.OauthGrantType, clientId, clientSecret string) (*m.Client, error) {
header := c.Get("Authorization")
if header != "" {
basic := strings.TrimPrefix(header, "Basic ")
@@ -155,44 +276,48 @@ func protect(c *fiber.Ctx, grant services.GrantType, clientId, clientSecret stri
// 查找客户端
if clientId == "" {
return nil, services.ErrOauthInvalidRequest
return nil, s.ErrOauthInvalidRequest
}
client, err := q.Client.Where(q.Client.ClientID.Eq(clientId)).Take()
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, services.ErrOauthInvalidClient
return nil, s.ErrOauthInvalidClient
}
return nil, err
}
// 验证客户端状态
if client.Status != 1 {
return nil, services.ErrOauthUnauthorizedClient
return nil, s.ErrOauthUnauthorizedClient
}
// 验证授权类型
switch grant {
case services.GrantTypeAuthorizationCode:
case s.OauthGrantTypeAuthorizationCode:
if !client.GrantCode {
return nil, services.ErrOauthUnauthorizedClient
return nil, s.ErrOauthUnauthorizedClient
}
case services.GrantTypeClientCredentials:
case s.OauthGrantTypeClientCredentials:
if !client.GrantClient || client.Spec != 0 {
return nil, services.ErrOauthUnauthorizedClient
return nil, s.ErrOauthUnauthorizedClient
}
case services.GrantTypeRefreshToken:
case s.OauthGrantTypeRefreshToken:
if !client.GrantRefresh {
return nil, services.ErrOauthUnauthorizedClient
return nil, s.ErrOauthUnauthorizedClient
}
case s.OauthGrantTypePassword:
if !client.GrantPassword {
return nil, s.ErrOauthUnauthorizedClient
}
}
// 如果客户端是 confidential验证 client_secret失败返回错误
if client.Spec == 0 {
if clientSecret == "" {
return nil, services.ErrOauthInvalidRequest
return nil, s.ErrOauthInvalidRequest
}
if bcrypt.CompareHashAndPassword([]byte(client.ClientSecret), []byte(clientSecret)) != nil {
return nil, services.ErrOauthInvalidClient
return nil, s.ErrOauthInvalidClient
}
}
@@ -200,7 +325,7 @@ func protect(c *fiber.Ctx, grant services.GrantType, clientId, clientSecret stri
}
// 发送成功响应
func sendSuccess(c *fiber.Ctx, details *services.TokenDetails) error {
func sendSuccess(c *fiber.Ctx, details *s.TokenDetails) error {
return c.JSON(TokenResp{
AccessToken: details.AccessToken,
TokenType: "Bearer",
@@ -211,23 +336,23 @@ func sendSuccess(c *fiber.Ctx, details *services.TokenDetails) error {
// 发送错误响应
func sendError(c *fiber.Ctx, err error, description ...string) error {
var sErr services.AuthServiceOauthError
var sErr s.AuthServiceOauthError
if errors.As(err, &sErr) {
status := fiber.StatusBadRequest
var desc string
switch {
case errors.Is(sErr, services.ErrOauthInvalidRequest):
case errors.Is(sErr, s.ErrOauthInvalidRequest):
desc = "无效的请求"
case errors.Is(sErr, services.ErrOauthInvalidClient):
case errors.Is(sErr, s.ErrOauthInvalidClient):
status = fiber.StatusUnauthorized
desc = "无效的客户端凭证"
case errors.Is(sErr, services.ErrOauthInvalidGrant):
case errors.Is(sErr, s.ErrOauthInvalidGrant):
desc = "无效的授权凭证"
case errors.Is(sErr, services.ErrOauthInvalidScope):
case errors.Is(sErr, s.ErrOauthInvalidScope):
desc = "无效的授权范围"
case errors.Is(sErr, services.ErrOauthUnauthorizedClient):
case errors.Is(sErr, s.ErrOauthUnauthorizedClient):
desc = "未授权的客户端"
case errors.Is(sErr, services.ErrOauthUnsupportedGrantType):
case errors.Is(sErr, s.ErrOauthUnsupportedGrantType):
desc = "不支持的授权类型"
}
if len(description) > 0 {

View File

@@ -2,6 +2,7 @@ package handlers
import (
"errors"
"platform/web/auth"
"platform/web/services"
"regexp"
"strconv"
@@ -16,6 +17,13 @@ type VerifierReq struct {
func SmsCode(c *fiber.Ctx) error {
_, err := auth.Protect(c, []services.PayloadType{
services.PayloadClientConfidential,
}, []string{})
if err != nil {
return err
}
// 解析请求参数
req := new(VerifierReq)
if err := c.BodyParser(req); err != nil {