认证授权测试代码与业务代码质量修复
This commit is contained in:
486
web/services/session_test.go
Normal file
486
web/services/session_test.go
Normal file
@@ -0,0 +1,486 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"platform/init/rds"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// 设置 Redis 模拟服务器
|
||||
func setupTestRedis(t *testing.T) *miniredis.Miniredis {
|
||||
mr, err := miniredis.Run()
|
||||
if err != nil {
|
||||
t.Fatalf("无法启动 miniredis: %v", err)
|
||||
}
|
||||
|
||||
// 替换 Redis 客户端为测试客户端
|
||||
origClient := rds.Client
|
||||
rds.Client = redis.NewClient(&redis.Options{
|
||||
Addr: mr.Addr(),
|
||||
})
|
||||
|
||||
t.Cleanup(func() {
|
||||
mr.Close()
|
||||
rds.Client = origClient
|
||||
})
|
||||
|
||||
return mr
|
||||
}
|
||||
|
||||
// 创建测试用的认证上下文
|
||||
func createTestAuthContext() AuthContext {
|
||||
return AuthContext{
|
||||
Payload: Payload{
|
||||
Type: PayloadUser,
|
||||
Id: 1001,
|
||||
},
|
||||
Permissions: map[string]struct{}{
|
||||
"read": {},
|
||||
"write": {},
|
||||
},
|
||||
Metadata: map[string]interface{}{
|
||||
"username": "testuser",
|
||||
"email": "test@example.com",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func Test_sessionService_Create(t *testing.T) {
|
||||
mr := setupTestRedis(t)
|
||||
ctx := context.Background()
|
||||
auth := createTestAuthContext()
|
||||
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
auth AuthContext
|
||||
config []SessionConfig
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want func(*TokenDetails) bool
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "使用默认配置创建会话",
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
auth: auth,
|
||||
},
|
||||
want: func(td *TokenDetails) bool {
|
||||
// 验证令牌存在且格式正确
|
||||
if td.AccessToken == "" || td.RefreshToken == "" {
|
||||
return false
|
||||
}
|
||||
// 验证到期时间在未来
|
||||
now := time.Now()
|
||||
if td.AccessTokenExpires.Before(now) || td.RefreshTokenExpires.Before(now) {
|
||||
return false
|
||||
}
|
||||
// 验证认证信息正确
|
||||
if !reflect.DeepEqual(td.Auth, auth) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "使用自定义配置创建会话",
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
auth: auth,
|
||||
config: []SessionConfig{
|
||||
{
|
||||
AccessTokenDuration: 10 * time.Minute,
|
||||
RefreshTokenDuration: 24 * time.Hour,
|
||||
},
|
||||
},
|
||||
},
|
||||
want: func(td *TokenDetails) bool {
|
||||
// 验证令牌存在且格式正确
|
||||
if td.AccessToken == "" || td.RefreshToken == "" {
|
||||
return false
|
||||
}
|
||||
// 验证到期时间在未来且接近预期时间
|
||||
now := time.Now()
|
||||
expectedAccessExpiry := now.Add(10 * time.Minute)
|
||||
expectedRefreshExpiry := now.Add(24 * time.Hour)
|
||||
|
||||
accessDiff := td.AccessTokenExpires.Sub(expectedAccessExpiry)
|
||||
refreshDiff := td.RefreshTokenExpires.Sub(expectedRefreshExpiry)
|
||||
|
||||
if accessDiff < -2*time.Second || accessDiff > 2*time.Second {
|
||||
return false
|
||||
}
|
||||
if refreshDiff < -2*time.Second || refreshDiff > 2*time.Second {
|
||||
return false
|
||||
}
|
||||
|
||||
// 验证认证信息正确
|
||||
if !reflect.DeepEqual(td.Auth, auth) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mr.FlushAll()
|
||||
s := &sessionService{}
|
||||
got, err := s.Create(tt.args.ctx, tt.args.auth, tt.args.config...)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Create() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !tt.want(got) {
|
||||
t.Errorf("Create() got = %v, want to satisfy conditions", got)
|
||||
}
|
||||
|
||||
// 验证 Redis 中是否有相应的键
|
||||
accessKey := accessKey(got.AccessToken)
|
||||
refreshKey := refreshKey(got.RefreshToken)
|
||||
|
||||
if !mr.Exists(accessKey) {
|
||||
t.Errorf("访问令牌键 %s 不存在于 Redis 中", accessKey)
|
||||
}
|
||||
|
||||
if !mr.Exists(refreshKey) {
|
||||
t.Errorf("刷新令牌键 %s 不存在于 Redis 中", refreshKey)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_sessionService_Find(t *testing.T) {
|
||||
_ = setupTestRedis(t)
|
||||
ctx := context.Background()
|
||||
auth := createTestAuthContext()
|
||||
s := &sessionService{}
|
||||
|
||||
// 创建一个有效的会话
|
||||
td, err := s.Create(ctx, auth)
|
||||
if err != nil {
|
||||
t.Fatalf("无法创建测试会话: %v", err)
|
||||
}
|
||||
|
||||
validToken := td.AccessToken
|
||||
invalidToken := "invalid-token"
|
||||
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
token string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *AuthContext
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "查找有效令牌",
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
token: validToken,
|
||||
},
|
||||
want: &auth,
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "查找无效令牌",
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
token: invalidToken,
|
||||
},
|
||||
want: nil,
|
||||
wantErr: ErrInvalidToken,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := s.Find(tt.args.ctx, tt.args.token)
|
||||
if !errors.Is(err, tt.wantErr) {
|
||||
t.Errorf("Find() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Find() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_sessionService_Refresh(t *testing.T) {
|
||||
mr := setupTestRedis(t)
|
||||
ctx := context.Background()
|
||||
auth := createTestAuthContext()
|
||||
s := &sessionService{}
|
||||
|
||||
// 创建一个初始会话
|
||||
td, err := s.Create(ctx, auth)
|
||||
if err != nil {
|
||||
t.Fatalf("无法创建初始会话: %v", err)
|
||||
}
|
||||
|
||||
validRefreshToken := td.RefreshToken
|
||||
invalidRefreshToken := "invalid-refresh-token"
|
||||
originalAccessToken := td.AccessToken
|
||||
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
refreshToken string
|
||||
config []SessionConfig
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want func(*TokenDetails) bool
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "使用有效的刷新令牌",
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
refreshToken: validRefreshToken,
|
||||
},
|
||||
want: func(td *TokenDetails) bool {
|
||||
if td.AccessToken == "" || td.RefreshToken == "" {
|
||||
return false
|
||||
}
|
||||
// 新的令牌应该与旧的不同
|
||||
if td.AccessToken == originalAccessToken || td.RefreshToken == validRefreshToken {
|
||||
return false
|
||||
}
|
||||
// 验证认证信息一致
|
||||
if !reflect.DeepEqual(td.Auth, auth) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "使用无效的刷新令牌",
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
refreshToken: invalidRefreshToken,
|
||||
},
|
||||
want: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := s.Refresh(tt.args.ctx, tt.args.refreshToken, tt.args.config...)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Refresh() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if tt.want != nil && !tt.want(got) {
|
||||
t.Errorf("Refresh() got = %v, want to satisfy conditions", got)
|
||||
}
|
||||
|
||||
if !tt.wantErr && got != nil {
|
||||
// 验证旧的令牌已被删除
|
||||
if mr.Exists(accessKey(originalAccessToken)) {
|
||||
t.Errorf("原始访问令牌键应被删除")
|
||||
}
|
||||
if mr.Exists(refreshKey(validRefreshToken)) {
|
||||
t.Errorf("原始刷新令牌键应被删除")
|
||||
}
|
||||
|
||||
// 验证新的令牌已被添加
|
||||
if !mr.Exists(accessKey(got.AccessToken)) {
|
||||
t.Errorf("新的访问令牌键应存在")
|
||||
}
|
||||
if !mr.Exists(refreshKey(got.RefreshToken)) {
|
||||
t.Errorf("新的刷新令牌键应存在")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_sessionService_Remove(t *testing.T) {
|
||||
mr := setupTestRedis(t)
|
||||
ctx := context.Background()
|
||||
auth := createTestAuthContext()
|
||||
s := &sessionService{}
|
||||
|
||||
// 创建一个会话
|
||||
td, err := s.Create(ctx, auth)
|
||||
if err != nil {
|
||||
t.Fatalf("无法创建测试会话: %v", err)
|
||||
}
|
||||
|
||||
validAccessToken := td.AccessToken
|
||||
validRefreshToken := td.RefreshToken
|
||||
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
accessToken string
|
||||
refreshToken string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "删除有效会话",
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
accessToken: validAccessToken,
|
||||
refreshToken: validRefreshToken,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "删除已删除的会话",
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
accessToken: validAccessToken,
|
||||
refreshToken: validRefreshToken,
|
||||
},
|
||||
wantErr: false, // 删除不存在的会话不应报错
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := s.Remove(tt.args.ctx, tt.args.accessToken, tt.args.refreshToken); (err != nil) != tt.wantErr {
|
||||
t.Errorf("Remove() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
|
||||
// 验证键已被删除
|
||||
if mr.Exists(accessKey(tt.args.accessToken)) {
|
||||
t.Errorf("访问令牌键应已被删除")
|
||||
}
|
||||
if mr.Exists(refreshKey(tt.args.refreshToken)) {
|
||||
t.Errorf("刷新令牌键应已被删除")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthContext_AnyPermission(t *testing.T) {
|
||||
type fields struct {
|
||||
Payload Payload
|
||||
Permissions map[string]struct{}
|
||||
Metadata map[string]interface{}
|
||||
}
|
||||
type args struct {
|
||||
requiredPermission []string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "用户拥有所需权限",
|
||||
fields: fields{
|
||||
Payload: Payload{Type: PayloadUser, Id: 1},
|
||||
Permissions: map[string]struct{}{
|
||||
"read": {},
|
||||
"write": {},
|
||||
},
|
||||
Metadata: nil,
|
||||
},
|
||||
args: args{
|
||||
requiredPermission: []string{"read"},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "用户拥有至少一个所需权限",
|
||||
fields: fields{
|
||||
Payload: Payload{Type: PayloadUser, Id: 1},
|
||||
Permissions: map[string]struct{}{
|
||||
"read": {},
|
||||
},
|
||||
Metadata: nil,
|
||||
},
|
||||
args: args{
|
||||
requiredPermission: []string{"read", "admin"},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "用户没有所需权限",
|
||||
fields: fields{
|
||||
Payload: Payload{Type: PayloadUser, Id: 1},
|
||||
Permissions: map[string]struct{}{
|
||||
"read": {},
|
||||
},
|
||||
Metadata: nil,
|
||||
},
|
||||
args: args{
|
||||
requiredPermission: []string{"admin", "delete"},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "空权限列表",
|
||||
fields: fields{
|
||||
Payload: Payload{Type: PayloadUser, Id: 1},
|
||||
Permissions: map[string]struct{}{},
|
||||
Metadata: nil,
|
||||
},
|
||||
args: args{
|
||||
requiredPermission: []string{"read"},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "nil权限列表",
|
||||
fields: fields{
|
||||
Payload: Payload{Type: PayloadUser, Id: 1},
|
||||
Permissions: nil,
|
||||
Metadata: nil,
|
||||
},
|
||||
args: args{
|
||||
requiredPermission: []string{"read"},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "nil认证上下文",
|
||||
fields: fields{
|
||||
Payload: Payload{},
|
||||
Permissions: nil,
|
||||
Metadata: nil,
|
||||
},
|
||||
args: args{
|
||||
requiredPermission: []string{"read"},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := &AuthContext{
|
||||
Payload: tt.fields.Payload,
|
||||
Permissions: tt.fields.Permissions,
|
||||
Metadata: tt.fields.Metadata,
|
||||
}
|
||||
if got := a.AnyPermission(tt.args.requiredPermission...); got != tt.want {
|
||||
t.Errorf("AnyPermission() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user