新增通道服务相关测试用例
This commit is contained in:
985
web/services/channel_test.go
Normal file
985
web/services/channel_test.go
Normal file
@@ -0,0 +1,985 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"platform/pkg/env"
|
||||
"platform/pkg/remote"
|
||||
"platform/pkg/testutil"
|
||||
"platform/web/models"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/gofiber/fiber/v2/middleware/requestid"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func Test_genPassPair(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
}{
|
||||
{
|
||||
name: "正常生成随机用户名和密码",
|
||||
},
|
||||
{
|
||||
name: "多次调用生成不同的值",
|
||||
},
|
||||
}
|
||||
|
||||
// 第一个测试:检查生成的用户名和密码是否有效
|
||||
t.Run(tests[0].name, func(t *testing.T) {
|
||||
username, password := genPassPair()
|
||||
if username == "" {
|
||||
t.Errorf("genPassPair() username is empty")
|
||||
}
|
||||
if password == "" {
|
||||
t.Errorf("genPassPair() password is empty")
|
||||
}
|
||||
})
|
||||
|
||||
// 第二个测试:确保多次调用生成不同的值
|
||||
t.Run(tests[1].name, func(t *testing.T) {
|
||||
username1, password1 := genPassPair()
|
||||
username2, password2 := genPassPair()
|
||||
|
||||
if username1 == username2 {
|
||||
t.Errorf("genPassPair() generated the same username twice: %v", username1)
|
||||
}
|
||||
if password1 == password2 {
|
||||
t.Errorf("genPassPair() generated the same password twice: %v", password1)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func Test_chKey(t *testing.T) {
|
||||
type args struct {
|
||||
channel *models.Channel
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "ID为1的通道",
|
||||
args: args{
|
||||
channel: &models.Channel{
|
||||
ID: 1,
|
||||
},
|
||||
},
|
||||
want: "channel:1",
|
||||
},
|
||||
{
|
||||
name: "ID为100的通道",
|
||||
args: args{
|
||||
channel: &models.Channel{
|
||||
ID: 100,
|
||||
},
|
||||
},
|
||||
want: "channel:100",
|
||||
},
|
||||
{
|
||||
name: "ID为0的通道",
|
||||
args: args{
|
||||
channel: &models.Channel{
|
||||
ID: 0,
|
||||
},
|
||||
},
|
||||
want: "channel:0",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := chKey(tt.args.channel); got != tt.want {
|
||||
t.Errorf("chKey() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_cache(t *testing.T) {
|
||||
mr := testutil.SetupRedisTest(t)
|
||||
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
channels []*models.Channel
|
||||
}
|
||||
|
||||
// 准备测试数据
|
||||
now := time.Now()
|
||||
expiration := now.Add(24 * time.Hour)
|
||||
|
||||
testChannels := []*models.Channel{
|
||||
{
|
||||
ID: 1,
|
||||
UserID: 100,
|
||||
ProxyID: 10,
|
||||
ProxyPort: 8080,
|
||||
Protocol: "http",
|
||||
Expiration: expiration,
|
||||
},
|
||||
{
|
||||
ID: 2,
|
||||
UserID: 101,
|
||||
ProxyID: 11,
|
||||
ProxyPort: 8081,
|
||||
Protocol: "socks5",
|
||||
Expiration: expiration,
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "正常缓存多个通道",
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
channels: testChannels,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "空通道列表",
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
channels: []*models.Channel{},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mr.FlushAll() // 清空 Redis 数据
|
||||
|
||||
if err := cache(tt.args.ctx, tt.args.channels); (err != nil) != tt.wantErr {
|
||||
t.Errorf("cache() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
// 验证缓存结果
|
||||
if len(tt.args.channels) > 0 {
|
||||
for _, channel := range tt.args.channels {
|
||||
key := fmt.Sprintf("channel:%d", channel.ID)
|
||||
if !mr.Exists(key) {
|
||||
t.Errorf("缓存未包含通道键 %s", key)
|
||||
} else {
|
||||
// 验证缓存的数据是否正确
|
||||
data, _ := mr.Get(key)
|
||||
var cachedChannel models.Channel
|
||||
err := json.Unmarshal([]byte(data), &cachedChannel)
|
||||
if err != nil {
|
||||
t.Errorf("无法解析缓存数据: %v", err)
|
||||
}
|
||||
if cachedChannel.ID != channel.ID {
|
||||
t.Errorf("缓存数据不匹配: 期望 ID %d, 得到 %d", channel.ID, cachedChannel.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 验证是否设置了过期时间
|
||||
for _, channel := range tt.args.channels {
|
||||
key := fmt.Sprintf("channel:%d", channel.ID)
|
||||
ttl := mr.TTL(key)
|
||||
if ttl <= 0 {
|
||||
t.Errorf("键 %s 没有设置过期时间", key)
|
||||
}
|
||||
}
|
||||
|
||||
// 验证是否添加了有序集合
|
||||
if !mr.Exists("tasks:channel") {
|
||||
t.Errorf("ZAdd未创建有序集合 tasks:channel")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_deleteCache(t *testing.T) {
|
||||
mr := testutil.SetupRedisTest(t)
|
||||
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
channels []*models.Channel
|
||||
}
|
||||
|
||||
// 准备测试数据
|
||||
testChannels := []*models.Channel{
|
||||
{ID: 1},
|
||||
{ID: 2},
|
||||
{ID: 3},
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "正常删除多个通道缓存",
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
channels: testChannels,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "空通道列表",
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
channels: []*models.Channel{},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mr.FlushAll() // 清空 Redis 数据
|
||||
|
||||
// 预先设置缓存数据
|
||||
for _, channel := range testChannels {
|
||||
key := fmt.Sprintf("channel:%d", channel.ID)
|
||||
data, _ := json.Marshal(channel)
|
||||
mr.Set(key, string(data))
|
||||
mr.SetTTL(key, 1*time.Hour) // 设置1小时的过期时间
|
||||
}
|
||||
|
||||
if err := deleteCache(tt.args.ctx, tt.args.channels); (err != nil) != tt.wantErr {
|
||||
t.Errorf("deleteCache() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
// 验证删除结果
|
||||
for _, channel := range tt.args.channels {
|
||||
key := fmt.Sprintf("channel:%d", channel.ID)
|
||||
if mr.Exists(key) {
|
||||
t.Errorf("通道键 %s 未被删除", key)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_channelService_CreateChannel(t *testing.T) {
|
||||
mr := testutil.SetupRedisTest(t)
|
||||
mdb := testutil.SetupDBTest(t)
|
||||
mc := testutil.SetupCloudClientMock(t)
|
||||
env.DebugExternalChange = false
|
||||
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
auth *AuthContext
|
||||
resourceId int32
|
||||
protocol ChannelProtocol
|
||||
authType ChannelAuthType
|
||||
count int
|
||||
nodeFilter []NodeFilterConfig
|
||||
}
|
||||
|
||||
// 准备测试数据
|
||||
ctx := context.WithValue(context.Background(), requestid.ConfigDefault.ContextKey, "test-request-id")
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
setup func()
|
||||
want []string
|
||||
wantErr bool
|
||||
wantErrContains string
|
||||
checkCache func(t *testing.T)
|
||||
}{
|
||||
{
|
||||
name: "用户创建HTTP密码通道",
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
auth: &AuthContext{Payload: Payload{Type: PayloadUser, Id: 100}},
|
||||
resourceId: 4,
|
||||
protocol: ProtocolHTTP,
|
||||
authType: ChannelAuthTypePass,
|
||||
count: 3,
|
||||
nodeFilter: []NodeFilterConfig{{Prov: "河南", City: "郑州", Isp: "电信"}},
|
||||
},
|
||||
setup: func() {
|
||||
// 清空Redis
|
||||
mr.FlushAll()
|
||||
|
||||
// 设置CloudAutoQuery的模拟返回
|
||||
mc.AutoQueryMock = func() (remote.CloudConnectResp, error) {
|
||||
return remote.CloudConnectResp{
|
||||
"proxy3": []remote.AutoConfig{
|
||||
{Province: "河南", City: "郑州", Isp: "电信", Count: 10},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 开始事务
|
||||
mdb.ExpectBegin()
|
||||
|
||||
// 模拟查询套餐
|
||||
resourceRows := sqlmock.NewRows([]string{
|
||||
"id", "user_id", "active",
|
||||
"type", "live", "daily_used", "daily_limit", "daily_last", "quota", "used", "expire",
|
||||
}).AddRow(
|
||||
4, 100, true,
|
||||
0, 86400, 0, 100, time.Now(), 1000, 0, time.Now().Add(24*time.Hour),
|
||||
)
|
||||
mdb.ExpectQuery("SELECT").WithArgs(int32(4)).WillReturnRows(resourceRows)
|
||||
|
||||
// 模拟查询代理
|
||||
proxyRows := sqlmock.NewRows([]string{"id", "name", "host", "secret", "type"}).
|
||||
AddRow(3, "proxy3", "proxy3.example.com", "key:secret", 1)
|
||||
mdb.ExpectQuery("SELECT").
|
||||
WithArgs(1).
|
||||
WillReturnRows(proxyRows)
|
||||
|
||||
// 模拟查询通道
|
||||
channelRows := sqlmock.NewRows([]string{"proxy_id", "proxy_port"})
|
||||
mdb.ExpectQuery("SELECT").
|
||||
WillReturnRows(channelRows)
|
||||
|
||||
// 模拟保存通道 - PostgreSQL返回ID
|
||||
mdb.ExpectQuery("INSERT INTO").WillReturnRows(
|
||||
sqlmock.NewRows([]string{"id"}).AddRow(4).AddRow(5).AddRow(6),
|
||||
)
|
||||
|
||||
// 模拟更新套餐使用记录
|
||||
mdb.ExpectExec("UPDATE").WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
// 提交事务
|
||||
mdb.ExpectCommit()
|
||||
},
|
||||
want: []string{
|
||||
"http://proxy3.example.com:10000",
|
||||
"http://proxy3.example.com:10001",
|
||||
"http://proxy3.example.com:10002",
|
||||
},
|
||||
checkCache: func(t *testing.T) {
|
||||
// 检查总共创建了3个通道
|
||||
for i := 4; i <= 6; i++ {
|
||||
key := fmt.Sprintf("channel:%d", i)
|
||||
if !mr.Exists(key) {
|
||||
t.Errorf("Redis缓存中应有键 %s", key)
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "用户创建HTTP白名单通道",
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
auth: &AuthContext{
|
||||
Payload: Payload{
|
||||
Type: PayloadUser,
|
||||
Id: 100,
|
||||
},
|
||||
},
|
||||
resourceId: 5,
|
||||
protocol: ProtocolHTTP,
|
||||
authType: ChannelAuthTypeIp,
|
||||
count: 2,
|
||||
},
|
||||
setup: func() {
|
||||
// 清空Redis
|
||||
mr.FlushAll()
|
||||
|
||||
// 设置CloudAutoQuery的模拟返回
|
||||
mc.AutoQueryMock = func() (remote.CloudConnectResp, error) {
|
||||
return remote.CloudConnectResp{
|
||||
"proxy3": []remote.AutoConfig{
|
||||
{Province: "河南", City: "郑州", Isp: "电信", Count: 10},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 开始事务
|
||||
mdb.ExpectBegin()
|
||||
|
||||
// 模拟查询套餐
|
||||
resourceRows := sqlmock.NewRows([]string{
|
||||
"id", "user_id", "active",
|
||||
"type", "live", "daily_used", "daily_limit", "daily_last", "quota", "used", "expire",
|
||||
}).AddRow(
|
||||
5, 100, true,
|
||||
0, 86400, 0, 100, time.Now(), 1000, 0, time.Now().Add(24*time.Hour),
|
||||
)
|
||||
mdb.ExpectQuery("SELECT").WithArgs(int32(5)).WillReturnRows(resourceRows)
|
||||
|
||||
// 模拟查询代理
|
||||
proxyRows := sqlmock.NewRows([]string{"id", "name", "host", "secret", "type"}).
|
||||
AddRow(3, "proxy3", "proxy3.example.com", "key:secret", 1)
|
||||
mdb.ExpectQuery("SELECT").
|
||||
WithArgs(1).
|
||||
WillReturnRows(proxyRows)
|
||||
|
||||
// 模拟查询通道
|
||||
channelRows := sqlmock.NewRows([]string{"proxy_id", "proxy_port"})
|
||||
mdb.ExpectQuery("SELECT").
|
||||
WillReturnRows(channelRows)
|
||||
|
||||
// 模拟查询白名单 - 3个IP
|
||||
whitelistRows := sqlmock.NewRows([]string{"host"}).
|
||||
AddRow("192.168.1.1").
|
||||
AddRow("192.168.1.2").
|
||||
AddRow("192.168.1.3")
|
||||
mdb.ExpectQuery("SELECT").
|
||||
WithArgs(int32(100)).
|
||||
WillReturnRows(whitelistRows)
|
||||
|
||||
// 模拟保存通道 - 2个通道 * 3个白名单 = 6个
|
||||
mdb.ExpectQuery("INSERT INTO").WillReturnRows(
|
||||
sqlmock.NewRows([]string{"id"}).
|
||||
AddRow(7).AddRow(8).AddRow(9).
|
||||
AddRow(10).AddRow(11).AddRow(12),
|
||||
)
|
||||
|
||||
// 模拟更新套餐使用记录
|
||||
mdb.ExpectExec("UPDATE").WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
// 提交事务
|
||||
mdb.ExpectCommit()
|
||||
},
|
||||
want: []string{
|
||||
"http://proxy3.example.com:10000",
|
||||
"http://proxy3.example.com:10001",
|
||||
},
|
||||
checkCache: func(t *testing.T) {
|
||||
// 检查应该创建了6个通道(2个通道 * 3个白名单)
|
||||
for i := 7; i <= 12; i++ {
|
||||
key := fmt.Sprintf("channel:%d", i)
|
||||
if !mr.Exists(key) {
|
||||
t.Errorf("Redis缓存中应有键 %s", key)
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "管理员创建SOCKS5密码通道",
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
auth: &AuthContext{
|
||||
Payload: Payload{
|
||||
Type: PayloadAdmin,
|
||||
Id: 1,
|
||||
},
|
||||
},
|
||||
resourceId: 6,
|
||||
protocol: ProtocolSocks5,
|
||||
authType: ChannelAuthTypePass,
|
||||
count: 2,
|
||||
},
|
||||
setup: func() {
|
||||
// 清空Redis
|
||||
mr.FlushAll()
|
||||
|
||||
// 设置CloudAutoQuery的模拟返回
|
||||
mc.AutoQueryMock = func() (remote.CloudConnectResp, error) {
|
||||
return remote.CloudConnectResp{
|
||||
"proxy4": []remote.AutoConfig{
|
||||
{Province: "河南", City: "郑州", Isp: "电信", Count: 5},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 设置CloudConnect的模拟逻辑
|
||||
mc.ConnectMock = func(param remote.CloudConnectReq) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 开始事务
|
||||
mdb.ExpectBegin()
|
||||
|
||||
// 模拟查询套餐
|
||||
resourceRows := sqlmock.NewRows([]string{
|
||||
"id", "user_id", "active",
|
||||
"type", "live", "daily_used", "daily_limit", "daily_last", "quota", "used", "expire",
|
||||
}).AddRow(
|
||||
6, 102, true,
|
||||
1, 86400, 0, 100, time.Now(), 0, 0, time.Now().Add(24*time.Hour),
|
||||
)
|
||||
mdb.ExpectQuery("SELECT").WithArgs(int32(6)).WillReturnRows(resourceRows)
|
||||
|
||||
// 模拟查询代理
|
||||
proxyRows := sqlmock.NewRows([]string{"id", "name", "host", "secret", "type"}).
|
||||
AddRow(4, "proxy4", "proxy4.example.com", "key:secret", 1)
|
||||
mdb.ExpectQuery("SELECT").
|
||||
WithArgs(1).
|
||||
WillReturnRows(proxyRows)
|
||||
|
||||
// 模拟查询通道
|
||||
channelRows := sqlmock.NewRows([]string{"proxy_id", "proxy_port"})
|
||||
mdb.ExpectQuery("SELECT").
|
||||
WillReturnRows(channelRows)
|
||||
|
||||
// 模拟保存通道
|
||||
mdb.ExpectQuery("INSERT INTO").WillReturnRows(
|
||||
sqlmock.NewRows([]string{"id"}).AddRow(13).AddRow(14),
|
||||
)
|
||||
|
||||
// 模拟更新套餐使用记录
|
||||
mdb.ExpectExec("UPDATE").WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
// 提交事务
|
||||
mdb.ExpectCommit()
|
||||
},
|
||||
want: []string{
|
||||
"socks5://proxy4.example.com:10000",
|
||||
"socks5://proxy4.example.com:10001",
|
||||
},
|
||||
checkCache: func(t *testing.T) {
|
||||
for i := 13; i <= 14; i++ {
|
||||
key := fmt.Sprintf("channel:%d", i)
|
||||
if !mr.Exists(key) {
|
||||
t.Errorf("Redis缓存中应有键 %s", key)
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "套餐不存在",
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
auth: &AuthContext{
|
||||
Payload: Payload{
|
||||
Type: PayloadUser,
|
||||
Id: 100,
|
||||
},
|
||||
},
|
||||
resourceId: 999,
|
||||
protocol: ProtocolHTTP,
|
||||
authType: ChannelAuthTypeIp,
|
||||
count: 1,
|
||||
},
|
||||
setup: func() {
|
||||
// 清空Redis
|
||||
mr.FlushAll()
|
||||
|
||||
// 开始事务
|
||||
mdb.ExpectBegin()
|
||||
|
||||
// 模拟查询套餐不存在
|
||||
mdb.ExpectQuery("SELECT").WithArgs(int32(999)).WillReturnError(gorm.ErrRecordNotFound)
|
||||
|
||||
// 回滚事务
|
||||
mdb.ExpectRollback()
|
||||
},
|
||||
wantErr: true,
|
||||
wantErrContains: "套餐不存在",
|
||||
},
|
||||
{
|
||||
name: "套餐没有权限",
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
auth: &AuthContext{
|
||||
Payload: Payload{
|
||||
Type: PayloadUser,
|
||||
Id: 101,
|
||||
},
|
||||
},
|
||||
resourceId: 7,
|
||||
protocol: ProtocolHTTP,
|
||||
authType: ChannelAuthTypeIp,
|
||||
count: 1,
|
||||
},
|
||||
setup: func() {
|
||||
// 清空Redis
|
||||
mr.FlushAll()
|
||||
|
||||
// 开始事务
|
||||
mdb.ExpectBegin()
|
||||
|
||||
// 模拟查询套餐
|
||||
resourceRows := sqlmock.NewRows([]string{
|
||||
"id", "user_id", "active",
|
||||
"type", "live", "daily_used", "daily_limit", "daily_last", "quota", "used", "expire",
|
||||
}).AddRow(
|
||||
7, 102, true, // 注意:user_id 与 auth.Id 不匹配
|
||||
0, 86400, 0, 100, time.Now(), 1000, 0, time.Now().Add(24*time.Hour),
|
||||
)
|
||||
mdb.ExpectQuery("SELECT").WithArgs(int32(7)).WillReturnRows(resourceRows)
|
||||
|
||||
// 回滚事务
|
||||
mdb.ExpectRollback()
|
||||
},
|
||||
wantErr: true,
|
||||
wantErrContains: "无权限访问",
|
||||
},
|
||||
{
|
||||
name: "套餐配额不足",
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
auth: &AuthContext{
|
||||
Payload: Payload{
|
||||
Type: PayloadUser,
|
||||
Id: 100,
|
||||
},
|
||||
},
|
||||
resourceId: 2,
|
||||
protocol: ProtocolHTTP,
|
||||
authType: ChannelAuthTypeIp,
|
||||
count: 10,
|
||||
},
|
||||
setup: func() {
|
||||
// 清空Redis
|
||||
mr.FlushAll()
|
||||
|
||||
// 开始事务
|
||||
mdb.ExpectBegin()
|
||||
|
||||
// 模拟查询套餐
|
||||
resourceRows := sqlmock.NewRows([]string{
|
||||
"id", "user_id", "active",
|
||||
"type", "live", "daily_used", "daily_limit", "daily_last", "quota", "used", "expire",
|
||||
}).AddRow(
|
||||
2, 100, true,
|
||||
0, 86400, 95, 100, time.Now(), 100, 95, time.Now().Add(24*time.Hour),
|
||||
)
|
||||
mdb.ExpectQuery("SELECT").WithArgs(int32(2)).WillReturnRows(resourceRows)
|
||||
|
||||
// 回滚事务
|
||||
mdb.ExpectRollback()
|
||||
},
|
||||
wantErr: true,
|
||||
wantErrContains: "套餐配额不足",
|
||||
},
|
||||
{
|
||||
name: "端口数量达到上限",
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
auth: &AuthContext{
|
||||
Payload: Payload{
|
||||
Type: PayloadUser,
|
||||
Id: 100,
|
||||
},
|
||||
},
|
||||
resourceId: 8,
|
||||
protocol: ProtocolHTTP,
|
||||
authType: ChannelAuthTypeIp,
|
||||
count: 1,
|
||||
},
|
||||
setup: func() {
|
||||
// 清空Redis
|
||||
mr.FlushAll()
|
||||
|
||||
// 设置CloudAutoQuery的模拟返回
|
||||
mc.AutoQueryMock = func() (remote.CloudConnectResp, error) {
|
||||
return remote.CloudConnectResp{
|
||||
"proxy5": []remote.AutoConfig{
|
||||
{Province: "河南", City: "郑州", Isp: "电信", Count: 10},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 开始事务
|
||||
mdb.ExpectBegin()
|
||||
|
||||
// 模拟查询套餐
|
||||
resourceRows := sqlmock.NewRows([]string{
|
||||
"id", "user_id", "active",
|
||||
"type", "live", "daily_used", "daily_limit", "daily_last", "quota", "used", "expire",
|
||||
}).AddRow(
|
||||
8, 100, true,
|
||||
0, 86400, 0, 100, time.Now(), 1000, 0, time.Now().Add(24*time.Hour),
|
||||
)
|
||||
mdb.ExpectQuery("SELECT").WithArgs(int32(8)).WillReturnRows(resourceRows)
|
||||
|
||||
// 模拟查询代理
|
||||
proxyRows := sqlmock.NewRows([]string{"id", "name", "host", "secret", "type"}).
|
||||
AddRow(5, "proxy5", "proxy5.example.com", "key:secret", 1)
|
||||
mdb.ExpectQuery("SELECT").
|
||||
WithArgs(1).
|
||||
WillReturnRows(proxyRows)
|
||||
|
||||
// 模拟通道端口已用尽
|
||||
// 构建一个大量已使用端口的结果集
|
||||
channelRows := sqlmock.NewRows([]string{"proxy_id", "proxy_port"})
|
||||
for i := 10000; i < 65535; i++ {
|
||||
channelRows.AddRow(5, i)
|
||||
}
|
||||
mdb.ExpectQuery("SELECT").
|
||||
WillReturnRows(channelRows)
|
||||
|
||||
// 回滚事务
|
||||
mdb.ExpectRollback()
|
||||
},
|
||||
wantErr: true,
|
||||
wantErrContains: "端口数量不足",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.setup != nil {
|
||||
tt.setup()
|
||||
}
|
||||
|
||||
s := &channelService{}
|
||||
got, err := s.CreateChannel(tt.args.ctx, tt.args.auth, tt.args.resourceId, tt.args.protocol, tt.args.authType, tt.args.count, tt.args.nodeFilter...)
|
||||
|
||||
// 检查错误或结果
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Errorf("CreateChannel() 应当返回错误")
|
||||
return
|
||||
}
|
||||
if tt.wantErrContains != "" && !strings.Contains(err.Error(), tt.wantErrContains) {
|
||||
t.Errorf("CreateChannel() 错误 = %v, 应包含 %v", err, tt.wantErrContains)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("CreateChannel() 错误 = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if len(got) != len(tt.want) {
|
||||
t.Errorf("CreateChannel() 返回长度 = %v, want %v", len(got), len(tt.want))
|
||||
return
|
||||
}
|
||||
|
||||
// 检查返回地址格式
|
||||
for _, addr := range got {
|
||||
protocol := string(tt.args.protocol)
|
||||
if !strings.HasPrefix(addr, protocol+"://") {
|
||||
t.Errorf("CreateChannel() 地址 %v 不是有效的 %s 地址", addr, protocol)
|
||||
}
|
||||
}
|
||||
|
||||
// 验证所有期望的 SQL 已执行
|
||||
if err := mdb.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("有未满足的SQL期望: %s", err)
|
||||
}
|
||||
|
||||
// 检查 Redis 缓存是否正确设置
|
||||
if tt.checkCache != nil {
|
||||
tt.checkCache(t)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_channelService_RemoveChannels(t *testing.T) {
|
||||
mr := testutil.SetupRedisTest(t)
|
||||
mdb := testutil.SetupDBTest(t)
|
||||
mg := testutil.SetupGatewayClientMock(t)
|
||||
env.DebugExternalChange = false
|
||||
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
auth *AuthContext
|
||||
id []int32
|
||||
}
|
||||
|
||||
// 准备测试数据
|
||||
ctx := context.WithValue(context.Background(), requestid.ConfigDefault.ContextKey, "test-request-id")
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
setup func()
|
||||
wantErr bool
|
||||
wantErrContains string
|
||||
checkCache func(t *testing.T)
|
||||
}{
|
||||
{
|
||||
name: "管理员删除多个通道",
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
auth: &AuthContext{
|
||||
Payload: Payload{
|
||||
Type: PayloadAdmin,
|
||||
Id: 1,
|
||||
},
|
||||
},
|
||||
id: []int32{1, 2, 3},
|
||||
},
|
||||
setup: func() {
|
||||
// 预设 Redis 缓存
|
||||
mr.FlushAll()
|
||||
for _, id := range []int32{1, 2, 3} {
|
||||
key := fmt.Sprintf("channel:%d", id)
|
||||
channel := models.Channel{ID: id, UserID: 100}
|
||||
data, _ := json.Marshal(channel)
|
||||
mr.Set(key, string(data))
|
||||
}
|
||||
|
||||
// 开始事务
|
||||
mdb.ExpectBegin()
|
||||
|
||||
// 查找通道
|
||||
channelRows := sqlmock.NewRows([]string{"id", "user_id", "proxy_id", "proxy_port", "protocol", "expiration"}).
|
||||
AddRow(1, 100, 1, 10001, "http", time.Now().Add(24*time.Hour)).
|
||||
AddRow(2, 100, 1, 10002, "http", time.Now().Add(24*time.Hour)).
|
||||
AddRow(3, 101, 2, 10001, "socks5", time.Now().Add(24*time.Hour))
|
||||
mdb.ExpectQuery(regexp.QuoteMeta("SELECT * FROM `channel` WHERE `channel`.`id` IN")).
|
||||
WithArgs(int32(1), int32(2), int32(3)).
|
||||
WillReturnRows(channelRows)
|
||||
|
||||
// 查找代理
|
||||
proxyRows := sqlmock.NewRows([]string{"id", "name", "host", "secret", "type"}).
|
||||
AddRow(1, "proxy1", "proxy1.example.com", "key:secret", 1).
|
||||
AddRow(2, "proxy2", "proxy2.example.com", "key:secret", 1)
|
||||
mdb.ExpectQuery(regexp.QuoteMeta("SELECT * FROM `proxy` WHERE `proxy`.`id` IN")).
|
||||
WithArgs(int32(1), int32(2)).
|
||||
WillReturnRows(proxyRows)
|
||||
|
||||
// 软删除通道
|
||||
mdb.ExpectExec(regexp.QuoteMeta("UPDATE `channel` SET")).
|
||||
WillReturnResult(sqlmock.NewResult(0, 3))
|
||||
|
||||
// 提交事务
|
||||
mdb.ExpectCommit()
|
||||
},
|
||||
checkCache: func(t *testing.T) {
|
||||
for _, id := range []int32{1, 2, 3} {
|
||||
key := fmt.Sprintf("channel:%d", id)
|
||||
if mr.Exists(key) {
|
||||
t.Errorf("通道缓存 %s 应被删除但仍存在", key)
|
||||
}
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "用户删除自己的通道",
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
auth: &AuthContext{
|
||||
Payload: Payload{
|
||||
Type: PayloadUser,
|
||||
Id: 100,
|
||||
},
|
||||
},
|
||||
id: []int32{1},
|
||||
},
|
||||
setup: func() {
|
||||
// 预设 Redis 缓存
|
||||
mr.FlushAll()
|
||||
key := "channel:1"
|
||||
channel := models.Channel{ID: 1, UserID: 100}
|
||||
data, _ := json.Marshal(channel)
|
||||
mr.Set(key, string(data))
|
||||
|
||||
// 模拟查询已激活的端口
|
||||
mg.PortActiveMock = func(param ...remote.PortActiveReq) (map[string]remote.PortData, error) {
|
||||
return map[string]remote.PortData{
|
||||
"10001": {
|
||||
Edge: []string{"edge1", "edge2"},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 开始事务
|
||||
mdb.ExpectBegin()
|
||||
|
||||
// 查找通道
|
||||
channelRows := sqlmock.NewRows([]string{"id", "user_id", "proxy_id", "proxy_port", "protocol", "expiration"}).
|
||||
AddRow(1, 100, 1, 10001, "http", time.Now().Add(24*time.Hour))
|
||||
mdb.ExpectQuery(regexp.QuoteMeta("SELECT * FROM `channel` WHERE `channel`.`id` IN")).
|
||||
WithArgs(int32(1)).
|
||||
WillReturnRows(channelRows)
|
||||
|
||||
// 查找代理
|
||||
proxyRows := sqlmock.NewRows([]string{"id", "name", "host", "secret", "type"}).
|
||||
AddRow(1, "proxy1", "proxy1.example.com", "key:secret", 1)
|
||||
mdb.ExpectQuery(regexp.QuoteMeta("SELECT * FROM `proxy` WHERE `proxy`.`id` IN")).
|
||||
WithArgs(int32(1)).
|
||||
WillReturnRows(proxyRows)
|
||||
|
||||
// 软删除通道
|
||||
mdb.ExpectExec(regexp.QuoteMeta("UPDATE `channel` SET")).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
// 提交事务
|
||||
mdb.ExpectCommit()
|
||||
},
|
||||
checkCache: func(t *testing.T) {
|
||||
key := "channel:1"
|
||||
if mr.Exists(key) {
|
||||
t.Errorf("通道缓存 %s 应被删除但仍存在", key)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "用户删除不属于自己的通道",
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
auth: &AuthContext{
|
||||
Payload: Payload{
|
||||
Type: PayloadUser,
|
||||
Id: 100,
|
||||
},
|
||||
},
|
||||
id: []int32{5},
|
||||
},
|
||||
setup: func() {
|
||||
// 预设 Redis 缓存
|
||||
mr.FlushAll()
|
||||
key := "channel:5"
|
||||
channel := models.Channel{ID: 5, UserID: 101}
|
||||
data, _ := json.Marshal(channel)
|
||||
mr.Set(key, string(data))
|
||||
|
||||
// 开始事务
|
||||
mdb.ExpectBegin()
|
||||
|
||||
// 查找通道
|
||||
channelRows := sqlmock.NewRows([]string{"id", "user_id", "proxy_id", "proxy_port", "protocol", "expiration"}).
|
||||
AddRow(5, 101, 1, 10005, "http", time.Now().Add(24*time.Hour))
|
||||
mdb.ExpectQuery(regexp.QuoteMeta("SELECT * FROM `channel` WHERE `channel`.`id` IN")).
|
||||
WithArgs(int32(5)).
|
||||
WillReturnRows(channelRows)
|
||||
|
||||
// 回滚事务
|
||||
mdb.ExpectRollback()
|
||||
},
|
||||
wantErr: true,
|
||||
wantErrContains: "无权限访问",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.setup != nil {
|
||||
tt.setup()
|
||||
}
|
||||
|
||||
s := &channelService{}
|
||||
err := s.RemoveChannels(tt.args.ctx, tt.args.auth, tt.args.id...)
|
||||
|
||||
// 检查错误
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Errorf("RemoveChannels() 应当返回错误")
|
||||
return
|
||||
}
|
||||
if tt.wantErrContains != "" && !strings.Contains(err.Error(), tt.wantErrContains) {
|
||||
t.Errorf("RemoveChannels() 错误 = %v, 应包含 %v", err, tt.wantErrContains)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("RemoveChannels() 错误 = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
// 验证所有期望的 SQL 已执行
|
||||
if err := mdb.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("有未满足的SQL期望: %s", err)
|
||||
}
|
||||
|
||||
// 检查 Redis 缓存是否正确设置
|
||||
if tt.checkCache != nil {
|
||||
tt.checkCache(t)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user