// File: platform/web/services/session_test.go

package services
import (
"context"
"errors"
"platform/web/auth"
"platform/web/testutil"
"reflect"
"testing"
"time"
)
// createTestAuthContext builds the auth.Context shared by the session tests:
// a user payload (auth.PayloadUser) with Id 1001, the "read" and "write"
// permissions, and username/email metadata.
//
// Every call evaluates fresh map literals, so each caller gets its own
// Permissions and Metadata maps and may mutate them without affecting others.
func createTestAuthContext() auth.Context {
	perms := map[string]struct{}{
		"read":  {},
		"write": {},
	}
	meta := map[string]interface{}{
		"username": "testuser",
		"email":    "test@example.com",
	}
	//goland:noinspection ALL
	return auth.Context{
		Payload:     auth.Payload{Type: auth.PayloadUser, Id: 1001},
		Permissions: perms,
		Metadata:    meta,
	}
}
func Test_sessionService_Create(t *testing.T) {
2025-04-01 11:32:17 +08:00
mr := testutil.SetupRedisTest(t)
ctx := context.Background()
authCtx := createTestAuthContext()
type args struct {
ctx context.Context
auth auth.Context
}
tests := []struct {
name string
args args
want func(*TokenDetails) bool
wantErr bool
}{
{
name: "创建会话",
args: args{
ctx: ctx,
auth: authCtx,
},
want: func(td *TokenDetails) bool {
// 验证令牌存在且格式正确
if td.AccessToken == "" || td.RefreshToken == "" {
return false
}
// 验证到期时间在未来
now := time.Now()
if td.AccessTokenExpires.Before(now) || td.RefreshTokenExpires.Before(now) {
return false
}
// 验证认证信息正确
if !reflect.DeepEqual(td.Auth, authCtx) {
return false
}
return true
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mr.FlushAll()
s := &sessionService{}
got, err := s.Create(tt.args.ctx, tt.args.auth, true)
if (err != nil) != tt.wantErr {
t.Errorf("Create() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !tt.want(got) {
t.Errorf("Create() got = %v, want to satisfy conditions", got)
}
// 验证 Redis 中是否有相应的键
accessKey := accessKey(got.AccessToken)
refreshKey := refreshKey(got.RefreshToken)
if !mr.Exists(accessKey) {
t.Errorf("访问令牌键 %s 不存在于 Redis 中", accessKey)
}
if !mr.Exists(refreshKey) {
t.Errorf("刷新令牌键 %s 不存在于 Redis 中", refreshKey)
}
})
}
}
func Test_sessionService_Find(t *testing.T) {
2025-04-01 11:32:17 +08:00
testutil.SetupRedisTest(t)
ctx := context.Background()
authCtx := createTestAuthContext()
s := &sessionService{}
// 创建一个有效的会话
td, err := s.Create(ctx, authCtx, true)
if err != nil {
t.Fatalf("无法创建测试会话: %v", err)
}
validToken := td.AccessToken
invalidToken := "invalid-token"
type args struct {
ctx context.Context
token string
}
tests := []struct {
name string
args args
want *auth.Context
wantErr error
}{
{
name: "查找有效令牌",
args: args{
ctx: ctx,
token: validToken,
},
want: &authCtx,
wantErr: nil,
},
{
name: "查找无效令牌",
args: args{
ctx: ctx,
token: invalidToken,
},
want: nil,
wantErr: ErrInvalidToken,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := s.Find(tt.args.ctx, tt.args.token)
if !errors.Is(err, tt.wantErr) {
t.Errorf("Find() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Find() got = %v, want %v", got, tt.want)
}
})
}
}
func Test_sessionService_Refresh(t *testing.T) {
2025-04-01 11:32:17 +08:00
mr := testutil.SetupRedisTest(t)
ctx := context.Background()
authCtx := createTestAuthContext()
s := &sessionService{}
// 创建一个初始会话
td, err := s.Create(ctx, authCtx, true)
if err != nil {
t.Fatalf("无法创建初始会话: %v", err)
}
validRefreshToken := td.RefreshToken
invalidRefreshToken := "invalid-refresh-token"
originalAccessToken := td.AccessToken
type args struct {
ctx context.Context
refreshToken string
}
tests := []struct {
name string
args args
want func(*TokenDetails) bool
wantErr bool
}{
{
name: "使用有效的刷新令牌",
args: args{
ctx: ctx,
refreshToken: validRefreshToken,
},
want: func(td *TokenDetails) bool {
if td.AccessToken == "" || td.RefreshToken == "" {
return false
}
// 新的令牌应该与旧的不同
if td.AccessToken == originalAccessToken || td.RefreshToken == validRefreshToken {
return false
}
// 验证认证信息一致
if !reflect.DeepEqual(td.Auth, authCtx) {
return false
}
return true
},
wantErr: false,
},
{
name: "使用无效的刷新令牌",
args: args{
ctx: ctx,
refreshToken: invalidRefreshToken,
},
want: nil,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := s.Refresh(tt.args.ctx, tt.args.refreshToken)
if (err != nil) != tt.wantErr {
t.Errorf("Refresh() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.want != nil && !tt.want(got) {
t.Errorf("Refresh() got = %v, want to satisfy conditions", got)
}
if !tt.wantErr && got != nil {
// 验证旧的令牌已被删除
if mr.Exists(accessKey(originalAccessToken)) {
t.Errorf("原始访问令牌键应被删除")
}
if mr.Exists(refreshKey(validRefreshToken)) {
t.Errorf("原始刷新令牌键应被删除")
}
// 验证新的令牌已被添加
if !mr.Exists(accessKey(got.AccessToken)) {
t.Errorf("新的访问令牌键应存在")
}
if !mr.Exists(refreshKey(got.RefreshToken)) {
t.Errorf("新的刷新令牌键应存在")
}
}
})
}
}
func Test_sessionService_Remove(t *testing.T) {
2025-04-01 11:32:17 +08:00
mr := testutil.SetupRedisTest(t)
ctx := context.Background()
authCtx := createTestAuthContext()
s := &sessionService{}
// 创建一个会话
td, err := s.Create(ctx, authCtx, true)
if err != nil {
t.Fatalf("无法创建测试会话: %v", err)
}
validAccessToken := td.AccessToken
validRefreshToken := td.RefreshToken
type args struct {
ctx context.Context
accessToken string
refreshToken string
}
tests := []struct {
name string
args args
wantErr bool
}{
{
name: "删除有效会话",
args: args{
ctx: ctx,
accessToken: validAccessToken,
refreshToken: validRefreshToken,
},
wantErr: false,
},
{
name: "删除已删除的会话",
args: args{
ctx: ctx,
accessToken: validAccessToken,
refreshToken: validRefreshToken,
},
wantErr: false, // 删除不存在的会话不应报错
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := s.Remove(tt.args.ctx, tt.args.accessToken, tt.args.refreshToken); (err != nil) != tt.wantErr {
t.Errorf("Remove() error = %v, wantErr %v", err, tt.wantErr)
}
// 验证键已被删除
if mr.Exists(accessKey(tt.args.accessToken)) {
t.Errorf("访问令牌键应已被删除")
}
if mr.Exists(refreshKey(tt.args.refreshToken)) {
t.Errorf("刷新令牌键应已被删除")
}
})
}
}
// TestAuthContext_AnyPermission table-tests auth.Context.AnyPermission. As the
// cases below encode it, the method reports true when the context holds at
// least one of the requested permissions. It reports false when it holds none,
// including when Permissions is empty or nil.
func TestAuthContext_AnyPermission(t *testing.T) {
	type fields struct {
		Payload     auth.Payload
		Permissions map[string]struct{}
		Metadata    map[string]interface{}
	}
	type args struct {
		requiredPermission []string
	}
	tests := []struct {
		name   string
		fields fields
		args   args
		want   bool
	}{
		{
			name: "用户拥有所需权限",
			fields: fields{
				Payload: auth.Payload{Type: auth.PayloadUser, Id: 1},
				Permissions: map[string]struct{}{
					"read":  {},
					"write": {},
				},
				Metadata: nil,
			},
			args: args{
				requiredPermission: []string{"read"},
			},
			want: true,
		},
		{
			name: "用户拥有至少一个所需权限",
			fields: fields{
				Payload: auth.Payload{Type: auth.PayloadUser, Id: 1},
				Permissions: map[string]struct{}{
					"read": {},
				},
				Metadata: nil,
			},
			args: args{
				requiredPermission: []string{"read", "admin"},
			},
			want: true,
		},
		{
			name: "用户没有所需权限",
			fields: fields{
				Payload: auth.Payload{Type: auth.PayloadUser, Id: 1},
				Permissions: map[string]struct{}{
					"read": {},
				},
				Metadata: nil,
			},
			args: args{
				requiredPermission: []string{"admin", "delete"},
			},
			want: false,
		},
		{
			name: "空权限列表",
			fields: fields{
				Payload:     auth.Payload{Type: auth.PayloadUser, Id: 1},
				Permissions: map[string]struct{}{},
				Metadata:    nil,
			},
			args: args{
				requiredPermission: []string{"read"},
			},
			want: false,
		},
		{
			name: "nil权限列表",
			fields: fields{
				Payload:     auth.Payload{Type: auth.PayloadUser, Id: 1},
				Permissions: nil,
				Metadata:    nil,
			},
			args: args{
				requiredPermission: []string{"read"},
			},
			want: false,
		},
		{
			// This case builds a non-nil *auth.Context whose fields are all
			// zero values; it does not exercise a nil receiver. It was
			// previously misnamed "nil认证上下文" (nil auth context).
			name: "零值认证上下文",
			fields: fields{
				Payload:     auth.Payload{},
				Permissions: nil,
				Metadata:    nil,
			},
			args: args{
				requiredPermission: []string{"read"},
			},
			want: false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			a := &auth.Context{
				Payload:     tt.fields.Payload,
				Permissions: tt.fields.Permissions,
				Metadata:    tt.fields.Metadata,
			}
			if got := a.AnyPermission(tt.args.requiredPermission...); got != tt.want {
				t.Errorf("AnyPermission() = %v, want %v", got, tt.want)
			}
		})
	}
}