认证授权测试代码与业务代码质量修复

This commit is contained in:
2025-03-22 16:37:24 +08:00
parent 6ddf1118a5
commit c3abb42bce
10 changed files with 960 additions and 33 deletions

View File

@@ -42,13 +42,13 @@ func Login(c *fiber.Ctx) error {
func loginByPhone(c *fiber.Ctx, req *LoginReq) error {
// 验证验证码
ok, err := services.Verifier.VerifySms(c.Context(), req.Username, req.Password)
err := services.Verifier.VerifySms(c.Context(), req.Username, req.Password)
if err != nil {
if errors.Is(err, services.ErrVerifierServiceInvalid) {
return fiber.NewError(fiber.StatusBadRequest, "验证码错误")
}
return err
}
if !ok {
return fiber.NewError(fiber.StatusBadRequest, "验证码错误")
}
// 查找用户 todo 获取权限信息
var tx = q.Q.Begin()

View File

@@ -104,7 +104,7 @@ func clientCredentials(c *fiber.Ctx, req *TokenReq) error {
}
scope := strings.Split(req.Scope, ",")
token, err := services.Auth.OauthClientCredentials(c.Context(), client, scope)
token, err := services.Auth.OauthClientCredentials(c.Context(), client, scope...)
if err != nil {
return sendError(c, err.(services.AuthServiceOauthError))
}

View File

@@ -8,8 +8,6 @@ import (
var Auth = &authService{}
type authService struct{}
type AuthServiceError string
func (e AuthServiceError) Error() string {
@@ -31,6 +29,8 @@ var (
ErrOauthUnsupportedGrantType = AuthServiceOauthError("unsupported_grant_type")
)
type authService struct{}
// OauthAuthorizationCode 验证授权码
func (s *authService) OauthAuthorizationCode(ctx context.Context, client *models.Client, code, redirectURI, codeVerifier string) (*TokenDetails, error) {
// TODO: 从数据库验证授权码
@@ -38,7 +38,7 @@ func (s *authService) OauthAuthorizationCode(ctx context.Context, client *models
}
// OauthClientCredentials 验证客户端凭证
func (s *authService) OauthClientCredentials(ctx context.Context, client *models.Client, scope ...[]string) (*TokenDetails, error) {
func (s *authService) OauthClientCredentials(ctx context.Context, client *models.Client, scope ...string) (*TokenDetails, error) {
var clientType PayloadType
switch client.Spec {
@@ -47,14 +47,17 @@ func (s *authService) OauthClientCredentials(ctx context.Context, client *models
case 1:
clientType = PayloadClientPublic
case 2:
clientType = PayloadClientConfidential
clientType = PayloadClientPublic
}
var permissions = make(map[string]struct{}, len(scope))
for _, item := range scope {
permissions[item] = struct{}{}
}
// 保存会话并返回令牌
auth := AuthContext{
Permissions: map[string]struct{}{
"client": {},
},
Permissions: permissions,
Payload: Payload{
Type: clientType,
Id: client.ID,

146
web/services/auth_test.go Normal file
View File

@@ -0,0 +1,146 @@
package services
import (
"context"
"platform/web/models"
"reflect"
"testing"
"time"
)
// mockSessionService 用于模拟Session服务的行为
type mockSessionService struct {
createFunc func(ctx context.Context, auth AuthContext) (*TokenDetails, error)
}
func (m *mockSessionService) Find(ctx context.Context, token string) (*AuthContext, error) {
panic("implement me")
}
func (m *mockSessionService) Refresh(ctx context.Context, refreshToken string, config ...SessionConfig) (*TokenDetails, error) {
panic("implement me")
}
func (m *mockSessionService) Remove(ctx context.Context, accessToken, refreshToken string) error {
panic("implement me")
}
func (m *mockSessionService) Create(ctx context.Context, auth AuthContext, config ...SessionConfig) (*TokenDetails, error) {
return m.createFunc(ctx, auth)
}
func Test_authService_OauthClientCredentials(t *testing.T) {
// 暂存原始Session服务
originalSession := Session
defer func() {
// 测试结束后恢复原始Session服务
Session = originalSession
}()
// 预设的令牌详情
expectedToken := &TokenDetails{
AccessToken: "test-access-token",
RefreshToken: "test-refresh-token",
AccessTokenExpires: time.Now().Add(3600 * time.Second),
}
type args struct {
ctx context.Context
client *models.Client
scope []string
}
tests := []struct {
name string
args args
mockCreateErr error
want *TokenDetails
wantErr bool
wantPayload Payload
}{
{
name: "成功 - 机密客户端 (Spec=0)",
args: args{
ctx: context.Background(),
client: &models.Client{ID: 1, Spec: 0},
scope: []string{"read", "write"},
},
mockCreateErr: nil,
want: expectedToken,
wantErr: false,
wantPayload: Payload{
Type: PayloadClientConfidential,
Id: 1,
},
},
{
name: "成功 - 公共客户端 (Spec=1)",
args: args{
ctx: context.Background(),
client: &models.Client{ID: 1, Spec: 1},
scope: []string{"read", "write"},
},
mockCreateErr: nil,
want: expectedToken,
wantErr: false,
wantPayload: Payload{
Type: PayloadClientPublic,
Id: 1,
},
},
{
name: "成功 - 公共客户端 (Spec=2)",
args: args{
ctx: context.Background(),
client: &models.Client{ID: 1, Spec: 2},
scope: []string{"read", "write"},
},
mockCreateErr: nil,
want: expectedToken,
wantErr: false,
wantPayload: Payload{
Type: PayloadClientPublic,
Id: 1,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 为每个测试用例设置模拟的Session服务
mockSession := &mockSessionService{
createFunc: func(ctx context.Context, auth AuthContext) (*TokenDetails, error) {
// 验证权限映射
if len(auth.Permissions) != len(tt.args.scope) {
t.Errorf("Permissions length = %v, want %v", len(auth.Permissions), len(tt.args.scope))
for key := range auth.Permissions {
if _, ok := auth.Permissions[key]; !ok {
t.Errorf("Permissions[%s] not found", key)
}
}
}
// 验证Payload
if auth.Payload.Type != tt.wantPayload.Type {
t.Errorf("Payload.Type = %v, want %v", auth.Payload.Type, tt.wantPayload.Type)
}
if auth.Payload.Id != tt.wantPayload.Id {
t.Errorf("Payload.Id = %v, want %v", auth.Payload.Id, tt.wantPayload.Id)
}
return expectedToken, tt.mockCreateErr
},
}
// 替换Session服务为模拟实现
Session = mockSession
s := &authService{}
got, err := s.OauthClientCredentials(tt.args.ctx, tt.args.client, tt.args.scope...)
if (err != nil) != tt.wantErr {
t.Errorf("OauthClientCredentials() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("OauthClientCredentials() got = %v, want %v", got, tt.want)
}
})
}
}

View File

@@ -14,9 +14,17 @@ import (
// region SessionService
var Session = &sessionService{}
var Session SessionServiceInter = &sessionService{}
type sessionService struct {
type SessionServiceInter interface {
// Find 通过访问令牌获取会话信息
Find(ctx context.Context, token string) (*AuthContext, error)
// Create 创建一个新的会话
Create(ctx context.Context, auth AuthContext, config ...SessionConfig) (*TokenDetails, error)
// Refresh 刷新一个会话
Refresh(ctx context.Context, refreshToken string, config ...SessionConfig) (*TokenDetails, error)
// Remove 删除会话
Remove(ctx context.Context, accessToken, refreshToken string) error
}
type SessionServiceError string
@@ -29,6 +37,8 @@ var (
ErrInvalidToken = SessionServiceError("invalid_token")
)
type sessionService struct{}
// Find 通过访问令牌获取会话信息
func (s *sessionService) Find(ctx context.Context, token string) (*AuthContext, error) {

View File

@@ -0,0 +1,486 @@
package services
import (
"context"
"errors"
"platform/init/rds"
"reflect"
"testing"
"time"
"github.com/alicebob/miniredis/v2"
"github.com/redis/go-redis/v9"
)
// 设置 Redis 模拟服务器
func setupTestRedis(t *testing.T) *miniredis.Miniredis {
mr, err := miniredis.Run()
if err != nil {
t.Fatalf("无法启动 miniredis: %v", err)
}
// 替换 Redis 客户端为测试客户端
origClient := rds.Client
rds.Client = redis.NewClient(&redis.Options{
Addr: mr.Addr(),
})
t.Cleanup(func() {
mr.Close()
rds.Client = origClient
})
return mr
}
// 创建测试用的认证上下文
func createTestAuthContext() AuthContext {
return AuthContext{
Payload: Payload{
Type: PayloadUser,
Id: 1001,
},
Permissions: map[string]struct{}{
"read": {},
"write": {},
},
Metadata: map[string]interface{}{
"username": "testuser",
"email": "test@example.com",
},
}
}
func Test_sessionService_Create(t *testing.T) {
mr := setupTestRedis(t)
ctx := context.Background()
auth := createTestAuthContext()
type args struct {
ctx context.Context
auth AuthContext
config []SessionConfig
}
tests := []struct {
name string
args args
want func(*TokenDetails) bool
wantErr bool
}{
{
name: "使用默认配置创建会话",
args: args{
ctx: ctx,
auth: auth,
},
want: func(td *TokenDetails) bool {
// 验证令牌存在且格式正确
if td.AccessToken == "" || td.RefreshToken == "" {
return false
}
// 验证到期时间在未来
now := time.Now()
if td.AccessTokenExpires.Before(now) || td.RefreshTokenExpires.Before(now) {
return false
}
// 验证认证信息正确
if !reflect.DeepEqual(td.Auth, auth) {
return false
}
return true
},
wantErr: false,
},
{
name: "使用自定义配置创建会话",
args: args{
ctx: ctx,
auth: auth,
config: []SessionConfig{
{
AccessTokenDuration: 10 * time.Minute,
RefreshTokenDuration: 24 * time.Hour,
},
},
},
want: func(td *TokenDetails) bool {
// 验证令牌存在且格式正确
if td.AccessToken == "" || td.RefreshToken == "" {
return false
}
// 验证到期时间在未来且接近预期时间
now := time.Now()
expectedAccessExpiry := now.Add(10 * time.Minute)
expectedRefreshExpiry := now.Add(24 * time.Hour)
accessDiff := td.AccessTokenExpires.Sub(expectedAccessExpiry)
refreshDiff := td.RefreshTokenExpires.Sub(expectedRefreshExpiry)
if accessDiff < -2*time.Second || accessDiff > 2*time.Second {
return false
}
if refreshDiff < -2*time.Second || refreshDiff > 2*time.Second {
return false
}
// 验证认证信息正确
if !reflect.DeepEqual(td.Auth, auth) {
return false
}
return true
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mr.FlushAll()
s := &sessionService{}
got, err := s.Create(tt.args.ctx, tt.args.auth, tt.args.config...)
if (err != nil) != tt.wantErr {
t.Errorf("Create() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !tt.want(got) {
t.Errorf("Create() got = %v, want to satisfy conditions", got)
}
// 验证 Redis 中是否有相应的键
accessKey := accessKey(got.AccessToken)
refreshKey := refreshKey(got.RefreshToken)
if !mr.Exists(accessKey) {
t.Errorf("访问令牌键 %s 不存在于 Redis 中", accessKey)
}
if !mr.Exists(refreshKey) {
t.Errorf("刷新令牌键 %s 不存在于 Redis 中", refreshKey)
}
})
}
}
func Test_sessionService_Find(t *testing.T) {
_ = setupTestRedis(t)
ctx := context.Background()
auth := createTestAuthContext()
s := &sessionService{}
// 创建一个有效的会话
td, err := s.Create(ctx, auth)
if err != nil {
t.Fatalf("无法创建测试会话: %v", err)
}
validToken := td.AccessToken
invalidToken := "invalid-token"
type args struct {
ctx context.Context
token string
}
tests := []struct {
name string
args args
want *AuthContext
wantErr error
}{
{
name: "查找有效令牌",
args: args{
ctx: ctx,
token: validToken,
},
want: &auth,
wantErr: nil,
},
{
name: "查找无效令牌",
args: args{
ctx: ctx,
token: invalidToken,
},
want: nil,
wantErr: ErrInvalidToken,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := s.Find(tt.args.ctx, tt.args.token)
if !errors.Is(err, tt.wantErr) {
t.Errorf("Find() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Find() got = %v, want %v", got, tt.want)
}
})
}
}
func Test_sessionService_Refresh(t *testing.T) {
mr := setupTestRedis(t)
ctx := context.Background()
auth := createTestAuthContext()
s := &sessionService{}
// 创建一个初始会话
td, err := s.Create(ctx, auth)
if err != nil {
t.Fatalf("无法创建初始会话: %v", err)
}
validRefreshToken := td.RefreshToken
invalidRefreshToken := "invalid-refresh-token"
originalAccessToken := td.AccessToken
type args struct {
ctx context.Context
refreshToken string
config []SessionConfig
}
tests := []struct {
name string
args args
want func(*TokenDetails) bool
wantErr bool
}{
{
name: "使用有效的刷新令牌",
args: args{
ctx: ctx,
refreshToken: validRefreshToken,
},
want: func(td *TokenDetails) bool {
if td.AccessToken == "" || td.RefreshToken == "" {
return false
}
// 新的令牌应该与旧的不同
if td.AccessToken == originalAccessToken || td.RefreshToken == validRefreshToken {
return false
}
// 验证认证信息一致
if !reflect.DeepEqual(td.Auth, auth) {
return false
}
return true
},
wantErr: false,
},
{
name: "使用无效的刷新令牌",
args: args{
ctx: ctx,
refreshToken: invalidRefreshToken,
},
want: nil,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := s.Refresh(tt.args.ctx, tt.args.refreshToken, tt.args.config...)
if (err != nil) != tt.wantErr {
t.Errorf("Refresh() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.want != nil && !tt.want(got) {
t.Errorf("Refresh() got = %v, want to satisfy conditions", got)
}
if !tt.wantErr && got != nil {
// 验证旧的令牌已被删除
if mr.Exists(accessKey(originalAccessToken)) {
t.Errorf("原始访问令牌键应被删除")
}
if mr.Exists(refreshKey(validRefreshToken)) {
t.Errorf("原始刷新令牌键应被删除")
}
// 验证新的令牌已被添加
if !mr.Exists(accessKey(got.AccessToken)) {
t.Errorf("新的访问令牌键应存在")
}
if !mr.Exists(refreshKey(got.RefreshToken)) {
t.Errorf("新的刷新令牌键应存在")
}
}
})
}
}
func Test_sessionService_Remove(t *testing.T) {
mr := setupTestRedis(t)
ctx := context.Background()
auth := createTestAuthContext()
s := &sessionService{}
// 创建一个会话
td, err := s.Create(ctx, auth)
if err != nil {
t.Fatalf("无法创建测试会话: %v", err)
}
validAccessToken := td.AccessToken
validRefreshToken := td.RefreshToken
type args struct {
ctx context.Context
accessToken string
refreshToken string
}
tests := []struct {
name string
args args
wantErr bool
}{
{
name: "删除有效会话",
args: args{
ctx: ctx,
accessToken: validAccessToken,
refreshToken: validRefreshToken,
},
wantErr: false,
},
{
name: "删除已删除的会话",
args: args{
ctx: ctx,
accessToken: validAccessToken,
refreshToken: validRefreshToken,
},
wantErr: false, // 删除不存在的会话不应报错
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := s.Remove(tt.args.ctx, tt.args.accessToken, tt.args.refreshToken); (err != nil) != tt.wantErr {
t.Errorf("Remove() error = %v, wantErr %v", err, tt.wantErr)
}
// 验证键已被删除
if mr.Exists(accessKey(tt.args.accessToken)) {
t.Errorf("访问令牌键应已被删除")
}
if mr.Exists(refreshKey(tt.args.refreshToken)) {
t.Errorf("刷新令牌键应已被删除")
}
})
}
}
func TestAuthContext_AnyPermission(t *testing.T) {
type fields struct {
Payload Payload
Permissions map[string]struct{}
Metadata map[string]interface{}
}
type args struct {
requiredPermission []string
}
tests := []struct {
name string
fields fields
args args
want bool
}{
{
name: "用户拥有所需权限",
fields: fields{
Payload: Payload{Type: PayloadUser, Id: 1},
Permissions: map[string]struct{}{
"read": {},
"write": {},
},
Metadata: nil,
},
args: args{
requiredPermission: []string{"read"},
},
want: true,
},
{
name: "用户拥有至少一个所需权限",
fields: fields{
Payload: Payload{Type: PayloadUser, Id: 1},
Permissions: map[string]struct{}{
"read": {},
},
Metadata: nil,
},
args: args{
requiredPermission: []string{"read", "admin"},
},
want: true,
},
{
name: "用户没有所需权限",
fields: fields{
Payload: Payload{Type: PayloadUser, Id: 1},
Permissions: map[string]struct{}{
"read": {},
},
Metadata: nil,
},
args: args{
requiredPermission: []string{"admin", "delete"},
},
want: false,
},
{
name: "空权限列表",
fields: fields{
Payload: Payload{Type: PayloadUser, Id: 1},
Permissions: map[string]struct{}{},
Metadata: nil,
},
args: args{
requiredPermission: []string{"read"},
},
want: false,
},
{
name: "nil权限列表",
fields: fields{
Payload: Payload{Type: PayloadUser, Id: 1},
Permissions: nil,
Metadata: nil,
},
args: args{
requiredPermission: []string{"read"},
},
want: false,
},
{
name: "nil认证上下文",
fields: fields{
Payload: Payload{},
Permissions: nil,
Metadata: nil,
},
args: args{
requiredPermission: []string{"read"},
},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := &AuthContext{
Payload: tt.fields.Payload,
Permissions: tt.fields.Permissions,
Metadata: tt.fields.Metadata,
}
if got := a.AnyPermission(tt.args.requiredPermission...); got != tt.want {
t.Errorf("AnyPermission() = %v, want %v", got, tt.want)
}
})
}
}

View File

@@ -15,9 +15,6 @@ import (
var Verifier = &verifierService{}
type verifierService struct {
}
type VerifierServiceError string
func (e VerifierServiceError) Error() string {
@@ -37,11 +34,10 @@ func (e VerifierServiceSendLimitErr) Error() string {
type VerifierSmsPurpose int
const (
Login VerifierSmsPurpose = iota
VerifierSmsPurposeLogin VerifierSmsPurpose = iota
)
func smsKey(phone string, purpose VerifierSmsPurpose) string {
return fmt.Sprintf("verify:sms:%d:%s", purpose, phone)
type verifierService struct {
}
func (s *verifierService) SendSms(ctx context.Context, phone string, purpose VerifierSmsPurpose) error {
@@ -83,8 +79,8 @@ func (s *verifierService) SendSms(ctx context.Context, phone string, purpose Ver
return nil
}
func (s *verifierService) VerifySms(ctx context.Context, phone, code string) (bool, error) {
key := smsKey(phone, Login)
func (s *verifierService) VerifySms(ctx context.Context, phone, code string) error {
key := smsKey(phone, VerifierSmsPurposeLogin)
keyLock := key + ":lock"
err := rds.Client.Watch(ctx, func(tx *redis.Tx) error {
@@ -114,11 +110,12 @@ func (s *verifierService) VerifySms(ctx context.Context, phone, code string) (bo
return nil
}, key)
if err != nil {
if errors.Is(err, ErrVerifierServiceInvalid) {
return false, nil
}
return false, err
return err
}
return true, nil
return nil
}
func smsKey(phone string, purpose VerifierSmsPurpose) string {
return fmt.Sprintf("verify:sms:%d:%s", purpose, phone)
}

View File

@@ -0,0 +1,257 @@
package services
import (
"context"
"platform/init/rds"
"strconv"
"testing"
"time"
"github.com/alicebob/miniredis/v2"
"github.com/redis/go-redis/v9"
)
// 设置测试的 Redis 环境
func setupRedisTest(t *testing.T) *miniredis.Miniredis {
mr, err := miniredis.Run()
if err != nil {
t.Fatalf("设置 miniredis 失败: %v", err)
}
// 替换 redis 客户端为测试客户端
rds.Client = redis.NewClient(&redis.Options{
Addr: mr.Addr(),
})
return mr
}
func Test_verifierService_SendSms(t *testing.T) {
type args struct {
ctx context.Context
phone string
purpose VerifierSmsPurpose
}
tests := []struct {
name string
args args
setup func(mr *miniredis.Miniredis)
wantErr bool
wantErrType error
}{
{
name: "正常发送成功(无旧验证码)",
args: args{
ctx: context.Background(),
phone: "13812345678",
purpose: VerifierSmsPurposeLogin,
},
setup: func(mr *miniredis.Miniredis) {},
wantErr: false,
},
{
name: "正常发送成功(有旧验证码)",
args: args{
ctx: context.Background(),
phone: "13812345679",
purpose: VerifierSmsPurposeLogin,
},
setup: func(mr *miniredis.Miniredis) {
key := smsKey("13812345679", VerifierSmsPurposeLogin)
mr.Set(key, "123456")
mr.SetTTL(key, 10*time.Minute)
},
wantErr: false,
},
{
name: "发送频率过快",
args: args{
ctx: context.Background(),
phone: "13812345680",
purpose: VerifierSmsPurposeLogin,
},
setup: func(mr *miniredis.Miniredis) {
key := smsKey("13812345680", VerifierSmsPurposeLogin) + ":lock"
mr.Set(key, "")
mr.SetTTL(key, 1*time.Minute)
},
wantErr: true,
wantErrType: VerifierServiceSendLimitErr(0),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 设置 Redis 测试环境
mr := setupRedisTest(t)
defer mr.Close()
// 执行测试前的设置
if tt.setup != nil {
tt.setup(mr)
}
s := &verifierService{}
err := s.SendSms(tt.args.ctx, tt.args.phone, tt.args.purpose)
// 验证错误
if (err != nil) != tt.wantErr {
t.Errorf("SendSms() error = %v, wantErr %v", err, tt.wantErr)
return
}
// 验证错误类型
if tt.wantErr && tt.wantErrType != nil {
if _, isSendLimitErr := err.(VerifierServiceSendLimitErr); isSendLimitErr {
if _, wantSendLimitErr := tt.wantErrType.(VerifierServiceSendLimitErr); !wantSendLimitErr {
t.Errorf("SendSms() error type = %T, wantErrType %T", err, tt.wantErrType)
}
}
}
// 验证 Redis 中的记录
if !tt.wantErr {
key := smsKey(tt.args.phone, tt.args.purpose)
keyLock := key + ":lock"
// 验证码应存在
val, err := mr.Get(key)
if err != nil {
t.Errorf("验证码应存在但不存在: %v", err)
}
// 限速锁应存在
_, err = mr.Get(keyLock)
if err != nil {
t.Errorf("限速锁应存在但不存在: %v", err)
}
// 验证码应为6位数字
code, err := strconv.Atoi(val)
if err != nil || code < 100000 || code > 999999 {
t.Errorf("验证码应为6位数字: %v", val)
}
}
})
}
}
func Test_verifierService_VerifySms(t *testing.T) {
type args struct {
ctx context.Context
phone string
code string
}
tests := []struct {
name string
args args
setup func(mr *miniredis.Miniredis)
wantErr bool
wantErrType error
}{
{
name: "验证码正确",
args: args{
ctx: context.Background(),
phone: "13812345678",
code: "123456",
},
setup: func(mr *miniredis.Miniredis) {
key := smsKey("13812345678", VerifierSmsPurposeLogin)
keyLock := key + ":lock"
mr.Set(key, "123456")
mr.SetTTL(key, 10*time.Minute)
mr.Set(keyLock, "")
mr.SetTTL(keyLock, 1*time.Minute)
},
wantErr: false,
},
{
name: "验证码错误",
args: args{
ctx: context.Background(),
phone: "13812345679",
code: "654321",
},
setup: func(mr *miniredis.Miniredis) {
key := smsKey("13812345679", VerifierSmsPurposeLogin)
keyLock := key + ":lock"
mr.Set(key, "123456")
mr.SetTTL(key, 10*time.Minute)
mr.Set(keyLock, "")
mr.SetTTL(keyLock, 1*time.Minute)
},
wantErr: true,
wantErrType: ErrVerifierServiceInvalid,
},
{
name: "验证码过期",
args: args{
ctx: context.Background(),
phone: "13812345680",
code: "123456",
},
setup: func(mr *miniredis.Miniredis) {
// 不设置验证码,模拟过期情况
},
wantErr: true,
wantErrType: ErrVerifierServiceInvalid,
},
{
name: "手机号错误",
args: args{
ctx: context.Background(),
phone: "13812345681",
code: "123456",
},
setup: func(mr *miniredis.Miniredis) {
// 设置一个不同手机号的验证码
key := smsKey("13800000000", VerifierSmsPurposeLogin)
mr.Set(key, "123456")
mr.SetTTL(key, 10*time.Minute)
},
wantErr: true,
wantErrType: ErrVerifierServiceInvalid,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 设置 Redis 测试环境
mr := setupRedisTest(t)
defer mr.Close()
// 执行测试前的设置
if tt.setup != nil {
tt.setup(mr)
}
s := &verifierService{}
err := s.VerifySms(tt.args.ctx, tt.args.phone, tt.args.code)
if (err != nil) != tt.wantErr {
t.Errorf("VerifySms() error = %v, wantErr %v", err, tt.wantErr)
return
}
// 检查错误类型
if tt.wantErr && tt.wantErrType != nil && err != tt.wantErrType {
t.Errorf("VerifySms() error = %v, wantErrType %v", err, tt.wantErrType)
return
}
// 验证成功后 Redis 中应该没有该记录
if err == nil {
key := smsKey(tt.args.phone, VerifierSmsPurposeLogin)
keyLock := key + ":lock"
_, redisErr := mr.Get(key)
if redisErr == nil {
t.Errorf("验证码验证成功后应删除,但仍存在")
}
_, redisErr = mr.Get(keyLock)
if redisErr == nil {
t.Errorf("限速锁验证成功后应删除,但仍存在")
}
}
})
}
}