package services

import (
	"context"
	"errors"
	"reflect"
	"testing"
	"time"

	"platform/pkg/testutil"
)

// createTestAuthContext builds the AuthContext fixture shared by the session
// tests: a user payload with id 1001, the "read" and "write" permissions, and
// username/email metadata.
func createTestAuthContext() AuthContext {
	return AuthContext{
		Payload: Payload{
			Type: PayloadUser,
			Id:   1001,
		},
		Permissions: map[string]struct{}{
			"read":  {},
			"write": {},
		},
		Metadata: map[string]interface{}{
			"username": "testuser",
			"email":    "test@example.com",
		},
	}
}

// Test_sessionService_Create checks that Create returns usable token details
// (non-empty tokens, plausible expiry times, the original auth context) and
// that both token keys are present in the store returned by
// testutil.SetupRedisTest.
//
// NOTE(review): the zero-value sessionService is presumably wired to the Redis
// instance configured by SetupRedisTest — confirm in testutil.
func Test_sessionService_Create(t *testing.T) {
	mr := testutil.SetupRedisTest(t)
	ctx := context.Background()
	auth := createTestAuthContext()

	type args struct {
		ctx    context.Context
		auth   AuthContext
		config []SessionConfig
	}
	tests := []struct {
		name    string
		args    args
		want    func(*TokenDetails) bool
		wantErr bool
	}{
		{
			name: "使用默认配置创建会话",
			args: args{
				ctx:  ctx,
				auth: auth,
			},
			want: func(td *TokenDetails) bool {
				// Both tokens must have been issued.
				if td.AccessToken == "" || td.RefreshToken == "" {
					return false
				}
				// Both expiry times must lie in the future.
				now := time.Now()
				if td.AccessTokenExpires.Before(now) || td.RefreshTokenExpires.Before(now) {
					return false
				}
				// The auth context must round-trip unchanged.
				return reflect.DeepEqual(td.Auth, auth)
			},
			wantErr: false,
		},
		{
			name: "使用自定义配置创建会话",
			args: args{
				ctx:  ctx,
				auth: auth,
				config: []SessionConfig{
					{
						AccessTokenDuration:  10 * time.Minute,
						RefreshTokenDuration: 24 * time.Hour,
					},
				},
			},
			want: func(td *TokenDetails) bool {
				// Both tokens must have been issued.
				if td.AccessToken == "" || td.RefreshToken == "" {
					return false
				}
				// Expiry times must match the configured durations; a small
				// tolerance absorbs the time elapsed between Create and now.
				const tolerance = 2 * time.Second
				now := time.Now()
				accessDiff := td.AccessTokenExpires.Sub(now.Add(10 * time.Minute))
				refreshDiff := td.RefreshTokenExpires.Sub(now.Add(24 * time.Hour))
				if accessDiff < -tolerance || accessDiff > tolerance {
					return false
				}
				if refreshDiff < -tolerance || refreshDiff > tolerance {
					return false
				}
				// The auth context must round-trip unchanged.
				return reflect.DeepEqual(td.Auth, auth)
			},
			wantErr: false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Start every subtest from an empty store.
			mr.FlushAll()
			s := &sessionService{}
			got, err := s.Create(tt.args.ctx, tt.args.auth, tt.args.config...)
			if (err != nil) != tt.wantErr {
				t.Errorf("Create() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if tt.wantErr {
				// An expected failure leaves nothing further to inspect.
				return
			}
			if got == nil {
				// Guard the dereferences below instead of panicking.
				t.Fatal("Create() returned nil TokenDetails without an error")
			}
			if !tt.want(got) {
				t.Errorf("Create() got = %v, want to satisfy conditions", got)
			}

			// Both token keys must have been written to the store. The locals
			// are named differently from the accessKey/refreshKey helpers so
			// they do not shadow them.
			accessRedisKey := accessKey(got.AccessToken)
			refreshRedisKey := refreshKey(got.RefreshToken)
			if !mr.Exists(accessRedisKey) {
				t.Errorf("访问令牌键 %s 不存在于 Redis 中", accessRedisKey)
			}
			if !mr.Exists(refreshRedisKey) {
				t.Errorf("刷新令牌键 %s 不存在于 Redis 中", refreshRedisKey)
			}
		})
	}
}

