重构错误处理逻辑,使用 fiber.Error 统一返回错误状态码;统一授权枚举值定义到 auth 包

This commit is contained in:
2025-05-10 13:38:47 +08:00
parent a06655ad29
commit 3140d35a95
9 changed files with 103 additions and 94 deletions

View File

@@ -5,19 +5,21 @@
- 页面 账户总览 - 页面 账户总览
- 页面 提取记录 - 页面 提取记录
- 页面 使用记录 - 页面 使用记录
- 代理数据表的 secret 字段 aes 加密存储
- 将 LocalDateTime 迁移到 orm
- globals 合并到 services 或者反之
- 自定义的服务错误没有必要,可以统一在 handler 层使用包装的 fiber.Error
- 公众号的到期提示 - 公众号的到期提示
- 支付回调处理 - 支付回调处理
- 保存 session 到数据库 - 保存 session 到数据库
### 重构
- 将 LocalDateTime 迁移到 orm
- globals 合并到 services 或者反之
- 自定义的服务错误没有必要,可以统一在 handler 层使用包装的 fiber.Error
- 增加 domain 层,缓解同包字段过长的问题
### 下阶段 ### 下阶段
- 增加 domain 层,缓解同包字段过长的问题 - 代理数据表的 secret 字段 aes 加密存储
- 扩展 device 权限验证方式,提供一种方法区分内部和外部服务 - 扩展 device 权限验证方式,提供一种方法区分内部和外部服务
- 废弃 password 授权模式,迁移到 authorization code 授权模式 - 废弃 password 授权模式,迁移到 authorization code 授权模式
- oauth token 验证授权范围 - oauth token 验证授权范围

18
web/auth/authorize.go Normal file
View File

@@ -0,0 +1,18 @@
package auth
type GrantType string
const (
GrantAuthorizationCode = GrantType("authorization_code") // 授权码模式
GrantClientCredentials = GrantType("client_credentials") // 客户端凭证模式
GrantRefreshToken = GrantType("refresh_token") // 刷新令牌模式
GrantPassword = GrantType("password") // 密码模式(私有扩展)
)
type PasswordGrantType string
const (
GrantPasswordSecret = PasswordGrantType("password") // 密码模式
GrantPasswordPhone = PasswordGrantType("phone_code") // 手机号模式
GrantPasswordEmail = PasswordGrantType("email_code") // 邮箱模式
)

View File

@@ -1,19 +1,17 @@
package core package core
type UnAuthorizedErr string import "github.com/gofiber/fiber/v2"
func (e UnAuthorizedErr) Error() string { // ErrInvalid 返回 400 状态码的错误
return string(e) func ErrInvalid(message ...string) error {
return fiber.NewError(fiber.StatusBadRequest, message...)
} }
type ForbiddenErr string func ErrUnauthorized(message ...string) error {
return fiber.NewError(fiber.StatusUnauthorized, message...)
func (e ForbiddenErr) Error() string {
return string(e)
} }
type DataErr string // ErrForbidden 返回 403 状态码的错误
func ErrForbidden(message ...string) error {
func (e DataErr) Error() string { return fiber.NewError(fiber.StatusForbidden, message...)
return string(e)
} }

View File

