暂时移除测试代码
This commit is contained in:
@@ -1,127 +0,0 @@
|
||||
package handlers
|
||||
|
||||
import "testing"
|
||||
|
||||
func Test_secureAddr(t *testing.T) {
|
||||
type args struct {
|
||||
str string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
// 有效的公网 IP 地址
|
||||
{
|
||||
name: "有效公网IPv4地址",
|
||||
args: args{str: "203.0.113.1"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "有效公网IPv6地址",
|
||||
args: args{str: "2001:db8::1"},
|
||||
wantErr: false,
|
||||
},
|
||||
|
||||
// 私有地址
|
||||
{
|
||||
name: "IPv4私有地址(10.x.x.x)",
|
||||
args: args{str: "10.0.0.1"},
|
||||
wantErr: false, // 取决于需求,通常私有地址是被允许的全局单播地址
|
||||
},
|
||||
{
|
||||
name: "IPv4私有地址(172.16.x.x)",
|
||||
args: args{str: "172.16.0.1"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "IPv4私有地址(192.168.x.x)",
|
||||
args: args{str: "192.168.0.1"},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "IPv6私有地址(ULA)",
|
||||
args: args{str: "fd00::1"},
|
||||
wantErr: false,
|
||||
},
|
||||
|
||||
// 广播地址
|
||||
{
|
||||
name: "IPv4本地广播地址",
|
||||
args: args{str: "255.255.255.255"},
|
||||
wantErr: true,
|
||||
},
|
||||
|
||||
// 未指定地址
|
||||
{
|
||||
name: "IPv4未指定地址",
|
||||
args: args{str: "0.0.0.0"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "IPv6未指定地址",
|
||||
args: args{str: "::"},
|
||||
wantErr: true,
|
||||
},
|
||||
|
||||
// 回环地址
|
||||
{
|
||||
name: "IPv4回环地址",
|
||||
args: args{str: "127.0.0.1"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "IPv6回环地址",
|
||||
args: args{str: "::1"},
|
||||
wantErr: true,
|
||||
},
|
||||
|
||||
// 组播地址
|
||||
{
|
||||
name: "IPv4组播地址",
|
||||
args: args{str: "224.0.0.1"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "IPv6组播地址",
|
||||
args: args{str: "ff00::1"},
|
||||
wantErr: true,
|
||||
},
|
||||
|
||||
// 链路本地地址
|
||||
{
|
||||
name: "IPv4链路本地地址",
|
||||
args: args{str: "169.254.0.1"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "IPv6链路本地地址",
|
||||
args: args{str: "fe80::1"},
|
||||
wantErr: true,
|
||||
},
|
||||
|
||||
// 格式错误的地址
|
||||
{
|
||||
name: "格式错误的IP地址",
|
||||
args: args{str: "not-an-ip"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "不完整的IP地址",
|
||||
args: args{str: "192.168.0"},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "超出范围的IP地址",
|
||||
args: args{str: "256.256.256.256"},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := secureAddr(tt.args.str); (err != nil) != tt.wantErr {
|
||||
t.Errorf("secureAddr() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,147 +0,0 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"platform/web/auth"
|
||||
"platform/web/models"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// mockSessionService 用于模拟Session服务的行为
|
||||
type mockSessionService struct {
|
||||
createFunc func(ctx context.Context, authCtx auth.Context) (*TokenDetails, error)
|
||||
}
|
||||
|
||||
func (m *mockSessionService) Find(ctx context.Context, token string) (*auth.Context, error) {
|
||||
panic("implement me")
|
||||
}
|
||||
func (m *mockSessionService) Refresh(ctx context.Context, refreshToken string) (*TokenDetails, error) {
|
||||
panic("implement me")
|
||||
}
|
||||
func (m *mockSessionService) Remove(ctx context.Context, accessToken, refreshToken string) error {
|
||||
panic("implement me")
|
||||
}
|
||||
func (m *mockSessionService) Create(ctx context.Context, authCtx auth.Context, remember bool) (*TokenDetails, error) {
|
||||
return m.createFunc(ctx, authCtx)
|
||||
}
|
||||
|
||||
func Test_authService_OauthClientCredentials(t *testing.T) {
|
||||
// 暂存原始Session服务
|
||||
originalSession := Session
|
||||
defer func() {
|
||||
// 测试结束后恢复原始Session服务
|
||||
Session = originalSession
|
||||
}()
|
||||
|
||||
// 预设的令牌详情
|
||||
expectedToken := &TokenDetails{
|
||||
AccessToken: "test-access-token",
|
||||
RefreshToken: "test-refresh-token",
|
||||
AccessTokenExpires: time.Now().Add(3600 * time.Second),
|
||||
}
|
||||
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
client *models.Client
|
||||
scope []string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
mockCreateErr error
|
||||
want *TokenDetails
|
||||
wantErr bool
|
||||
wantPayload auth.Payload
|
||||
}{
|
||||
{
|
||||
name: "成功 - 机密客户端 (Spec=0)",
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
client: &models.Client{ID: 1, Spec: 3},
|
||||
scope: []string{"read", "write"},
|
||||
},
|
||||
mockCreateErr: nil,
|
||||
want: expectedToken,
|
||||
wantErr: false,
|
||||
wantPayload: auth.Payload{
|
||||
Type: auth.PayloadSecuredServer,
|
||||
Id: 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "成功 - 公共客户端 (Spec=1)",
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
client: &models.Client{ID: 1, Spec: 1},
|
||||
scope: []string{"read", "write"},
|
||||
},
|
||||
mockCreateErr: nil,
|
||||
want: expectedToken,
|
||||
wantErr: false,
|
||||
wantPayload: auth.Payload{
|
||||
Type: auth.PayloadPublicServer,
|
||||
Id: 1,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "成功 - 公共客户端 (Spec=2)",
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
client: &models.Client{ID: 1, Spec: 2},
|
||||
scope: []string{"read", "write"},
|
||||
},
|
||||
mockCreateErr: nil,
|
||||
want: expectedToken,
|
||||
wantErr: false,
|
||||
wantPayload: auth.Payload{
|
||||
Type: auth.PayloadPublicServer,
|
||||
Id: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
||||
// 为每个测试用例设置模拟的Session服务
|
||||
mockSession := &mockSessionService{
|
||||
createFunc: func(ctx context.Context, authCtx auth.Context) (*TokenDetails, error) {
|
||||
// 验证权限映射
|
||||
if len(authCtx.Permissions) != len(tt.args.scope) {
|
||||
t.Errorf("Permissions length = %v, want %v", len(authCtx.Permissions), len(tt.args.scope))
|
||||
for key := range authCtx.Permissions {
|
||||
if _, ok := authCtx.Permissions[key]; !ok {
|
||||
t.Errorf("Permissions[%s] not found", key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 验证Payload
|
||||
if authCtx.Payload.Type != tt.wantPayload.Type {
|
||||
t.Errorf("Payload.Type = %v, want %v", authCtx.Payload.Type, tt.wantPayload.Type)
|
||||
}
|
||||
if authCtx.Payload.Id != tt.wantPayload.Id {
|
||||
t.Errorf("Payload.Id = %v, want %v", authCtx.Payload.Id, tt.wantPayload.Id)
|
||||
}
|
||||
|
||||
return expectedToken, tt.mockCreateErr
|
||||
},
|
||||
}
|
||||
|
||||
// 替换Session服务为模拟实现
|
||||
Session = mockSession
|
||||
|
||||
s := &authService{}
|
||||
got, err := s.OauthClientCredentials(tt.args.ctx, tt.args.client, tt.args.scope...)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("OauthClientCredentials() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("OauthClientCredentials() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,422 +0,0 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"platform/web/auth"
|
||||
"platform/web/testutil"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// 创建测试用的认证上下文
|
||||
func createTestAuthContext() auth.Context {
|
||||
//goland:noinspection ALL
|
||||
return auth.Context{
|
||||
Payload: auth.Payload{
|
||||
Type: auth.PayloadUser,
|
||||
Id: 1001,
|
||||
},
|
||||
Permissions: map[string]struct{}{
|
||||
"read": {},
|
||||
"write": {},
|
||||
},
|
||||
Metadata: map[string]interface{}{
|
||||
"username": "testuser",
|
||||
"email": "test@example.com",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func Test_sessionService_Create(t *testing.T) {
|
||||
mr := testutil.SetupRedisTest(t)
|
||||
ctx := context.Background()
|
||||
authCtx := createTestAuthContext()
|
||||
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
auth auth.Context
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want func(*TokenDetails) bool
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "创建会话",
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
auth: authCtx,
|
||||
},
|
||||
want: func(td *TokenDetails) bool {
|
||||
// 验证令牌存在且格式正确
|
||||
if td.AccessToken == "" || td.RefreshToken == "" {
|
||||
return false
|
||||
}
|
||||
// 验证到期时间在未来
|
||||
now := time.Now()
|
||||
if td.AccessTokenExpires.Before(now) || td.RefreshTokenExpires.Before(now) {
|
||||
return false
|
||||
}
|
||||
// 验证认证信息正确
|
||||
if !reflect.DeepEqual(td.Auth, authCtx) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mr.FlushAll()
|
||||
s := &sessionService{}
|
||||
got, err := s.Create(tt.args.ctx, tt.args.auth, true)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Create() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !tt.want(got) {
|
||||
t.Errorf("Create() got = %v, want to satisfy conditions", got)
|
||||
}
|
||||
|
||||
// 验证 Redis 中是否有相应的键
|
||||
accessKey := accessKey(got.AccessToken)
|
||||
refreshKey := refreshKey(got.RefreshToken)
|
||||
|
||||
if !mr.Exists(accessKey) {
|
||||
t.Errorf("访问令牌键 %s 不存在于 Redis 中", accessKey)
|
||||
}
|
||||
|
||||
if !mr.Exists(refreshKey) {
|
||||
t.Errorf("刷新令牌键 %s 不存在于 Redis 中", refreshKey)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_sessionService_Find(t *testing.T) {
|
||||
testutil.SetupRedisTest(t)
|
||||
ctx := context.Background()
|
||||
authCtx := createTestAuthContext()
|
||||
s := &sessionService{}
|
||||
|
||||
// 创建一个有效的会话
|
||||
td, err := s.Create(ctx, authCtx, true)
|
||||
if err != nil {
|
||||
t.Fatalf("无法创建测试会话: %v", err)
|
||||
}
|
||||
|
||||
validToken := td.AccessToken
|
||||
invalidToken := "invalid-token"
|
||||
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
token string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *auth.Context
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "查找有效令牌",
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
token: validToken,
|
||||
},
|
||||
want: &authCtx,
|
||||
wantErr: nil,
|
||||
},
|
||||
{
|
||||
name: "查找无效令牌",
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
token: invalidToken,
|
||||
},
|
||||
want: nil,
|
||||
wantErr: ErrInvalidToken,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := s.Find(tt.args.ctx, tt.args.token)
|
||||
if !errors.Is(err, tt.wantErr) {
|
||||
t.Errorf("Find() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("Find() got = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_sessionService_Refresh(t *testing.T) {
|
||||
mr := testutil.SetupRedisTest(t)
|
||||
ctx := context.Background()
|
||||
authCtx := createTestAuthContext()
|
||||
s := &sessionService{}
|
||||
|
||||
// 创建一个初始会话
|
||||
td, err := s.Create(ctx, authCtx, true)
|
||||
if err != nil {
|
||||
t.Fatalf("无法创建初始会话: %v", err)
|
||||
}
|
||||
|
||||
validRefreshToken := td.RefreshToken
|
||||
invalidRefreshToken := "invalid-refresh-token"
|
||||
originalAccessToken := td.AccessToken
|
||||
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
refreshToken string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want func(*TokenDetails) bool
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "使用有效的刷新令牌",
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
refreshToken: validRefreshToken,
|
||||
},
|
||||
want: func(td *TokenDetails) bool {
|
||||
if td.AccessToken == "" || td.RefreshToken == "" {
|
||||
return false
|
||||
}
|
||||
// 新的令牌应该与旧的不同
|
||||
if td.AccessToken == originalAccessToken || td.RefreshToken == validRefreshToken {
|
||||
return false
|
||||
}
|
||||
// 验证认证信息一致
|
||||
if !reflect.DeepEqual(td.Auth, authCtx) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "使用无效的刷新令牌",
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
refreshToken: invalidRefreshToken,
|
||||
},
|
||||
want: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := s.Refresh(tt.args.ctx, tt.args.refreshToken)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Refresh() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if tt.want != nil && !tt.want(got) {
|
||||
t.Errorf("Refresh() got = %v, want to satisfy conditions", got)
|
||||
}
|
||||
|
||||
if !tt.wantErr && got != nil {
|
||||
// 验证旧的令牌已被删除
|
||||
if mr.Exists(accessKey(originalAccessToken)) {
|
||||
t.Errorf("原始访问令牌键应被删除")
|
||||
}
|
||||
if mr.Exists(refreshKey(validRefreshToken)) {
|
||||
t.Errorf("原始刷新令牌键应被删除")
|
||||
}
|
||||
|
||||
// 验证新的令牌已被添加
|
||||
if !mr.Exists(accessKey(got.AccessToken)) {
|
||||
t.Errorf("新的访问令牌键应存在")
|
||||
}
|
||||
if !mr.Exists(refreshKey(got.RefreshToken)) {
|
||||
t.Errorf("新的刷新令牌键应存在")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_sessionService_Remove(t *testing.T) {
|
||||
mr := testutil.SetupRedisTest(t)
|
||||
ctx := context.Background()
|
||||
authCtx := createTestAuthContext()
|
||||
s := &sessionService{}
|
||||
|
||||
// 创建一个会话
|
||||
td, err := s.Create(ctx, authCtx, true)
|
||||
if err != nil {
|
||||
t.Fatalf("无法创建测试会话: %v", err)
|
||||
}
|
||||
|
||||
validAccessToken := td.AccessToken
|
||||
validRefreshToken := td.RefreshToken
|
||||
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
accessToken string
|
||||
refreshToken string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "删除有效会话",
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
accessToken: validAccessToken,
|
||||
refreshToken: validRefreshToken,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "删除已删除的会话",
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
accessToken: validAccessToken,
|
||||
refreshToken: validRefreshToken,
|
||||
},
|
||||
wantErr: false, // 删除不存在的会话不应报错
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := s.Remove(tt.args.ctx, tt.args.accessToken, tt.args.refreshToken); (err != nil) != tt.wantErr {
|
||||
t.Errorf("Remove() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
|
||||
// 验证键已被删除
|
||||
if mr.Exists(accessKey(tt.args.accessToken)) {
|
||||
t.Errorf("访问令牌键应已被删除")
|
||||
}
|
||||
if mr.Exists(refreshKey(tt.args.refreshToken)) {
|
||||
t.Errorf("刷新令牌键应已被删除")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthContext_AnyPermission(t *testing.T) {
|
||||
type fields struct {
|
||||
Payload auth.Payload
|
||||
Permissions map[string]struct{}
|
||||
Metadata map[string]interface{}
|
||||
}
|
||||
type args struct {
|
||||
requiredPermission []string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "用户拥有所需权限",
|
||||
fields: fields{
|
||||
Payload: auth.Payload{Type: auth.PayloadUser, Id: 1},
|
||||
Permissions: map[string]struct{}{
|
||||
"read": {},
|
||||
"write": {},
|
||||
},
|
||||
Metadata: nil,
|
||||
},
|
||||
args: args{
|
||||
requiredPermission: []string{"read"},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "用户拥有至少一个所需权限",
|
||||
fields: fields{
|
||||
Payload: auth.Payload{Type: auth.PayloadUser, Id: 1},
|
||||
Permissions: map[string]struct{}{
|
||||
"read": {},
|
||||
},
|
||||
Metadata: nil,
|
||||
},
|
||||
args: args{
|
||||
requiredPermission: []string{"read", "admin"},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "用户没有所需权限",
|
||||
fields: fields{
|
||||
Payload: auth.Payload{Type: auth.PayloadUser, Id: 1},
|
||||
Permissions: map[string]struct{}{
|
||||
"read": {},
|
||||
},
|
||||
Metadata: nil,
|
||||
},
|
||||
args: args{
|
||||
requiredPermission: []string{"admin", "delete"},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "空权限列表",
|
||||
fields: fields{
|
||||
Payload: auth.Payload{Type: auth.PayloadUser, Id: 1},
|
||||
Permissions: map[string]struct{}{},
|
||||
Metadata: nil,
|
||||
},
|
||||
args: args{
|
||||
requiredPermission: []string{"read"},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "nil权限列表",
|
||||
fields: fields{
|
||||
Payload: auth.Payload{Type: auth.PayloadUser, Id: 1},
|
||||
Permissions: nil,
|
||||
Metadata: nil,
|
||||
},
|
||||
args: args{
|
||||
requiredPermission: []string{"read"},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "nil认证上下文",
|
||||
fields: fields{
|
||||
Payload: auth.Payload{},
|
||||
Permissions: nil,
|
||||
Metadata: nil,
|
||||
},
|
||||
args: args{
|
||||
requiredPermission: []string{"read"},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := &auth.Context{
|
||||
Payload: tt.fields.Payload,
|
||||
Permissions: tt.fields.Permissions,
|
||||
Metadata: tt.fields.Metadata,
|
||||
}
|
||||
if got := a.AnyPermission(tt.args.requiredPermission...); got != tt.want {
|
||||
t.Errorf("AnyPermission() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,244 +0,0 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"platform/web/testutil"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
)
|
||||
|
||||
func Test_verifierService_SendSms(t *testing.T) {
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
phone string
|
||||
purpose VerifierSmsPurpose
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
setup func(mr *miniredis.Miniredis)
|
||||
wantErr bool
|
||||
wantErrType error
|
||||
}{
|
||||
{
|
||||
name: "正常发送成功(无旧验证码)",
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
phone: "13812345678",
|
||||
purpose: VerifierSmsPurposeLogin,
|
||||
},
|
||||
setup: func(mr *miniredis.Miniredis) {},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "正常发送成功(有旧验证码)",
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
phone: "13812345679",
|
||||
purpose: VerifierSmsPurposeLogin,
|
||||
},
|
||||
setup: func(mr *miniredis.Miniredis) {
|
||||
key := smsKey("13812345679", VerifierSmsPurposeLogin)
|
||||
mr.Set(key, "123456")
|
||||
mr.SetTTL(key, 10*time.Minute)
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "发送频率过快",
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
phone: "13812345680",
|
||||
purpose: VerifierSmsPurposeLogin,
|
||||
},
|
||||
setup: func(mr *miniredis.Miniredis) {
|
||||
key := smsKey("13812345680", VerifierSmsPurposeLogin) + ":lock"
|
||||
mr.Set(key, "")
|
||||
mr.SetTTL(key, 1*time.Minute)
|
||||
},
|
||||
wantErr: true,
|
||||
wantErrType: VerifierServiceSendLimitErr(0),
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 设置 Redis 测试环境
|
||||
mr := testutil.SetupRedisTest(t)
|
||||
defer mr.Close()
|
||||
|
||||
// 执行测试前的设置
|
||||
if tt.setup != nil {
|
||||
tt.setup(mr)
|
||||
}
|
||||
|
||||
s := &verifierService{}
|
||||
err := s.SendSms(tt.args.ctx, tt.args.phone, tt.args.purpose)
|
||||
|
||||
// 验证错误
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("SendSms() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
// 验证错误类型
|
||||
if tt.wantErr && tt.wantErrType != nil {
|
||||
var verifierServiceSendLimitErr VerifierServiceSendLimitErr
|
||||
if errors.As(err, &verifierServiceSendLimitErr) {
|
||||
var verifierServiceSendLimitErr VerifierServiceSendLimitErr
|
||||
if !errors.As(tt.wantErrType, &verifierServiceSendLimitErr) {
|
||||
t.Errorf("SendSms() error type = %T, wantErrType %T", err, tt.wantErrType)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 验证 Redis 中的记录
|
||||
if !tt.wantErr {
|
||||
key := smsKey(tt.args.phone, tt.args.purpose)
|
||||
keyLock := key + ":lock"
|
||||
|
||||
// 验证码应存在
|
||||
val, err := mr.Get(key)
|
||||
if err != nil {
|
||||
t.Errorf("验证码应存在但不存在: %v", err)
|
||||
}
|
||||
|
||||
// 限速锁应存在
|
||||
_, err = mr.Get(keyLock)
|
||||
if err != nil {
|
||||
t.Errorf("限速锁应存在但不存在: %v", err)
|
||||
}
|
||||
|
||||
// 验证码应为6位数字
|
||||
code, err := strconv.Atoi(val)
|
||||
if err != nil || code < 100000 || code > 999999 {
|
||||
t.Errorf("验证码应为6位数字: %v", val)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_verifierService_VerifySms(t *testing.T) {
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
phone string
|
||||
code string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
setup func(mr *miniredis.Miniredis)
|
||||
wantErr bool
|
||||
wantErrType error
|
||||
}{
|
||||
{
|
||||
name: "验证码正确",
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
phone: "13812345678",
|
||||
code: "123456",
|
||||
},
|
||||
setup: func(mr *miniredis.Miniredis) {
|
||||
key := smsKey("13812345678", VerifierSmsPurposeLogin)
|
||||
keyLock := key + ":lock"
|
||||
mr.Set(key, "123456")
|
||||
mr.SetTTL(key, 10*time.Minute)
|
||||
mr.Set(keyLock, "")
|
||||
mr.SetTTL(keyLock, 1*time.Minute)
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "验证码错误",
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
phone: "13812345679",
|
||||
code: "654321",
|
||||
},
|
||||
setup: func(mr *miniredis.Miniredis) {
|
||||
key := smsKey("13812345679", VerifierSmsPurposeLogin)
|
||||
keyLock := key + ":lock"
|
||||
mr.Set(key, "123456")
|
||||
mr.SetTTL(key, 10*time.Minute)
|
||||
mr.Set(keyLock, "")
|
||||
mr.SetTTL(keyLock, 1*time.Minute)
|
||||
},
|
||||
wantErr: true,
|
||||
wantErrType: ErrVerifierServiceInvalid,
|
||||
},
|
||||
{
|
||||
name: "验证码过期",
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
phone: "13812345680",
|
||||
code: "123456",
|
||||
},
|
||||
setup: func(mr *miniredis.Miniredis) {
|
||||
// 不设置验证码,模拟过期情况
|
||||
},
|
||||
wantErr: true,
|
||||
wantErrType: ErrVerifierServiceInvalid,
|
||||
},
|
||||
{
|
||||
name: "手机号错误",
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
phone: "13812345681",
|
||||
code: "123456",
|
||||
},
|
||||
setup: func(mr *miniredis.Miniredis) {
|
||||
// 设置一个不同手机号的验证码
|
||||
key := smsKey("13800000000", VerifierSmsPurposeLogin)
|
||||
mr.Set(key, "123456")
|
||||
mr.SetTTL(key, 10*time.Minute)
|
||||
},
|
||||
wantErr: true,
|
||||
wantErrType: ErrVerifierServiceInvalid,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// 设置 Redis 测试环境
|
||||
mr := testutil.SetupRedisTest(t)
|
||||
defer mr.Close()
|
||||
|
||||
// 执行测试前的设置
|
||||
if tt.setup != nil {
|
||||
tt.setup(mr)
|
||||
}
|
||||
|
||||
s := &verifierService{}
|
||||
err := s.VerifySms(tt.args.ctx, tt.args.phone, tt.args.code)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("VerifySms() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
// 检查错误类型
|
||||
if tt.wantErr && tt.wantErrType != nil && !errors.Is(err, tt.wantErrType) {
|
||||
t.Errorf("VerifySms() error = %v, wantErrType %v", err, tt.wantErrType)
|
||||
return
|
||||
}
|
||||
|
||||
// 验证成功后 Redis 中应该没有该记录
|
||||
if err == nil {
|
||||
key := smsKey(tt.args.phone, VerifierSmsPurposeLogin)
|
||||
keyLock := key + ":lock"
|
||||
|
||||
_, redisErr := mr.Get(key)
|
||||
if redisErr == nil {
|
||||
t.Errorf("验证码验证成功后应删除,但仍存在")
|
||||
}
|
||||
|
||||
_, redisErr = mr.Get(keyLock)
|
||||
if redisErr == nil {
|
||||
t.Errorf("限速锁验证成功后应删除,但仍存在")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,39 +0,0 @@
|
||||
package testutil
|
||||
|
||||
import (
|
||||
g"platform/web/globals"
|
||||
m"platform/web/models"
|
||||
q "platform/web/queries"
|
||||
"testing"
|
||||
|
||||
"github.com/glebarez/sqlite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// SetupDBTest 创建一个基于 SQLite 内存数据库的 GORM 连接
|
||||
func SetupDBTest(t *testing.T) *gorm.DB {
|
||||
// 使用 SQLite 内存数据库
|
||||
gormDB, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
|
||||
if err != nil {
|
||||
t.Fatalf("gorm 打开 SQLite 内存数据库失败: %v", err)
|
||||
}
|
||||
|
||||
// 自动迁移数据表结构
|
||||
err = gormDB.AutoMigrate(
|
||||
&m.User{},
|
||||
&m.Whitelist{},
|
||||
&m.Resource{},
|
||||
&m.ResourcePss{},
|
||||
&m.Proxy{},
|
||||
&m.Channel{},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("自动迁移表结构失败: %v", err)
|
||||
}
|
||||
|
||||
// 设置全局数据库连接
|
||||
q.SetDefault(gormDB)
|
||||
g.DB = gormDB
|
||||
|
||||
return gormDB
|
||||
}
|
||||
@@ -1,30 +0,0 @@
|
||||
package testutil
|
||||
|
||||
import (
|
||||
g "platform/web/globals"
|
||||
"testing"
|
||||
|
||||
"github.com/alicebob/miniredis/v2"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// SetupRedisTest 创建一个测试用的Redis实例
|
||||
// 返回miniredis实例,使用t.Cleanup自动清理资源
|
||||
func SetupRedisTest(t *testing.T) *miniredis.Miniredis {
|
||||
mr, err := miniredis.Run()
|
||||
if err != nil {
|
||||
t.Fatalf("设置 miniredis 失败: %v", err)
|
||||
}
|
||||
|
||||
// 替换 Redis 客户端为测试客户端
|
||||
g.Redis = redis.NewClient(&redis.Options{
|
||||
Addr: mr.Addr(),
|
||||
})
|
||||
|
||||
// 使用t.Cleanup确保测试结束后恢复原始客户端并关闭miniredis
|
||||
t.Cleanup(func() {
|
||||
mr.Close()
|
||||
})
|
||||
|
||||
return mr
|
||||
}
|
||||
@@ -1,150 +0,0 @@
|
||||
package testutil
|
||||
|
||||
import (
|
||||
g "platform/web/globals"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// MockCloudClient 是CloudClient接口的测试实现
|
||||
type MockCloudClient struct {
|
||||
// 存储预期结果的字段
|
||||
EdgesMock func(param g.CloudEdgesReq) (*g.CloudEdgesResp, error)
|
||||
ConnectMock func(param g.CloudConnectReq) error
|
||||
DisconnectMock func(param g.CloudDisconnectReq) (int, error)
|
||||
AutoQueryMock func() (g.CloudConnectResp, error)
|
||||
|
||||
// 记录调用历史
|
||||
EdgesCalls []g.CloudEdgesReq
|
||||
ConnectCalls []g.CloudConnectReq
|
||||
DisconnectCalls []g.CloudDisconnectReq
|
||||
AutoQueryCalls int
|
||||
|
||||
// 用于并发安全
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// 确保MockCloudClient实现了CloudClient接口
|
||||
var _ g.CloudClient = (*MockCloudClient)(nil)
|
||||
|
||||
func (m *MockCloudClient) CloudEdges(param g.CloudEdgesReq) (*g.CloudEdgesResp, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.EdgesCalls = append(m.EdgesCalls, param)
|
||||
if m.EdgesMock != nil {
|
||||
return m.EdgesMock(param)
|
||||
}
|
||||
return &g.CloudEdgesResp{}, nil
|
||||
}
|
||||
|
||||
func (m *MockCloudClient) CloudConnect(param g.CloudConnectReq) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.ConnectCalls = append(m.ConnectCalls, param)
|
||||
if m.ConnectMock != nil {
|
||||
return m.ConnectMock(param)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockCloudClient) CloudDisconnect(param g.CloudDisconnectReq) (int, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.DisconnectCalls = append(m.DisconnectCalls, param)
|
||||
if m.DisconnectMock != nil {
|
||||
return m.DisconnectMock(param)
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (m *MockCloudClient) CloudAutoQuery() (g.CloudConnectResp, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.AutoQueryCalls++
|
||||
if m.AutoQueryMock != nil {
|
||||
return m.AutoQueryMock()
|
||||
}
|
||||
return g.CloudConnectResp{}, nil
|
||||
}
|
||||
|
||||
// SetupCloudClientMock 替换全局CloudClient为测试实现并在测试完成后恢复
|
||||
func SetupCloudClientMock(t *testing.T) *MockCloudClient {
|
||||
mock := &MockCloudClient{
|
||||
EdgesMock: func(param g.CloudEdgesReq) (*g.CloudEdgesResp, error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
ConnectMock: func(param g.CloudConnectReq) error {
|
||||
panic("not implemented")
|
||||
},
|
||||
DisconnectMock: func(param g.CloudDisconnectReq) (int, error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
AutoQueryMock: func() (g.CloudConnectResp, error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
}
|
||||
g.Cloud = mock
|
||||
|
||||
return mock
|
||||
}
|
||||
|
||||
// MockGatewayClient 是GatewayClient接口的测试实现
|
||||
type MockGatewayClient struct {
|
||||
Host string
|
||||
}
|
||||
|
||||
// 确保MockGatewayClient实现了GatewayClient接口
|
||||
var _ g.GatewayClient = (*MockGatewayClient)(nil)
|
||||
|
||||
func (m *MockGatewayClient) GatewayPortConfigs(params []g.PortConfigsReq) error {
|
||||
testGatewayBase.mu.Lock()
|
||||
defer testGatewayBase.mu.Unlock()
|
||||
testGatewayBase.PortConfigsCalls = append(testGatewayBase.PortConfigsCalls, params)
|
||||
if testGatewayBase.PortConfigsMock != nil {
|
||||
return testGatewayBase.PortConfigsMock(m, params)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockGatewayClient) GatewayPortActive(param ...g.PortActiveReq) (map[string]g.PortData, error) {
|
||||
testGatewayBase.mu.Lock()
|
||||
defer testGatewayBase.mu.Unlock()
|
||||
testGatewayBase.PortActiveCalls = append(testGatewayBase.PortActiveCalls, param)
|
||||
if testGatewayBase.PortActiveMock != nil {
|
||||
return testGatewayBase.PortActiveMock(m, param...)
|
||||
}
|
||||
return map[string]g.PortData{}, nil
|
||||
}
|
||||
|
||||
type GatewayClientIns struct {
|
||||
|
||||
// 存储预期结果的字段
|
||||
PortConfigsMock func(c *MockGatewayClient, params []g.PortConfigsReq) error
|
||||
PortActiveMock func(c *MockGatewayClient, param ...g.PortActiveReq) (map[string]g.PortData, error)
|
||||
|
||||
// 记录调用历史
|
||||
PortConfigsCalls [][]g.PortConfigsReq
|
||||
PortActiveCalls [][]g.PortActiveReq
|
||||
|
||||
// 用于并发安全
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
var testGatewayBase = &GatewayClientIns{
|
||||
PortConfigsMock: func(c *MockGatewayClient, params []g.PortConfigsReq) error {
|
||||
panic("not implemented")
|
||||
},
|
||||
PortActiveMock: func(c *MockGatewayClient, param ...g.PortActiveReq) (map[string]g.PortData, error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
}
|
||||
|
||||
// SetupGatewayClientMock 创建一个MockGatewayClient并提供替换函数
|
||||
func SetupGatewayClientMock(t *testing.T) *GatewayClientIns {
|
||||
g.GatewayInitializer = func(url, username, password string) g.GatewayClient {
|
||||
return &MockGatewayClient{
|
||||
Host: url,
|
||||
}
|
||||
}
|
||||
return testGatewayBase
|
||||
}
|
||||
@@ -1,26 +0,0 @@
|
||||
package testutil
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sort"
|
||||
)
|
||||
|
||||
// SliceEqual 检查两个字符串切片是否完全相等(忽略顺序)
|
||||
func SliceEqual(a, b []string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
|
||||
// 复制切片以避免修改原始数据
|
||||
aCopy := make([]string, len(a))
|
||||
bCopy := make([]string, len(b))
|
||||
copy(aCopy, a)
|
||||
copy(bCopy, b)
|
||||
|
||||
// 排序两个切片
|
||||
sort.Strings(aCopy)
|
||||
sort.Strings(bCopy)
|
||||
|
||||
// 比较排序后的切片
|
||||
return reflect.DeepEqual(aCopy, bCopy)
|
||||
}
|
||||
Reference in New Issue
Block a user