完善登录逻辑,登录接口统一到 /token
This commit is contained in:
@@ -51,7 +51,7 @@ func ListBill(c *fiber.Ctx) error {
|
||||
do = do.Where(q.Bill.BillNo.Eq(*req.BillNo))
|
||||
}
|
||||
|
||||
bills, err := do.Debug().
|
||||
bills, err := do.
|
||||
Preload(q.Bill.Resource, q.Bill.Trade, q.Bill.Refund).
|
||||
Preload(q.Bill.Resource.Pss).
|
||||
Order(q.Bill.CreatedAt.Desc()).
|
||||
|
||||
@@ -3,9 +3,10 @@ package handlers
|
||||
import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"platform/web/models"
|
||||
"log/slog"
|
||||
m "platform/web/models"
|
||||
q "platform/web/queries"
|
||||
"platform/web/services"
|
||||
s "platform/web/services"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -17,22 +18,42 @@ import (
|
||||
// region Token
|
||||
|
||||
type TokenReq struct {
|
||||
ClientID string `json:"client_id" form:"client_id"`
|
||||
ClientSecret string `json:"client_secret" form:"client_secret"`
|
||||
GrantType TokenGrantType `json:"grant_type" form:"grant_type"`
|
||||
Code string `json:"code" form:"code"`
|
||||
RedirectURI string `json:"redirect_uri" form:"redirect_uri"`
|
||||
CodeVerifier string `json:"code_verifier" form:"code_verifier"`
|
||||
RefreshToken string `json:"refresh_token" form:"refresh_token"`
|
||||
Scope string `json:"scope" form:"scope"`
|
||||
GrantType s.OauthGrantType `json:"grant_type" form:"grant_type"`
|
||||
ClientID string `json:"client_id" form:"client_id"`
|
||||
ClientSecret string `json:"client_secret" form:"client_secret"`
|
||||
Scope string `json:"scope" form:"scope"`
|
||||
TokenReqCode
|
||||
TokenReqClient
|
||||
TokenReqRefresh
|
||||
TokenReqPassword
|
||||
}
|
||||
|
||||
type TokenReqCode struct {
|
||||
Code string `json:"code" form:"code"`
|
||||
RedirectURI string `json:"redirect_uri" form:"redirect_uri"`
|
||||
CodeVerifier string `json:"code_verifier" form:"code_verifier"`
|
||||
}
|
||||
|
||||
type TokenReqClient struct {
|
||||
}
|
||||
|
||||
type TokenReqRefresh struct {
|
||||
RefreshToken string `json:"refresh_token" form:"refresh_token"`
|
||||
}
|
||||
|
||||
type TokenReqPassword struct {
|
||||
LoginType s.OauthGrantLoginType `json:"login_type" form:"login_type"`
|
||||
Username string `json:"username" form:"username"`
|
||||
Password string `json:"password" form:"password"`
|
||||
Remember bool `json:"remember" form:"remember"`
|
||||
}
|
||||
|
||||
type TokenResp struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
TokenType string `json:"token_type"`
|
||||
Scope string `json:"scope,omitempty"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
}
|
||||
|
||||
type TokenErrResp struct {
|
||||
@@ -40,57 +61,57 @@ type TokenErrResp struct {
|
||||
Description string `json:"error_description,omitempty"`
|
||||
}
|
||||
|
||||
type TokenGrantType string
|
||||
|
||||
const (
|
||||
AuthorizationCode = TokenGrantType("authorization_code")
|
||||
ClientCredentials = TokenGrantType("client_credentials")
|
||||
RefreshToken = TokenGrantType("refresh_token")
|
||||
)
|
||||
|
||||
// Token 处理 OAuth2.0 授权请求
|
||||
func Token(c *fiber.Ctx) error {
|
||||
|
||||
// 验证请求参数
|
||||
req := new(TokenReq)
|
||||
if err := c.BodyParser(req); err != nil {
|
||||
return sendError(c, services.ErrOauthInvalidRequest, "无法解析请求参数")
|
||||
return sendError(c, s.ErrOauthInvalidRequest, "无法解析请求参数")
|
||||
}
|
||||
if req.GrantType == "" {
|
||||
return sendError(c, services.ErrOauthInvalidRequest, "缺少必要参数:grant_type")
|
||||
return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数:grant_type")
|
||||
}
|
||||
|
||||
slog.Debug("oauth token", slog.String("grant_type",
|
||||
string(req.GrantType)),
|
||||
slog.String("client_id", req.ClientID),
|
||||
)
|
||||
|
||||
// 基于授权类型处理请求
|
||||
switch req.GrantType {
|
||||
|
||||
case AuthorizationCode:
|
||||
case s.OauthGrantTypeAuthorizationCode:
|
||||
return authorizationCode(c, req)
|
||||
|
||||
case ClientCredentials:
|
||||
case s.OauthGrantTypeClientCredentials:
|
||||
return clientCredentials(c, req)
|
||||
|
||||
case RefreshToken:
|
||||
case s.OauthGrantTypeRefreshToken:
|
||||
return refreshToken(c, req)
|
||||
|
||||
case s.OauthGrantTypePassword:
|
||||
return password(c, req)
|
||||
|
||||
default:
|
||||
return sendError(c, services.ErrOauthUnsupportedGrantType)
|
||||
return sendError(c, s.ErrOauthUnsupportedGrantType)
|
||||
}
|
||||
}
|
||||
|
||||
// 授权码
|
||||
func authorizationCode(c *fiber.Ctx, req *TokenReq) error {
|
||||
if req.Code == "" {
|
||||
return sendError(c, services.ErrOauthInvalidRequest, "缺少必要参数:code")
|
||||
return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数:code")
|
||||
}
|
||||
|
||||
client, err := protect(c, services.GrantTypeAuthorizationCode, req.ClientID, req.ClientSecret)
|
||||
client, err := protect(c, s.OauthGrantTypeAuthorizationCode, req.ClientID, req.ClientSecret)
|
||||
if err != nil {
|
||||
return sendError(c, err)
|
||||
}
|
||||
|
||||
token, err := services.Auth.OauthAuthorizationCode(c.Context(), client, req.Code, req.RedirectURI, req.CodeVerifier)
|
||||
token, err := s.Auth.OauthAuthorizationCode(c.Context(), client, req.Code, req.RedirectURI, req.CodeVerifier)
|
||||
if err != nil {
|
||||
return sendError(c, err.(services.AuthServiceOauthError))
|
||||
return sendError(c, err.(s.AuthServiceOauthError))
|
||||
}
|
||||
|
||||
return sendSuccess(c, token)
|
||||
@@ -98,15 +119,15 @@ func authorizationCode(c *fiber.Ctx, req *TokenReq) error {
|
||||
|
||||
// 客户端凭证
|
||||
func clientCredentials(c *fiber.Ctx, req *TokenReq) error {
|
||||
client, err := protect(c, services.GrantTypeClientCredentials, req.ClientID, req.ClientSecret)
|
||||
client, err := protect(c, s.OauthGrantTypeClientCredentials, req.ClientID, req.ClientSecret)
|
||||
if err != nil {
|
||||
return sendError(c, err)
|
||||
}
|
||||
|
||||
scope := strings.Split(req.Scope, ",")
|
||||
token, err := services.Auth.OauthClientCredentials(c.Context(), client, scope...)
|
||||
token, err := s.Auth.OauthClientCredentials(c.Context(), client, scope...)
|
||||
if err != nil {
|
||||
return sendError(c, err.(services.AuthServiceOauthError))
|
||||
return sendError(c, err.(s.AuthServiceOauthError))
|
||||
}
|
||||
|
||||
return sendSuccess(c, token)
|
||||
@@ -115,19 +136,19 @@ func clientCredentials(c *fiber.Ctx, req *TokenReq) error {
|
||||
// 刷新令牌
|
||||
func refreshToken(c *fiber.Ctx, req *TokenReq) error {
|
||||
if req.RefreshToken == "" {
|
||||
return sendError(c, services.ErrOauthInvalidRequest, "缺少必要参数:refresh_token")
|
||||
return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数:refresh_token")
|
||||
}
|
||||
|
||||
client, err := protect(c, services.GrantTypeRefreshToken, req.ClientID, req.ClientSecret)
|
||||
client, err := protect(c, s.OauthGrantTypeRefreshToken, req.ClientID, req.ClientSecret)
|
||||
if err != nil {
|
||||
return sendError(c, err)
|
||||
}
|
||||
|
||||
scope := strings.Split(req.Scope, ",")
|
||||
token, err := services.Auth.OauthRefreshToken(c.Context(), client, req.RefreshToken, scope)
|
||||
token, err := s.Auth.OauthRefreshToken(c.Context(), client, req.RefreshToken, scope)
|
||||
if err != nil {
|
||||
if errors.Is(err, services.ErrInvalidToken) {
|
||||
return sendError(c, services.ErrOauthInvalidGrant)
|
||||
if errors.Is(err, s.ErrInvalidToken) {
|
||||
return sendError(c, s.ErrOauthInvalidGrant)
|
||||
}
|
||||
return sendError(c, err)
|
||||
}
|
||||
@@ -135,8 +156,108 @@ func refreshToken(c *fiber.Ctx, req *TokenReq) error {
|
||||
return sendSuccess(c, token)
|
||||
}
|
||||
|
||||
func password(c *fiber.Ctx, req *TokenReq) error {
|
||||
if req.LoginType == "" {
|
||||
return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数:password_type")
|
||||
}
|
||||
if req.Username == "" {
|
||||
return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数:username")
|
||||
}
|
||||
if req.Password == "" {
|
||||
return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数:password")
|
||||
}
|
||||
|
||||
// 验证客户端凭证
|
||||
_, err := protect(c, s.OauthGrantTypePassword, req.ClientID, req.ClientSecret)
|
||||
if err != nil {
|
||||
return sendError(c, err)
|
||||
}
|
||||
|
||||
// 验证验证码
|
||||
err = s.Verifier.VerifySms(c.Context(), req.Username, req.Password)
|
||||
if err != nil {
|
||||
if errors.Is(err, s.ErrVerifierServiceInvalid) {
|
||||
return fiber.NewError(fiber.StatusBadRequest, "验证码错误")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// 查找用户
|
||||
var user *m.User
|
||||
err = q.Q.Transaction(func(tx *q.Query) error {
|
||||
|
||||
switch req.LoginType {
|
||||
case s.OauthGrantPasswordTypePhoneCode:
|
||||
user, err = tx.User.Where(tx.User.Phone.Eq(req.Username)).Take()
|
||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return err
|
||||
}
|
||||
case s.OauthGrantPasswordTypeEmailCode:
|
||||
user, err = tx.User.Where(tx.User.Email.Eq(req.Username)).Take()
|
||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return err
|
||||
}
|
||||
case s.OauthGrantPasswordTypePassword:
|
||||
user, err = tx.User.
|
||||
Where(tx.User.Or(
|
||||
tx.User.Phone.Eq(req.Username),
|
||||
tx.User.Email.Eq(req.Username),
|
||||
tx.User.Username.Eq(req.Username),
|
||||
)).
|
||||
Take()
|
||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
return sendError(c, s.ErrOauthInvalidRequest, "无效的登录类型")
|
||||
}
|
||||
|
||||
// 如果用户不存在,初始化用户 todo 初始化默认权限信息
|
||||
if user == nil {
|
||||
user = &m.User{
|
||||
Phone: req.Username,
|
||||
Username: req.Username,
|
||||
}
|
||||
}
|
||||
|
||||
// 更新用户的登录时间
|
||||
user.LastLogin = time.Now()
|
||||
user.LastLoginHost = c.IP()
|
||||
user.LastLoginAgent = c.Get("User-Agent")
|
||||
if err := tx.User.Omit(q.User.AdminID).Save(user); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 保存到会话
|
||||
auth := s.AuthContext{
|
||||
Payload: s.Payload{
|
||||
Id: user.ID,
|
||||
Type: s.PayloadUser,
|
||||
Name: user.Name,
|
||||
Avatar: user.Avatar,
|
||||
},
|
||||
}
|
||||
|
||||
duration := s.DefaultSessionConfig
|
||||
if !req.Remember {
|
||||
duration.RefreshTokenDuration = 0
|
||||
}
|
||||
token, err := s.Session.Create(c.Context(), auth)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return sendSuccess(c, token)
|
||||
}
|
||||
|
||||
// 检查客户端凭证
|
||||
func protect(c *fiber.Ctx, grant services.GrantType, clientId, clientSecret string) (*models.Client, error) {
|
||||
func protect(c *fiber.Ctx, grant s.OauthGrantType, clientId, clientSecret string) (*m.Client, error) {
|
||||
header := c.Get("Authorization")
|
||||
if header != "" {
|
||||
basic := strings.TrimPrefix(header, "Basic ")
|
||||
@@ -155,44 +276,48 @@ func protect(c *fiber.Ctx, grant services.GrantType, clientId, clientSecret stri
|
||||
|
||||
// 查找客户端
|
||||
if clientId == "" {
|
||||
return nil, services.ErrOauthInvalidRequest
|
||||
return nil, s.ErrOauthInvalidRequest
|
||||
}
|
||||
client, err := q.Client.Where(q.Client.ClientID.Eq(clientId)).Take()
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, services.ErrOauthInvalidClient
|
||||
return nil, s.ErrOauthInvalidClient
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 验证客户端状态
|
||||
if client.Status != 1 {
|
||||
return nil, services.ErrOauthUnauthorizedClient
|
||||
return nil, s.ErrOauthUnauthorizedClient
|
||||
}
|
||||
|
||||
// 验证授权类型
|
||||
switch grant {
|
||||
case services.GrantTypeAuthorizationCode:
|
||||
case s.OauthGrantTypeAuthorizationCode:
|
||||
if !client.GrantCode {
|
||||
return nil, services.ErrOauthUnauthorizedClient
|
||||
return nil, s.ErrOauthUnauthorizedClient
|
||||
}
|
||||
case services.GrantTypeClientCredentials:
|
||||
case s.OauthGrantTypeClientCredentials:
|
||||
if !client.GrantClient || client.Spec != 0 {
|
||||
return nil, services.ErrOauthUnauthorizedClient
|
||||
return nil, s.ErrOauthUnauthorizedClient
|
||||
}
|
||||
case services.GrantTypeRefreshToken:
|
||||
case s.OauthGrantTypeRefreshToken:
|
||||
if !client.GrantRefresh {
|
||||
return nil, services.ErrOauthUnauthorizedClient
|
||||
return nil, s.ErrOauthUnauthorizedClient
|
||||
}
|
||||
case s.OauthGrantTypePassword:
|
||||
if !client.GrantPassword {
|
||||
return nil, s.ErrOauthUnauthorizedClient
|
||||
}
|
||||
}
|
||||
|
||||
// 如果客户端是 confidential,验证 client_secret,失败返回错误
|
||||
if client.Spec == 0 {
|
||||
if clientSecret == "" {
|
||||
return nil, services.ErrOauthInvalidRequest
|
||||
return nil, s.ErrOauthInvalidRequest
|
||||
}
|
||||
if bcrypt.CompareHashAndPassword([]byte(client.ClientSecret), []byte(clientSecret)) != nil {
|
||||
return nil, services.ErrOauthInvalidClient
|
||||
return nil, s.ErrOauthInvalidClient
|
||||
}
|
||||
}
|
||||
|
||||
@@ -200,7 +325,7 @@ func protect(c *fiber.Ctx, grant services.GrantType, clientId, clientSecret stri
|
||||
}
|
||||
|
||||
// 发送成功响应
|
||||
func sendSuccess(c *fiber.Ctx, details *services.TokenDetails) error {
|
||||
func sendSuccess(c *fiber.Ctx, details *s.TokenDetails) error {
|
||||
return c.JSON(TokenResp{
|
||||
AccessToken: details.AccessToken,
|
||||
TokenType: "Bearer",
|
||||
@@ -211,23 +336,23 @@ func sendSuccess(c *fiber.Ctx, details *services.TokenDetails) error {
|
||||
|
||||
// 发送错误响应
|
||||
func sendError(c *fiber.Ctx, err error, description ...string) error {
|
||||
var sErr services.AuthServiceOauthError
|
||||
var sErr s.AuthServiceOauthError
|
||||
if errors.As(err, &sErr) {
|
||||
status := fiber.StatusBadRequest
|
||||
var desc string
|
||||
switch {
|
||||
case errors.Is(sErr, services.ErrOauthInvalidRequest):
|
||||
case errors.Is(sErr, s.ErrOauthInvalidRequest):
|
||||
desc = "无效的请求"
|
||||
case errors.Is(sErr, services.ErrOauthInvalidClient):
|
||||
case errors.Is(sErr, s.ErrOauthInvalidClient):
|
||||
status = fiber.StatusUnauthorized
|
||||
desc = "无效的客户端凭证"
|
||||
case errors.Is(sErr, services.ErrOauthInvalidGrant):
|
||||
case errors.Is(sErr, s.ErrOauthInvalidGrant):
|
||||
desc = "无效的授权凭证"
|
||||
case errors.Is(sErr, services.ErrOauthInvalidScope):
|
||||
case errors.Is(sErr, s.ErrOauthInvalidScope):
|
||||
desc = "无效的授权范围"
|
||||
case errors.Is(sErr, services.ErrOauthUnauthorizedClient):
|
||||
case errors.Is(sErr, s.ErrOauthUnauthorizedClient):
|
||||
desc = "未授权的客户端"
|
||||
case errors.Is(sErr, services.ErrOauthUnsupportedGrantType):
|
||||
case errors.Is(sErr, s.ErrOauthUnsupportedGrantType):
|
||||
desc = "不支持的授权类型"
|
||||
}
|
||||
if len(description) > 0 {
|
||||
|
||||
@@ -2,6 +2,7 @@ package handlers
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"platform/web/auth"
|
||||
"platform/web/services"
|
||||
"regexp"
|
||||
"strconv"
|
||||
@@ -16,6 +17,13 @@ type VerifierReq struct {
|
||||
|
||||
func SmsCode(c *fiber.Ctx) error {
|
||||
|
||||
_, err := auth.Protect(c, []services.PayloadType{
|
||||
services.PayloadClientConfidential,
|
||||
}, []string{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 解析请求参数
|
||||
req := new(VerifierReq)
|
||||
if err := c.BodyParser(req); err != nil {
|
||||
|
||||
Reference in New Issue
Block a user