diff --git a/README.md b/README.md index 902961c..7507856 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,7 @@ ## todo -- 微信支付 - 错误处理类型转换失败问题 -- 检查数据库枚举字段,0 值只作为空值使用 +- 微信支付 - 移动端适配 - channel 接口 - 重新梳理逻辑流程,简化循环 diff --git a/cmd/fill/main.go b/cmd/fill/main.go index a12bbde..d085e8d 100644 --- a/cmd/fill/main.go +++ b/cmd/fill/main.go @@ -84,14 +84,14 @@ func main() { ClientSecret: string(testSecret), GrantClient: true, GrantRefresh: true, - Spec: 0, + Spec: 3, Name: "默认客户端", }, &m.Client{ ClientID: "tasks", ClientSecret: string(tasksSecret), GrantClient: true, GrantRefresh: true, - Spec: 0, + Spec: 3, Name: "异步任务处理服务", }) return nil diff --git a/scripts/sql/init.sql b/scripts/sql/init.sql index a5726c0..533cb0b 100644 --- a/scripts/sql/init.sql +++ b/scripts/sql/init.sql @@ -51,7 +51,7 @@ comment on column admin.name is '真实姓名'; comment on column admin.avatar is '头像URL'; comment on column admin.phone is '手机号码'; comment on column admin.email is '邮箱'; -comment on column admin.status is '状态:1-正常,0-禁用'; +comment on column admin.status is '状态:0-禁用,1-正常'; comment on column admin.last_login is '最后登录时间'; comment on column admin.last_login_host is '最后登录地址'; comment on column admin.last_login_agent is '最后登录代理'; @@ -132,7 +132,7 @@ comment on column "user".username is '用户名'; comment on column "user".phone is '手机号码'; comment on column "user".name is '真实姓名'; comment on column "user".avatar is '头像URL'; -comment on column "user".status is '用户状态:1-正常,0-禁用'; +comment on column "user".status is '用户状态:0-禁用,1-正常'; comment on column "user".balance is '账户余额'; comment on column "user".id_type is '认证类型:0-未认证,1-个人认证,2-企业认证'; comment on column "user".id_no is '身份证号或营业执照号'; @@ -211,10 +211,10 @@ comment on column client.grant_code is '允许授权码授予'; comment on column client.grant_client is '允许客户端凭证授予'; comment on column client.grant_refresh is '允许刷新令牌授予'; comment on column client.grant_password is '允许密码授予'; -comment on column client.spec is '安全规范:0-web,1-native,2-browser'; +comment on column client.spec is '安全规范:1-native,2-browser,3-web'; comment on column client.name is '名称'; comment on column client.icon is '图标URL'; -comment on column client.status is '状态:1-正常,0-禁用'; +comment on column client.status is '状态:0-禁用,1-正常'; comment on column client.created_at is '创建时间'; comment on column client.updated_at is '更新时间'; comment on column client.deleted_at is '删除时间'; @@ -416,7 +416,7 @@ comment on column proxy.id is '代理服务ID'; comment on column proxy.version is '代理服务版本'; comment on column proxy.name is '代理服务名称'; comment on column proxy.host is '代理服务地址'; -comment on column proxy.type is '代理服务类型:0-自有,1-三方'; +comment on column proxy.type is '代理服务类型:1-三方,2-自有'; comment on column proxy.secret is '代理服务密钥'; comment on column proxy.created_at is '创建时间'; comment on column proxy.updated_at is '更新时间'; @@ -455,7 +455,7 @@ comment on column node.id is '节点ID'; comment on column node.version is '节点版本'; comment on column node.name is '节点名称'; comment on column node.host is '节点地址'; -comment on column node.isp is '运营商:0-其他,1-电信,2-联通,3-移动'; +comment on column node.isp is '运营商:0-未知,1-电信,2-联通,3-移动'; comment on column node.prov is '省份'; comment on column node.city is '城市'; comment on column node.proxy_id is '代理ID'; @@ -582,7 +582,7 @@ comment on column product.code is '产品代码'; comment on column product.name is '产品名称'; comment on column product.description is '产品描述'; comment on column product.sort is '排序'; -comment on column product.status is '产品状态:1-正常,0-禁用'; +comment on column product.status is '产品状态:0-禁用,1-正常'; comment on column product.created_at is '创建时间'; comment on column product.updated_at is '更新时间'; comment on column product.deleted_at is '删除时间'; @@ -728,7 +728,7 @@ comment on column trade.id is '订单ID'; comment on column trade.user_id is '用户ID'; comment on column trade.inner_no is '内部订单号'; comment on column trade.outer_no is '外部订单号'; -comment on column trade.type is '订单类型:0-充值余额,1-购买产品'; +comment on column trade.type is '订单类型:1-购买产品,2-充值余额'; comment on column trade.subject is '订单主题'; comment on column trade.remark is '订单备注'; comment on column trade.amount is '订单总金额'; @@ -816,7 +816,7 @@ comment on column bill.resource_id is '套餐ID'; comment on column bill.refund_id is '退款ID'; comment on column bill.bill_no is '易读账单号'; comment on column bill.info is '产品可读信息'; -comment on column bill.type is '账单类型:0-充值,1-消费,2-退款'; +comment on column bill.type is '账单类型:1-消费,2-退款,3-充值'; comment on column bill.amount is '账单金额'; comment on column bill.created_at is '创建时间'; comment on column bill.updated_at is '更新时间'; diff --git a/web/auth/auth.go b/web/auth/auth.go index 5f20bb4..18af515 100644 --- a/web/auth/auth.go +++ b/web/auth/auth.go @@ -100,7 +100,7 @@ func authBasic(_ context.Context, token string) (*services.AuthContext, error) { client, err := q.Client. Where( q.Client.ClientID.Eq(clientID), - q.Client.Spec.Eq(0), + q.Client.Spec.Eq(3), q.Client.GrantClient.Is(true), q.Client.Status.Eq(1)). Take() @@ -108,16 +108,6 @@ func authBasic(_ context.Context, token string) (*services.AuthContext, error) { return nil, err } - // 检查客户端状态 - if client.Status != 1 { - return nil, errors.New("客户端已被禁用") - } - - // 检查客户端类型 - if client.Spec != 0 { - return nil, errors.New("客户端类型错误") - } - // 检查客户端密钥 var clientSecret = split[1] if bcrypt.CompareHashAndPassword([]byte(client.ClientSecret), []byte(clientSecret)) != nil { diff --git a/web/handlers/auth.go b/web/handlers/auth.go index 135dbb4..9de8040 100644 --- a/web/handlers/auth.go +++ b/web/handlers/auth.go @@ -188,7 +188,7 @@ func protect(c *fiber.Ctx, grant s.OauthGrantType, clientId, clientSecret string return nil, s.ErrOauthUnauthorizedClient } case s.OauthGrantTypeClientCredentials: - if !client.GrantClient || client.Spec != 0 { + if !client.GrantClient || client.Spec != 3 { return nil, s.ErrOauthUnauthorizedClient } case s.OauthGrantTypeRefreshToken: @@ -202,7 +202,7 @@ func protect(c *fiber.Ctx, grant s.OauthGrantType, clientId, clientSecret string } // 如果客户端是 confidential,验证 client_secret,失败返回错误 - if client.Spec == 0 { + if client.Spec == 3 { if clientSecret == "" { return nil, s.ErrOauthInvalidRequest } diff --git a/web/handlers/trade.go b/web/handlers/trade.go index 80ad289..18bb126 100644 --- a/web/handlers/trade.go +++ b/web/handlers/trade.go @@ -68,7 +68,7 @@ func AlipayCallback(c *fiber.Ctx) error { switch trade.Type { // 余额充值 - case 0: + case 2: err := s.User.RechargeConfirm(c.Context(), notification.OutTradeNo, verified) if err != nil { return err @@ -175,7 +175,7 @@ func WechatPayCallback(c *fiber.Ctx) error { switch { // 余额充值 - case trade.Type == 0: + case trade.Type == 2: err := s.User.RechargeConfirm(c.Context(), *content.OutTradeNo, verified) if err != nil { return err diff --git a/web/services/auth.go b/web/services/auth.go index c327db9..21b2743 100644 --- a/web/services/auth.go +++ b/web/services/auth.go @@ -26,12 +26,12 @@ func (s *authService) OauthClientCredentials(ctx context.Context, client *m.Clie var clientType PayloadType switch client.Spec { - case 0: - clientType = PayloadClientConfidential case 1: clientType = PayloadClientPublic case 2: clientType = PayloadClientPublic + case 3: + clientType = PayloadClientConfidential } var permissions = make(map[string]struct{}, len(scope)) @@ -50,7 +50,7 @@ func (s *authService) OauthClientCredentials(ctx context.Context, client *m.Clie } // todo 数据库定义会话持续时间 - token, err := Session.Create(ctx, auth) + token, err := Session.Create(ctx, auth, false) if err != nil { return nil, err } @@ -145,11 +145,7 @@ func (s *authService) OauthPassword(ctx context.Context, _ *m.Client, data *Gran }, } - duration := DefaultSessionConfig - if !data.Remember { - duration.RefreshTokenDuration = 0 - } - token, err := Session.Create(ctx, auth) + token, err := Session.Create(ctx, auth, data.Remember) if err != nil { return nil, err } diff --git a/web/services/auth_test.go b/web/services/auth_test.go index 2c2c2f2..458d094 100644 --- a/web/services/auth_test.go +++ b/web/services/auth_test.go @@ -16,13 +16,13 @@ type mockSessionService struct { func (m *mockSessionService) Find(ctx context.Context, token string) (*AuthContext, error) { panic("implement me") } -func (m *mockSessionService) Refresh(ctx context.Context, refreshToken string, config ...SessionConfig) (*TokenDetails, error) { +func (m *mockSessionService) Refresh(ctx context.Context, refreshToken string) (*TokenDetails, error) { panic("implement me") } func (m *mockSessionService) Remove(ctx context.Context, accessToken, refreshToken string) error { panic("implement me") } -func (m *mockSessionService) Create(ctx context.Context, auth AuthContext, config ...SessionConfig) (*TokenDetails, error) { +func (m *mockSessionService) Create(ctx context.Context, auth AuthContext, remember bool) (*TokenDetails, error) { return m.createFunc(ctx, auth) } @@ -58,7 +58,7 @@ func Test_authService_OauthClientCredentials(t *testing.T) { name: "成功 - 机密客户端 (Spec=0)", args: args{ ctx: context.Background(), - client: &models.Client{ID: 1, Spec: 0}, + client: &models.Client{ID: 1, Spec: 3}, scope: []string{"read", "write"}, }, mockCreateErr: nil, diff --git a/web/services/channel_test.go b/web/services/channel_test.go index e6beb95..69a656c 100644 --- a/web/services/channel_test.go +++ b/web/services/channel_test.go @@ -109,7 +109,7 @@ func Test_cache(t *testing.T) { // 准备测试数据 now := time.Now() - expiration := now.Add(24 * time.Hour) + expiration := common.LocalDateTime(now.Add(24 * time.Hour)) testChannels := []*models.Channel{ { @@ -471,7 +471,7 @@ func Test_channelService_CreateChannel(t *testing.T) { if ch.Password != *info.Password { return fmt.Errorf("通道密码不正确,期望 %s,得到 %s", *info.Password, ch.Password) } - if ch.Expiration.IsZero() { + if time.Time(ch.Expiration).IsZero() { return fmt.Errorf("通道过期时间不应为空") } @@ -615,7 +615,7 @@ func Test_channelService_CreateChannel(t *testing.T) { if ch.Protocol != int32(info.Proto) { return fmt.Errorf("通道协议不正确,期望 %d,得到 %d", info.Proto, ch.Protocol) } - if ch.Expiration.IsZero() { + if time.Time(ch.Expiration).IsZero() { return fmt.Errorf("通道过期时间不应为空") } @@ -765,7 +765,7 @@ func Test_channelService_CreateChannel(t *testing.T) { if ch.Password != *info.Password { return fmt.Errorf("通道密码不正确,期望 %s,得到 %s", *info.Password, ch.Password) } - if ch.Expiration.IsZero() { + if time.Time(ch.Expiration).IsZero() { return fmt.Errorf("通道过期时间不应为空") } @@ -914,7 +914,7 @@ func Test_channelService_CreateChannel(t *testing.T) { ProxyID: 1, ProxyPort: int32(i + 10000), UserID: 101, - Expiration: expr, + Expiration: common.LocalDateTime(expr), } } db.CreateInBatches(channels, 1000) @@ -1040,9 +1040,9 @@ func Test_channelService_RemoveChannels(t *testing.T) { // 创建通道 channels := []models.Channel{ - {ID: 1, UserID: 101, ProxyID: 1, ProxyPort: 10001, Protocol: 1, Expiration: time.Now().Add(24 * time.Hour)}, - {ID: 2, UserID: 101, ProxyID: 1, ProxyPort: 10002, Protocol: 1, Expiration: time.Now().Add(24 * time.Hour)}, - {ID: 3, UserID: 101, ProxyID: 2, ProxyPort: 10001, Protocol: 3, Expiration: time.Now().Add(24 * time.Hour)}, + {ID: 1, UserID: 101, ProxyID: 1, ProxyPort: 10001, Protocol: 1, Expiration: common.LocalDateTime(time.Now().Add(24 * time.Hour))}, + {ID: 2, UserID: 101, ProxyID: 1, ProxyPort: 10002, Protocol: 1, Expiration: common.LocalDateTime(time.Now().Add(24 * time.Hour))}, + {ID: 3, UserID: 101, ProxyID: 2, ProxyPort: 10001, Protocol: 3, Expiration: common.LocalDateTime(time.Now().Add(24 * time.Hour))}, } // 保存预设数据 @@ -1154,9 +1154,9 @@ func Test_channelService_RemoveChannels(t *testing.T) { // 创建通道 channels := []models.Channel{ - {ID: 1, UserID: 101, ProxyID: 1, ProxyPort: 10001, Protocol: 1, Expiration: time.Now().Add(24 * time.Hour)}, - {ID: 2, UserID: 101, ProxyID: 1, ProxyPort: 10002, Protocol: 1, Expiration: time.Now().Add(24 * time.Hour)}, - {ID: 3, UserID: 101, ProxyID: 2, ProxyPort: 10001, Protocol: 3, Expiration: time.Now().Add(24 * time.Hour)}, + {ID: 1, UserID: 101, ProxyID: 1, ProxyPort: 10001, Protocol: 1, Expiration: common.LocalDateTime(time.Now().Add(24 * time.Hour))}, + {ID: 2, UserID: 101, ProxyID: 1, ProxyPort: 10002, Protocol: 1, Expiration: common.LocalDateTime(time.Now().Add(24 * time.Hour))}, + {ID: 3, UserID: 101, ProxyID: 2, ProxyPort: 10001, Protocol: 3, Expiration: common.LocalDateTime(time.Now().Add(24 * time.Hour))}, } // 保存预设数据 @@ -1268,9 +1268,9 @@ func Test_channelService_RemoveChannels(t *testing.T) { // 创建通道 channels := []models.Channel{ - {ID: 1, UserID: 101, ProxyID: 1, ProxyPort: 10001, Protocol: 1, Expiration: time.Now().Add(24 * time.Hour)}, - {ID: 2, UserID: 101, ProxyID: 1, ProxyPort: 10002, Protocol: 1, Expiration: time.Now().Add(24 * time.Hour)}, - {ID: 3, UserID: 102, ProxyID: 2, ProxyPort: 10001, Protocol: 3, Expiration: time.Now().Add(24 * time.Hour)}, + {ID: 1, UserID: 101, ProxyID: 1, ProxyPort: 10001, Protocol: 1, Expiration: common.LocalDateTime(time.Now().Add(24 * time.Hour))}, + {ID: 2, UserID: 101, ProxyID: 1, ProxyPort: 10002, Protocol: 1, Expiration: common.LocalDateTime(time.Now().Add(24 * time.Hour))}, + {ID: 3, UserID: 102, ProxyID: 2, ProxyPort: 10001, Protocol: 3, Expiration: common.LocalDateTime(time.Now().Add(24 * time.Hour))}, } // 保存预设数据 diff --git a/web/services/session.go b/web/services/session.go index ee98e51..f66023e 100644 --- a/web/services/session.go +++ b/web/services/session.go @@ -21,7 +21,7 @@ type SessionServiceInter interface { // Find 通过访问令牌获取会话信息 Find(ctx context.Context, token string) (*AuthContext, error) // Create 创建一个新的会话 - Create(ctx context.Context, auth AuthContext) (*TokenDetails, error) + Create(ctx context.Context, auth AuthContext, remember bool) (*TokenDetails, error) // Refresh 刷新一个会话 Refresh(ctx context.Context, refreshToken string) (*TokenDetails, error) // Remove 删除会话 @@ -62,7 +62,7 @@ func (s *sessionService) Find(ctx context.Context, token string) (*AuthContext, } // Create 创建一个新的会话 -func (s *sessionService) Create(ctx context.Context, auth AuthContext) (*TokenDetails, error) { +func (s *sessionService) Create(ctx context.Context, auth AuthContext, remember bool) (*TokenDetails, error) { var now = time.Now() // 生成令牌组 @@ -90,7 +90,9 @@ func (s *sessionService) Create(ctx context.Context, auth AuthContext) (*TokenDe pipe := rds.Client.TxPipeline() pipe.Set(ctx, accessKey(accessToken), authData, accessExpire) - pipe.Set(ctx, refreshKey(refreshToken), refreshData, refreshExpire) + if remember { + pipe.Set(ctx, refreshKey(refreshToken), refreshData, refreshExpire) + } _, err = pipe.Exec(ctx) if err != nil { return nil, err diff --git a/web/services/session_test.go b/web/services/session_test.go index 6a52ddb..16aa120 100644 --- a/web/services/session_test.go +++ b/web/services/session_test.go @@ -34,9 +34,8 @@ func Test_sessionService_Create(t *testing.T) { auth := createTestAuthContext() type args struct { - ctx context.Context - auth AuthContext - config []SessionConfig + ctx context.Context + auth AuthContext } tests := []struct { name string @@ -45,7 +44,7 @@ func Test_sessionService_Create(t *testing.T) { wantErr bool }{ { - name: "使用默认配置创建会话", + name: "创建会话", args: args{ ctx: ctx, auth: auth, @@ -68,53 +67,13 @@ func Test_sessionService_Create(t *testing.T) { }, wantErr: false, }, - { - name: "使用自定义配置创建会话", - args: args{ - ctx: ctx, - auth: auth, - config: []SessionConfig{ - { - AccessTokenDuration: 10 * time.Minute, - RefreshTokenDuration: 24 * time.Hour, - }, - }, - }, - want: func(td *TokenDetails) bool { - // 验证令牌存在且格式正确 - if td.AccessToken == "" || td.RefreshToken == "" { - return false - } - // 验证到期时间在未来且接近预期时间 - now := time.Now() - expectedAccessExpiry := now.Add(10 * time.Minute) - expectedRefreshExpiry := now.Add(24 * time.Hour) - - accessDiff := td.AccessTokenExpires.Sub(expectedAccessExpiry) - refreshDiff := td.RefreshTokenExpires.Sub(expectedRefreshExpiry) - - if accessDiff < -2*time.Second || accessDiff > 2*time.Second { - return false - } - if refreshDiff < -2*time.Second || refreshDiff > 2*time.Second { - return false - } - - // 验证认证信息正确 - if !reflect.DeepEqual(td.Auth, auth) { - return false - } - return true - }, - wantErr: false, - }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { mr.FlushAll() s := &sessionService{} - got, err := s.Create(tt.args.ctx, tt.args.auth, tt.args.config...) + got, err := s.Create(tt.args.ctx, tt.args.auth, true) if (err != nil) != tt.wantErr { t.Errorf("Create() error = %v, wantErr %v", err, tt.wantErr) return @@ -145,7 +104,7 @@ func Test_sessionService_Find(t *testing.T) { s := &sessionService{} // 创建一个有效的会话 - td, err := s.Create(ctx, auth) + td, err := s.Create(ctx, auth, true) if err != nil { t.Fatalf("无法创建测试会话: %v", err) } @@ -204,7 +163,7 @@ func Test_sessionService_Refresh(t *testing.T) { s := &sessionService{} // 创建一个初始会话 - td, err := s.Create(ctx, auth) + td, err := s.Create(ctx, auth, true) if err != nil { t.Fatalf("无法创建初始会话: %v", err) } @@ -216,7 +175,6 @@ func Test_sessionService_Refresh(t *testing.T) { type args struct { ctx context.Context refreshToken string - config []SessionConfig } tests := []struct { name string @@ -259,7 +217,7 @@ func Test_sessionService_Refresh(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := s.Refresh(tt.args.ctx, tt.args.refreshToken, tt.args.config...) + got, err := s.Refresh(tt.args.ctx, tt.args.refreshToken) if (err != nil) != tt.wantErr { t.Errorf("Refresh() error = %v, wantErr %v", err, tt.wantErr) return @@ -297,7 +255,7 @@ func Test_sessionService_Remove(t *testing.T) { s := &sessionService{} // 创建一个会话 - td, err := s.Create(ctx, auth) + td, err := s.Create(ctx, auth, true) if err != nil { t.Fatalf("无法创建测试会话: %v", err) } diff --git a/web/services/transaction.go b/web/services/transaction.go index 16858bc..ec9da60 100644 --- a/web/services/transaction.go +++ b/web/services/transaction.go @@ -134,12 +134,23 @@ func (s *transactionService) PrepareTransaction(ctx context.Context, q *q.Query, } // 保存交易订单 + var tradeType int + var billType int + switch tType { + case TransactionTypeRecharge: + tradeType = 2 + billType = 3 + case TransactionTypePurchase: + tradeType = 1 + billType = 1 + } + var trade = m.Trade{ UserID: uid, InnerNo: tradeNo, Subject: subject, Method: int32(method), - Type: int32(tType), + Type: int32(tradeType), Amount: amount, Status: 0, // 0-待支付 PayURL: payUrl, @@ -155,7 +166,7 @@ func (s *transactionService) PrepareTransaction(ctx context.Context, q *q.Query, UserID: uid, TradeID: trade.ID, Info: subject, - Type: int32(tType), + Type: int32(billType), Amount: amount, } err = q.Bill. @@ -329,8 +340,8 @@ func (s *transactionService) FinishTransaction(ctx context.Context, q *q.Query, type TransactionType int32 const ( - TransactionTypeRecharge TransactionType = iota - TransactionTypePurchase + TransactionTypePurchase TransactionType = iota + 1 + TransactionTypeRecharge ) type TransactionMethod int32