新增通道服务相关测试用例

This commit is contained in:
2025-04-01 11:32:17 +08:00
parent 87eecdb8cb
commit e4bd86642e
10 changed files with 1249 additions and 69 deletions

View File

@@ -99,6 +99,9 @@ func (s *channelService) RemoveChannels(ctx context.Context, auth *AuthContext,
proxies, err := tx.Proxy.Where(
q.Proxy.ID.In(proxyIds...),
).Find()
if err != nil {
return err
}
slog.Debug("查找代理", "rid", rid, "step", time.Since(step))
@@ -163,7 +166,7 @@ func (s *channelService) RemoveChannels(ctx context.Context, auth *AuthContext,
}
var secret = strings.Split(proxy.Secret, ":")
gateway := remote.InitGateway(
gateway := remote.NewGateway(
proxy.Host,
secret[0],
secret[1],
@@ -204,7 +207,7 @@ func (s *channelService) RemoveChannels(ctx context.Context, auth *AuthContext,
}
}
if len(edges) > 0 {
_, err := remote.Client.CloudDisconnect(remote.CloudDisconnectReq{
_, err := remote.Cloud.CloudDisconnect(remote.CloudDisconnectReq{
Uuid: proxy.Name,
Edge: edges,
})
@@ -395,7 +398,7 @@ func assignEdge(count int, filter NodeFilterConfig) (*AssignEdgeResult, error) {
// 查询已配置的节点
step = time.Now()
rProxyConfigs, err := remote.Client.CloudAutoQuery()
rProxyConfigs, err := remote.Cloud.CloudAutoQuery()
if err != nil {
return nil, err
}
@@ -466,7 +469,7 @@ func assignEdge(count int, filter NodeFilterConfig) (*AssignEdgeResult, error) {
step = time.Now()
slog.Debug("新增新节点", "proxy", info.proxy.Name, "used", info.used, "count", info.count)
err := remote.Client.CloudConnect(remote.CloudConnectReq{
err := remote.Cloud.CloudConnect(remote.CloudConnectReq{
Uuid: info.proxy.Name,
Edge: nil,
AutoConfig: []remote.AutoConfig{{
@@ -520,7 +523,7 @@ func assignPort(
expiration time.Time,
filter NodeFilterConfig,
) ([]string, []*models.Channel, error) {
var step = time.Now()
var step time.Time
var configs = proxies.configs
var exists = proxies.channels
@@ -639,7 +642,7 @@ func assignPort(
step = time.Now()
var secret = strings.Split(proxy.Secret, ":")
gateway := remote.InitGateway(
gateway := remote.NewGateway(
proxy.Host,
secret[0],
secret[1],
@@ -677,6 +680,10 @@ func chKey(channel *models.Channel) string {
}
func cache(ctx context.Context, channels []*models.Channel) error {
if len(channels) == 0 {
return nil
}
pipe := rds.Client.TxPipeline()
zList := make([]redis.Z, 0, len(channels))
@@ -685,7 +692,7 @@ func cache(ctx context.Context, channels []*models.Channel) error {
if err != nil {
return err
}
pipe.Set(ctx, chKey(channel), string(marshal), channel.Expiration.Sub(time.Now()))
pipe.Set(ctx, chKey(channel), string(marshal), time.Until(channel.Expiration))
zList = append(zList, redis.Z{
Score: float64(channel.Expiration.Unix()),
Member: channel.ID,
@@ -702,6 +709,10 @@ func cache(ctx context.Context, channels []*models.Channel) error {
}
func deleteCache(ctx context.Context, channels []*models.Channel) error {
if len(channels) == 0 {
return nil
}
keys := make([]string, len(channels))
for i := range channels {
keys[i] = chKey(channels[i])

View File

@@ -0,0 +1,985 @@
package services
import (
"context"
"encoding/json"
"fmt"
"platform/pkg/env"
"platform/pkg/remote"
"platform/pkg/testutil"
"platform/web/models"
"regexp"
"strings"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/gofiber/fiber/v2/middleware/requestid"
"gorm.io/gorm"
)
func Test_genPassPair(t *testing.T) {
tests := []struct {
name string
}{
{
name: "正常生成随机用户名和密码",
},
{
name: "多次调用生成不同的值",
},
}
// 第一个测试:检查生成的用户名和密码是否有效
t.Run(tests[0].name, func(t *testing.T) {
username, password := genPassPair()
if username == "" {
t.Errorf("genPassPair() username is empty")
}
if password == "" {
t.Errorf("genPassPair() password is empty")
}
})
// 第二个测试:确保多次调用生成不同的值
t.Run(tests[1].name, func(t *testing.T) {
username1, password1 := genPassPair()
username2, password2 := genPassPair()
if username1 == username2 {
t.Errorf("genPassPair() generated the same username twice: %v", username1)
}
if password1 == password2 {
t.Errorf("genPassPair() generated the same password twice: %v", password1)
}
})
}
func Test_chKey(t *testing.T) {
type args struct {
channel *models.Channel
}
tests := []struct {
name string
args args
want string
}{
{
name: "ID为1的通道",
args: args{
channel: &models.Channel{
ID: 1,
},
},
want: "channel:1",
},
{
name: "ID为100的通道",
args: args{
channel: &models.Channel{
ID: 100,
},
},
want: "channel:100",
},
{
name: "ID为0的通道",
args: args{
channel: &models.Channel{
ID: 0,
},
},
want: "channel:0",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := chKey(tt.args.channel); got != tt.want {
t.Errorf("chKey() = %v, want %v", got, tt.want)
}
})
}
}
func Test_cache(t *testing.T) {
mr := testutil.SetupRedisTest(t)
type args struct {
ctx context.Context
channels []*models.Channel
}
// 准备测试数据
now := time.Now()
expiration := now.Add(24 * time.Hour)
testChannels := []*models.Channel{
{
ID: 1,
UserID: 100,
ProxyID: 10,
ProxyPort: 8080,
Protocol: "http",
Expiration: expiration,
},
{
ID: 2,
UserID: 101,
ProxyID: 11,
ProxyPort: 8081,
Protocol: "socks5",
Expiration: expiration,
},
}
tests := []struct {
name string
args args
wantErr bool
}{
{
name: "正常缓存多个通道",
args: args{
ctx: context.Background(),
channels: testChannels,
},
wantErr: false,
},
{
name: "空通道列表",
args: args{
ctx: context.Background(),
channels: []*models.Channel{},
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mr.FlushAll() // 清空 Redis 数据
if err := cache(tt.args.ctx, tt.args.channels); (err != nil) != tt.wantErr {
t.Errorf("cache() error = %v, wantErr %v", err, tt.wantErr)
return
}
// 验证缓存结果
if len(tt.args.channels) > 0 {
for _, channel := range tt.args.channels {
key := fmt.Sprintf("channel:%d", channel.ID)
if !mr.Exists(key) {
t.Errorf("缓存未包含通道键 %s", key)
} else {
// 验证缓存的数据是否正确
data, _ := mr.Get(key)
var cachedChannel models.Channel
err := json.Unmarshal([]byte(data), &cachedChannel)
if err != nil {
t.Errorf("无法解析缓存数据: %v", err)
}
if cachedChannel.ID != channel.ID {
t.Errorf("缓存数据不匹配: 期望 ID %d, 得到 %d", channel.ID, cachedChannel.ID)
}
}
}
// 验证是否设置了过期时间
for _, channel := range tt.args.channels {
key := fmt.Sprintf("channel:%d", channel.ID)
ttl := mr.TTL(key)
if ttl <= 0 {
t.Errorf("键 %s 没有设置过期时间", key)
}
}
// 验证是否添加了有序集合
if !mr.Exists("tasks:channel") {
t.Errorf("ZAdd未创建有序集合 tasks:channel")
}
}
})
}
}
func Test_deleteCache(t *testing.T) {
mr := testutil.SetupRedisTest(t)
type args struct {
ctx context.Context
channels []*models.Channel
}
// 准备测试数据
testChannels := []*models.Channel{
{ID: 1},
{ID: 2},
{ID: 3},
}
ctx := context.Background()
tests := []struct {
name string
args args
wantErr bool
}{
{
name: "正常删除多个通道缓存",
args: args{
ctx: ctx,
channels: testChannels,
},
wantErr: false,
},
{
name: "空通道列表",
args: args{
ctx: ctx,
channels: []*models.Channel{},
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mr.FlushAll() // 清空 Redis 数据
// 预先设置缓存数据
for _, channel := range testChannels {
key := fmt.Sprintf("channel:%d", channel.ID)
data, _ := json.Marshal(channel)
mr.Set(key, string(data))
mr.SetTTL(key, 1*time.Hour) // 设置1小时的过期时间
}
if err := deleteCache(tt.args.ctx, tt.args.channels); (err != nil) != tt.wantErr {
t.Errorf("deleteCache() error = %v, wantErr %v", err, tt.wantErr)
return
}
// 验证删除结果
for _, channel := range tt.args.channels {
key := fmt.Sprintf("channel:%d", channel.ID)
if mr.Exists(key) {
t.Errorf("通道键 %s 未被删除", key)
}
}
})
}
}
func Test_channelService_CreateChannel(t *testing.T) {
mr := testutil.SetupRedisTest(t)
mdb := testutil.SetupDBTest(t)
mc := testutil.SetupCloudClientMock(t)
env.DebugExternalChange = false
type args struct {
ctx context.Context
auth *AuthContext
resourceId int32
protocol ChannelProtocol
authType ChannelAuthType
count int
nodeFilter []NodeFilterConfig
}
// 准备测试数据
ctx := context.WithValue(context.Background(), requestid.ConfigDefault.ContextKey, "test-request-id")
tests := []struct {
name string
args args
setup func()
want []string
wantErr bool
wantErrContains string
checkCache func(t *testing.T)
}{
{
name: "用户创建HTTP密码通道",
args: args{
ctx: ctx,
auth: &AuthContext{Payload: Payload{Type: PayloadUser, Id: 100}},
resourceId: 4,
protocol: ProtocolHTTP,
authType: ChannelAuthTypePass,
count: 3,
nodeFilter: []NodeFilterConfig{{Prov: "河南", City: "郑州", Isp: "电信"}},
},
setup: func() {
// 清空Redis
mr.FlushAll()
// 设置CloudAutoQuery的模拟返回
mc.AutoQueryMock = func() (remote.CloudConnectResp, error) {
return remote.CloudConnectResp{
"proxy3": []remote.AutoConfig{
{Province: "河南", City: "郑州", Isp: "电信", Count: 10},
},
}, nil
}
// 开始事务
mdb.ExpectBegin()
// 模拟查询套餐
resourceRows := sqlmock.NewRows([]string{
"id", "user_id", "active",
"type", "live", "daily_used", "daily_limit", "daily_last", "quota", "used", "expire",
}).AddRow(
4, 100, true,
0, 86400, 0, 100, time.Now(), 1000, 0, time.Now().Add(24*time.Hour),
)
mdb.ExpectQuery("SELECT").WithArgs(int32(4)).WillReturnRows(resourceRows)
// 模拟查询代理
proxyRows := sqlmock.NewRows([]string{"id", "name", "host", "secret", "type"}).
AddRow(3, "proxy3", "proxy3.example.com", "key:secret", 1)
mdb.ExpectQuery("SELECT").
WithArgs(1).
WillReturnRows(proxyRows)
// 模拟查询通道
channelRows := sqlmock.NewRows([]string{"proxy_id", "proxy_port"})
mdb.ExpectQuery("SELECT").
WillReturnRows(channelRows)
// 模拟保存通道 - PostgreSQL返回ID
mdb.ExpectQuery("INSERT INTO").WillReturnRows(
sqlmock.NewRows([]string{"id"}).AddRow(4).AddRow(5).AddRow(6),
)
// 模拟更新套餐使用记录
mdb.ExpectExec("UPDATE").WillReturnResult(sqlmock.NewResult(0, 1))
// 提交事务
mdb.ExpectCommit()
},
want: []string{
"http://proxy3.example.com:10000",
"http://proxy3.example.com:10001",
"http://proxy3.example.com:10002",
},
checkCache: func(t *testing.T) {
// 检查总共创建了3个通道
for i := 4; i <= 6; i++ {
key := fmt.Sprintf("channel:%d", i)
if !mr.Exists(key) {
t.Errorf("Redis缓存中应有键 %s", key)
}
}
},
},
{
name: "用户创建HTTP白名单通道",
args: args{
ctx: ctx,
auth: &AuthContext{
Payload: Payload{
Type: PayloadUser,
Id: 100,
},
},
resourceId: 5,
protocol: ProtocolHTTP,
authType: ChannelAuthTypeIp,
count: 2,
},
setup: func() {
// 清空Redis
mr.FlushAll()
// 设置CloudAutoQuery的模拟返回
mc.AutoQueryMock = func() (remote.CloudConnectResp, error) {
return remote.CloudConnectResp{
"proxy3": []remote.AutoConfig{
{Province: "河南", City: "郑州", Isp: "电信", Count: 10},
},
}, nil
}
// 开始事务
mdb.ExpectBegin()
// 模拟查询套餐
resourceRows := sqlmock.NewRows([]string{
"id", "user_id", "active",
"type", "live", "daily_used", "daily_limit", "daily_last", "quota", "used", "expire",
}).AddRow(
5, 100, true,
0, 86400, 0, 100, time.Now(), 1000, 0, time.Now().Add(24*time.Hour),
)
mdb.ExpectQuery("SELECT").WithArgs(int32(5)).WillReturnRows(resourceRows)
// 模拟查询代理
proxyRows := sqlmock.NewRows([]string{"id", "name", "host", "secret", "type"}).
AddRow(3, "proxy3", "proxy3.example.com", "key:secret", 1)
mdb.ExpectQuery("SELECT").
WithArgs(1).
WillReturnRows(proxyRows)
// 模拟查询通道
channelRows := sqlmock.NewRows([]string{"proxy_id", "proxy_port"})
mdb.ExpectQuery("SELECT").
WillReturnRows(channelRows)
// 模拟查询白名单 - 3个IP
whitelistRows := sqlmock.NewRows([]string{"host"}).
AddRow("192.168.1.1").
AddRow("192.168.1.2").
AddRow("192.168.1.3")
mdb.ExpectQuery("SELECT").
WithArgs(int32(100)).
WillReturnRows(whitelistRows)
// 模拟保存通道 - 2个通道 * 3个白名单 = 6个
mdb.ExpectQuery("INSERT INTO").WillReturnRows(
sqlmock.NewRows([]string{"id"}).
AddRow(7).AddRow(8).AddRow(9).
AddRow(10).AddRow(11).AddRow(12),
)
// 模拟更新套餐使用记录
mdb.ExpectExec("UPDATE").WillReturnResult(sqlmock.NewResult(0, 1))
// 提交事务
mdb.ExpectCommit()
},
want: []string{
"http://proxy3.example.com:10000",
"http://proxy3.example.com:10001",
},
checkCache: func(t *testing.T) {
// 检查应该创建了6个通道2个通道 * 3个白名单
for i := 7; i <= 12; i++ {
key := fmt.Sprintf("channel:%d", i)
if !mr.Exists(key) {
t.Errorf("Redis缓存中应有键 %s", key)
}
}
},
},
{
name: "管理员创建SOCKS5密码通道",
args: args{
ctx: ctx,
auth: &AuthContext{
Payload: Payload{
Type: PayloadAdmin,
Id: 1,
},
},
resourceId: 6,
protocol: ProtocolSocks5,
authType: ChannelAuthTypePass,
count: 2,
},
setup: func() {
// 清空Redis
mr.FlushAll()
// 设置CloudAutoQuery的模拟返回
mc.AutoQueryMock = func() (remote.CloudConnectResp, error) {
return remote.CloudConnectResp{
"proxy4": []remote.AutoConfig{
{Province: "河南", City: "郑州", Isp: "电信", Count: 5},
},
}, nil
}
// 设置CloudConnect的模拟逻辑
mc.ConnectMock = func(param remote.CloudConnectReq) error {
return nil
}
// 开始事务
mdb.ExpectBegin()
// 模拟查询套餐
resourceRows := sqlmock.NewRows([]string{
"id", "user_id", "active",
"type", "live", "daily_used", "daily_limit", "daily_last", "quota", "used", "expire",
}).AddRow(
6, 102, true,
1, 86400, 0, 100, time.Now(), 0, 0, time.Now().Add(24*time.Hour),
)
mdb.ExpectQuery("SELECT").WithArgs(int32(6)).WillReturnRows(resourceRows)
// 模拟查询代理
proxyRows := sqlmock.NewRows([]string{"id", "name", "host", "secret", "type"}).
AddRow(4, "proxy4", "proxy4.example.com", "key:secret", 1)
mdb.ExpectQuery("SELECT").
WithArgs(1).
WillReturnRows(proxyRows)
// 模拟查询通道
channelRows := sqlmock.NewRows([]string{"proxy_id", "proxy_port"})
mdb.ExpectQuery("SELECT").
WillReturnRows(channelRows)
// 模拟保存通道
mdb.ExpectQuery("INSERT INTO").WillReturnRows(
sqlmock.NewRows([]string{"id"}).AddRow(13).AddRow(14),
)
// 模拟更新套餐使用记录
mdb.ExpectExec("UPDATE").WillReturnResult(sqlmock.NewResult(0, 1))
// 提交事务
mdb.ExpectCommit()
},
want: []string{
"socks5://proxy4.example.com:10000",
"socks5://proxy4.example.com:10001",
},
checkCache: func(t *testing.T) {
for i := 13; i <= 14; i++ {
key := fmt.Sprintf("channel:%d", i)
if !mr.Exists(key) {
t.Errorf("Redis缓存中应有键 %s", key)
}
}
},
},
{
name: "套餐不存在",
args: args{
ctx: ctx,
auth: &AuthContext{
Payload: Payload{
Type: PayloadUser,
Id: 100,
},
},
resourceId: 999,
protocol: ProtocolHTTP,
authType: ChannelAuthTypeIp,
count: 1,
},
setup: func() {
// 清空Redis
mr.FlushAll()
// 开始事务
mdb.ExpectBegin()
// 模拟查询套餐不存在
mdb.ExpectQuery("SELECT").WithArgs(int32(999)).WillReturnError(gorm.ErrRecordNotFound)
// 回滚事务
mdb.ExpectRollback()
},
wantErr: true,
wantErrContains: "套餐不存在",
},
{
name: "套餐没有权限",
args: args{
ctx: ctx,
auth: &AuthContext{
Payload: Payload{
Type: PayloadUser,
Id: 101,
},
},
resourceId: 7,
protocol: ProtocolHTTP,
authType: ChannelAuthTypeIp,
count: 1,
},
setup: func() {
// 清空Redis
mr.FlushAll()
// 开始事务
mdb.ExpectBegin()
// 模拟查询套餐
resourceRows := sqlmock.NewRows([]string{
"id", "user_id", "active",
"type", "live", "daily_used", "daily_limit", "daily_last", "quota", "used", "expire",
}).AddRow(
7, 102, true, // 注意user_id 与 auth.Id 不匹配
0, 86400, 0, 100, time.Now(), 1000, 0, time.Now().Add(24*time.Hour),
)
mdb.ExpectQuery("SELECT").WithArgs(int32(7)).WillReturnRows(resourceRows)
// 回滚事务
mdb.ExpectRollback()
},
wantErr: true,
wantErrContains: "无权限访问",
},
{
name: "套餐配额不足",
args: args{
ctx: ctx,
auth: &AuthContext{
Payload: Payload{
Type: PayloadUser,
Id: 100,
},
},
resourceId: 2,
protocol: ProtocolHTTP,
authType: ChannelAuthTypeIp,
count: 10,
},
setup: func() {
// 清空Redis
mr.FlushAll()
// 开始事务
mdb.ExpectBegin()
// 模拟查询套餐
resourceRows := sqlmock.NewRows([]string{
"id", "user_id", "active",
"type", "live", "daily_used", "daily_limit", "daily_last", "quota", "used", "expire",
}).AddRow(
2, 100, true,
0, 86400, 95, 100, time.Now(), 100, 95, time.Now().Add(24*time.Hour),
)
mdb.ExpectQuery("SELECT").WithArgs(int32(2)).WillReturnRows(resourceRows)
// 回滚事务
mdb.ExpectRollback()
},
wantErr: true,
wantErrContains: "套餐配额不足",
},
{
name: "端口数量达到上限",
args: args{
ctx: ctx,
auth: &AuthContext{
Payload: Payload{
Type: PayloadUser,
Id: 100,
},
},
resourceId: 8,
protocol: ProtocolHTTP,
authType: ChannelAuthTypeIp,
count: 1,
},
setup: func() {
// 清空Redis
mr.FlushAll()
// 设置CloudAutoQuery的模拟返回
mc.AutoQueryMock = func() (remote.CloudConnectResp, error) {
return remote.CloudConnectResp{
"proxy5": []remote.AutoConfig{
{Province: "河南", City: "郑州", Isp: "电信", Count: 10},
},
}, nil
}
// 开始事务
mdb.ExpectBegin()
// 模拟查询套餐
resourceRows := sqlmock.NewRows([]string{
"id", "user_id", "active",
"type", "live", "daily_used", "daily_limit", "daily_last", "quota", "used", "expire",
}).AddRow(
8, 100, true,
0, 86400, 0, 100, time.Now(), 1000, 0, time.Now().Add(24*time.Hour),
)
mdb.ExpectQuery("SELECT").WithArgs(int32(8)).WillReturnRows(resourceRows)
// 模拟查询代理
proxyRows := sqlmock.NewRows([]string{"id", "name", "host", "secret", "type"}).
AddRow(5, "proxy5", "proxy5.example.com", "key:secret", 1)
mdb.ExpectQuery("SELECT").
WithArgs(1).
WillReturnRows(proxyRows)
// 模拟通道端口已用尽
// 构建一个大量已使用端口的结果集
channelRows := sqlmock.NewRows([]string{"proxy_id", "proxy_port"})
for i := 10000; i < 65535; i++ {
channelRows.AddRow(5, i)
}
mdb.ExpectQuery("SELECT").
WillReturnRows(channelRows)
// 回滚事务
mdb.ExpectRollback()
},
wantErr: true,
wantErrContains: "端口数量不足",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.setup != nil {
tt.setup()
}
s := &channelService{}
got, err := s.CreateChannel(tt.args.ctx, tt.args.auth, tt.args.resourceId, tt.args.protocol, tt.args.authType, tt.args.count, tt.args.nodeFilter...)
// 检查错误或结果
if tt.wantErr {
if err == nil {
t.Errorf("CreateChannel() 应当返回错误")
return
}
if tt.wantErrContains != "" && !strings.Contains(err.Error(), tt.wantErrContains) {
t.Errorf("CreateChannel() 错误 = %v, 应包含 %v", err, tt.wantErrContains)
}
return
}
if err != nil {
t.Errorf("CreateChannel() 错误 = %v, wantErr %v", err, tt.wantErr)
return
}
if len(got) != len(tt.want) {
t.Errorf("CreateChannel() 返回长度 = %v, want %v", len(got), len(tt.want))
return
}
// 检查返回地址格式
for _, addr := range got {
protocol := string(tt.args.protocol)
if !strings.HasPrefix(addr, protocol+"://") {
t.Errorf("CreateChannel() 地址 %v 不是有效的 %s 地址", addr, protocol)
}
}
// 验证所有期望的 SQL 已执行
if err := mdb.ExpectationsWereMet(); err != nil {
t.Errorf("有未满足的SQL期望: %s", err)
}
// 检查 Redis 缓存是否正确设置
if tt.checkCache != nil {
tt.checkCache(t)
}
})
}
}
func Test_channelService_RemoveChannels(t *testing.T) {
mr := testutil.SetupRedisTest(t)
mdb := testutil.SetupDBTest(t)
mg := testutil.SetupGatewayClientMock(t)
env.DebugExternalChange = false
type args struct {
ctx context.Context
auth *AuthContext
id []int32
}
// 准备测试数据
ctx := context.WithValue(context.Background(), requestid.ConfigDefault.ContextKey, "test-request-id")
tests := []struct {
name string
args args
setup func()
wantErr bool
wantErrContains string
checkCache func(t *testing.T)
}{
{
name: "管理员删除多个通道",
args: args{
ctx: ctx,
auth: &AuthContext{
Payload: Payload{
Type: PayloadAdmin,
Id: 1,
},
},
id: []int32{1, 2, 3},
},
setup: func() {
// 预设 Redis 缓存
mr.FlushAll()
for _, id := range []int32{1, 2, 3} {
key := fmt.Sprintf("channel:%d", id)
channel := models.Channel{ID: id, UserID: 100}
data, _ := json.Marshal(channel)
mr.Set(key, string(data))
}
// 开始事务
mdb.ExpectBegin()
// 查找通道
channelRows := sqlmock.NewRows([]string{"id", "user_id", "proxy_id", "proxy_port", "protocol", "expiration"}).
AddRow(1, 100, 1, 10001, "http", time.Now().Add(24*time.Hour)).
AddRow(2, 100, 1, 10002, "http", time.Now().Add(24*time.Hour)).
AddRow(3, 101, 2, 10001, "socks5", time.Now().Add(24*time.Hour))
mdb.ExpectQuery(regexp.QuoteMeta("SELECT * FROM `channel` WHERE `channel`.`id` IN")).
WithArgs(int32(1), int32(2), int32(3)).
WillReturnRows(channelRows)
// 查找代理
proxyRows := sqlmock.NewRows([]string{"id", "name", "host", "secret", "type"}).
AddRow(1, "proxy1", "proxy1.example.com", "key:secret", 1).
AddRow(2, "proxy2", "proxy2.example.com", "key:secret", 1)
mdb.ExpectQuery(regexp.QuoteMeta("SELECT * FROM `proxy` WHERE `proxy`.`id` IN")).
WithArgs(int32(1), int32(2)).
WillReturnRows(proxyRows)
// 软删除通道
mdb.ExpectExec(regexp.QuoteMeta("UPDATE `channel` SET")).
WillReturnResult(sqlmock.NewResult(0, 3))
// 提交事务
mdb.ExpectCommit()
},
checkCache: func(t *testing.T) {
for _, id := range []int32{1, 2, 3} {
key := fmt.Sprintf("channel:%d", id)
if mr.Exists(key) {
t.Errorf("通道缓存 %s 应被删除但仍存在", key)
}
}
},
},
{
name: "用户删除自己的通道",
args: args{
ctx: ctx,
auth: &AuthContext{
Payload: Payload{
Type: PayloadUser,
Id: 100,
},
},
id: []int32{1},
},
setup: func() {
// 预设 Redis 缓存
mr.FlushAll()
key := "channel:1"
channel := models.Channel{ID: 1, UserID: 100}
data, _ := json.Marshal(channel)
mr.Set(key, string(data))
// 模拟查询已激活的端口
mg.PortActiveMock = func(param ...remote.PortActiveReq) (map[string]remote.PortData, error) {
return map[string]remote.PortData{
"10001": {
Edge: []string{"edge1", "edge2"},
},
}, nil
}
// 开始事务
mdb.ExpectBegin()
// 查找通道
channelRows := sqlmock.NewRows([]string{"id", "user_id", "proxy_id", "proxy_port", "protocol", "expiration"}).
AddRow(1, 100, 1, 10001, "http", time.Now().Add(24*time.Hour))
mdb.ExpectQuery(regexp.QuoteMeta("SELECT * FROM `channel` WHERE `channel`.`id` IN")).
WithArgs(int32(1)).
WillReturnRows(channelRows)
// 查找代理
proxyRows := sqlmock.NewRows([]string{"id", "name", "host", "secret", "type"}).
AddRow(1, "proxy1", "proxy1.example.com", "key:secret", 1)
mdb.ExpectQuery(regexp.QuoteMeta("SELECT * FROM `proxy` WHERE `proxy`.`id` IN")).
WithArgs(int32(1)).
WillReturnRows(proxyRows)
// 软删除通道
mdb.ExpectExec(regexp.QuoteMeta("UPDATE `channel` SET")).
WillReturnResult(sqlmock.NewResult(0, 1))
// 提交事务
mdb.ExpectCommit()
},
checkCache: func(t *testing.T) {
key := "channel:1"
if mr.Exists(key) {
t.Errorf("通道缓存 %s 应被删除但仍存在", key)
}
},
},
{
name: "用户删除不属于自己的通道",
args: args{
ctx: ctx,
auth: &AuthContext{
Payload: Payload{
Type: PayloadUser,
Id: 100,
},
},
id: []int32{5},
},
setup: func() {
// 预设 Redis 缓存
mr.FlushAll()
key := "channel:5"
channel := models.Channel{ID: 5, UserID: 101}
data, _ := json.Marshal(channel)
mr.Set(key, string(data))
// 开始事务
mdb.ExpectBegin()
// 查找通道
channelRows := sqlmock.NewRows([]string{"id", "user_id", "proxy_id", "proxy_port", "protocol", "expiration"}).
AddRow(5, 101, 1, 10005, "http", time.Now().Add(24*time.Hour))
mdb.ExpectQuery(regexp.QuoteMeta("SELECT * FROM `channel` WHERE `channel`.`id` IN")).
WithArgs(int32(5)).
WillReturnRows(channelRows)
// 回滚事务
mdb.ExpectRollback()
},
wantErr: true,
wantErrContains: "无权限访问",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.setup != nil {
tt.setup()
}
s := &channelService{}
err := s.RemoveChannels(tt.args.ctx, tt.args.auth, tt.args.id...)
// 检查错误
if tt.wantErr {
if err == nil {
t.Errorf("RemoveChannels() 应当返回错误")
return
}
if tt.wantErrContains != "" && !strings.Contains(err.Error(), tt.wantErrContains) {
t.Errorf("RemoveChannels() 错误 = %v, 应包含 %v", err, tt.wantErrContains)
}
return
}
if err != nil {
t.Errorf("RemoveChannels() 错误 = %v, wantErr %v", err, tt.wantErr)
return
}
// 验证所有期望的 SQL 已执行
if err := mdb.ExpectationsWereMet(); err != nil {
t.Errorf("有未满足的SQL期望: %s", err)
}
// 检查 Redis 缓存是否正确设置
if tt.checkCache != nil {
tt.checkCache(t)
}
})
}
}

View File

@@ -3,36 +3,12 @@ package services
import (
"context"
"errors"
"platform/pkg/rds"
"platform/pkg/testutil"
"reflect"
"testing"
"time"
"github.com/alicebob/miniredis/v2"
"github.com/redis/go-redis/v9"
)
// 设置 Redis 模拟服务器
func setupTestRedis(t *testing.T) *miniredis.Miniredis {
mr, err := miniredis.Run()
if err != nil {
t.Fatalf("无法启动 miniredis: %v", err)
}
// 替换 Redis 客户端为测试客户端
origClient := rds.Client
rds.Client = redis.NewClient(&redis.Options{
Addr: mr.Addr(),
})
t.Cleanup(func() {
mr.Close()
rds.Client = origClient
})
return mr
}
// 创建测试用的认证上下文
func createTestAuthContext() AuthContext {
return AuthContext{
@@ -52,7 +28,7 @@ func createTestAuthContext() AuthContext {
}
func Test_sessionService_Create(t *testing.T) {
mr := setupTestRedis(t)
mr := testutil.SetupRedisTest(t)
ctx := context.Background()
auth := createTestAuthContext()
@@ -162,7 +138,7 @@ func Test_sessionService_Create(t *testing.T) {
}
func Test_sessionService_Find(t *testing.T) {
_ = setupTestRedis(t)
testutil.SetupRedisTest(t)
ctx := context.Background()
auth := createTestAuthContext()
s := &sessionService{}
@@ -221,7 +197,7 @@ func Test_sessionService_Find(t *testing.T) {
}
func Test_sessionService_Refresh(t *testing.T) {
mr := setupTestRedis(t)
mr := testutil.SetupRedisTest(t)
ctx := context.Background()
auth := createTestAuthContext()
s := &sessionService{}
@@ -314,7 +290,7 @@ func Test_sessionService_Refresh(t *testing.T) {
}
func Test_sessionService_Remove(t *testing.T) {
mr := setupTestRedis(t)
mr := testutil.SetupRedisTest(t)
ctx := context.Background()
auth := createTestAuthContext()
s := &sessionService{}

View File

@@ -2,30 +2,14 @@ package services
import (
"context"
"platform/pkg/rds"
"platform/pkg/testutil"
"strconv"
"testing"
"time"
"github.com/alicebob/miniredis/v2"
"github.com/redis/go-redis/v9"
)
// 设置测试的 Redis 环境
func setupRedisTest(t *testing.T) *miniredis.Miniredis {
mr, err := miniredis.Run()
if err != nil {
t.Fatalf("设置 miniredis 失败: %v", err)
}
// 替换 redis 客户端为测试客户端
rds.Client = redis.NewClient(&redis.Options{
Addr: mr.Addr(),
})
return mr
}
func Test_verifierService_SendSms(t *testing.T) {
type args struct {
ctx context.Context
@@ -82,7 +66,7 @@ func Test_verifierService_SendSms(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 设置 Redis 测试环境
mr := setupRedisTest(t)
mr := testutil.SetupRedisTest(t)
defer mr.Close()
// 执行测试前的设置
@@ -216,7 +200,7 @@ func Test_verifierService_VerifySms(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// 设置 Redis 测试环境
mr := setupRedisTest(t)
mr := testutil.SetupRedisTest(t)
defer mr.Close()
// 执行测试前的设置