147 lines
3.8 KiB
Go
147 lines
3.8 KiB
Go
package services
|
|
|
|
import (
|
|
"context"
|
|
"platform/web/models"
|
|
"reflect"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
// mockSessionService 用于模拟Session服务的行为
|
|
type mockSessionService struct {
|
|
createFunc func(ctx context.Context, auth AuthContext) (*TokenDetails, error)
|
|
}
|
|
|
|
func (m *mockSessionService) Find(ctx context.Context, token string) (*AuthContext, error) {
|
|
panic("implement me")
|
|
}
|
|
func (m *mockSessionService) Refresh(ctx context.Context, refreshToken string, config ...SessionConfig) (*TokenDetails, error) {
|
|
panic("implement me")
|
|
}
|
|
func (m *mockSessionService) Remove(ctx context.Context, accessToken, refreshToken string) error {
|
|
panic("implement me")
|
|
}
|
|
func (m *mockSessionService) Create(ctx context.Context, auth AuthContext, config ...SessionConfig) (*TokenDetails, error) {
|
|
return m.createFunc(ctx, auth)
|
|
}
|
|
|
|
func Test_authService_OauthClientCredentials(t *testing.T) {
|
|
// 暂存原始Session服务
|
|
originalSession := Session
|
|
defer func() {
|
|
// 测试结束后恢复原始Session服务
|
|
Session = originalSession
|
|
}()
|
|
|
|
// 预设的令牌详情
|
|
expectedToken := &TokenDetails{
|
|
AccessToken: "test-access-token",
|
|
RefreshToken: "test-refresh-token",
|
|
AccessTokenExpires: time.Now().Add(3600 * time.Second),
|
|
}
|
|
|
|
type args struct {
|
|
ctx context.Context
|
|
client *models.Client
|
|
scope []string
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
args args
|
|
mockCreateErr error
|
|
want *TokenDetails
|
|
wantErr bool
|
|
wantPayload Payload
|
|
}{
|
|
{
|
|
name: "成功 - 机密客户端 (Spec=0)",
|
|
args: args{
|
|
ctx: context.Background(),
|
|
client: &models.Client{ID: 1, Spec: 0},
|
|
scope: []string{"read", "write"},
|
|
},
|
|
mockCreateErr: nil,
|
|
want: expectedToken,
|
|
wantErr: false,
|
|
wantPayload: Payload{
|
|
Type: PayloadClientConfidential,
|
|
Id: 1,
|
|
},
|
|
},
|
|
{
|
|
name: "成功 - 公共客户端 (Spec=1)",
|
|
args: args{
|
|
ctx: context.Background(),
|
|
client: &models.Client{ID: 1, Spec: 1},
|
|
scope: []string{"read", "write"},
|
|
},
|
|
mockCreateErr: nil,
|
|
want: expectedToken,
|
|
wantErr: false,
|
|
wantPayload: Payload{
|
|
Type: PayloadClientPublic,
|
|
Id: 1,
|
|
},
|
|
},
|
|
{
|
|
name: "成功 - 公共客户端 (Spec=2)",
|
|
args: args{
|
|
ctx: context.Background(),
|
|
client: &models.Client{ID: 1, Spec: 2},
|
|
scope: []string{"read", "write"},
|
|
},
|
|
mockCreateErr: nil,
|
|
want: expectedToken,
|
|
wantErr: false,
|
|
wantPayload: Payload{
|
|
Type: PayloadClientPublic,
|
|
Id: 1,
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
// 为每个测试用例设置模拟的Session服务
|
|
mockSession := &mockSessionService{
|
|
createFunc: func(ctx context.Context, auth AuthContext) (*TokenDetails, error) {
|
|
// 验证权限映射
|
|
if len(auth.Permissions) != len(tt.args.scope) {
|
|
t.Errorf("Permissions length = %v, want %v", len(auth.Permissions), len(tt.args.scope))
|
|
for key := range auth.Permissions {
|
|
if _, ok := auth.Permissions[key]; !ok {
|
|
t.Errorf("Permissions[%s] not found", key)
|
|
}
|
|
}
|
|
}
|
|
|
|
// 验证Payload
|
|
if auth.Payload.Type != tt.wantPayload.Type {
|
|
t.Errorf("Payload.Type = %v, want %v", auth.Payload.Type, tt.wantPayload.Type)
|
|
}
|
|
if auth.Payload.Id != tt.wantPayload.Id {
|
|
t.Errorf("Payload.Id = %v, want %v", auth.Payload.Id, tt.wantPayload.Id)
|
|
}
|
|
|
|
return expectedToken, tt.mockCreateErr
|
|
},
|
|
}
|
|
|
|
// 替换Session服务为模拟实现
|
|
Session = mockSession
|
|
|
|
s := &authService{}
|
|
got, err := s.OauthClientCredentials(tt.args.ctx, tt.args.client, tt.args.scope...)
|
|
if (err != nil) != tt.wantErr {
|
|
t.Errorf("OauthClientCredentials() error = %v, wantErr %v", err, tt.wantErr)
|
|
return
|
|
}
|
|
if !reflect.DeepEqual(got, tt.want) {
|
|
t.Errorf("OauthClientCredentials() got = %v, want %v", got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|