Files
platform/web/services/auth_test.go

147 lines
3.8 KiB
Go

package services
import (
"context"
"platform/web/models"
"reflect"
"testing"
"time"
)
// mockSessionService is a test double that stands in for the package-level
// Session service. Only Create does real work: it forwards to createFunc so a
// test case can inspect the AuthContext it receives and decide what to return.
// The other methods panic if called.
type mockSessionService struct {
	// createFunc backs Create. The variadic SessionConfig arguments passed to
	// Create are not forwarded to it.
	createFunc func(ctx context.Context, auth AuthContext) (*TokenDetails, error)
}

// Find is not used by these tests and panics if invoked.
func (*mockSessionService) Find(context.Context, string) (*AuthContext, error) {
	panic("implement me")
}

// Refresh is not used by these tests and panics if invoked.
func (*mockSessionService) Refresh(context.Context, string, ...SessionConfig) (*TokenDetails, error) {
	panic("implement me")
}

// Remove is not used by these tests and panics if invoked.
func (*mockSessionService) Remove(context.Context, string, string) error {
	panic("implement me")
}

// Create delegates to createFunc, discarding any SessionConfig values.
func (mock *mockSessionService) Create(ctx context.Context, auth AuthContext, _ ...SessionConfig) (*TokenDetails, error) {
	create := mock.createFunc
	return create(ctx, auth)
}
// Test_authService_OauthClientCredentials checks that
// authService.OauthClientCredentials hands Session.Create an AuthContext whose
// Permissions contain every requested scope and whose Payload type/ID match the
// client's Spec and ID. It also checks that the token returned by Create is
// passed back to the caller.
//
// The package-level Session is swapped for a mockSessionService in each
// sub-test and restored when the test ends, so these sub-tests must not run in
// parallel.
func Test_authService_OauthClientCredentials(t *testing.T) {
	// Save the real Session service and restore it once the test finishes.
	originalSession := Session
	defer func() {
		Session = originalSession
	}()

	// Token the mocked Session.Create hands back on success.
	expectedToken := &TokenDetails{
		AccessToken:        "test-access-token",
		RefreshToken:       "test-refresh-token",
		AccessTokenExpires: time.Now().Add(time.Hour),
	}

	type args struct {
		ctx    context.Context
		client *models.Client
		scope  []string
	}
	tests := []struct {
		name          string
		args          args
		mockCreateErr error
		want          *TokenDetails
		wantErr       bool
		wantPayload   Payload
	}{
		{
			name: "成功 - 机密客户端 (Spec=0)",
			args: args{
				ctx:    context.Background(),
				client: &models.Client{ID: 1, Spec: 0},
				scope:  []string{"read", "write"},
			},
			mockCreateErr: nil,
			want:          expectedToken,
			wantErr:       false,
			wantPayload: Payload{
				Type: PayloadClientConfidential,
				Id:   1,
			},
		},
		{
			name: "成功 - 公共客户端 (Spec=1)",
			args: args{
				ctx:    context.Background(),
				client: &models.Client{ID: 1, Spec: 1},
				scope:  []string{"read", "write"},
			},
			mockCreateErr: nil,
			want:          expectedToken,
			wantErr:       false,
			wantPayload: Payload{
				Type: PayloadClientPublic,
				Id:   1,
			},
		},
		{
			name: "成功 - 公共客户端 (Spec=2)",
			args: args{
				ctx:    context.Background(),
				client: &models.Client{ID: 1, Spec: 2},
				scope:  []string{"read", "write"},
			},
			mockCreateErr: nil,
			want:          expectedToken,
			wantErr:       false,
			wantPayload: Payload{
				Type: PayloadClientPublic,
				Id:   1,
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Tracks whether the code under test reached Session.Create at all;
			// without this a skipped call would silently skip every check below.
			createCalled := false

			// Install a fresh mock for this case; it validates the AuthContext
			// it receives and returns the configured result.
			Session = &mockSessionService{
				createFunc: func(_ context.Context, auth AuthContext) (*TokenDetails, error) {
					createCalled = true

					// Permission mapping: one entry per requested scope, and
					// every requested scope present. The membership check runs
					// regardless of whether the lengths match.
					// NOTE(review): assumes Permissions is keyed by plain
					// string scope names — confirm against AuthContext.
					if len(auth.Permissions) != len(tt.args.scope) {
						t.Errorf("Permissions length = %v, want %v", len(auth.Permissions), len(tt.args.scope))
					}
					for _, scope := range tt.args.scope {
						if _, ok := auth.Permissions[scope]; !ok {
							t.Errorf("Permissions[%q] not found", scope)
						}
					}

					// Payload must reflect the client's kind and ID.
					if auth.Payload.Type != tt.wantPayload.Type {
						t.Errorf("Payload.Type = %v, want %v", auth.Payload.Type, tt.wantPayload.Type)
					}
					if auth.Payload.Id != tt.wantPayload.Id {
						t.Errorf("Payload.Id = %v, want %v", auth.Payload.Id, tt.wantPayload.Id)
					}

					// A failing Create yields no token, mirroring the usual
					// (nil, err) contract.
					if tt.mockCreateErr != nil {
						return nil, tt.mockCreateErr
					}
					return expectedToken, nil
				},
			}

			s := &authService{}
			got, err := s.OauthClientCredentials(tt.args.ctx, tt.args.client, tt.args.scope...)
			if (err != nil) != tt.wantErr {
				t.Errorf("OauthClientCredentials() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if tt.wantErr {
				// The returned token is meaningless alongside an error.
				return
			}
			if !createCalled {
				t.Errorf("OauthClientCredentials() did not call Session.Create")
			}
			if !reflect.DeepEqual(got, tt.want) {
				t.Errorf("OauthClientCredentials() got = %v, want %v", got, tt.want)
			}
		})
	}
}