// Test_sessionService_Find checks that Find returns the stored AuthContext for
// an access token issued by Create, and a nil result with ErrInvalidToken for
// an unknown token.
func Test_sessionService_Find(t *testing.T) {
	testutil.SetupRedisTest(t)
	ctx := context.Background()
	auth := createTestAuthContext()
	s := &sessionService{}

	// Create a valid session to look up.
	td, err := s.Create(ctx, auth)
	if err != nil {
		t.Fatalf("无法创建测试会话: %v", err)
	}

	validToken := td.AccessToken
	invalidToken := "invalid-token"

	type args struct {
		ctx   context.Context
		token string
	}
	tests := []struct {
		name    string
		args    args
		want    *AuthContext
		wantErr error
	}{
		{
			name: "查找有效令牌",
			args: args{
				ctx:   ctx,
				token: validToken,
			},
			want:    &auth,
			wantErr: nil,
		},
		{
			name: "查找无效令牌",
			args: args{
				ctx:   ctx,
				token: invalidToken,
			},
			want:    nil,
			wantErr: ErrInvalidToken,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := s.Find(tt.args.ctx, tt.args.token)
			// errors.Is(nil, nil) is true, so this also covers the success case.
			if !errors.Is(err, tt.wantErr) {
				t.Errorf("Find() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if !reflect.DeepEqual(got, tt.want) {
				t.Errorf("Find() got = %v, want %v", got, tt.want)
			}
		})
	}
}

// Test_sessionService_Refresh checks that Refresh with a valid refresh token
// issues a new token pair carrying the same auth context, removes the old keys
// from the store and adds the new ones, and that an unknown refresh token
// results in an error.
func Test_sessionService_Refresh(t *testing.T) {
	mr := testutil.SetupRedisTest(t)
	ctx := context.Background()
	auth := createTestAuthContext()
	s := &sessionService{}

	// Create the initial session that will be refreshed.
	td, err := s.Create(ctx, auth)
	if err != nil {
		t.Fatalf("无法创建初始会话: %v", err)
	}

	validRefreshToken := td.RefreshToken
	invalidRefreshToken := "invalid-refresh-token"
	originalAccessToken := td.AccessToken

	type args struct {
		ctx          context.Context
		refreshToken string
		config       []SessionConfig
	}
	tests := []struct {
		name    string
		args    args
		want    func(*TokenDetails) bool
		wantErr bool
	}{
		{
			name: "使用有效的刷新令牌",
			args: args{
				ctx:          ctx,
				refreshToken: validRefreshToken,
			},
			want: func(td *TokenDetails) bool {
				if td.AccessToken == "" || td.RefreshToken == "" {
					return false
				}
				// The new tokens must differ from the ones being replaced.
				if td.AccessToken == originalAccessToken || td.RefreshToken == validRefreshToken {
					return false
				}
				// The auth context must be carried over unchanged.
				return reflect.DeepEqual(td.Auth, auth)
			},
			wantErr: false,
		},
		{
			name: "使用无效的刷新令牌",
			args: args{
				ctx:          ctx,
				refreshToken: invalidRefreshToken,
			},
			want:    nil,
			wantErr: true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := s.Refresh(tt.args.ctx, tt.args.refreshToken, tt.args.config...)
			if (err != nil) != tt.wantErr {
				t.Errorf("Refresh() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if tt.wantErr {
				// An expected failure leaves nothing further to inspect.
				return
			}
			if got == nil {
				// Guard the dereferences below instead of panicking.
				t.Fatal("Refresh() returned nil TokenDetails without an error")
			}
			if tt.want != nil && !tt.want(got) {
				t.Errorf("Refresh() got = %v, want to satisfy conditions", got)
			}

			// The old token keys must be gone from the store.
			if mr.Exists(accessKey(originalAccessToken)) {
				t.Errorf("原始访问令牌键应被删除")
			}
			if mr.Exists(refreshKey(validRefreshToken)) {
				t.Errorf("原始刷新令牌键应被删除")
			}

			// The new token keys must be present in the store.
			if !mr.Exists(accessKey(got.AccessToken)) {
				t.Errorf("新的访问令牌键应存在")
			}
			if !mr.Exists(refreshKey(got.RefreshToken)) {
				t.Errorf("新的刷新令牌键应存在")
			}
		})
	}
}

