重构错误处理逻辑,使用 fiber.Error 统一返回错误状态码;统一授权枚举值定义到 auth 包

This commit is contained in:
2025-05-10 13:38:47 +08:00
parent a06655ad29
commit 3140d35a95
9 changed files with 103 additions and 94 deletions

View File

@@ -3,7 +3,7 @@ package services
import (
"context"
"errors"
"platform/web/auth"
auth2 "platform/web/auth"
"platform/web/core"
client2 "platform/web/domains/client"
m "platform/web/models"
@@ -26,12 +26,12 @@ func (s *authService) OauthAuthorizationCode(ctx context.Context, client *m.Clie
// OauthClientCredentials 验证客户端凭证
func (s *authService) OauthClientCredentials(ctx context.Context, client *m.Client, scope ...string) (*TokenDetails, error) {
var clientType auth.PayloadType
var clientType auth2.PayloadType
switch client2.Spec(client.Spec) {
case client2.SpecNative, client2.SpecBrowser:
clientType = auth.PayloadPublicServer
clientType = auth2.PayloadPublicServer
case client2.SpecWeb, client2.SpecTrusted:
clientType = auth.PayloadSecuredServer
clientType = auth2.PayloadSecuredServer
}
var permissions = make(map[string]struct{}, len(scope))
@@ -40,9 +40,9 @@ func (s *authService) OauthClientCredentials(ctx context.Context, client *m.Clie
}
// 保存会话并返回令牌
authCtx := auth.Context{
authCtx := auth2.Context{
Permissions: permissions,
Payload: auth.Payload{
Payload: auth2.Payload{
Id: client.ID,
Type: clientType,
Name: client.Name,
@@ -75,7 +75,7 @@ func (s *authService) OauthPassword(ctx context.Context, _ *m.Client, data *Gran
err := q.Q.Transaction(func(tx *q.Query) error {
switch data.LoginType {
case OauthGrantPasswordTypePhoneCode:
case auth2.GrantPasswordPhone:
// 验证验证码
err := Verifier.VerifySms(ctx, data.Username, data.Password)
if err != nil {
@@ -91,13 +91,13 @@ func (s *authService) OauthPassword(ctx context.Context, _ *m.Client, data *Gran
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return err
}
case OauthGrantPasswordTypeEmailCode:
case auth2.GrantPasswordEmail:
var err error
user, err = tx.User.Where(tx.User.Email.Eq(data.Username)).Take()
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return err
}
case OauthGrantPasswordTypePassword:
case auth2.GrantPasswordSecret:
var err error
user, err = tx.User.
Where(tx.User.Or(
@@ -136,10 +136,10 @@ func (s *authService) OauthPassword(ctx context.Context, _ *m.Client, data *Gran
}
// 保存到会话
authCtx := auth.Context{
Payload: auth.Payload{
authCtx := auth2.Context{
Payload: auth2.Payload{
Id: user.ID,
Type: auth.PayloadUser,
Type: auth2.PayloadUser,
Name: user.Name,
Avatar: user.Avatar,
},
@@ -167,29 +167,12 @@ type GrantRefreshData struct {
}
type GrantPasswordData struct {
LoginType OauthGrantLoginType `json:"login_type" form:"login_type"`
Username string `json:"username" form:"username"`
Password string `json:"password" form:"password"`
Remember bool `json:"remember" form:"remember"`
LoginType auth2.PasswordGrantType `json:"login_type" form:"login_type"`
Username string `json:"username" form:"username"`
Password string `json:"password" form:"password"`
Remember bool `json:"remember" form:"remember"`
}
type OauthGrantType string
const (
OauthGrantTypeAuthorizationCode = OauthGrantType("authorization_code")
OauthGrantTypeClientCredentials = OauthGrantType("client_credentials")
OauthGrantTypeRefreshToken = OauthGrantType("refresh_token")
OauthGrantTypePassword = OauthGrantType("password")
)
type OauthGrantLoginType string
const (
OauthGrantPasswordTypePassword = OauthGrantLoginType("password")
OauthGrantPasswordTypePhoneCode = OauthGrantLoginType("phone_code")
OauthGrantPasswordTypeEmailCode = OauthGrantLoginType("email_code")
)
type AuthServiceError string
func (e AuthServiceError) Error() string {

View File

@@ -2,6 +2,7 @@ package services
import (
"context"
"platform/web/auth"
"platform/web/models"
"reflect"
"testing"
@@ -10,10 +11,10 @@ import (
// mockSessionService 用于模拟Session服务的行为
type mockSessionService struct {
createFunc func(ctx context.Context, auth AuthContext) (*TokenDetails, error)
createFunc func(ctx context.Context, authCtx auth.Context) (*TokenDetails, error)
}
func (m *mockSessionService) Find(ctx context.Context, token string) (*AuthContext, error) {
func (m *mockSessionService) Find(ctx context.Context, token string) (*auth.Context, error) {
panic("implement me")
}
func (m *mockSessionService) Refresh(ctx context.Context, refreshToken string) (*TokenDetails, error) {
@@ -22,8 +23,8 @@ func (m *mockSessionService) Refresh(ctx context.Context, refreshToken string) (
func (m *mockSessionService) Remove(ctx context.Context, accessToken, refreshToken string) error {
panic("implement me")
}
func (m *mockSessionService) Create(ctx context.Context, auth AuthContext, remember bool) (*TokenDetails, error) {
return m.createFunc(ctx, auth)
func (m *mockSessionService) Create(ctx context.Context, authCtx auth.Context, remember bool) (*TokenDetails, error) {
return m.createFunc(ctx, authCtx)
}
func Test_authService_OauthClientCredentials(t *testing.T) {
@@ -52,7 +53,7 @@ func Test_authService_OauthClientCredentials(t *testing.T) {
mockCreateErr error
want *TokenDetails
wantErr bool
wantPayload Payload
wantPayload auth.Payload
}{
{
name: "成功 - 机密客户端 (Spec=0)",
@@ -64,8 +65,8 @@ func Test_authService_OauthClientCredentials(t *testing.T) {
mockCreateErr: nil,
want: expectedToken,
wantErr: false,
wantPayload: Payload{
Type: PayloadClientConfidential,
wantPayload: auth.Payload{
Type: auth.PayloadSecuredServer,
Id: 1,
},
},
@@ -79,8 +80,8 @@ func Test_authService_OauthClientCredentials(t *testing.T) {
mockCreateErr: nil,
want: expectedToken,
wantErr: false,
wantPayload: Payload{
Type: PayloadClientPublic,
wantPayload: auth.Payload{
Type: auth.PayloadPublicServer,
Id: 1,
},
},
@@ -94,8 +95,8 @@ func Test_authService_OauthClientCredentials(t *testing.T) {
mockCreateErr: nil,
want: expectedToken,
wantErr: false,
wantPayload: Payload{
Type: PayloadClientPublic,
wantPayload: auth.Payload{
Type: auth.PayloadPublicServer,
Id: 1,
},
},
@@ -106,23 +107,23 @@ func Test_authService_OauthClientCredentials(t *testing.T) {
// 为每个测试用例设置模拟的Session服务
mockSession := &mockSessionService{
createFunc: func(ctx context.Context, auth AuthContext) (*TokenDetails, error) {
createFunc: func(ctx context.Context, authCtx auth.Context) (*TokenDetails, error) {
// 验证权限映射
if len(auth.Permissions) != len(tt.args.scope) {
t.Errorf("Permissions length = %v, want %v", len(auth.Permissions), len(tt.args.scope))
for key := range auth.Permissions {
if _, ok := auth.Permissions[key]; !ok {
if len(authCtx.Permissions) != len(tt.args.scope) {
t.Errorf("Permissions length = %v, want %v", len(authCtx.Permissions), len(tt.args.scope))
for key := range authCtx.Permissions {
if _, ok := authCtx.Permissions[key]; !ok {
t.Errorf("Permissions[%s] not found", key)
}
}
}
// 验证Payload
if auth.Payload.Type != tt.wantPayload.Type {
t.Errorf("Payload.Type = %v, want %v", auth.Payload.Type, tt.wantPayload.Type)
if authCtx.Payload.Type != tt.wantPayload.Type {
t.Errorf("Payload.Type = %v, want %v", authCtx.Payload.Type, tt.wantPayload.Type)
}
if auth.Payload.Id != tt.wantPayload.Id {
t.Errorf("Payload.Id = %v, want %v", auth.Payload.Id, tt.wantPayload.Id)
if authCtx.Payload.Id != tt.wantPayload.Id {
t.Errorf("Payload.Id = %v, want %v", authCtx.Payload.Id, tt.wantPayload.Id)
}
return expectedToken, tt.mockCreateErr

View File

@@ -67,7 +67,7 @@ func (s *channelService) RemoveChannels(ctx context.Context, authCtx *auth.Conte
// 检查权限,如果为用户操作的话,则只能删除自己的通道
for _, channel := range channels {
if authCtx.Payload.Type == auth.PayloadUser && authCtx.Payload.Id != channel.UserID {
return core.ForbiddenErr("无权限访问")
return core.ErrForbidden()
}
}

View File

@@ -107,7 +107,8 @@ func (s *transactionService) PrepareTransaction(ctx context.Context, q *q.Query,
// 调用支付宝支付接口
case trade2.MethodAlipay:
resp, err := g.Alipay.TradePagePay(alipay.TradePagePay{
QRPayMode: "4",
QRPayMode: "4",
QRCodeWidth: "196", // 二维码宽度需要-4支付宝页面布局有问题
Trade: alipay.Trade{
ProductCode: "FAST_INSTANT_TRADE_PAY",
OutTradeNo: tradeNo,
@@ -380,7 +381,13 @@ type TransactionCompleteResult struct {
Trade *m.Trade
}
type TransactionErr string
func (e TransactionErr) Error() string {
return string(e)
}
var (
ErrTransactionNotPaid = core.NewErr("transaction", "交易未完成")
ErrTransactionNotSupported = core.NewErr("transaction", "不支持的支付方式")
ErrTransactionNotPaid = TransactionErr("交易未支付")
ErrTransactionNotSupported = TransactionErr("不支持的支付方式")
)