@@ -4,7 +4,7 @@ import (
"encoding/base64" "encoding/base64"
"errors" "errors"
"log/slog" "log/slog"
"platform/web/auth" auth2 "platform/web/auth"
client2 "platform/web/domains/client" client2 "platform/web/domains/client"
m "platform/web/models" m "platform/web/models"
q "platform/web/queries" q "platform/web/queries"
@@ -20,10 +20,10 @@ import (
// region /token // region /token
type TokenReq struct { type TokenReq struct {
GrantType s.OauthGrantType `json:"grant_type" form:"grant_type"` GrantType auth2.GrantType `json:"grant_type" form:"grant_type"`
ClientID string `json:"client_id" form:"client_id"` ClientID string `json:"client_id" form:"client_id"`
ClientSecret string `json:"client_secret" form:"client_secret"` ClientSecret string `json:"client_secret" form:"client_secret"`
Scope string `json:"scope" form:"scope"` Scope string `json:"scope" form:"scope"`
s.GrantCodeData s.GrantCodeData
s.GrantClientData s.GrantClientData
s.GrantRefreshData s.GrantRefreshData
@@ -64,7 +64,7 @@ func Token(c *fiber.Ctx) error {
switch req.GrantType { switch req.GrantType {
// 授权码模式 // 授权码模式
case s.OauthGrantTypeAuthorizationCode: case auth2.GrantAuthorizationCode:
if req.Code == "" { if req.Code == "" {
return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数code") return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数code")
} }
@@ -82,7 +82,7 @@ func Token(c *fiber.Ctx) error {
return sendSuccess(c, token) return sendSuccess(c, token)
// 客户端凭证模式 // 客户端凭证模式
case s.OauthGrantTypeClientCredentials: case auth2.GrantClientCredentials:
client, err := protect(c, req.GrantType, req.ClientID, req.ClientSecret) client, err := protect(c, req.GrantType, req.ClientID, req.ClientSecret)
if err != nil { if err != nil {
return sendError(c, err) return sendError(c, err)
@@ -97,7 +97,7 @@ func Token(c *fiber.Ctx) error {
return sendSuccess(c, token) return sendSuccess(c, token)
// 刷新令牌模式 // 刷新令牌模式
case s.OauthGrantTypeRefreshToken: case auth2.GrantRefreshToken:
if req.RefreshToken == "" { if req.RefreshToken == "" {
return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数refresh_token") return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数refresh_token")
} }
@@ -119,7 +119,7 @@ func Token(c *fiber.Ctx) error {
return sendSuccess(c, token) return sendSuccess(c, token)
// 密码模式 // 密码模式
case s.OauthGrantTypePassword: case auth2.GrantPassword:
if req.LoginType == "" { if req.LoginType == "" {
return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数password_type") return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数password_type")
} }
@@ -148,7 +148,7 @@ func Token(c *fiber.Ctx) error {
} }
// 检查客户端凭证 // 检查客户端凭证
func protect(c *fiber.Ctx, grant s.OauthGrantType, clientId, clientSecret string) (*m.Client, error) { func protect(c *fiber.Ctx, grant auth2.GrantType, clientId, clientSecret string) (*m.Client, error) {
header := c.Get("Authorization") header := c.Get("Authorization")
if header != "" { if header != "" {
basic := strings.TrimPrefix(header, "Basic ") basic := strings.TrimPrefix(header, "Basic ")
@@ -184,19 +184,19 @@ func protect(c *fiber.Ctx, grant s.OauthGrantType, clientId, clientSecret string
// 验证授权类型 // 验证授权类型
switch grant { switch grant {
case s.OauthGrantTypeAuthorizationCode: case auth2.GrantAuthorizationCode:
if !client.GrantCode { if !client.GrantCode {
return nil, s.ErrOauthUnauthorizedClient return nil, s.ErrOauthUnauthorizedClient
} }
case s.OauthGrantTypeClientCredentials: case auth2.GrantClientCredentials:
if !client.GrantClient || client.Spec != int32(client2.SpecWeb) || client.Spec != int32(client2.SpecTrusted) { if !client.GrantClient || client.Spec != int32(client2.SpecWeb) || client.Spec != int32(client2.SpecTrusted) {
return nil, s.ErrOauthUnauthorizedClient return nil, s.ErrOauthUnauthorizedClient
} }
case s.OauthGrantTypeRefreshToken: case auth2.GrantRefreshToken:
if !client.GrantRefresh { if !client.GrantRefresh {
return nil, s.ErrOauthUnauthorizedClient return nil, s.ErrOauthUnauthorizedClient
} }
case s.OauthGrantTypePassword: case auth2.GrantPassword:
if !client.GrantPassword { if !client.GrantPassword {
return nil, s.ErrOauthUnauthorizedClient return nil, s.ErrOauthUnauthorizedClient
} }
@@ -213,10 +213,10 @@ func protect(c *fiber.Ctx, grant s.OauthGrantType, clientId, clientSecret string
} }
// 保存 auth 信息到上下文(以兼容通用 auth 处理逻辑) // 保存 auth 信息到上下文(以兼容通用 auth 处理逻辑)
auth.Locals(c, &auth.Context{ auth2.Locals(c, &auth2.Context{
Payload: auth.Payload{ Payload: auth2.Payload{
Id: client.ID, Id: client.ID,
Type: auth.PayloadSecuredServer, Type: auth2.PayloadSecuredServer,
Name: client.Name, Name: client.Name,
Avatar: client.Icon, Avatar: client.Icon,
}, },
@@ -279,7 +279,7 @@ type RevokeReq struct {
} }
func Revoke(c *fiber.Ctx) error { func Revoke(c *fiber.Ctx) error {
_, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{}) _, err := auth2.Protect(c, []auth2.PayloadType{auth2.PayloadUser}, []string{})
if err != nil { if err != nil {
// 用户未登录 // 用户未登录
return nil return nil
@@ -310,7 +310,7 @@ type IntrospectResp struct {
func Introspect(c *fiber.Ctx) error { func Introspect(c *fiber.Ctx) error {
// 验证权限 // 验证权限
authCtx, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{}) authCtx, err := auth2.Protect(c, []auth2.PayloadType{auth2.PayloadUser}, []string{})
if err != nil { if err != nil {
return err return err
} }

View File

@@ -115,7 +115,7 @@ func UpdatePassword(c *fiber.Ctx) error {
// 验证手机令牌 // 验证手机令牌
if req.Phone == "" || req.Code == "" { if req.Phone == "" || req.Code == "" {
return core.NewErr("user", "手机号码和验证码不能为空") return core.ErrInvalid("手机号码和验证码不能为空")
} }
err = s.Verifier.VerifySms(c.Context(), req.Phone, req.Code) err = s.Verifier.VerifySms(c.Context(), req.Phone, req.Code)
if err != nil { if err != nil {

View File

@@ -3,7 +3,7 @@ package services
import ( import (
"context" "context"
"errors" "errors"
"platform/web/auth" auth2 "platform/web/auth"
"platform/web/core" "platform/web/core"
client2 "platform/web/domains/client" client2 "platform/web/domains/client"
m "platform/web/models" m "platform/web/models"
@@ -26,12 +26,12 @@ func (s *authService) OauthAuthorizationCode(ctx context.Context, client *m.Clie
// OauthClientCredentials 验证客户端凭证 // OauthClientCredentials 验证客户端凭证
func (s *authService) OauthClientCredentials(ctx context.Context, client *m.Client, scope ...string) (*TokenDetails, error) { func (s *authService) OauthClientCredentials(ctx context.Context, client *m.Client, scope ...string) (*TokenDetails, error) {
var clientType auth.PayloadType var clientType auth2.PayloadType
switch client2.Spec(client.Spec) { switch client2.Spec(client.Spec) {
case client2.SpecNative, client2.SpecBrowser: case client2.SpecNative, client2.SpecBrowser:
clientType = auth.PayloadPublicServer clientType = auth2.PayloadPublicServer
case client2.SpecWeb, client2.SpecTrusted: case client2.SpecWeb, client2.SpecTrusted:
clientType = auth.PayloadSecuredServer clientType = auth2.PayloadSecuredServer
} }
var permissions = make(map[string]struct{}, len(scope)) var permissions = make(map[string]struct{}, len(scope))
@@ -40,9 +40,9 @@ func (s *authService) OauthClientCredentials(ctx context.Context, client *m.Clie
} }
// 保存会话并返回令牌 // 保存会话并返回令牌
authCtx := auth.Context{ authCtx := auth2.Context{
Permissions: permissions, Permissions: permissions,
Payload: auth.Payload{ Payload: auth2.Payload{
Id: client.ID, Id: client.ID,
Type: clientType, Type: clientType,
Name: client.Name, Name: client.Name,
@@ -75,7 +75,7 @@ func (s *authService) OauthPassword(ctx context.Context, _ *m.Client, data *Gran
err := q.Q.Transaction(func(tx *q.Query) error { err := q.Q.Transaction(func(tx *q.Query) error {
switch data.LoginType { switch data.LoginType {
case OauthGrantPasswordTypePhoneCode: case auth2.GrantPasswordPhone:
// 验证验证码 // 验证验证码
err := Verifier.VerifySms(ctx, data.Username, data.Password) err := Verifier.VerifySms(ctx, data.Username, data.Password)
if err != nil { if err != nil {
@@ -91,13 +91,13 @@ func (s *authService) OauthPassword(ctx context.Context, _ *m.Client, data *Gran
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return err return err
} }
case OauthGrantPasswordTypeEmailCode: case auth2.GrantPasswordEmail:
var err error var err error
user, err = tx.User.Where(tx.User.Email.Eq(data.Username)).Take() user, err = tx.User.Where(tx.User.Email.Eq(data.Username)).Take()
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return err return err
} }
case OauthGrantPasswordTypePassword: case auth2.GrantPasswordSecret:
var err error var err error
user, err = tx.User. user, err = tx.User.
Where(tx.User.Or( Where(tx.User.Or(
@@ -136,10 +136,10 @@ func (s *authService) OauthPassword(ctx context.Context, _ *m.Client, data *Gran
} }
// 保存到会话 // 保存到会话
authCtx := auth.Context{ authCtx := auth2.Context{
Payload: auth.Payload{ Payload: auth2.Payload{
Id: user.ID, Id: user.ID,
Type: auth.PayloadUser, Type: auth2.PayloadUser,
Name: user.Name, Name: user.Name,
Avatar: user.Avatar, Avatar: user.Avatar,
}, },
@@ -167,29 +167,12 @@ type GrantRefreshData struct {
} }
type GrantPasswordData struct { type GrantPasswordData struct {
LoginType OauthGrantLoginType `json:"login_type" form:"login_type"` LoginType auth2.PasswordGrantType `json:"login_type" form:"login_type"`
Username string `json:"username" form:"username"` Username string `json:"username" form:"username"`
Password string `json:"password" form:"password"` Password string `json:"password" form:"password"`
Remember bool `json:"remember" form:"remember"` Remember bool `json:"remember" form:"remember"`
} }
type OauthGrantType string
const (
OauthGrantTypeAuthorizationCode = OauthGrantType("authorization_code")
OauthGrantTypeClientCredentials = OauthGrantType("client_credentials")
OauthGrantTypeRefreshToken = OauthGrantType("refresh_token")
OauthGrantTypePassword = OauthGrantType("password")
)
type OauthGrantLoginType string
const (
OauthGrantPasswordTypePassword = OauthGrantLoginType("password")
OauthGrantPasswordTypePhoneCode = OauthGrantLoginType("phone_code")
OauthGrantPasswordTypeEmailCode = OauthGrantLoginType("email_code")
)
type AuthServiceError string type AuthServiceError string
func (e AuthServiceError) Error() string { func (e AuthServiceError) Error() string {

View File

@@ -2,6 +2,7 @@ package services
import ( import (
"context" "context"
"platform/web/auth"
"platform/web/models" "platform/web/models"
"reflect" "reflect"
"testing" "testing"
@@ -10,10 +11,10 @@ import (
// mockSessionService 用于模拟Session服务的行为 // mockSessionService 用于模拟Session服务的行为
type mockSessionService struct { type mockSessionService struct {
createFunc func(ctx context.Context, auth AuthContext) (*TokenDetails, error) createFunc func(ctx context.Context, authCtx auth.Context) (*TokenDetails, error)
} }
func (m *mockSessionService) Find(ctx context.Context, token string) (*AuthContext, error) { func (m *mockSessionService) Find(ctx context.Context, token string) (*auth.Context, error) {
panic("implement me") panic("implement me")
} }
func (m *mockSessionService) Refresh(ctx context.Context, refreshToken string) (*TokenDetails, error) { func (m *mockSessionService) Refresh(ctx context.Context, refreshToken string) (*TokenDetails, error) {
@@ -22,8 +23,8 @@ func (m *mockSessionService) Refresh(ctx context.Context, refreshToken string) (
func (m *mockSessionService) Remove(ctx context.Context, accessToken, refreshToken string) error { func (m *mockSessionService) Remove(ctx context.Context, accessToken, refreshToken string) error {
panic("implement me") panic("implement me")
} }
func (m *mockSessionService) Create(ctx context.Context, auth AuthContext, remember bool) (*TokenDetails, error) { func (m *mockSessionService) Create(ctx context.Context, authCtx auth.Context, remember bool) (*TokenDetails, error) {
return m.createFunc(ctx, auth) return m.createFunc(ctx, authCtx)
} }
func Test_authService_OauthClientCredentials(t *testing.T) { func Test_authService_OauthClientCredentials(t *testing.T) {
@@ -52,7 +53,7 @@ func Test_authService_OauthClientCredentials(t *testing.T) {
mockCreateErr error mockCreateErr error
want *TokenDetails want *TokenDetails
wantErr bool wantErr bool
wantPayload Payload wantPayload auth.Payload
}{ }{
{ {
name: "成功 - 机密客户端 (Spec=0)", name: "成功 - 机密客户端 (Spec=0)",
@@ -64,8 +65,8 @@ func Test_authService_OauthClientCredentials(t *testing.T) {
mockCreateErr: nil, mockCreateErr: nil,
want: expectedToken, want: expectedToken,
wantErr: false, wantErr: false,
wantPayload: Payload{ wantPayload: auth.Payload{
Type: PayloadClientConfidential, Type: auth.PayloadSecuredServer,
Id: 1, Id: 1,
}, },
}, },
@@ -79,8 +80,8 @@ func Test_authService_OauthClientCredentials(t *testing.T) {
mockCreateErr: nil, mockCreateErr: nil,
want: expectedToken, want: expectedToken,
wantErr: false, wantErr: false,
wantPayload: Payload{ wantPayload: auth.Payload{
Type: PayloadClientPublic, Type: auth.PayloadPublicServer,
Id: 1, Id: 1,
}, },
}, },
@@ -94,8 +95,8 @@ func Test_authService_OauthClientCredentials(t *testing.T) {
mockCreateErr: nil, mockCreateErr: nil,
want: expectedToken, want: expectedToken,
wantErr: false, wantErr: false,
wantPayload: Payload{ wantPayload: auth.Payload{
Type: PayloadClientPublic, Type: auth.PayloadPublicServer,
Id: 1, Id: 1,
}, },
}, },
@@ -106,23 +107,23 @@ func Test_authService_OauthClientCredentials(t *testing.T) {
// 为每个测试用例设置模拟的Session服务 // 为每个测试用例设置模拟的Session服务
mockSession := &mockSessionService{ mockSession := &mockSessionService{
createFunc: func(ctx context.Context, auth AuthContext) (*TokenDetails, error) { createFunc: func(ctx context.Context, authCtx auth.Context) (*TokenDetails, error) {
// 验证权限映射 // 验证权限映射
if len(auth.Permissions) != len(tt.args.scope) { if len(authCtx.Permissions) != len(tt.args.scope) {
t.Errorf("Permissions length = %v, want %v", len(auth.Permissions), len(tt.args.scope)) t.Errorf("Permissions length = %v, want %v", len(authCtx.Permissions), len(tt.args.scope))
for key := range auth.Permissions { for key := range authCtx.Permissions {
if _, ok := auth.Permissions[key]; !ok { if _, ok := authCtx.Permissions[key]; !ok {
t.Errorf("Permissions[%s] not found", key) t.Errorf("Permissions[%s] not found", key)
} }
} }
} }
// 验证Payload // 验证Payload
if auth.Payload.Type != tt.wantPayload.Type { if authCtx.Payload.Type != tt.wantPayload.Type {
t.Errorf("Payload.Type = %v, want %v", auth.Payload.Type, tt.wantPayload.Type) t.Errorf("Payload.Type = %v, want %v", authCtx.Payload.Type, tt.wantPayload.Type)
} }
if auth.Payload.Id != tt.wantPayload.Id { if authCtx.Payload.Id != tt.wantPayload.Id {
t.Errorf("Payload.Id = %v, want %v", auth.Payload.Id, tt.wantPayload.Id) t.Errorf("Payload.Id = %v, want %v", authCtx.Payload.Id, tt.wantPayload.Id)
} }
return expectedToken, tt.mockCreateErr return expectedToken, tt.mockCreateErr

View File

@@ -67,7 +67,7 @@ func (s *channelService) RemoveChannels(ctx context.Context, authCtx *auth.Conte
// 检查权限,如果为用户操作的话,则只能删除自己的通道 // 检查权限,如果为用户操作的话,则只能删除自己的通道
for _, channel := range channels { for _, channel := range channels {
if authCtx.Payload.Type == auth.PayloadUser && authCtx.Payload.Id != channel.UserID { if authCtx.Payload.Type == auth.PayloadUser && authCtx.Payload.Id != channel.UserID {
return core.ForbiddenErr("无权限访问") return core.ErrForbidden()
} }
} }

View File

@@ -107,7 +107,8 @@ func (s *transactionService) PrepareTransaction(ctx context.Context, q *q.Query,
// 调用支付宝支付接口 // 调用支付宝支付接口
case trade2.MethodAlipay: case trade2.MethodAlipay:
resp, err := g.Alipay.TradePagePay(alipay.TradePagePay{ resp, err := g.Alipay.TradePagePay(alipay.TradePagePay{
QRPayMode: "4", QRPayMode: "4",
QRCodeWidth: "196", // 二维码宽度需要-4支付宝页面布局有问题
Trade: alipay.Trade{ Trade: alipay.Trade{
ProductCode: "FAST_INSTANT_TRADE_PAY", ProductCode: "FAST_INSTANT_TRADE_PAY",
OutTradeNo: tradeNo, OutTradeNo: tradeNo,
@@ -380,7 +381,13 @@ type TransactionCompleteResult struct {
Trade *m.Trade Trade *m.Trade
} }
type TransactionErr string
func (e TransactionErr) Error() string {
return string(e)
}
var ( var (
ErrTransactionNotPaid = core.NewErr("transaction", "交易未完成") ErrTransactionNotPaid = TransactionErr("交易未支付")
ErrTransactionNotSupported = core.NewErr("transaction", "不支持的支付方式") ErrTransactionNotSupported = TransactionErr("不支持的支付方式")
) )