// Test_sessionService_Remove checks that Remove deletes both token keys from
// the store and that removing the same session a second time does not return
// an error. The cases run in order: the second relies on the first having
// already removed the session.
func Test_sessionService_Remove(t *testing.T) {
	mr := testutil.SetupRedisTest(t)
	ctx := context.Background()
	auth := createTestAuthContext()
	s := &sessionService{}

	// Create the session that will be removed.
	td, err := s.Create(ctx, auth)
	if err != nil {
		t.Fatalf("无法创建测试会话: %v", err)
	}

	validAccessToken := td.AccessToken
	validRefreshToken := td.RefreshToken

	type args struct {
		ctx          context.Context
		accessToken  string
		refreshToken string
	}
	tests := []struct {
		name    string
		args    args
		wantErr bool
	}{
		{
			name: "删除有效会话",
			args: args{
				ctx:          ctx,
				accessToken:  validAccessToken,
				refreshToken: validRefreshToken,
			},
			wantErr: false,
		},
		{
			name: "删除已删除的会话",
			args: args{
				ctx:          ctx,
				accessToken:  validAccessToken,
				refreshToken: validRefreshToken,
			},
			wantErr: false, // removing a session that no longer exists must not fail
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if err := s.Remove(tt.args.ctx, tt.args.accessToken, tt.args.refreshToken); (err != nil) != tt.wantErr {
				t.Errorf("Remove() error = %v, wantErr %v", err, tt.wantErr)
			}

			// Neither key may remain in the store afterwards.
			if mr.Exists(accessKey(tt.args.accessToken)) {
				t.Errorf("访问令牌键应已被删除")
			}
			if mr.Exists(refreshKey(tt.args.refreshToken)) {
				t.Errorf("刷新令牌键应已被删除")
			}
		})
	}
}

// TestAuthContext_AnyPermission checks that AnyPermission reports true when
// the context holds at least one of the requested permissions and false
// otherwise, including for empty and nil permission sets.
func TestAuthContext_AnyPermission(t *testing.T) {
	type fields struct {
		Payload     Payload
		Permissions map[string]struct{}
		Metadata    map[string]interface{}
	}
	type args struct {
		requiredPermission []string
	}
	tests := []struct {
		name   string
		fields fields
		args   args
		want   bool
	}{
		{
			name: "用户拥有所需权限",
			fields: fields{
				Payload: Payload{Type: PayloadUser, Id: 1},
				Permissions: map[string]struct{}{
					"read":  {},
					"write": {},
				},
				Metadata: nil,
			},
			args: args{
				requiredPermission: []string{"read"},
			},
			want: true,
		},
		{
			name: "用户拥有至少一个所需权限",
			fields: fields{
				Payload: Payload{Type: PayloadUser, Id: 1},
				Permissions: map[string]struct{}{
					"read": {},
				},
				Metadata: nil,
			},
			args: args{
				requiredPermission: []string{"read", "admin"},
			},
			want: true,
		},
		{
			name: "用户没有所需权限",
			fields: fields{
				Payload: Payload{Type: PayloadUser, Id: 1},
				Permissions: map[string]struct{}{
					"read": {},
				},
				Metadata: nil,
			},
			args: args{
				requiredPermission: []string{"admin", "delete"},
			},
			want: false,
		},
		{
			name: "空权限列表",
			fields: fields{
				Payload:     Payload{Type: PayloadUser, Id: 1},
				Permissions: map[string]struct{}{},
				Metadata:    nil,
			},
			args: args{
				requiredPermission: []string{"read"},
			},
			want: false,
		},
		{
			name: "nil权限列表",
			fields: fields{
				Payload:     Payload{Type: PayloadUser, Id: 1},
				Permissions: nil,
				Metadata:    nil,
			},
			args: args{
				requiredPermission: []string{"read"},
			},
			want: false,
		},
		{
			// NOTE(review): despite its name, this case exercises a zero-value
			// AuthContext, not a nil *AuthContext — the receiver built in the
			// loop below is never nil.
			name: "nil认证上下文",
			fields: fields{
				Payload:     Payload{},
				Permissions: nil,
				Metadata:    nil,
			},
			args: args{
				requiredPermission: []string{"read"},
			},
			want: false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			a := &AuthContext{
				Payload:     tt.fields.Payload,
				Permissions: tt.fields.Permissions,
				Metadata:    tt.fields.Metadata,
			}
			if got := a.AnyPermission(tt.args.requiredPermission...); got != tt.want {
				t.Errorf("AnyPermission() = %v, want %v", got, tt.want)
			}
		})
	}
}