重构错误处理逻辑,使用 fiber.Error 统一返回错误状态码;统一授权枚举值定义到 auth 包
This commit is contained in:
14
README.md
14
README.md
@@ -5,19 +5,21 @@
|
|||||||
- 页面 账户总览
|
- 页面 账户总览
|
||||||
- 页面 提取记录
|
- 页面 提取记录
|
||||||
- 页面 使用记录
|
- 页面 使用记录
|
||||||
- 代理数据表的 secret 字段 aes 加密存储
|
|
||||||
|
|
||||||
- 将 LocalDateTime 迁移到 orm
|
|
||||||
- globals 合并到 services 或者反之
|
|
||||||
- 自定义的服务错误没有必要,可以统一在 handler 层使用包装的 fiber.Error
|
|
||||||
|
|
||||||
- 公众号的到期提示
|
- 公众号的到期提示
|
||||||
- 支付回调处理
|
- 支付回调处理
|
||||||
- 保存 session 到数据库
|
- 保存 session 到数据库
|
||||||
|
|
||||||
|
### 重构
|
||||||
|
|
||||||
|
- 将 LocalDateTime 迁移到 orm
|
||||||
|
- globals 合并到 services 或者反之
|
||||||
|
- 自定义的服务错误没有必要,可以统一在 handler 层使用包装的 fiber.Error
|
||||||
|
- 增加 domain 层,缓解同包字段过长的问题
|
||||||
|
|
||||||
### 下阶段
|
### 下阶段
|
||||||
|
|
||||||
- 增加 domain 层,缓解同包字段过长的问题
|
- 代理数据表的 secret 字段 aes 加密存储
|
||||||
- 扩展 device 权限验证方式,提供一种方法区分内部和外部服务
|
- 扩展 device 权限验证方式,提供一种方法区分内部和外部服务
|
||||||
- 废弃 password 授权模式,迁移到 authorization code 授权模式
|
- 废弃 password 授权模式,迁移到 authorization code 授权模式
|
||||||
- oauth token 验证授权范围
|
- oauth token 验证授权范围
|
||||||
|
|||||||
18
web/auth/authorize.go
Normal file
18
web/auth/authorize.go
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
type GrantType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
GrantAuthorizationCode = GrantType("authorization_code") // 授权码模式
|
||||||
|
GrantClientCredentials = GrantType("client_credentials") // 客户端凭证模式
|
||||||
|
GrantRefreshToken = GrantType("refresh_token") // 刷新令牌模式
|
||||||
|
GrantPassword = GrantType("password") // 密码模式(私有扩展)
|
||||||
|
)
|
||||||
|
|
||||||
|
type PasswordGrantType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
GrantPasswordSecret = PasswordGrantType("password") // 密码模式
|
||||||
|
GrantPasswordPhone = PasswordGrantType("phone_code") // 手机号模式
|
||||||
|
GrantPasswordEmail = PasswordGrantType("email_code") // 邮箱模式
|
||||||
|
)
|
||||||
@@ -1,19 +1,17 @@
|
|||||||
package core
|
package core
|
||||||
|
|
||||||
type UnAuthorizedErr string
|
import "github.com/gofiber/fiber/v2"
|
||||||
|
|
||||||
func (e UnAuthorizedErr) Error() string {
|
// ErrInvalid 返回 400 状态码的错误
|
||||||
return string(e)
|
func ErrInvalid(message ...string) error {
|
||||||
|
return fiber.NewError(fiber.StatusBadRequest, message...)
|
||||||
}
|
}
|
||||||
|
|
||||||
type ForbiddenErr string
|
func ErrUnauthorized(message ...string) error {
|
||||||
|
return fiber.NewError(fiber.StatusUnauthorized, message...)
|
||||||
func (e ForbiddenErr) Error() string {
|
|
||||||
return string(e)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type DataErr string
|
// ErrForbidden 返回 403 状态码的错误
|
||||||
|
func ErrForbidden(message ...string) error {
|
||||||
func (e DataErr) Error() string {
|
return fiber.NewError(fiber.StatusForbidden, message...)
|
||||||
return string(e)
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import (
|
|||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"errors"
|
"errors"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"platform/web/auth"
|
auth2 "platform/web/auth"
|
||||||
client2 "platform/web/domains/client"
|
client2 "platform/web/domains/client"
|
||||||
m "platform/web/models"
|
m "platform/web/models"
|
||||||
q "platform/web/queries"
|
q "platform/web/queries"
|
||||||
@@ -20,10 +20,10 @@ import (
|
|||||||
// region /token
|
// region /token
|
||||||
|
|
||||||
type TokenReq struct {
|
type TokenReq struct {
|
||||||
GrantType s.OauthGrantType `json:"grant_type" form:"grant_type"`
|
GrantType auth2.GrantType `json:"grant_type" form:"grant_type"`
|
||||||
ClientID string `json:"client_id" form:"client_id"`
|
ClientID string `json:"client_id" form:"client_id"`
|
||||||
ClientSecret string `json:"client_secret" form:"client_secret"`
|
ClientSecret string `json:"client_secret" form:"client_secret"`
|
||||||
Scope string `json:"scope" form:"scope"`
|
Scope string `json:"scope" form:"scope"`
|
||||||
s.GrantCodeData
|
s.GrantCodeData
|
||||||
s.GrantClientData
|
s.GrantClientData
|
||||||
s.GrantRefreshData
|
s.GrantRefreshData
|
||||||
@@ -64,7 +64,7 @@ func Token(c *fiber.Ctx) error {
|
|||||||
switch req.GrantType {
|
switch req.GrantType {
|
||||||
|
|
||||||
// 授权码模式
|
// 授权码模式
|
||||||
case s.OauthGrantTypeAuthorizationCode:
|
case auth2.GrantAuthorizationCode:
|
||||||
if req.Code == "" {
|
if req.Code == "" {
|
||||||
return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数:code")
|
return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数:code")
|
||||||
}
|
}
|
||||||
@@ -82,7 +82,7 @@ func Token(c *fiber.Ctx) error {
|
|||||||
return sendSuccess(c, token)
|
return sendSuccess(c, token)
|
||||||
|
|
||||||
// 客户端凭证模式
|
// 客户端凭证模式
|
||||||
case s.OauthGrantTypeClientCredentials:
|
case auth2.GrantClientCredentials:
|
||||||
client, err := protect(c, req.GrantType, req.ClientID, req.ClientSecret)
|
client, err := protect(c, req.GrantType, req.ClientID, req.ClientSecret)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return sendError(c, err)
|
return sendError(c, err)
|
||||||
@@ -97,7 +97,7 @@ func Token(c *fiber.Ctx) error {
|
|||||||
return sendSuccess(c, token)
|
return sendSuccess(c, token)
|
||||||
|
|
||||||
// 刷新令牌模式
|
// 刷新令牌模式
|
||||||
case s.OauthGrantTypeRefreshToken:
|
case auth2.GrantRefreshToken:
|
||||||
if req.RefreshToken == "" {
|
if req.RefreshToken == "" {
|
||||||
return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数:refresh_token")
|
return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数:refresh_token")
|
||||||
}
|
}
|
||||||
@@ -119,7 +119,7 @@ func Token(c *fiber.Ctx) error {
|
|||||||
return sendSuccess(c, token)
|
return sendSuccess(c, token)
|
||||||
|
|
||||||
// 密码模式
|
// 密码模式
|
||||||
case s.OauthGrantTypePassword:
|
case auth2.GrantPassword:
|
||||||
if req.LoginType == "" {
|
if req.LoginType == "" {
|
||||||
return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数:password_type")
|
return sendError(c, s.ErrOauthInvalidRequest, "缺少必要参数:password_type")
|
||||||
}
|
}
|
||||||
@@ -148,7 +148,7 @@ func Token(c *fiber.Ctx) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 检查客户端凭证
|
// 检查客户端凭证
|
||||||
func protect(c *fiber.Ctx, grant s.OauthGrantType, clientId, clientSecret string) (*m.Client, error) {
|
func protect(c *fiber.Ctx, grant auth2.GrantType, clientId, clientSecret string) (*m.Client, error) {
|
||||||
header := c.Get("Authorization")
|
header := c.Get("Authorization")
|
||||||
if header != "" {
|
if header != "" {
|
||||||
basic := strings.TrimPrefix(header, "Basic ")
|
basic := strings.TrimPrefix(header, "Basic ")
|
||||||
@@ -184,19 +184,19 @@ func protect(c *fiber.Ctx, grant s.OauthGrantType, clientId, clientSecret string
|
|||||||
|
|
||||||
// 验证授权类型
|
// 验证授权类型
|
||||||
switch grant {
|
switch grant {
|
||||||
case s.OauthGrantTypeAuthorizationCode:
|
case auth2.GrantAuthorizationCode:
|
||||||
if !client.GrantCode {
|
if !client.GrantCode {
|
||||||
return nil, s.ErrOauthUnauthorizedClient
|
return nil, s.ErrOauthUnauthorizedClient
|
||||||
}
|
}
|
||||||
case s.OauthGrantTypeClientCredentials:
|
case auth2.GrantClientCredentials:
|
||||||
if !client.GrantClient || client.Spec != int32(client2.SpecWeb) || client.Spec != int32(client2.SpecTrusted) {
|
if !client.GrantClient || client.Spec != int32(client2.SpecWeb) || client.Spec != int32(client2.SpecTrusted) {
|
||||||
return nil, s.ErrOauthUnauthorizedClient
|
return nil, s.ErrOauthUnauthorizedClient
|
||||||
}
|
}
|
||||||
case s.OauthGrantTypeRefreshToken:
|
case auth2.GrantRefreshToken:
|
||||||
if !client.GrantRefresh {
|
if !client.GrantRefresh {
|
||||||
return nil, s.ErrOauthUnauthorizedClient
|
return nil, s.ErrOauthUnauthorizedClient
|
||||||
}
|
}
|
||||||
case s.OauthGrantTypePassword:
|
case auth2.GrantPassword:
|
||||||
if !client.GrantPassword {
|
if !client.GrantPassword {
|
||||||
return nil, s.ErrOauthUnauthorizedClient
|
return nil, s.ErrOauthUnauthorizedClient
|
||||||
}
|
}
|
||||||
@@ -213,10 +213,10 @@ func protect(c *fiber.Ctx, grant s.OauthGrantType, clientId, clientSecret string
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 保存 auth 信息到上下文(以兼容通用 auth 处理逻辑)
|
// 保存 auth 信息到上下文(以兼容通用 auth 处理逻辑)
|
||||||
auth.Locals(c, &auth.Context{
|
auth2.Locals(c, &auth2.Context{
|
||||||
Payload: auth.Payload{
|
Payload: auth2.Payload{
|
||||||
Id: client.ID,
|
Id: client.ID,
|
||||||
Type: auth.PayloadSecuredServer,
|
Type: auth2.PayloadSecuredServer,
|
||||||
Name: client.Name,
|
Name: client.Name,
|
||||||
Avatar: client.Icon,
|
Avatar: client.Icon,
|
||||||
},
|
},
|
||||||
@@ -279,7 +279,7 @@ type RevokeReq struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func Revoke(c *fiber.Ctx) error {
|
func Revoke(c *fiber.Ctx) error {
|
||||||
_, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
|
_, err := auth2.Protect(c, []auth2.PayloadType{auth2.PayloadUser}, []string{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// 用户未登录
|
// 用户未登录
|
||||||
return nil
|
return nil
|
||||||
@@ -310,7 +310,7 @@ type IntrospectResp struct {
|
|||||||
|
|
||||||
func Introspect(c *fiber.Ctx) error {
|
func Introspect(c *fiber.Ctx) error {
|
||||||
// 验证权限
|
// 验证权限
|
||||||
authCtx, err := auth.Protect(c, []auth.PayloadType{auth.PayloadUser}, []string{})
|
authCtx, err := auth2.Protect(c, []auth2.PayloadType{auth2.PayloadUser}, []string{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -115,7 +115,7 @@ func UpdatePassword(c *fiber.Ctx) error {
|
|||||||
|
|
||||||
// 验证手机令牌
|
// 验证手机令牌
|
||||||
if req.Phone == "" || req.Code == "" {
|
if req.Phone == "" || req.Code == "" {
|
||||||
return core.NewErr("user", "手机号码和验证码不能为空")
|
return core.ErrInvalid("手机号码和验证码不能为空")
|
||||||
}
|
}
|
||||||
err = s.Verifier.VerifySms(c.Context(), req.Phone, req.Code)
|
err = s.Verifier.VerifySms(c.Context(), req.Phone, req.Code)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ package services
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"platform/web/auth"
|
auth2 "platform/web/auth"
|
||||||
"platform/web/core"
|
"platform/web/core"
|
||||||
client2 "platform/web/domains/client"
|
client2 "platform/web/domains/client"
|
||||||
m "platform/web/models"
|
m "platform/web/models"
|
||||||
@@ -26,12 +26,12 @@ func (s *authService) OauthAuthorizationCode(ctx context.Context, client *m.Clie
|
|||||||
// OauthClientCredentials 验证客户端凭证
|
// OauthClientCredentials 验证客户端凭证
|
||||||
func (s *authService) OauthClientCredentials(ctx context.Context, client *m.Client, scope ...string) (*TokenDetails, error) {
|
func (s *authService) OauthClientCredentials(ctx context.Context, client *m.Client, scope ...string) (*TokenDetails, error) {
|
||||||
|
|
||||||
var clientType auth.PayloadType
|
var clientType auth2.PayloadType
|
||||||
switch client2.Spec(client.Spec) {
|
switch client2.Spec(client.Spec) {
|
||||||
case client2.SpecNative, client2.SpecBrowser:
|
case client2.SpecNative, client2.SpecBrowser:
|
||||||
clientType = auth.PayloadPublicServer
|
clientType = auth2.PayloadPublicServer
|
||||||
case client2.SpecWeb, client2.SpecTrusted:
|
case client2.SpecWeb, client2.SpecTrusted:
|
||||||
clientType = auth.PayloadSecuredServer
|
clientType = auth2.PayloadSecuredServer
|
||||||
}
|
}
|
||||||
|
|
||||||
var permissions = make(map[string]struct{}, len(scope))
|
var permissions = make(map[string]struct{}, len(scope))
|
||||||
@@ -40,9 +40,9 @@ func (s *authService) OauthClientCredentials(ctx context.Context, client *m.Clie
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 保存会话并返回令牌
|
// 保存会话并返回令牌
|
||||||
authCtx := auth.Context{
|
authCtx := auth2.Context{
|
||||||
Permissions: permissions,
|
Permissions: permissions,
|
||||||
Payload: auth.Payload{
|
Payload: auth2.Payload{
|
||||||
Id: client.ID,
|
Id: client.ID,
|
||||||
Type: clientType,
|
Type: clientType,
|
||||||
Name: client.Name,
|
Name: client.Name,
|
||||||
@@ -75,7 +75,7 @@ func (s *authService) OauthPassword(ctx context.Context, _ *m.Client, data *Gran
|
|||||||
err := q.Q.Transaction(func(tx *q.Query) error {
|
err := q.Q.Transaction(func(tx *q.Query) error {
|
||||||
|
|
||||||
switch data.LoginType {
|
switch data.LoginType {
|
||||||
case OauthGrantPasswordTypePhoneCode:
|
case auth2.GrantPasswordPhone:
|
||||||
// 验证验证码
|
// 验证验证码
|
||||||
err := Verifier.VerifySms(ctx, data.Username, data.Password)
|
err := Verifier.VerifySms(ctx, data.Username, data.Password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -91,13 +91,13 @@ func (s *authService) OauthPassword(ctx context.Context, _ *m.Client, data *Gran
|
|||||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
case OauthGrantPasswordTypeEmailCode:
|
case auth2.GrantPasswordEmail:
|
||||||
var err error
|
var err error
|
||||||
user, err = tx.User.Where(tx.User.Email.Eq(data.Username)).Take()
|
user, err = tx.User.Where(tx.User.Email.Eq(data.Username)).Take()
|
||||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
case OauthGrantPasswordTypePassword:
|
case auth2.GrantPasswordSecret:
|
||||||
var err error
|
var err error
|
||||||
user, err = tx.User.
|
user, err = tx.User.
|
||||||
Where(tx.User.Or(
|
Where(tx.User.Or(
|
||||||
@@ -136,10 +136,10 @@ func (s *authService) OauthPassword(ctx context.Context, _ *m.Client, data *Gran
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 保存到会话
|
// 保存到会话
|
||||||
authCtx := auth.Context{
|
authCtx := auth2.Context{
|
||||||
Payload: auth.Payload{
|
Payload: auth2.Payload{
|
||||||
Id: user.ID,
|
Id: user.ID,
|
||||||
Type: auth.PayloadUser,
|
Type: auth2.PayloadUser,
|
||||||
Name: user.Name,
|
Name: user.Name,
|
||||||
Avatar: user.Avatar,
|
Avatar: user.Avatar,
|
||||||
},
|
},
|
||||||
@@ -167,29 +167,12 @@ type GrantRefreshData struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type GrantPasswordData struct {
|
type GrantPasswordData struct {
|
||||||
LoginType OauthGrantLoginType `json:"login_type" form:"login_type"`
|
LoginType auth2.PasswordGrantType `json:"login_type" form:"login_type"`
|
||||||
Username string `json:"username" form:"username"`
|
Username string `json:"username" form:"username"`
|
||||||
Password string `json:"password" form:"password"`
|
Password string `json:"password" form:"password"`
|
||||||
Remember bool `json:"remember" form:"remember"`
|
Remember bool `json:"remember" form:"remember"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type OauthGrantType string
|
|
||||||
|
|
||||||
const (
|
|
||||||
OauthGrantTypeAuthorizationCode = OauthGrantType("authorization_code")
|
|
||||||
OauthGrantTypeClientCredentials = OauthGrantType("client_credentials")
|
|
||||||
OauthGrantTypeRefreshToken = OauthGrantType("refresh_token")
|
|
||||||
OauthGrantTypePassword = OauthGrantType("password")
|
|
||||||
)
|
|
||||||
|
|
||||||
type OauthGrantLoginType string
|
|
||||||
|
|
||||||
const (
|
|
||||||
OauthGrantPasswordTypePassword = OauthGrantLoginType("password")
|
|
||||||
OauthGrantPasswordTypePhoneCode = OauthGrantLoginType("phone_code")
|
|
||||||
OauthGrantPasswordTypeEmailCode = OauthGrantLoginType("email_code")
|
|
||||||
)
|
|
||||||
|
|
||||||
type AuthServiceError string
|
type AuthServiceError string
|
||||||
|
|
||||||
func (e AuthServiceError) Error() string {
|
func (e AuthServiceError) Error() string {
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package services
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"platform/web/auth"
|
||||||
"platform/web/models"
|
"platform/web/models"
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -10,10 +11,10 @@ import (
|
|||||||
|
|
||||||
// mockSessionService 用于模拟Session服务的行为
|
// mockSessionService 用于模拟Session服务的行为
|
||||||
type mockSessionService struct {
|
type mockSessionService struct {
|
||||||
createFunc func(ctx context.Context, auth AuthContext) (*TokenDetails, error)
|
createFunc func(ctx context.Context, authCtx auth.Context) (*TokenDetails, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockSessionService) Find(ctx context.Context, token string) (*AuthContext, error) {
|
func (m *mockSessionService) Find(ctx context.Context, token string) (*auth.Context, error) {
|
||||||
panic("implement me")
|
panic("implement me")
|
||||||
}
|
}
|
||||||
func (m *mockSessionService) Refresh(ctx context.Context, refreshToken string) (*TokenDetails, error) {
|
func (m *mockSessionService) Refresh(ctx context.Context, refreshToken string) (*TokenDetails, error) {
|
||||||
@@ -22,8 +23,8 @@ func (m *mockSessionService) Refresh(ctx context.Context, refreshToken string) (
|
|||||||
func (m *mockSessionService) Remove(ctx context.Context, accessToken, refreshToken string) error {
|
func (m *mockSessionService) Remove(ctx context.Context, accessToken, refreshToken string) error {
|
||||||
panic("implement me")
|
panic("implement me")
|
||||||
}
|
}
|
||||||
func (m *mockSessionService) Create(ctx context.Context, auth AuthContext, remember bool) (*TokenDetails, error) {
|
func (m *mockSessionService) Create(ctx context.Context, authCtx auth.Context, remember bool) (*TokenDetails, error) {
|
||||||
return m.createFunc(ctx, auth)
|
return m.createFunc(ctx, authCtx)
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_authService_OauthClientCredentials(t *testing.T) {
|
func Test_authService_OauthClientCredentials(t *testing.T) {
|
||||||
@@ -52,7 +53,7 @@ func Test_authService_OauthClientCredentials(t *testing.T) {
|
|||||||
mockCreateErr error
|
mockCreateErr error
|
||||||
want *TokenDetails
|
want *TokenDetails
|
||||||
wantErr bool
|
wantErr bool
|
||||||
wantPayload Payload
|
wantPayload auth.Payload
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "成功 - 机密客户端 (Spec=0)",
|
name: "成功 - 机密客户端 (Spec=0)",
|
||||||
@@ -64,8 +65,8 @@ func Test_authService_OauthClientCredentials(t *testing.T) {
|
|||||||
mockCreateErr: nil,
|
mockCreateErr: nil,
|
||||||
want: expectedToken,
|
want: expectedToken,
|
||||||
wantErr: false,
|
wantErr: false,
|
||||||
wantPayload: Payload{
|
wantPayload: auth.Payload{
|
||||||
Type: PayloadClientConfidential,
|
Type: auth.PayloadSecuredServer,
|
||||||
Id: 1,
|
Id: 1,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -79,8 +80,8 @@ func Test_authService_OauthClientCredentials(t *testing.T) {
|
|||||||
mockCreateErr: nil,
|
mockCreateErr: nil,
|
||||||
want: expectedToken,
|
want: expectedToken,
|
||||||
wantErr: false,
|
wantErr: false,
|
||||||
wantPayload: Payload{
|
wantPayload: auth.Payload{
|
||||||
Type: PayloadClientPublic,
|
Type: auth.PayloadPublicServer,
|
||||||
Id: 1,
|
Id: 1,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -94,8 +95,8 @@ func Test_authService_OauthClientCredentials(t *testing.T) {
|
|||||||
mockCreateErr: nil,
|
mockCreateErr: nil,
|
||||||
want: expectedToken,
|
want: expectedToken,
|
||||||
wantErr: false,
|
wantErr: false,
|
||||||
wantPayload: Payload{
|
wantPayload: auth.Payload{
|
||||||
Type: PayloadClientPublic,
|
Type: auth.PayloadPublicServer,
|
||||||
Id: 1,
|
Id: 1,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -106,23 +107,23 @@ func Test_authService_OauthClientCredentials(t *testing.T) {
|
|||||||
|
|
||||||
// 为每个测试用例设置模拟的Session服务
|
// 为每个测试用例设置模拟的Session服务
|
||||||
mockSession := &mockSessionService{
|
mockSession := &mockSessionService{
|
||||||
createFunc: func(ctx context.Context, auth AuthContext) (*TokenDetails, error) {
|
createFunc: func(ctx context.Context, authCtx auth.Context) (*TokenDetails, error) {
|
||||||
// 验证权限映射
|
// 验证权限映射
|
||||||
if len(auth.Permissions) != len(tt.args.scope) {
|
if len(authCtx.Permissions) != len(tt.args.scope) {
|
||||||
t.Errorf("Permissions length = %v, want %v", len(auth.Permissions), len(tt.args.scope))
|
t.Errorf("Permissions length = %v, want %v", len(authCtx.Permissions), len(tt.args.scope))
|
||||||
for key := range auth.Permissions {
|
for key := range authCtx.Permissions {
|
||||||
if _, ok := auth.Permissions[key]; !ok {
|
if _, ok := authCtx.Permissions[key]; !ok {
|
||||||
t.Errorf("Permissions[%s] not found", key)
|
t.Errorf("Permissions[%s] not found", key)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 验证Payload
|
// 验证Payload
|
||||||
if auth.Payload.Type != tt.wantPayload.Type {
|
if authCtx.Payload.Type != tt.wantPayload.Type {
|
||||||
t.Errorf("Payload.Type = %v, want %v", auth.Payload.Type, tt.wantPayload.Type)
|
t.Errorf("Payload.Type = %v, want %v", authCtx.Payload.Type, tt.wantPayload.Type)
|
||||||
}
|
}
|
||||||
if auth.Payload.Id != tt.wantPayload.Id {
|
if authCtx.Payload.Id != tt.wantPayload.Id {
|
||||||
t.Errorf("Payload.Id = %v, want %v", auth.Payload.Id, tt.wantPayload.Id)
|
t.Errorf("Payload.Id = %v, want %v", authCtx.Payload.Id, tt.wantPayload.Id)
|
||||||
}
|
}
|
||||||
|
|
||||||
return expectedToken, tt.mockCreateErr
|
return expectedToken, tt.mockCreateErr
|
||||||
|
|||||||
@@ -67,7 +67,7 @@ func (s *channelService) RemoveChannels(ctx context.Context, authCtx *auth.Conte
|
|||||||
// 检查权限,如果为用户操作的话,则只能删除自己的通道
|
// 检查权限,如果为用户操作的话,则只能删除自己的通道
|
||||||
for _, channel := range channels {
|
for _, channel := range channels {
|
||||||
if authCtx.Payload.Type == auth.PayloadUser && authCtx.Payload.Id != channel.UserID {
|
if authCtx.Payload.Type == auth.PayloadUser && authCtx.Payload.Id != channel.UserID {
|
||||||
return core.ForbiddenErr("无权限访问")
|
return core.ErrForbidden()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -107,7 +107,8 @@ func (s *transactionService) PrepareTransaction(ctx context.Context, q *q.Query,
|
|||||||
// 调用支付宝支付接口
|
// 调用支付宝支付接口
|
||||||
case trade2.MethodAlipay:
|
case trade2.MethodAlipay:
|
||||||
resp, err := g.Alipay.TradePagePay(alipay.TradePagePay{
|
resp, err := g.Alipay.TradePagePay(alipay.TradePagePay{
|
||||||
QRPayMode: "4",
|
QRPayMode: "4",
|
||||||
|
QRCodeWidth: "196", // 二维码宽度需要-4,支付宝页面布局有问题
|
||||||
Trade: alipay.Trade{
|
Trade: alipay.Trade{
|
||||||
ProductCode: "FAST_INSTANT_TRADE_PAY",
|
ProductCode: "FAST_INSTANT_TRADE_PAY",
|
||||||
OutTradeNo: tradeNo,
|
OutTradeNo: tradeNo,
|
||||||
@@ -380,7 +381,13 @@ type TransactionCompleteResult struct {
|
|||||||
Trade *m.Trade
|
Trade *m.Trade
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type TransactionErr string
|
||||||
|
|
||||||
|
func (e TransactionErr) Error() string {
|
||||||
|
return string(e)
|
||||||
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrTransactionNotPaid = core.NewErr("transaction", "交易未完成")
|
ErrTransactionNotPaid = TransactionErr("交易未支付")
|
||||||
ErrTransactionNotSupported = core.NewErr("transaction", "不支持的支付方式")
|
ErrTransactionNotSupported = TransactionErr("不支持的支付方式")
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user