修改部分枚举字段的编号与注释确保 0 值的正确语义
This commit is contained in:
@@ -26,12 +26,12 @@ func (s *authService) OauthClientCredentials(ctx context.Context, client *m.Clie
|
||||
|
||||
var clientType PayloadType
|
||||
switch client.Spec {
|
||||
case 0:
|
||||
clientType = PayloadClientConfidential
|
||||
case 1:
|
||||
clientType = PayloadClientPublic
|
||||
case 2:
|
||||
clientType = PayloadClientPublic
|
||||
case 3:
|
||||
clientType = PayloadClientConfidential
|
||||
}
|
||||
|
||||
var permissions = make(map[string]struct{}, len(scope))
|
||||
@@ -50,7 +50,7 @@ func (s *authService) OauthClientCredentials(ctx context.Context, client *m.Clie
|
||||
}
|
||||
|
||||
// todo 数据库定义会话持续时间
|
||||
token, err := Session.Create(ctx, auth)
|
||||
token, err := Session.Create(ctx, auth, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -145,11 +145,7 @@ func (s *authService) OauthPassword(ctx context.Context, _ *m.Client, data *Gran
|
||||
},
|
||||
}
|
||||
|
||||
duration := DefaultSessionConfig
|
||||
if !data.Remember {
|
||||
duration.RefreshTokenDuration = 0
|
||||
}
|
||||
token, err := Session.Create(ctx, auth)
|
||||
token, err := Session.Create(ctx, auth, data.Remember)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -16,13 +16,13 @@ type mockSessionService struct {
|
||||
func (m *mockSessionService) Find(ctx context.Context, token string) (*AuthContext, error) {
|
||||
panic("implement me")
|
||||
}
|
||||
func (m *mockSessionService) Refresh(ctx context.Context, refreshToken string, config ...SessionConfig) (*TokenDetails, error) {
|
||||
func (m *mockSessionService) Refresh(ctx context.Context, refreshToken string) (*TokenDetails, error) {
|
||||
panic("implement me")
|
||||
}
|
||||
func (m *mockSessionService) Remove(ctx context.Context, accessToken, refreshToken string) error {
|
||||
panic("implement me")
|
||||
}
|
||||
func (m *mockSessionService) Create(ctx context.Context, auth AuthContext, config ...SessionConfig) (*TokenDetails, error) {
|
||||
func (m *mockSessionService) Create(ctx context.Context, auth AuthContext, remember bool) (*TokenDetails, error) {
|
||||
return m.createFunc(ctx, auth)
|
||||
}
|
||||
|
||||
@@ -58,7 +58,7 @@ func Test_authService_OauthClientCredentials(t *testing.T) {
|
||||
name: "成功 - 机密客户端 (Spec=0)",
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
client: &models.Client{ID: 1, Spec: 0},
|
||||
client: &models.Client{ID: 1, Spec: 3},
|
||||
scope: []string{"read", "write"},
|
||||
},
|
||||
mockCreateErr: nil,
|
||||
|
||||
@@ -109,7 +109,7 @@ func Test_cache(t *testing.T) {
|
||||
|
||||
// 准备测试数据
|
||||
now := time.Now()
|
||||
expiration := now.Add(24 * time.Hour)
|
||||
expiration := common.LocalDateTime(now.Add(24 * time.Hour))
|
||||
|
||||
testChannels := []*models.Channel{
|
||||
{
|
||||
@@ -471,7 +471,7 @@ func Test_channelService_CreateChannel(t *testing.T) {
|
||||
if ch.Password != *info.Password {
|
||||
return fmt.Errorf("通道密码不正确,期望 %s,得到 %s", *info.Password, ch.Password)
|
||||
}
|
||||
if ch.Expiration.IsZero() {
|
||||
if time.Time(ch.Expiration).IsZero() {
|
||||
return fmt.Errorf("通道过期时间不应为空")
|
||||
}
|
||||
|
||||
@@ -615,7 +615,7 @@ func Test_channelService_CreateChannel(t *testing.T) {
|
||||
if ch.Protocol != int32(info.Proto) {
|
||||
return fmt.Errorf("通道协议不正确,期望 %d,得到 %d", info.Proto, ch.Protocol)
|
||||
}
|
||||
if ch.Expiration.IsZero() {
|
||||
if time.Time(ch.Expiration).IsZero() {
|
||||
return fmt.Errorf("通道过期时间不应为空")
|
||||
}
|
||||
|
||||
@@ -765,7 +765,7 @@ func Test_channelService_CreateChannel(t *testing.T) {
|
||||
if ch.Password != *info.Password {
|
||||
return fmt.Errorf("通道密码不正确,期望 %s,得到 %s", *info.Password, ch.Password)
|
||||
}
|
||||
if ch.Expiration.IsZero() {
|
||||
if time.Time(ch.Expiration).IsZero() {
|
||||
return fmt.Errorf("通道过期时间不应为空")
|
||||
}
|
||||
|
||||
@@ -914,7 +914,7 @@ func Test_channelService_CreateChannel(t *testing.T) {
|
||||
ProxyID: 1,
|
||||
ProxyPort: int32(i + 10000),
|
||||
UserID: 101,
|
||||
Expiration: expr,
|
||||
Expiration: common.LocalDateTime(expr),
|
||||
}
|
||||
}
|
||||
db.CreateInBatches(channels, 1000)
|
||||
@@ -1040,9 +1040,9 @@ func Test_channelService_RemoveChannels(t *testing.T) {
|
||||
|
||||
// 创建通道
|
||||
channels := []models.Channel{
|
||||
{ID: 1, UserID: 101, ProxyID: 1, ProxyPort: 10001, Protocol: 1, Expiration: time.Now().Add(24 * time.Hour)},
|
||||
{ID: 2, UserID: 101, ProxyID: 1, ProxyPort: 10002, Protocol: 1, Expiration: time.Now().Add(24 * time.Hour)},
|
||||
{ID: 3, UserID: 101, ProxyID: 2, ProxyPort: 10001, Protocol: 3, Expiration: time.Now().Add(24 * time.Hour)},
|
||||
{ID: 1, UserID: 101, ProxyID: 1, ProxyPort: 10001, Protocol: 1, Expiration: common.LocalDateTime(time.Now().Add(24 * time.Hour))},
|
||||
{ID: 2, UserID: 101, ProxyID: 1, ProxyPort: 10002, Protocol: 1, Expiration: common.LocalDateTime(time.Now().Add(24 * time.Hour))},
|
||||
{ID: 3, UserID: 101, ProxyID: 2, ProxyPort: 10001, Protocol: 3, Expiration: common.LocalDateTime(time.Now().Add(24 * time.Hour))},
|
||||
}
|
||||
|
||||
// 保存预设数据
|
||||
@@ -1154,9 +1154,9 @@ func Test_channelService_RemoveChannels(t *testing.T) {
|
||||
|
||||
// 创建通道
|
||||
channels := []models.Channel{
|
||||
{ID: 1, UserID: 101, ProxyID: 1, ProxyPort: 10001, Protocol: 1, Expiration: time.Now().Add(24 * time.Hour)},
|
||||
{ID: 2, UserID: 101, ProxyID: 1, ProxyPort: 10002, Protocol: 1, Expiration: time.Now().Add(24 * time.Hour)},
|
||||
{ID: 3, UserID: 101, ProxyID: 2, ProxyPort: 10001, Protocol: 3, Expiration: time.Now().Add(24 * time.Hour)},
|
||||
{ID: 1, UserID: 101, ProxyID: 1, ProxyPort: 10001, Protocol: 1, Expiration: common.LocalDateTime(time.Now().Add(24 * time.Hour))},
|
||||
{ID: 2, UserID: 101, ProxyID: 1, ProxyPort: 10002, Protocol: 1, Expiration: common.LocalDateTime(time.Now().Add(24 * time.Hour))},
|
||||
{ID: 3, UserID: 101, ProxyID: 2, ProxyPort: 10001, Protocol: 3, Expiration: common.LocalDateTime(time.Now().Add(24 * time.Hour))},
|
||||
}
|
||||
|
||||
// 保存预设数据
|
||||
@@ -1268,9 +1268,9 @@ func Test_channelService_RemoveChannels(t *testing.T) {
|
||||
|
||||
// 创建通道
|
||||
channels := []models.Channel{
|
||||
{ID: 1, UserID: 101, ProxyID: 1, ProxyPort: 10001, Protocol: 1, Expiration: time.Now().Add(24 * time.Hour)},
|
||||
{ID: 2, UserID: 101, ProxyID: 1, ProxyPort: 10002, Protocol: 1, Expiration: time.Now().Add(24 * time.Hour)},
|
||||
{ID: 3, UserID: 102, ProxyID: 2, ProxyPort: 10001, Protocol: 3, Expiration: time.Now().Add(24 * time.Hour)},
|
||||
{ID: 1, UserID: 101, ProxyID: 1, ProxyPort: 10001, Protocol: 1, Expiration: common.LocalDateTime(time.Now().Add(24 * time.Hour))},
|
||||
{ID: 2, UserID: 101, ProxyID: 1, ProxyPort: 10002, Protocol: 1, Expiration: common.LocalDateTime(time.Now().Add(24 * time.Hour))},
|
||||
{ID: 3, UserID: 102, ProxyID: 2, ProxyPort: 10001, Protocol: 3, Expiration: common.LocalDateTime(time.Now().Add(24 * time.Hour))},
|
||||
}
|
||||
|
||||
// 保存预设数据
|
||||
|
||||
@@ -21,7 +21,7 @@ type SessionServiceInter interface {
|
||||
// Find 通过访问令牌获取会话信息
|
||||
Find(ctx context.Context, token string) (*AuthContext, error)
|
||||
// Create 创建一个新的会话
|
||||
Create(ctx context.Context, auth AuthContext) (*TokenDetails, error)
|
||||
Create(ctx context.Context, auth AuthContext, remember bool) (*TokenDetails, error)
|
||||
// Refresh 刷新一个会话
|
||||
Refresh(ctx context.Context, refreshToken string) (*TokenDetails, error)
|
||||
// Remove 删除会话
|
||||
@@ -62,7 +62,7 @@ func (s *sessionService) Find(ctx context.Context, token string) (*AuthContext,
|
||||
}
|
||||
|
||||
// Create 创建一个新的会话
|
||||
func (s *sessionService) Create(ctx context.Context, auth AuthContext) (*TokenDetails, error) {
|
||||
func (s *sessionService) Create(ctx context.Context, auth AuthContext, remember bool) (*TokenDetails, error) {
|
||||
var now = time.Now()
|
||||
|
||||
// 生成令牌组
|
||||
@@ -90,7 +90,9 @@ func (s *sessionService) Create(ctx context.Context, auth AuthContext) (*TokenDe
|
||||
|
||||
pipe := rds.Client.TxPipeline()
|
||||
pipe.Set(ctx, accessKey(accessToken), authData, accessExpire)
|
||||
pipe.Set(ctx, refreshKey(refreshToken), refreshData, refreshExpire)
|
||||
if remember {
|
||||
pipe.Set(ctx, refreshKey(refreshToken), refreshData, refreshExpire)
|
||||
}
|
||||
_, err = pipe.Exec(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -34,9 +34,8 @@ func Test_sessionService_Create(t *testing.T) {
|
||||
auth := createTestAuthContext()
|
||||
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
auth AuthContext
|
||||
config []SessionConfig
|
||||
ctx context.Context
|
||||
auth AuthContext
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -45,7 +44,7 @@ func Test_sessionService_Create(t *testing.T) {
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "使用默认配置创建会话",
|
||||
name: "创建会话",
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
auth: auth,
|
||||
@@ -68,53 +67,13 @@ func Test_sessionService_Create(t *testing.T) {
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "使用自定义配置创建会话",
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
auth: auth,
|
||||
config: []SessionConfig{
|
||||
{
|
||||
AccessTokenDuration: 10 * time.Minute,
|
||||
RefreshTokenDuration: 24 * time.Hour,
|
||||
},
|
||||
},
|
||||
},
|
||||
want: func(td *TokenDetails) bool {
|
||||
// 验证令牌存在且格式正确
|
||||
if td.AccessToken == "" || td.RefreshToken == "" {
|
||||
return false
|
||||
}
|
||||
// 验证到期时间在未来且接近预期时间
|
||||
now := time.Now()
|
||||
expectedAccessExpiry := now.Add(10 * time.Minute)
|
||||
expectedRefreshExpiry := now.Add(24 * time.Hour)
|
||||
|
||||
accessDiff := td.AccessTokenExpires.Sub(expectedAccessExpiry)
|
||||
refreshDiff := td.RefreshTokenExpires.Sub(expectedRefreshExpiry)
|
||||
|
||||
if accessDiff < -2*time.Second || accessDiff > 2*time.Second {
|
||||
return false
|
||||
}
|
||||
if refreshDiff < -2*time.Second || refreshDiff > 2*time.Second {
|
||||
return false
|
||||
}
|
||||
|
||||
// 验证认证信息正确
|
||||
if !reflect.DeepEqual(td.Auth, auth) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mr.FlushAll()
|
||||
s := &sessionService{}
|
||||
got, err := s.Create(tt.args.ctx, tt.args.auth, tt.args.config...)
|
||||
got, err := s.Create(tt.args.ctx, tt.args.auth, true)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Create() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
@@ -145,7 +104,7 @@ func Test_sessionService_Find(t *testing.T) {
|
||||
s := &sessionService{}
|
||||
|
||||
// 创建一个有效的会话
|
||||
td, err := s.Create(ctx, auth)
|
||||
td, err := s.Create(ctx, auth, true)
|
||||
if err != nil {
|
||||
t.Fatalf("无法创建测试会话: %v", err)
|
||||
}
|
||||
@@ -204,7 +163,7 @@ func Test_sessionService_Refresh(t *testing.T) {
|
||||
s := &sessionService{}
|
||||
|
||||
// 创建一个初始会话
|
||||
td, err := s.Create(ctx, auth)
|
||||
td, err := s.Create(ctx, auth, true)
|
||||
if err != nil {
|
||||
t.Fatalf("无法创建初始会话: %v", err)
|
||||
}
|
||||
@@ -216,7 +175,6 @@ func Test_sessionService_Refresh(t *testing.T) {
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
refreshToken string
|
||||
config []SessionConfig
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -259,7 +217,7 @@ func Test_sessionService_Refresh(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := s.Refresh(tt.args.ctx, tt.args.refreshToken, tt.args.config...)
|
||||
got, err := s.Refresh(tt.args.ctx, tt.args.refreshToken)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Refresh() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
@@ -297,7 +255,7 @@ func Test_sessionService_Remove(t *testing.T) {
|
||||
s := &sessionService{}
|
||||
|
||||
// 创建一个会话
|
||||
td, err := s.Create(ctx, auth)
|
||||
td, err := s.Create(ctx, auth, true)
|
||||
if err != nil {
|
||||
t.Fatalf("无法创建测试会话: %v", err)
|
||||
}
|
||||
|
||||
@@ -134,12 +134,23 @@ func (s *transactionService) PrepareTransaction(ctx context.Context, q *q.Query,
|
||||
}
|
||||
|
||||
// 保存交易订单
|
||||
var tradeType int
|
||||
var billType int
|
||||
switch tType {
|
||||
case TransactionTypeRecharge:
|
||||
tradeType = 2
|
||||
billType = 3
|
||||
case TransactionTypePurchase:
|
||||
tradeType = 1
|
||||
billType = 1
|
||||
}
|
||||
|
||||
var trade = m.Trade{
|
||||
UserID: uid,
|
||||
InnerNo: tradeNo,
|
||||
Subject: subject,
|
||||
Method: int32(method),
|
||||
Type: int32(tType),
|
||||
Type: int32(tradeType),
|
||||
Amount: amount,
|
||||
Status: 0, // 0-待支付
|
||||
PayURL: payUrl,
|
||||
@@ -155,7 +166,7 @@ func (s *transactionService) PrepareTransaction(ctx context.Context, q *q.Query,
|
||||
UserID: uid,
|
||||
TradeID: trade.ID,
|
||||
Info: subject,
|
||||
Type: int32(tType),
|
||||
Type: int32(billType),
|
||||
Amount: amount,
|
||||
}
|
||||
err = q.Bill.
|
||||
@@ -329,8 +340,8 @@ func (s *transactionService) FinishTransaction(ctx context.Context, q *q.Query,
|
||||
type TransactionType int32
|
||||
|
||||
const (
|
||||
TransactionTypeRecharge TransactionType = iota
|
||||
TransactionTypePurchase
|
||||
TransactionTypePurchase TransactionType = iota + 1
|
||||
TransactionTypeRecharge
|
||||
)
|
||||
|
||||
type TransactionMethod int32
|
||||
|
||||
Reference in New Issue
Block a user