From 3140d35a95e89cc5ca4544eeb6fb9e4c8a62f0f6 Mon Sep 17 00:00:00 2001 From: luorijun Date: Sat, 10 May 2025 13:38:47 +0800 Subject: [PATCH] =?UTF-8?q?=E9=87=8D=E6=9E=84=E9=94=99=E8=AF=AF=E5=A4=84?= =?UTF-8?q?=E7=90=86=E9=80=BB=E8=BE=91=EF=BC=8C=E4=BD=BF=E7=94=A8=20fiber.?= =?UTF-8?q?Error=20=E7=BB=9F=E4=B8=80=E8=BF=94=E5=9B=9E=E9=94=99=E8=AF=AF?= =?UTF-8?q?=E7=8A=B6=E6=80=81=E7=A0=81=EF=BC=9B=E7=BB=9F=E4=B8=80=E6=8E=88?= =?UTF-8?q?=E6=9D=83=E6=9E=9A=E4=B8=BE=E5=80=BC=E5=AE=9A=E4=B9=89=E5=88=B0?= =?UTF-8?q?=20auth=20=E5=8C=85?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 14 ++++++----- web/auth/authorize.go | 18 ++++++++++++++ web/core/errors.go | 20 +++++++-------- web/handlers/auth.go | 38 ++++++++++++++-------------- web/handlers/user.go | 2 +- web/services/auth.go | 49 ++++++++++++------------------------- web/services/auth_test.go | 41 ++++++++++++++++--------------- web/services/channel.go | 2 +- web/services/transaction.go | 13 +++++++--- 9 files changed, 103 insertions(+), 94 deletions(-) create mode 100644 web/auth/authorize.go diff --git a/README.md b/README.md index 57be8de..75f154b 100644 --- a/README.md +++ b/README.md @@ -5,19 +5,21 @@ - 页面 账户总览 - 页面 提取记录 - 页面 使用记录 -- 代理数据表的 secret 字段 aes 加密存储 - -- 将 LocalDateTime 迁移到 orm -- globals 合并到 services 或者反之 -- 自定义的服务错误没有必要,可以统一在 handler 层使用包装的 fiber.Error - 公众号的到期提示 - 支付回调处理 - 保存 session 到数据库 +### 重构 + +- 将 LocalDateTime 迁移到 orm +- globals 合并到 services 或者反之 +- 自定义的服务错误没有必要,可以统一在 handler 层使用包装的 fiber.Error +- 增加 domain 层,缓解同包字段过长的问题 + ### 下阶段 -- 增加 domain 层,缓解同包字段过长的问题 +- 代理数据表的 secret 字段 aes 加密存储 - 扩展 device 权限验证方式,提供一种方法区分内部和外部服务 - 废弃 password 授权模式,迁移到 authorization code 授权模式 - oauth token 验证授权范围 diff --git a/web/auth/authorize.go b/web/auth/authorize.go new file mode 100644 index 0000000..cdcd3db --- /dev/null +++ b/web/auth/authorize.go @@ -0,0 +1,18 @@ +package auth + +type GrantType string + +const ( + GrantAuthorizationCode = GrantType("authorization_code") // 授权码模式 + GrantClientCredentials = GrantType("client_credentials") // 客户端凭证模式 + GrantRefreshToken = GrantType("refresh_token") // 刷新令牌模式 + GrantPassword = GrantType("password") // 密码模式(私有扩展) +) + +type PasswordGrantType string + +const ( + GrantPasswordSecret = PasswordGrantType("password") // 密码模式 + GrantPasswordPhone = PasswordGrantType("phone_code") // 手机号模式 + GrantPasswordEmail = PasswordGrantType("email_code") // 邮箱模式 +) diff --git a/web/core/errors.go b/web/core/errors.go index f59decb..e94a713 100644 --- a/web/core/errors.go +++ b/web/core/errors.go @@ -1,19 +1,17 @@ package core -type UnAuthorizedErr string +import "github.com/gofiber/fiber/v2" -func (e UnAuthorizedErr) Error() string { - return string(e) +// ErrInvalid 返回 400 状态码的错误 +func ErrInvalid(message ...string) error { + return fiber.NewError(fiber.StatusBadRequest, message...) } -type ForbiddenErr string - -func (e ForbiddenErr) Error() string { - return string(e) +func ErrUnauthorized(message ...string) error { + return fiber.NewError(fiber.StatusUnauthorized, message...) } -type DataErr string - -func (e DataErr) Error() string { - return string(e) +// ErrForbidden 返回 403 状态码的错误 +func ErrForbidden(message ...string) error { + return fiber.NewError(fiber.StatusForbidden, message...) } diff --git a/web/handlers/auth.go b/web/handlers/auth.go index ea3c316..64b3b49 100644 --- a/web/handlers/auth.go +++ b/web/handlers/auth.go @@ -4,7 +4,7 @@ import ( "encoding/base64" "errors" "log/slog" - "platform/web/auth" + auth2 "platform/web/auth" client2 "platform/web/domains/client" m "platform/web/models" q "platform/web/queries" @@ -20,10 +20,10 @@ import ( // region /token type TokenReq struct { - GrantType s.OauthGrantType `json:"grant_type" form:"grant_type"` - ClientID string `json:"client_id" form:"client_id"` - ClientSecret string `json:"client_secret" form:"client_secret"` - Scope string `json:"scope" form:"scope"` + GrantType auth2.GrantType `json:"grant_type" form:"grant_type"` + ClientID string `json:"client_id" form:"client_id"` + ClientSecret string `json:"client_secret" form:"client_secret"` + Scope string `json:"scope" form:"scope"` s.GrantCodeData s.GrantClientData s.GrantRefreshData @@ -64,7 +64,7 @@ func Token(c *fiber.Ctx) error { switch req.GrantType { // 授权码模式 - case s.OauthGrantTypeAuthorizationCode: + case auth2.GrantAuthorizationCode: if req.Code == "" { return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数:code") } @@ -82,7 +82,7 @@ func Token(c *fiber.Ctx) error { return sendSuccess(c, token) // 客户端凭证模式 - case s.OauthGrantTypeClientCredentials: + case auth2.GrantClientCredentials: client, err := protect(c, req.GrantType, req.ClientID, req.ClientSecret) if err != nil { return sendError(c, err) @@ -97,7 +97,7 @@ func Token(c *fiber.Ctx) error { return sendSuccess(c, token) // 刷新令牌模式 - case s.OauthGrantTypeRefreshToken: + case auth2.GrantRefreshToken: if req.RefreshToken == "" { return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数:refresh_token") } @@ -119,7 +119,7 @@ func Token(c *fiber.Ctx) error { return sendSuccess(c, token) // 密码模式 - case s.OauthGrantTypePassword: + case auth2.GrantPassword: if req.LoginType == "" { return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数:password_type") } @@ -148,7 +148,7 @@ func Token(c *fiber.Ctx) error { } // 检查客户端凭证 -func protect(c *fiber.Ctx, grant s.OauthGrantType, clientId, clientSecret string) (*m.Client, error) { +func protect(c *fiber.Ctx, grant auth2.GrantType, clientId, clientSecret string) (*m.Client, error) { header := c.Get("Authorization") if header != "" { basic := strings.TrimPrefix(header, "Basic ") @@ -184,19 +184,19 @@ func protect(c *fiber.Ctx, grant s.OauthGrantType, clientId, clientSecret string // 验证授权类型 switch grant { - case s.OauthGrantTypeAuthorizationCode: + case auth2.GrantAuthorizationCode: if !client.GrantCode { return nil, s.ErrOauthUnauthorizedClient } - case s.OauthGrantTypeClientCredentials: + case auth2.GrantClientCredentials: if !client.GrantClient || client.Spec != int32(client2.SpecWeb) || client.Spec != int32(client2.SpecTrusted) { return nil, s.ErrOauthUnauthorizedClient } - case s.OauthGrantTypeRefreshToken: + case auth2.GrantRefreshToken: if !client.GrantRefresh { return nil, s.ErrOauthUnauthorizedClient } - case s.OauthGrantTypePassword: + case auth2.GrantPassword: if !client.GrantPassword { return nil, s.ErrOauthUnauthorizedClient } @@ -213,10 +213,10 @@ func protect(c *fiber.Ctx, grant s.OauthGrantType, clientId, clientSecret string } // 保存 auth 信息到上下文(以兼容通用 auth 处理逻辑) - auth.Locals(c, &auth.Context{ - Payload: auth.Payload{ + auth2.Locals(c, &auth2.Context{ + Payload: auth2.Payload{ Id: client.ID, - Type: auth.PayloadSecuredServer, + Type: auth2.PayloadSecuredServer, Name: client.Name, Avatar: client.Icon, }, @@ -279,7 +279,7 @@ type RevokeReq struct { } func Revoke(c *fiber.Ctx) error { - _, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{}) + _, err := auth2.Protect(c, []auth2.PayloadType{auth2.PayloadUser}, []string{}) if err != nil { // 用户未登录 return nil @@ -310,7 +310,7 @@ type IntrospectResp struct { func Introspect(c *fiber.Ctx) error { // 验证权限 - authCtx, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{}) + authCtx, err := auth2.Protect(c, []auth2.PayloadType{auth2.PayloadUser}, []string{}) if err != nil { return err } diff --git a/web/handlers/user.go b/web/handlers/user.go index d37bb39..b1e6b5f 100644 --- a/web/handlers/user.go +++ b/web/handlers/user.go @@ -115,7 +115,7 @@ func UpdatePassword(c *fiber.Ctx) error { // 验证手机令牌 if req.Phone == "" || req.Code == "" { - return core.NewErr("user", "手机号码和验证码不能为空") + return core.ErrInvalid("手机号码和验证码不能为空") } err = s.Verifier.VerifySms(c.Context(), req.Phone, req.Code) if err != nil { diff --git a/web/services/auth.go b/web/services/auth.go index 328b93e..96ed243 100644 --- a/web/services/auth.go +++ b/web/services/auth.go @@ -3,7 +3,7 @@ package services import ( "context" "errors" - "platform/web/auth" + auth2 "platform/web/auth" "platform/web/core" client2 "platform/web/domains/client" m "platform/web/models" @@ -26,12 +26,12 @@ func (s *authService) OauthAuthorizationCode(ctx context.Context, client *m.Clie // OauthClientCredentials 验证客户端凭证 func (s *authService) OauthClientCredentials(ctx context.Context, client *m.Client, scope ...string) (*TokenDetails, error) { - var clientType auth.PayloadType + var clientType auth2.PayloadType switch client2.Spec(client.Spec) { case client2.SpecNative, client2.SpecBrowser: - clientType = auth.PayloadPublicServer + clientType = auth2.PayloadPublicServer case client2.SpecWeb, client2.SpecTrusted: - clientType = auth.PayloadSecuredServer + clientType = auth2.PayloadSecuredServer } var permissions = make(map[string]struct{}, len(scope)) @@ -40,9 +40,9 @@ func (s *authService) OauthClientCredentials(ctx context.Context, client *m.Clie } // 保存会话并返回令牌 - authCtx := auth.Context{ + authCtx := auth2.Context{ Permissions: permissions, - Payload: auth.Payload{ + Payload: auth2.Payload{ Id: client.ID, Type: clientType, Name: client.Name, @@ -75,7 +75,7 @@ func (s *authService) OauthPassword(ctx context.Context, _ *m.Client, data *Gran err := q.Q.Transaction(func(tx *q.Query) error { switch data.LoginType { - case OauthGrantPasswordTypePhoneCode: + case auth2.GrantPasswordPhone: // 验证验证码 err := Verifier.VerifySms(ctx, data.Username, data.Password) if err != nil { @@ -91,13 +91,13 @@ func (s *authService) OauthPassword(ctx context.Context, _ *m.Client, data *Gran if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { return err } - case OauthGrantPasswordTypeEmailCode: + case auth2.GrantPasswordEmail: var err error user, err = tx.User.Where(tx.User.Email.Eq(data.Username)).Take() if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { return err } - case OauthGrantPasswordTypePassword: + case auth2.GrantPasswordSecret: var err error user, err = tx.User. Where(tx.User.Or( @@ -136,10 +136,10 @@ func (s *authService) OauthPassword(ctx context.Context, _ *m.Client, data *Gran } // 保存到会话 - authCtx := auth.Context{ - Payload: auth.Payload{ + authCtx := auth2.Context{ + Payload: auth2.Payload{ Id: user.ID, - Type: auth.PayloadUser, + Type: auth2.PayloadUser, Name: user.Name, Avatar: user.Avatar, }, @@ -167,29 +167,12 @@ type GrantRefreshData struct { } type GrantPasswordData struct { - LoginType OauthGrantLoginType `json:"login_type" form:"login_type"` - Username string `json:"username" form:"username"` - Password string `json:"password" form:"password"` - Remember bool `json:"remember" form:"remember"` + LoginType auth2.PasswordGrantType `json:"login_type" form:"login_type"` + Username string `json:"username" form:"username"` + Password string `json:"password" form:"password"` + Remember bool `json:"remember" form:"remember"` } -type OauthGrantType string - -const ( - OauthGrantTypeAuthorizationCode = OauthGrantType("authorization_code") - OauthGrantTypeClientCredentials = OauthGrantType("client_credentials") - OauthGrantTypeRefreshToken = OauthGrantType("refresh_token") - OauthGrantTypePassword = OauthGrantType("password") -) - -type OauthGrantLoginType string - -const ( - OauthGrantPasswordTypePassword = OauthGrantLoginType("password") - OauthGrantPasswordTypePhoneCode = OauthGrantLoginType("phone_code") - OauthGrantPasswordTypeEmailCode = OauthGrantLoginType("email_code") -) - type AuthServiceError string func (e AuthServiceError) Error() string { diff --git a/web/services/auth_test.go b/web/services/auth_test.go index 458d094..fb0e1bb 100644 --- a/web/services/auth_test.go +++ b/web/services/auth_test.go @@ -2,6 +2,7 @@ package services import ( "context" + "platform/web/auth" "platform/web/models" "reflect" "testing" @@ -10,10 +11,10 @@ import ( // mockSessionService 用于模拟Session服务的行为 type mockSessionService struct { - createFunc func(ctx context.Context, auth AuthContext) (*TokenDetails, error) + createFunc func(ctx context.Context, authCtx auth.Context) (*TokenDetails, error) } -func (m *mockSessionService) Find(ctx context.Context, token string) (*AuthContext, error) { +func (m *mockSessionService) Find(ctx context.Context, token string) (*auth.Context, error) { panic("implement me") } func (m *mockSessionService) Refresh(ctx context.Context, refreshToken string) (*TokenDetails, error) { @@ -22,8 +23,8 @@ func (m *mockSessionService) Refresh(ctx context.Context, refreshToken string) ( func (m *mockSessionService) Remove(ctx context.Context, accessToken, refreshToken string) error { panic("implement me") } -func (m *mockSessionService) Create(ctx context.Context, auth AuthContext, remember bool) (*TokenDetails, error) { - return m.createFunc(ctx, auth) +func (m *mockSessionService) Create(ctx context.Context, authCtx auth.Context, remember bool) (*TokenDetails, error) { + return m.createFunc(ctx, authCtx) } func Test_authService_OauthClientCredentials(t *testing.T) { @@ -52,7 +53,7 @@ func Test_authService_OauthClientCredentials(t *testing.T) { mockCreateErr error want *TokenDetails wantErr bool - wantPayload Payload + wantPayload auth.Payload }{ { name: "成功 - 机密客户端 (Spec=0)", @@ -64,8 +65,8 @@ func Test_authService_OauthClientCredentials(t *testing.T) { mockCreateErr: nil, want: expectedToken, wantErr: false, - wantPayload: Payload{ - Type: PayloadClientConfidential, + wantPayload: auth.Payload{ + Type: auth.PayloadSecuredServer, Id: 1, }, }, @@ -79,8 +80,8 @@ func Test_authService_OauthClientCredentials(t *testing.T) { mockCreateErr: nil, want: expectedToken, wantErr: false, - wantPayload: Payload{ - Type: PayloadClientPublic, + wantPayload: auth.Payload{ + Type: auth.PayloadPublicServer, Id: 1, }, }, @@ -94,8 +95,8 @@ func Test_authService_OauthClientCredentials(t *testing.T) { mockCreateErr: nil, want: expectedToken, wantErr: false, - wantPayload: Payload{ - Type: PayloadClientPublic, + wantPayload: auth.Payload{ + Type: auth.PayloadPublicServer, Id: 1, }, }, @@ -106,23 +107,23 @@ func Test_authService_OauthClientCredentials(t *testing.T) { // 为每个测试用例设置模拟的Session服务 mockSession := &mockSessionService{ - createFunc: func(ctx context.Context, auth AuthContext) (*TokenDetails, error) { + createFunc: func(ctx context.Context, authCtx auth.Context) (*TokenDetails, error) { // 验证权限映射 - if len(auth.Permissions) != len(tt.args.scope) { - t.Errorf("Permissions length = %v, want %v", len(auth.Permissions), len(tt.args.scope)) - for key := range auth.Permissions { - if _, ok := auth.Permissions[key]; !ok { + if len(authCtx.Permissions) != len(tt.args.scope) { + t.Errorf("Permissions length = %v, want %v", len(authCtx.Permissions), len(tt.args.scope)) + for key := range authCtx.Permissions { + if _, ok := authCtx.Permissions[key]; !ok { t.Errorf("Permissions[%s] not found", key) } } } // 验证Payload - if auth.Payload.Type != tt.wantPayload.Type { - t.Errorf("Payload.Type = %v, want %v", auth.Payload.Type, tt.wantPayload.Type) + if authCtx.Payload.Type != tt.wantPayload.Type { + t.Errorf("Payload.Type = %v, want %v", authCtx.Payload.Type, tt.wantPayload.Type) } - if auth.Payload.Id != tt.wantPayload.Id { - t.Errorf("Payload.Id = %v, want %v", auth.Payload.Id, tt.wantPayload.Id) + if authCtx.Payload.Id != tt.wantPayload.Id { + t.Errorf("Payload.Id = %v, want %v", authCtx.Payload.Id, tt.wantPayload.Id) } return expectedToken, tt.mockCreateErr diff --git a/web/services/channel.go b/web/services/channel.go index 2f011cc..d0600f6 100644 --- a/web/services/channel.go +++ b/web/services/channel.go @@ -67,7 +67,7 @@ func (s *channelService) RemoveChannels(ctx context.Context, authCtx *auth.Conte // 检查权限,如果为用户操作的话,则只能删除自己的通道 for _, channel := range channels { if authCtx.Payload.Type == auth.PayloadUser && authCtx.Payload.Id != channel.UserID { - return core.ForbiddenErr("无权限访问") + return core.ErrForbidden() } } diff --git a/web/services/transaction.go b/web/services/transaction.go index 432d961..658a5d8 100644 --- a/web/services/transaction.go +++ b/web/services/transaction.go @@ -107,7 +107,8 @@ func (s *transactionService) PrepareTransaction(ctx context.Context, q *q.Query, // 调用支付宝支付接口 case trade2.MethodAlipay: resp, err := g.Alipay.TradePagePay(alipay.TradePagePay{ - QRPayMode: "4", + QRPayMode: "4", + QRCodeWidth: "196", // 二维码宽度需要-4,支付宝页面布局有问题 Trade: alipay.Trade{ ProductCode: "FAST_INSTANT_TRADE_PAY", OutTradeNo: tradeNo, @@ -380,7 +381,13 @@ type TransactionCompleteResult struct { Trade *m.Trade } +type TransactionErr string + +func (e TransactionErr) Error() string { + return string(e) +} + var ( - ErrTransactionNotPaid = core.NewErr("transaction", "交易未完成") - ErrTransactionNotSupported = core.NewErr("transaction", "不支持的支付方式") + ErrTransactionNotPaid = TransactionErr("交易未支付") + ErrTransactionNotSupported = TransactionErr("不支持的支付方式") )