Files
platform/web/services/channel_test.go

986 lines
24 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package services
import (
"context"
"encoding/json"
"fmt"
"platform/pkg/env"
"platform/pkg/remote"
"platform/pkg/testutil"
"platform/web/models"
"regexp"
"strings"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/gofiber/fiber/v2/middleware/requestid"
"gorm.io/gorm"
)
// Test_genPassPair checks that genPassPair yields non-empty credentials and
// that two back-to-back calls repeat neither the username nor the password.
func Test_genPassPair(t *testing.T) {
	// Both values must be populated.
	t.Run("正常生成随机用户名和密码", func(t *testing.T) {
		user, pass := genPassPair()
		if user == "" {
			t.Errorf("genPassPair() username is empty")
		}
		if pass == "" {
			t.Errorf("genPassPair() password is empty")
		}
	})

	// Consecutive calls must produce distinct values.
	t.Run("多次调用生成不同的值", func(t *testing.T) {
		firstUser, firstPass := genPassPair()
		secondUser, secondPass := genPassPair()
		if firstUser == secondUser {
			t.Errorf("genPassPair() generated the same username twice: %v", firstUser)
		}
		if firstPass == secondPass {
			t.Errorf("genPassPair() generated the same password twice: %v", firstPass)
		}
	})
}
// Test_chKey verifies that chKey renders the cache key as "channel:<ID>",
// including for the zero ID.
func Test_chKey(t *testing.T) {
	cases := []struct {
		name    string
		channel *models.Channel
		want    string
	}{
		{name: "ID为1的通道", channel: &models.Channel{ID: 1}, want: "channel:1"},
		{name: "ID为100的通道", channel: &models.Channel{ID: 100}, want: "channel:100"},
		{name: "ID为0的通道", channel: &models.Channel{ID: 0}, want: "channel:0"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := chKey(tc.channel)
			if got == tc.want {
				return
			}
			t.Errorf("chKey() = %v, want %v", got, tc.want)
		})
	}
}
// Test_cache verifies that cache writes each channel as JSON under
// "channel:<ID>" with a TTL and populates the "tasks:channel" sorted set.
// An empty channel list must succeed without error; nothing is asserted
// about Redis contents in that case.
func Test_cache(t *testing.T) {
	mr := testutil.SetupRedisTest(t)
	type args struct {
		ctx      context.Context
		channels []*models.Channel
	}

	// Test fixtures.
	expiration := time.Now().Add(24 * time.Hour)
	testChannels := []*models.Channel{
		{
			ID:         1,
			UserID:     100,
			ProxyID:    10,
			ProxyPort:  8080,
			Protocol:   "http",
			Expiration: expiration,
		},
		{
			ID:         2,
			UserID:     101,
			ProxyID:    11,
			ProxyPort:  8081,
			Protocol:   "socks5",
			Expiration: expiration,
		},
	}
	tests := []struct {
		name    string
		args    args
		wantErr bool
	}{
		{
			name: "正常缓存多个通道",
			args: args{
				ctx:      context.Background(),
				channels: testChannels,
			},
			wantErr: false,
		},
		{
			name: "空通道列表",
			args: args{
				ctx:      context.Background(),
				channels: []*models.Channel{},
			},
			wantErr: false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			mr.FlushAll() // start each case from an empty Redis
			if err := cache(tt.args.ctx, tt.args.channels); (err != nil) != tt.wantErr {
				t.Errorf("cache() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			// After an expected error there is no cache state worth verifying;
			// with no channels there is nothing to look up.
			if tt.wantErr || len(tt.args.channels) == 0 {
				return
			}

			for _, channel := range tt.args.channels {
				key := fmt.Sprintf("channel:%d", channel.ID)
				if !mr.Exists(key) {
					t.Errorf("缓存未包含通道键 %s", key)
					continue
				}
				// The key must carry an expiration.
				if ttl := mr.TTL(key); ttl <= 0 {
					t.Errorf("键 %s 没有设置过期时间", key)
				}
				// The cached payload must decode back to the same channel.
				data, err := mr.Get(key)
				if err != nil {
					t.Errorf("无法读取缓存数据 %s: %v", key, err)
					continue
				}
				var cachedChannel models.Channel
				if err := json.Unmarshal([]byte(data), &cachedChannel); err != nil {
					// Skip the ID comparison: a zero-value struct would only add noise.
					t.Errorf("无法解析缓存数据: %v", err)
					continue
				}
				if cachedChannel.ID != channel.ID {
					t.Errorf("缓存数据不匹配: 期望 ID %d, 得到 %d", channel.ID, cachedChannel.ID)
				}
			}

			// The expiry-tracking sorted set must have been created.
			if !mr.Exists("tasks:channel") {
				t.Errorf("ZAdd未创建有序集合 tasks:channel")
			}
		})
	}
}
// Test_deleteCache seeds every fixture channel into Redis. It then checks that
// deleteCache removes exactly the keys of the channels it is given and leaves
// all other channel keys in place.
func Test_deleteCache(t *testing.T) {
	mr := testutil.SetupRedisTest(t)
	type args struct {
		ctx      context.Context
		channels []*models.Channel
	}

	// Test fixtures.
	testChannels := []*models.Channel{
		{ID: 1},
		{ID: 2},
		{ID: 3},
	}
	ctx := context.Background()
	tests := []struct {
		name    string
		args    args
		wantErr bool
	}{
		{
			name: "正常删除多个通道缓存",
			args: args{
				ctx:      ctx,
				channels: testChannels,
			},
			wantErr: false,
		},
		{
			name: "空通道列表",
			args: args{
				ctx:      ctx,
				channels: []*models.Channel{},
			},
			wantErr: false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			mr.FlushAll() // start each case from an empty Redis

			// Seed every fixture channel so both deletion and retention can be checked.
			for _, channel := range testChannels {
				key := fmt.Sprintf("channel:%d", channel.ID)
				data, err := json.Marshal(channel)
				if err != nil {
					t.Fatalf("无法序列化通道 %d: %v", channel.ID, err)
				}
				if err := mr.Set(key, string(data)); err != nil {
					t.Fatalf("无法写入缓存 %s: %v", key, err)
				}
				mr.SetTTL(key, 1*time.Hour) // 1h expiration, mirroring production entries
			}

			if err := deleteCache(tt.args.ctx, tt.args.channels); (err != nil) != tt.wantErr {
				t.Errorf("deleteCache() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if tt.wantErr {
				return
			}

			// Every requested key must be gone.
			requested := make(map[string]bool, len(tt.args.channels))
			for _, channel := range tt.args.channels {
				key := fmt.Sprintf("channel:%d", channel.ID)
				requested[key] = true
				if mr.Exists(key) {
					t.Errorf("通道键 %s 未被删除", key)
				}
			}
			// Keys that were not requested must survive (covers the empty-list case).
			for _, channel := range testChannels {
				key := fmt.Sprintf("channel:%d", channel.ID)
				if !requested[key] && !mr.Exists(key) {
					t.Errorf("通道键 %s 不应被删除", key)
				}
			}
		})
	}
}
func Test_channelService_CreateChannel(t *testing.T) {
mr := testutil.SetupRedisTest(t)
mdb := testutil.SetupDBTest(t)
mc := testutil.SetupCloudClientMock(t)
env.DebugExternalChange = false
type args struct {
ctx context.Context
auth *AuthContext
resourceId int32
protocol ChannelProtocol
authType ChannelAuthType
count int
nodeFilter []NodeFilterConfig
}
// 准备测试数据
ctx := context.WithValue(context.Background(), requestid.ConfigDefault.ContextKey, "test-request-id")
tests := []struct {
name string
args args
setup func()
want []string
wantErr bool
wantErrContains string
checkCache func(t *testing.T)
}{
{
name: "用户创建HTTP密码通道",
args: args{
ctx: ctx,
auth: &AuthContext{Payload: Payload{Type: PayloadUser, Id: 100}},
resourceId: 4,
protocol: ProtocolHTTP,
authType: ChannelAuthTypePass,
count: 3,
nodeFilter: []NodeFilterConfig{{Prov: "河南", City: "郑州", Isp: "电信"}},
},
setup: func() {
// 清空Redis
mr.FlushAll()
// 设置CloudAutoQuery的模拟返回
mc.AutoQueryMock = func() (remote.CloudConnectResp, error) {
return remote.CloudConnectResp{
"proxy3": []remote.AutoConfig{
{Province: "河南", City: "郑州", Isp: "电信", Count: 10},
},
}, nil
}
// 开始事务
mdb.ExpectBegin()
// 模拟查询套餐
resourceRows := sqlmock.NewRows([]string{
"id", "user_id", "active",
"type", "live", "daily_used", "daily_limit", "daily_last", "quota", "used", "expire",
}).AddRow(
4, 100, true,
0, 86400, 0, 100, time.Now(), 1000, 0, time.Now().Add(24*time.Hour),
)
mdb.ExpectQuery("SELECT").WithArgs(int32(4)).WillReturnRows(resourceRows)
// 模拟查询代理
proxyRows := sqlmock.NewRows([]string{"id", "name", "host", "secret", "type"}).
AddRow(3, "proxy3", "proxy3.example.com", "key:secret", 1)
mdb.ExpectQuery("SELECT").
WithArgs(1).
WillReturnRows(proxyRows)
// 模拟查询通道
channelRows := sqlmock.NewRows([]string{"proxy_id", "proxy_port"})
mdb.ExpectQuery("SELECT").
WillReturnRows(channelRows)
// 模拟保存通道 - PostgreSQL返回ID
mdb.ExpectQuery("INSERT INTO").WillReturnRows(
sqlmock.NewRows([]string{"id"}).AddRow(4).AddRow(5).AddRow(6),
)
// 模拟更新套餐使用记录
mdb.ExpectExec("UPDATE").WillReturnResult(sqlmock.NewResult(0, 1))
// 提交事务
mdb.ExpectCommit()
},
want: []string{
"http://proxy3.example.com:10000",
"http://proxy3.example.com:10001",
"http://proxy3.example.com:10002",
},
checkCache: func(t *testing.T) {
// 检查总共创建了3个通道
for i := 4; i <= 6; i++ {
key := fmt.Sprintf("channel:%d", i)
if !mr.Exists(key) {
t.Errorf("Redis缓存中应有键 %s", key)
}
}
},
},
{
name: "用户创建HTTP白名单通道",
args: args{
ctx: ctx,
auth: &AuthContext{
Payload: Payload{
Type: PayloadUser,
Id: 100,
},
},
resourceId: 5,
protocol: ProtocolHTTP,
authType: ChannelAuthTypeIp,
count: 2,
},
setup: func() {
// 清空Redis
mr.FlushAll()
// 设置CloudAutoQuery的模拟返回
mc.AutoQueryMock = func() (remote.CloudConnectResp, error) {
return remote.CloudConnectResp{
"proxy3": []remote.AutoConfig{
{Province: "河南", City: "郑州", Isp: "电信", Count: 10},
},
}, nil
}
// 开始事务
mdb.ExpectBegin()
// 模拟查询套餐
resourceRows := sqlmock.NewRows([]string{
"id", "user_id", "active",
"type", "live", "daily_used", "daily_limit", "daily_last", "quota", "used", "expire",
}).AddRow(
5, 100, true,
0, 86400, 0, 100, time.Now(), 1000, 0, time.Now().Add(24*time.Hour),
)
mdb.ExpectQuery("SELECT").WithArgs(int32(5)).WillReturnRows(resourceRows)
// 模拟查询代理
proxyRows := sqlmock.NewRows([]string{"id", "name", "host", "secret", "type"}).
AddRow(3, "proxy3", "proxy3.example.com", "key:secret", 1)
mdb.ExpectQuery("SELECT").
WithArgs(1).
WillReturnRows(proxyRows)
// 模拟查询通道
channelRows := sqlmock.NewRows([]string{"proxy_id", "proxy_port"})
mdb.ExpectQuery("SELECT").
WillReturnRows(channelRows)
// 模拟查询白名单 - 3个IP
whitelistRows := sqlmock.NewRows([]string{"host"}).
AddRow("192.168.1.1").
AddRow("192.168.1.2").
AddRow("192.168.1.3")
mdb.ExpectQuery("SELECT").
WithArgs(int32(100)).
WillReturnRows(whitelistRows)
// 模拟保存通道 - 2个通道 * 3个白名单 = 6个
mdb.ExpectQuery("INSERT INTO").WillReturnRows(
sqlmock.NewRows([]string{"id"}).
AddRow(7).AddRow(8).AddRow(9).
AddRow(10).AddRow(11).AddRow(12),
)
// 模拟更新套餐使用记录
mdb.ExpectExec("UPDATE").WillReturnResult(sqlmock.NewResult(0, 1))
// 提交事务
mdb.ExpectCommit()
},
want: []string{
"http://proxy3.example.com:10000",
"http://proxy3.example.com:10001",
},
checkCache: func(t *testing.T) {
// 检查应该创建了6个通道2个通道 * 3个白名单
for i := 7; i <= 12; i++ {
key := fmt.Sprintf("channel:%d", i)
if !mr.Exists(key) {
t.Errorf("Redis缓存中应有键 %s", key)
}
}
},
},
{
name: "管理员创建SOCKS5密码通道",
args: args{
ctx: ctx,
auth: &AuthContext{
Payload: Payload{
Type: PayloadAdmin,
Id: 1,
},
},
resourceId: 6,
protocol: ProtocolSocks5,
authType: ChannelAuthTypePass,
count: 2,
},
setup: func() {
// 清空Redis
mr.FlushAll()
// 设置CloudAutoQuery的模拟返回
mc.AutoQueryMock = func() (remote.CloudConnectResp, error) {
return remote.CloudConnectResp{
"proxy4": []remote.AutoConfig{
{Province: "河南", City: "郑州", Isp: "电信", Count: 5},
},
}, nil
}
// 设置CloudConnect的模拟逻辑
mc.ConnectMock = func(param remote.CloudConnectReq) error {
return nil
}
// 开始事务
mdb.ExpectBegin()
// 模拟查询套餐
resourceRows := sqlmock.NewRows([]string{
"id", "user_id", "active",
"type", "live", "daily_used", "daily_limit", "daily_last", "quota", "used", "expire",
}).AddRow(
6, 102, true,
1, 86400, 0, 100, time.Now(), 0, 0, time.Now().Add(24*time.Hour),
)
mdb.ExpectQuery("SELECT").WithArgs(int32(6)).WillReturnRows(resourceRows)
// 模拟查询代理
proxyRows := sqlmock.NewRows([]string{"id", "name", "host", "secret", "type"}).
AddRow(4, "proxy4", "proxy4.example.com", "key:secret", 1)
mdb.ExpectQuery("SELECT").
WithArgs(1).
WillReturnRows(proxyRows)
// 模拟查询通道
channelRows := sqlmock.NewRows([]string{"proxy_id", "proxy_port"})
mdb.ExpectQuery("SELECT").
WillReturnRows(channelRows)
// 模拟保存通道
mdb.ExpectQuery("INSERT INTO").WillReturnRows(
sqlmock.NewRows([]string{"id"}).AddRow(13).AddRow(14),
)
// 模拟更新套餐使用记录
mdb.ExpectExec("UPDATE").WillReturnResult(sqlmock.NewResult(0, 1))
// 提交事务
mdb.ExpectCommit()
},
want: []string{
"socks5://proxy4.example.com:10000",
"socks5://proxy4.example.com:10001",
},
checkCache: func(t *testing.T) {
for i := 13; i <= 14; i++ {
key := fmt.Sprintf("channel:%d", i)
if !mr.Exists(key) {
t.Errorf("Redis缓存中应有键 %s", key)
}
}
},
},
{
name: "套餐不存在",
args: args{
ctx: ctx,
auth: &AuthContext{
Payload: Payload{
Type: PayloadUser,
Id: 100,
},
},
resourceId: 999,
protocol: ProtocolHTTP,
authType: ChannelAuthTypeIp,
count: 1,
},
setup: func() {
// 清空Redis
mr.FlushAll()
// 开始事务
mdb.ExpectBegin()
// 模拟查询套餐不存在
mdb.ExpectQuery("SELECT").WithArgs(int32(999)).WillReturnError(gorm.ErrRecordNotFound)
// 回滚事务
mdb.ExpectRollback()
},
wantErr: true,
wantErrContains: "套餐不存在",
},
{
name: "套餐没有权限",
args: args{
ctx: ctx,
auth: &AuthContext{
Payload: Payload{
Type: PayloadUser,
Id: 101,
},
},
resourceId: 7,
protocol: ProtocolHTTP,
authType: ChannelAuthTypeIp,
count: 1,
},
setup: func() {
// 清空Redis
mr.FlushAll()
// 开始事务
mdb.ExpectBegin()
// 模拟查询套餐
resourceRows := sqlmock.NewRows([]string{
"id", "user_id", "active",
"type", "live", "daily_used", "daily_limit", "daily_last", "quota", "used", "expire",
}).AddRow(
7, 102, true, // 注意user_id 与 auth.Id 不匹配
0, 86400, 0, 100, time.Now(), 1000, 0, time.Now().Add(24*time.Hour),
)
mdb.ExpectQuery("SELECT").WithArgs(int32(7)).WillReturnRows(resourceRows)
// 回滚事务
mdb.ExpectRollback()
},
wantErr: true,
wantErrContains: "无权限访问",
},
{
name: "套餐配额不足",
args: args{
ctx: ctx,
auth: &AuthContext{
Payload: Payload{
Type: PayloadUser,
Id: 100,
},
},
resourceId: 2,
protocol: ProtocolHTTP,
authType: ChannelAuthTypeIp,
count: 10,
},
setup: func() {
// 清空Redis
mr.FlushAll()
// 开始事务
mdb.ExpectBegin()
// 模拟查询套餐
resourceRows := sqlmock.NewRows([]string{
"id", "user_id", "active",
"type", "live", "daily_used", "daily_limit", "daily_last", "quota", "used", "expire",
}).AddRow(
2, 100, true,
0, 86400, 95, 100, time.Now(), 100, 95, time.Now().Add(24*time.Hour),
)
mdb.ExpectQuery("SELECT").WithArgs(int32(2)).WillReturnRows(resourceRows)
// 回滚事务
mdb.ExpectRollback()
},
wantErr: true,
wantErrContains: "套餐配额不足",
},
{
name: "端口数量达到上限",
args: args{
ctx: ctx,
auth: &AuthContext{
Payload: Payload{
Type: PayloadUser,
Id: 100,
},
},
resourceId: 8,
protocol: ProtocolHTTP,
authType: ChannelAuthTypeIp,
count: 1,
},
setup: func() {
// 清空Redis
mr.FlushAll()
// 设置CloudAutoQuery的模拟返回
mc.AutoQueryMock = func() (remote.CloudConnectResp, error) {
return remote.CloudConnectResp{
"proxy5": []remote.AutoConfig{
{Province: "河南", City: "郑州", Isp: "电信", Count: 10},
},
}, nil
}
// 开始事务
mdb.ExpectBegin()
// 模拟查询套餐
resourceRows := sqlmock.NewRows([]string{
"id", "user_id", "active",
"type", "live", "daily_used", "daily_limit", "daily_last", "quota", "used", "expire",
}).AddRow(
8, 100, true,
0, 86400, 0, 100, time.Now(), 1000, 0, time.Now().Add(24*time.Hour),
)
mdb.ExpectQuery("SELECT").WithArgs(int32(8)).WillReturnRows(resourceRows)
// 模拟查询代理
proxyRows := sqlmock.NewRows([]string{"id", "name", "host", "secret", "type"}).
AddRow(5, "proxy5", "proxy5.example.com", "key:secret", 1)
mdb.ExpectQuery("SELECT").
WithArgs(1).
WillReturnRows(proxyRows)
// 模拟通道端口已用尽
// 构建一个大量已使用端口的结果集
channelRows := sqlmock.NewRows([]string{"proxy_id", "proxy_port"})
for i := 10000; i < 65535; i++ {
channelRows.AddRow(5, i)
}
mdb.ExpectQuery("SELECT").
WillReturnRows(channelRows)
// 回滚事务
mdb.ExpectRollback()
},
wantErr: true,
wantErrContains: "端口数量不足",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.setup != nil {
tt.setup()
}
s := &channelService{}
got, err := s.CreateChannel(tt.args.ctx, tt.args.auth, tt.args.resourceId, tt.args.protocol, tt.args.authType, tt.args.count, tt.args.nodeFilter...)
// 检查错误或结果
if tt.wantErr {
if err == nil {
t.Errorf("CreateChannel() 应当返回错误")
return
}
if tt.wantErrContains != "" && !strings.Contains(err.Error(), tt.wantErrContains) {
t.Errorf("CreateChannel() 错误 = %v, 应包含 %v", err, tt.wantErrContains)
}
return
}
if err != nil {
t.Errorf("CreateChannel() 错误 = %v, wantErr %v", err, tt.wantErr)
return
}
if len(got) != len(tt.want) {
t.Errorf("CreateChannel() 返回长度 = %v, want %v", len(got), len(tt.want))
return
}
// 检查返回地址格式
for _, addr := range got {
protocol := string(tt.args.protocol)
if !strings.HasPrefix(addr, protocol+"://") {
t.Errorf("CreateChannel() 地址 %v 不是有效的 %s 地址", addr, protocol)
}
}
// 验证所有期望的 SQL 已执行
if err := mdb.ExpectationsWereMet(); err != nil {
t.Errorf("有未满足的SQL期望: %s", err)
}
// 检查 Redis 缓存是否正确设置
if tt.checkCache != nil {
tt.checkCache(t)
}
})
}
}
// Test_channelService_RemoveChannels drives channelService.RemoveChannels against
// mocked SQL, Redis and gateway clients. Success cases check that the channel
// cache keys were removed. Error cases check the error message. Every case
// checks that the queued SQL expectations, including the final commit/rollback,
// were all consumed.
func Test_channelService_RemoveChannels(t *testing.T) {
	mr := testutil.SetupRedisTest(t)
	mdb := testutil.SetupDBTest(t)
	mg := testutil.SetupGatewayClientMock(t)

	// Disable external side effects for this test and restore the global afterwards.
	prevDebug := env.DebugExternalChange
	env.DebugExternalChange = false
	t.Cleanup(func() { env.DebugExternalChange = prevDebug })

	// sqlPattern turns a backtick-quoted SQL fragment into a regexp that accepts
	// either backtick (MySQL) or double-quote (PostgreSQL) identifier quoting.
	// It matches a strict superset of what regexp.QuoteMeta(fragment) matches.
	sqlPattern := func(fragment string) string {
		return strings.ReplaceAll(regexp.QuoteMeta(fragment), "`", "[`\"]")
	}
	selectChannels := sqlPattern("SELECT * FROM `channel` WHERE `channel`.`id` IN")
	selectProxies := sqlPattern("SELECT * FROM `proxy` WHERE `proxy`.`id` IN")
	updateChannels := sqlPattern("UPDATE `channel` SET")

	type args struct {
		ctx  context.Context
		auth *AuthContext
		id   []int32
	}

	// Context carrying a request ID, as the fiber middleware would provide.
	ctx := context.WithValue(context.Background(), requestid.ConfigDefault.ContextKey, "test-request-id")
	tests := []struct {
		name            string
		args            args
		setup           func()
		wantErr         bool
		wantErrContains string
		checkCache      func(t *testing.T)
	}{
		{
			name: "管理员删除多个通道",
			args: args{
				ctx: ctx,
				auth: &AuthContext{
					Payload: Payload{
						Type: PayloadAdmin,
						Id:   1,
					},
				},
				id: []int32{1, 2, 3},
			},
			setup: func() {
				// Pre-populate the Redis cache.
				mr.FlushAll()
				for _, id := range []int32{1, 2, 3} {
					key := fmt.Sprintf("channel:%d", id)
					channel := models.Channel{ID: id, UserID: 100}
					data, _ := json.Marshal(channel)
					mr.Set(key, string(data))
				}
				mdb.ExpectBegin()
				// Channel lookup (spans two users and two proxies).
				channelRows := sqlmock.NewRows([]string{"id", "user_id", "proxy_id", "proxy_port", "protocol", "expiration"}).
					AddRow(1, 100, 1, 10001, "http", time.Now().Add(24*time.Hour)).
					AddRow(2, 100, 1, 10002, "http", time.Now().Add(24*time.Hour)).
					AddRow(3, 101, 2, 10001, "socks5", time.Now().Add(24*time.Hour))
				mdb.ExpectQuery(selectChannels).
					WithArgs(int32(1), int32(2), int32(3)).
					WillReturnRows(channelRows)
				// Proxy lookup.
				proxyRows := sqlmock.NewRows([]string{"id", "name", "host", "secret", "type"}).
					AddRow(1, "proxy1", "proxy1.example.com", "key:secret", 1).
					AddRow(2, "proxy2", "proxy2.example.com", "key:secret", 1)
				mdb.ExpectQuery(selectProxies).
					WithArgs(int32(1), int32(2)).
					WillReturnRows(proxyRows)
				// Soft delete of the channels.
				mdb.ExpectExec(updateChannels).
					WillReturnResult(sqlmock.NewResult(0, 3))
				mdb.ExpectCommit()
			},
			checkCache: func(t *testing.T) {
				for _, id := range []int32{1, 2, 3} {
					key := fmt.Sprintf("channel:%d", id)
					if mr.Exists(key) {
						t.Errorf("通道缓存 %s 应被删除但仍存在", key)
					}
				}
			},
		},
		{
			name: "用户删除自己的通道",
			args: args{
				ctx: ctx,
				auth: &AuthContext{
					Payload: Payload{
						Type: PayloadUser,
						Id:   100,
					},
				},
				id: []int32{1},
			},
			setup: func() {
				// Pre-populate the Redis cache.
				mr.FlushAll()
				key := "channel:1"
				channel := models.Channel{ID: 1, UserID: 100}
				data, _ := json.Marshal(channel)
				mr.Set(key, string(data))
				// Gateway reports port 10001 as active on two edges.
				mg.PortActiveMock = func(param ...remote.PortActiveReq) (map[string]remote.PortData, error) {
					return map[string]remote.PortData{
						"10001": {
							Edge: []string{"edge1", "edge2"},
						},
					}, nil
				}
				mdb.ExpectBegin()
				// Channel lookup.
				channelRows := sqlmock.NewRows([]string{"id", "user_id", "proxy_id", "proxy_port", "protocol", "expiration"}).
					AddRow(1, 100, 1, 10001, "http", time.Now().Add(24*time.Hour))
				mdb.ExpectQuery(selectChannels).
					WithArgs(int32(1)).
					WillReturnRows(channelRows)
				// Proxy lookup.
				proxyRows := sqlmock.NewRows([]string{"id", "name", "host", "secret", "type"}).
					AddRow(1, "proxy1", "proxy1.example.com", "key:secret", 1)
				mdb.ExpectQuery(selectProxies).
					WithArgs(int32(1)).
					WillReturnRows(proxyRows)
				// Soft delete of the channel.
				mdb.ExpectExec(updateChannels).
					WillReturnResult(sqlmock.NewResult(0, 1))
				mdb.ExpectCommit()
			},
			checkCache: func(t *testing.T) {
				key := "channel:1"
				if mr.Exists(key) {
					t.Errorf("通道缓存 %s 应被删除但仍存在", key)
				}
			},
		},
		{
			name: "用户删除不属于自己的通道",
			args: args{
				ctx: ctx,
				auth: &AuthContext{
					Payload: Payload{
						Type: PayloadUser,
						Id:   100,
					},
				},
				id: []int32{5},
			},
			setup: func() {
				// Pre-populate the Redis cache with a channel owned by another user.
				mr.FlushAll()
				key := "channel:5"
				channel := models.Channel{ID: 5, UserID: 101}
				data, _ := json.Marshal(channel)
				mr.Set(key, string(data))
				mdb.ExpectBegin()
				// Channel lookup returns user 101's channel.
				channelRows := sqlmock.NewRows([]string{"id", "user_id", "proxy_id", "proxy_port", "protocol", "expiration"}).
					AddRow(5, 101, 1, 10005, "http", time.Now().Add(24*time.Hour))
				mdb.ExpectQuery(selectChannels).
					WithArgs(int32(5)).
					WillReturnRows(channelRows)
				mdb.ExpectRollback()
			},
			wantErr:         true,
			wantErrContains: "无权限访问",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if tt.setup != nil {
				tt.setup()
			}
			s := &channelService{}
			err := s.RemoveChannels(tt.args.ctx, tt.args.auth, tt.args.id...)

			// Check SQL expectations on every path, so the error case's
			// ExpectRollback is actually verified.
			if sqlErr := mdb.ExpectationsWereMet(); sqlErr != nil {
				t.Errorf("有未满足的SQL期望: %s", sqlErr)
			}

			if tt.wantErr {
				if err == nil {
					t.Errorf("RemoveChannels() 应当返回错误")
					return
				}
				if tt.wantErrContains != "" && !strings.Contains(err.Error(), tt.wantErrContains) {
					t.Errorf("RemoveChannels() 错误 = %v, 应包含 %v", err, tt.wantErrContains)
				}
				return
			}
			if err != nil {
				t.Errorf("RemoveChannels() 错误 = %v, wantErr %v", err, tt.wantErr)
				return
			}

			// Verify the Redis cache entries were removed.
			if tt.checkCache != nil {
				tt.checkCache(t)
			}
		})
	}
}