Files
platform/web/services/channel_test.go

921 lines
24 KiB
Go
Raw Normal View History

2025-04-01 11:32:17 +08:00
package services
import (
"context"
"encoding/json"
"fmt"
"platform/pkg/remote"
"platform/pkg/testutil"
"platform/web/models"
2025-04-03 13:30:57 +08:00
"slices"
2025-04-01 11:32:17 +08:00
"strings"
"testing"
"time"
"github.com/gofiber/fiber/v2/middleware/requestid"
)
// Test_genPassPair checks two properties of genPassPair: a single call yields
// a non-empty username and password, and two consecutive calls do not yield
// the same username or the same password.
func Test_genPassPair(t *testing.T) {
	cases := []struct {
		name  string
		check func(t *testing.T)
	}{
		{
			// A single call must produce usable credentials.
			name: "正常生成随机用户名和密码",
			check: func(t *testing.T) {
				user, pass := genPassPair()
				if user == "" {
					t.Errorf("genPassPair() username is empty")
				}
				if pass == "" {
					t.Errorf("genPassPair() password is empty")
				}
			},
		},
		{
			// Two calls in a row must differ in both halves of the pair.
			name: "多次调用生成不同的值",
			check: func(t *testing.T) {
				firstUser, firstPass := genPassPair()
				secondUser, secondPass := genPassPair()
				if firstUser == secondUser {
					t.Errorf("genPassPair() generated the same username twice: %v", firstUser)
				}
				if firstPass == secondPass {
					t.Errorf("genPassPair() generated the same password twice: %v", firstPass)
				}
			},
		},
	}

	for i := range cases {
		t.Run(cases[i].name, cases[i].check)
	}
}
// Test_chKey checks that chKey renders a channel's ID as "channel:<ID>",
// including the zero ID.
func Test_chKey(t *testing.T) {
	cases := []struct {
		name    string
		channel *models.Channel
		want    string
	}{
		{name: "ID为1的通道", channel: &models.Channel{ID: 1}, want: "channel:1"},
		{name: "ID为100的通道", channel: &models.Channel{ID: 100}, want: "channel:100"},
		{name: "ID为0的通道", channel: &models.Channel{ID: 0}, want: "channel:0"},
	}

	for i := range cases {
		tc := cases[i]
		t.Run(tc.name, func(t *testing.T) {
			got := chKey(tc.channel)
			if got == tc.want {
				return
			}
			t.Errorf("chKey() = %v, want %v", got, tc.want)
		})
	}
}
// Test_cache verifies cache against the Redis test instance from testutil.
//
// For a non-empty channel list it checks that every channel is stored under
// "channel:<ID>" as JSON that decodes back to the same ID, that each key has
// a positive TTL, and that the "tasks:channel" key exists. For an empty list
// it only checks that cache returns no error.
func Test_cache(t *testing.T) {
	mr := testutil.SetupRedisTest(t)

	// Test data: two channels that expire one day from now.
	expiration := time.Now().Add(24 * time.Hour)
	testChannels := []*models.Channel{
		{
			ID:         1,
			UserID:     100,
			ProxyID:    10,
			ProxyPort:  8080,
			Protocol:   "http",
			Expiration: expiration,
		},
		{
			ID:         2,
			UserID:     101,
			ProxyID:    11,
			ProxyPort:  8081,
			Protocol:   "socks5",
			Expiration: expiration,
		},
	}

	tests := []struct {
		name     string
		ctx      context.Context
		channels []*models.Channel
		wantErr  bool
	}{
		{
			name:     "正常缓存多个通道",
			ctx:      context.Background(),
			channels: testChannels,
			wantErr:  false,
		},
		{
			name:     "空通道列表",
			ctx:      context.Background(),
			channels: []*models.Channel{},
			wantErr:  false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			mr.FlushAll() // start every case from an empty Redis

			if err := cache(tt.ctx, tt.channels); (err != nil) != tt.wantErr {
				t.Errorf("cache() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if len(tt.channels) == 0 {
				return
			}

			for _, channel := range tt.channels {
				key := fmt.Sprintf("channel:%d", channel.ID)
				if !mr.Exists(key) {
					t.Errorf("缓存未包含通道键 %s", key)
					continue
				}

				// Read the payload back; a failed read must not be mistaken
				// for an empty (and therefore mismatching) payload.
				data, err := mr.Get(key)
				if err != nil {
					t.Errorf("无法读取缓存键 %s: %v", key, err)
					continue
				}
				var cachedChannel models.Channel
				if err := json.Unmarshal([]byte(data), &cachedChannel); err != nil {
					t.Errorf("无法解析缓存数据: %v", err)
					continue
				}
				if cachedChannel.ID != channel.ID {
					t.Errorf("缓存数据不匹配: 期望 ID %d, 得到 %d", channel.ID, cachedChannel.ID)
				}

				// The key must carry an expiry.
				if ttl := mr.TTL(key); ttl <= 0 {
					t.Errorf("键 %s 没有设置过期时间", key)
				}
			}

			// cache is also expected to create the "tasks:channel" sorted set.
			if !mr.Exists("tasks:channel") {
				t.Errorf("ZAdd未创建有序集合 tasks:channel")
			}
		})
	}
}
// Test_deleteCache verifies deleteCache against the Redis test instance from
// testutil.
//
// Every case starts with the keys "channel:1".."channel:3" seeded with a
// one-hour TTL. After deleteCache runs, the keys of the channels passed in
// must be gone and the keys of channels that were not passed in must remain.
func Test_deleteCache(t *testing.T) {
	mr := testutil.SetupRedisTest(t)

	ctx := context.Background()
	seeded := []*models.Channel{
		{ID: 1},
		{ID: 2},
		{ID: 3},
	}
	keyOf := func(channel *models.Channel) string {
		return fmt.Sprintf("channel:%d", channel.ID)
	}

	tests := []struct {
		name     string
		channels []*models.Channel
		wantErr  bool
	}{
		{
			name:     "正常删除多个通道缓存",
			channels: seeded,
			wantErr:  false,
		},
		{
			name:     "空通道列表",
			channels: []*models.Channel{},
			wantErr:  false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			mr.FlushAll() // start every case from an empty Redis

			// Seed the cache; a failure here invalidates the whole case, so
			// stop immediately instead of asserting on a half-built fixture.
			for _, channel := range seeded {
				key := keyOf(channel)
				data, err := json.Marshal(channel)
				if err != nil {
					t.Fatalf("无法序列化通道 %s: %v", key, err)
				}
				if err := mr.Set(key, string(data)); err != nil {
					t.Fatalf("无法预置缓存键 %s: %v", key, err)
				}
				mr.SetTTL(key, 1*time.Hour)
			}

			if err := deleteCache(ctx, tt.channels); (err != nil) != tt.wantErr {
				t.Errorf("deleteCache() error = %v, wantErr %v", err, tt.wantErr)
				return
			}

			// Keys of the requested channels must be gone.
			removed := make(map[string]bool, len(tt.channels))
			for _, channel := range tt.channels {
				key := keyOf(channel)
				removed[key] = true
				if mr.Exists(key) {
					t.Errorf("通道键 %s 未被删除", key)
				}
			}

			// Keys that were not requested must survive; this is what makes
			// the empty-list case assert something.
			for _, channel := range seeded {
				if key := keyOf(channel); !removed[key] && !mr.Exists(key) {
					t.Errorf("通道键 %s 不应被删除", key)
				}
			}
		})
	}
}
func Test_channelService_CreateChannel(t *testing.T) {
mr := testutil.SetupRedisTest(t)
db := testutil.SetupDBTest(t)
2025-04-01 11:32:17 +08:00
mc := testutil.SetupCloudClientMock(t)
type args struct {
ctx context.Context
auth *AuthContext
resourceId int32
protocol ChannelProtocol
authType ChannelAuthType
count int
nodeFilter []NodeFilterConfig
}
// 准备测试数据
ctx := context.WithValue(context.Background(), requestid.ConfigDefault.ContextKey, "test-request-id")
var adminAuth = &AuthContext{Payload: Payload{Id: 100, Type: PayloadAdmin}}
var userAuth = &AuthContext{Payload: Payload{Id: 101, Type: PayloadUser}}
var user = &models.User{
ID: 101,
Phone: "12312341234",
}
db.Create(user)
var whitelists = []*models.Whitelist{
{ID: 1, UserID: 101, Host: "123.123.123.123"},
{ID: 2, UserID: 101, Host: "456.456.456.456"},
{ID: 3, UserID: 101, Host: "789.789.789.789"},
}
db.Create(whitelists)
var resource = &models.Resource{
ID: 1,
UserID: 101,
Active: true,
}
db.Create(resource)
var resourcePss = &models.ResourcePss{
ID: 1,
ResourceID: 1,
Type: 1,
Live: 180,
Expire: time.Now().AddDate(1, 0, 0),
DailyLimit: 10000,
}
db.Create(resourcePss)
var proxy = &models.Proxy{
ID: 1,
Version: 1,
Name: "test-proxy",
Host: "111.111.111.111",
Type: 1,
Secret: "test:secret",
}
db.Create(proxy)
mc.AutoQueryMock = func() (remote.CloudConnectResp, error) {
return remote.CloudConnectResp{
"test-proxy": []remote.AutoConfig{
{Province: "河南", City: "郑州", Isp: "电信", Count: 10},
},
}, nil
}
var clearDb = func() {
db.Exec("delete from channel where true")
db.Exec("update resource_pss set daily_used = 0, daily_last = null, used = 0 where true")
}
2025-04-01 11:32:17 +08:00
tests := []struct {
name string
args args
setup func()
wantErr bool
wantErrContains string
2025-04-03 13:30:57 +08:00
want func(t *testing.T, got []*PortInfo) error
2025-04-01 11:32:17 +08:00
}{
{
name: "用户创建HTTP密码通道",
args: args{
ctx: ctx,
auth: userAuth,
resourceId: 1,
2025-04-01 11:32:17 +08:00
protocol: ProtocolHTTP,
authType: ChannelAuthTypePass,
count: 3,
nodeFilter: []NodeFilterConfig{{Prov: "河南", City: "郑州", Isp: "电信"}},
},
2025-04-03 13:30:57 +08:00
want: func(t *testing.T, got []*PortInfo) error {
// 验证返回结果
if len(got) == 0 {
return fmt.Errorf("返回的 PortInfo 不应为空")
}
// 验证协议正确
for _, port := range got {
if port.Proto != "http" {
return fmt.Errorf("期望协议为 http得到 %s", port.Proto)
}
if port.Host != proxy.Host {
return fmt.Errorf("期望主机为 %s得到 %s", proxy.Host, port.Host)
}
}
// 验证数据库字段
var channels []*models.Channel
db.Where("user_id = ? AND proxy_id = ?", userAuth.Payload.Id, proxy.ID).Find(&channels)
for _, ch := range channels {
if ch.Protocol != "http" {
return fmt.Errorf("通道协议不正确,期望 http得到 %s", ch.Protocol)
}
if ch.UserID != userAuth.Payload.Id {
return fmt.Errorf("通道用户ID不正确期望 %d得到 %d", userAuth.Payload.Id, ch.UserID)
}
if ch.ProxyID != proxy.ID {
return fmt.Errorf("通道代理ID不正确期望 %d得到 %d", proxy.ID, ch.ProxyID)
}
// 检查Redis缓存中的字段
key := fmt.Sprintf("channel:%d", ch.ID)
if !mr.Exists(key) {
return fmt.Errorf("Redis缓存中应有键 %s", key)
}
data, _ := mr.Get(key)
var cachedChannel models.Channel
err := json.Unmarshal([]byte(data), &cachedChannel)
if err != nil {
return fmt.Errorf("无法解析缓存数据: %v", err)
}
if cachedChannel.ID != ch.ID {
return fmt.Errorf("缓存ID不正确期望 %d得到 %d", ch.ID, cachedChannel.ID)
}
if cachedChannel.Protocol != ch.Protocol {
return fmt.Errorf("缓存协议不正确,期望 %s得到 %s", ch.Protocol, cachedChannel.Protocol)
}
}
return nil
2025-04-01 11:32:17 +08:00
},
},
{
name: "用户创建HTTP白名单通道",
args: args{
ctx: ctx,
auth: userAuth,
resourceId: 1,
2025-04-01 11:32:17 +08:00
protocol: ProtocolHTTP,
authType: ChannelAuthTypeIp,
count: 2,
},
2025-04-03 13:30:57 +08:00
want: func(t *testing.T, got []*PortInfo) error {
return nil
},
2025-04-01 11:32:17 +08:00
},
{
name: "管理员创建SOCKS5密码通道",
args: args{
ctx: ctx,
auth: adminAuth,
resourceId: 1,
2025-04-01 11:32:17 +08:00
protocol: ProtocolSocks5,
authType: ChannelAuthTypePass,
count: 2,
},
2025-04-03 13:30:57 +08:00
want: func(t *testing.T, got []*PortInfo) error {
return nil
},
2025-04-01 11:32:17 +08:00
},
{
name: "套餐不存在",
args: args{
ctx: ctx,
auth: userAuth,
2025-04-01 11:32:17 +08:00
resourceId: 999,
protocol: ProtocolHTTP,
authType: ChannelAuthTypeIp,
count: 1,
},
wantErr: true,
wantErrContains: "套餐不存在",
},
{
name: "套餐没有权限",
args: args{
ctx: ctx,
auth: userAuth,
resourceId: 2,
2025-04-01 11:32:17 +08:00
protocol: ProtocolHTTP,
authType: ChannelAuthTypeIp,
count: 1,
},
wantErr: true,
wantErrContains: "无权限访问",
},
{
name: "套餐配额不足",
args: args{
ctx: ctx,
auth: &AuthContext{
Payload: Payload{
Type: PayloadUser,
Id: 100,
},
},
resourceId: 2,
protocol: ProtocolHTTP,
authType: ChannelAuthTypeIp,
count: 10,
},
setup: func() {
// 创建一个配额几乎用完的资源包
resource2 := models.Resource{
ID: 2,
UserID: 101,
Active: true,
}
resourcePss2 := models.ResourcePss{
ID: 1,
ResourceID: 1,
Type: 2,
Quota: 100,
Used: 91,
Live: 180,
DailyLimit: 10000,
}
db.Create(&resource2).Create(&resourcePss2)
2025-04-01 11:32:17 +08:00
},
wantErr: true,
wantErrContains: "套餐配额不足",
},
{
name: "端口数量达到上限",
args: args{
ctx: ctx,
auth: userAuth,
resourceId: 1,
2025-04-01 11:32:17 +08:00
protocol: ProtocolHTTP,
authType: ChannelAuthTypeIp,
count: 1,
},
setup: func() {
// 创建大量占用端口的通道
for i := 10000; i < 20000; i++ {
channel := models.Channel{
ProxyID: 1,
ProxyPort: int32(i),
UserID: 101,
}
db.Create(&channel)
2025-04-01 11:32:17 +08:00
}
},
wantErr: true,
wantErrContains: "端口数量不足",
},
2025-04-02 17:24:12 +08:00
// todo 跨天用量更新
// todo 多地区混杂条件提取
2025-04-01 11:32:17 +08:00
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mr.FlushAll()
clearDb()
2025-04-01 11:32:17 +08:00
if tt.setup != nil {
tt.setup()
}
s := &channelService{}
got, err := s.CreateChannel(tt.args.ctx, tt.args.auth, tt.args.resourceId, tt.args.protocol, tt.args.authType, tt.args.count, tt.args.nodeFilter...)
// 检查错误或结果
if tt.wantErr {
if err == nil {
t.Errorf("CreateChannel() 应当返回错误")
return
}
if tt.wantErrContains != "" && !strings.Contains(err.Error(), tt.wantErrContains) {
t.Errorf("CreateChannel() 错误 = %v, 应包含 %v", err, tt.wantErrContains)
}
return
}
if err != nil {
t.Errorf("CreateChannel() 错误 = %v, wantErr %v", err, tt.wantErr)
return
}
2025-04-03 13:30:57 +08:00
// 使用检查函数验证结果
if tt.want != nil {
if err := tt.want(t, got); err != nil {
t.Errorf("结果验证失败: %v", err)
}
2025-04-01 11:32:17 +08:00
}
})
}
}
func Test_channelService_RemoveChannels(t *testing.T) {
mr := testutil.SetupRedisTest(t)
2025-04-03 13:30:57 +08:00
md := testutil.SetupDBTest(t)
2025-04-01 11:32:17 +08:00
mg := testutil.SetupGatewayClientMock(t)
2025-04-03 13:30:57 +08:00
mc := testutil.SetupCloudClientMock(t)
2025-04-01 11:32:17 +08:00
type args struct {
ctx context.Context
auth *AuthContext
id []int32
}
// 准备测试数据
ctx := context.WithValue(context.Background(), requestid.ConfigDefault.ContextKey, "test-request-id")
2025-04-03 13:30:57 +08:00
// 创建用户
var user = &models.User{
ID: 101,
Phone: "12312341234",
}
md.Create(user)
// 创建管理员
var adminUser = &models.User{
ID: 100,
Phone: "99999999999",
}
md.Create(adminUser)
// 认证上下文
var adminAuth = &AuthContext{Payload: Payload{Id: 100, Type: PayloadAdmin}}
var userAuth = &AuthContext{Payload: Payload{Id: 101, Type: PayloadUser}}
// 创建代理
var proxy = &models.Proxy{
ID: 1,
Version: 1,
Name: "test-proxy",
Host: "111.111.111.111",
Type: 1,
Secret: "test:secret",
}
md.Create(proxy)
var proxy2 = &models.Proxy{
ID: 2,
Version: 1,
Name: "test-proxy-2",
Host: "222.222.222.222",
Type: 1,
Secret: "test:secret2",
}
md.Create(proxy2)
// 清空数据库函数
var clearDb = func() {
md.Exec("delete from channel where true")
mr.FlushAll()
}
2025-04-01 11:32:17 +08:00
tests := []struct {
name string
args args
setup func()
wantErr bool
wantErrContains string
2025-04-03 13:30:57 +08:00
want func(t *testing.T) error
2025-04-01 11:32:17 +08:00
}{
{
name: "管理员删除多个通道",
args: args{
2025-04-03 13:30:57 +08:00
ctx: ctx,
auth: adminAuth,
id: []int32{1, 2, 3},
2025-04-01 11:32:17 +08:00
},
setup: func() {
mr.FlushAll()
2025-04-03 13:30:57 +08:00
clearDb()
// 创建通道
channels := []models.Channel{
2025-04-03 13:30:57 +08:00
{ID: 1, UserID: 101, ProxyID: 1, ProxyPort: 10001, Protocol: "http", Expiration: time.Now().Add(24 * time.Hour)},
{ID: 2, UserID: 101, ProxyID: 1, ProxyPort: 10002, Protocol: "http", Expiration: time.Now().Add(24 * time.Hour)},
{ID: 3, UserID: 101, ProxyID: 2, ProxyPort: 10001, Protocol: "socks5", Expiration: time.Now().Add(24 * time.Hour)},
}
2025-04-03 13:30:57 +08:00
// 保存预设数据
md.Create(channels)
for _, channel := range channels {
key := fmt.Sprintf("channel:%d", channel.ID)
data, _ := json.Marshal(channel)
_ = mr.Set(key, string(data))
}
// 模拟网关客户端的响应
mg.PortActiveMock = func(m *testutil.MockGatewayClient, param ...remote.PortActiveReq) (map[string]remote.PortData, error) {
switch {
case m.Host == proxy.Host:
return map[string]remote.PortData{
"10001": {Edge: []string{"edge1", "edge4"}},
"10002": {Edge: []string{"edge2"}},
}, nil
case m.Host == proxy2.Host:
return map[string]remote.PortData{
"10001": {Edge: []string{"edge3"}},
}, nil
}
return nil, fmt.Errorf("代理主机不符合预期: %s", m.Host)
}
mg.PortConfigsMock = func(m *testutil.MockGatewayClient, params []remote.PortConfigsReq) error {
switch {
case m.Host == proxy.Host:
for _, param := range params {
if param.Port != 10001 && param.Port != 10002 {
return fmt.Errorf("端口配置不符合预期: %d", param.Port)
}
if param.Status != false {
return fmt.Errorf("端口状态不符合预期: %v", param.Status)
}
if param.Edge == nil || len(*param.Edge) != 0 {
return fmt.Errorf("边缘节点不符合预期: %v", param.Edge)
}
}
case m.Host == proxy2.Host:
for _, param := range params {
if param.Port != 10001 {
return fmt.Errorf("端口配置不符合预期: %d", param.Port)
}
if param.Status != false {
return fmt.Errorf("端口状态不符合预期: %v", param.Status)
}
if param.Edge == nil || len(*param.Edge) != 0 {
return fmt.Errorf("边缘节点不符合预期: %v", param.Edge)
}
}
}
return fmt.Errorf("代理主机不符合预期: %s", m.Host)
}
mc.DisconnectMock = func(param remote.CloudDisconnectReq) (int, error) {
switch {
case param.Uuid == proxy.Name:
var edges = []string{"edge1", "edge2", "edge4"}
if !slices.Equal(edges, param.Edge) {
return 0, fmt.Errorf("边缘节点不符合预期: %v", param.Edge)
}
if len(param.Config) != 0 {
return 0, fmt.Errorf("配置不符合预期: %v", param.Config)
}
return len(param.Edge), nil
case param.Uuid == proxy2.Name:
var edges = []string{"edge3"}
if !slices.Equal(edges, param.Edge) {
return 0, fmt.Errorf("边缘节点不符合预期: %v", param.Edge)
}
if len(param.Config) != 0 {
return 0, fmt.Errorf("配置不符合预期: %v", param.Config)
}
return len(param.Edge), nil
}
return 0, fmt.Errorf("代理名称不符合预期: %s", param.Uuid)
}
2025-04-01 11:32:17 +08:00
},
2025-04-03 13:30:57 +08:00
want: func(t *testing.T) error {
// 检查通道是否被软删除
var count int64
2025-04-03 13:30:57 +08:00
md.Model(&models.Channel{}).Where("id IN ? AND deleted_at IS NULL", []int32{1, 2, 3}).Count(&count)
if count > 0 {
2025-04-03 13:30:57 +08:00
return fmt.Errorf("应该软删除了所有通道,但仍有 %d 个未删除", count)
}
// 检查Redis缓存是否被删除
2025-04-01 11:32:17 +08:00
for _, id := range []int32{1, 2, 3} {
key := fmt.Sprintf("channel:%d", id)
if mr.Exists(key) {
2025-04-03 13:30:57 +08:00
return fmt.Errorf("通道缓存 %s 应被删除但仍存在", key)
2025-04-01 11:32:17 +08:00
}
}
2025-04-03 13:30:57 +08:00
return nil
2025-04-01 11:32:17 +08:00
},
},
{
name: "用户删除自己的通道",
args: args{
2025-04-03 13:30:57 +08:00
ctx: ctx,
auth: userAuth,
id: []int32{1, 2, 3},
2025-04-01 11:32:17 +08:00
},
setup: func() {
mr.FlushAll()
2025-04-03 13:30:57 +08:00
clearDb()
2025-04-01 11:32:17 +08:00
2025-04-03 13:30:57 +08:00
// 创建通道
channels := []models.Channel{
{ID: 1, UserID: 101, ProxyID: 1, ProxyPort: 10001, Protocol: "http", Expiration: time.Now().Add(24 * time.Hour)},
{ID: 2, UserID: 101, ProxyID: 1, ProxyPort: 10002, Protocol: "http", Expiration: time.Now().Add(24 * time.Hour)},
{ID: 3, UserID: 101, ProxyID: 2, ProxyPort: 10001, Protocol: "socks5", Expiration: time.Now().Add(24 * time.Hour)},
}
2025-04-03 13:30:57 +08:00
// 保存预设数据
md.Create(channels)
for _, channel := range channels {
key := fmt.Sprintf("channel:%d", channel.ID)
data, _ := json.Marshal(channel)
_ = mr.Set(key, string(data))
}
2025-04-03 13:30:57 +08:00
// 模拟网关客户端的响应
mg.PortActiveMock = func(m *testutil.MockGatewayClient, param ...remote.PortActiveReq) (map[string]remote.PortData, error) {
switch {
case m.Host == proxy.Host:
return map[string]remote.PortData{
"10001": {Edge: []string{"edge1", "edge4"}},
"10002": {Edge: []string{"edge2"}},
}, nil
case m.Host == proxy2.Host:
return map[string]remote.PortData{
"10001": {Edge: []string{"edge3"}},
}, nil
}
return nil, fmt.Errorf("代理主机不符合预期: %s", m.Host)
}
mg.PortConfigsMock = func(m *testutil.MockGatewayClient, params []remote.PortConfigsReq) error {
switch {
case m.Host == proxy.Host:
for _, param := range params {
if param.Port != 10001 && param.Port != 10002 {
return fmt.Errorf("端口配置不符合预期: %d", param.Port)
}
if param.Status != false {
return fmt.Errorf("端口状态不符合预期: %v", param.Status)
}
if param.Edge == nil || len(*param.Edge) != 0 {
return fmt.Errorf("边缘节点不符合预期: %v", param.Edge)
}
}
case m.Host == proxy2.Host:
for _, param := range params {
if param.Port != 10001 {
return fmt.Errorf("端口配置不符合预期: %d", param.Port)
}
if param.Status != false {
return fmt.Errorf("端口状态不符合预期: %v", param.Status)
}
if param.Edge == nil || len(*param.Edge) != 0 {
return fmt.Errorf("边缘节点不符合预期: %v", param.Edge)
}
}
}
return fmt.Errorf("代理主机不符合预期: %s", m.Host)
}
mc.DisconnectMock = func(param remote.CloudDisconnectReq) (int, error) {
switch {
case param.Uuid == proxy.Name:
var edges = []string{"edge1", "edge2", "edge4"}
if !slices.Equal(edges, param.Edge) {
return 0, fmt.Errorf("边缘节点不符合预期: %v", param.Edge)
}
if len(param.Config) != 0 {
return 0, fmt.Errorf("配置不符合预期: %v", param.Config)
}
return len(param.Edge), nil
case param.Uuid == proxy2.Name:
var edges = []string{"edge3"}
if !slices.Equal(edges, param.Edge) {
return 0, fmt.Errorf("边缘节点不符合预期: %v", param.Edge)
}
if len(param.Config) != 0 {
return 0, fmt.Errorf("配置不符合预期: %v", param.Config)
}
return len(param.Edge), nil
}
return 0, fmt.Errorf("代理名称不符合预期: %s", param.Uuid)
2025-04-01 11:32:17 +08:00
}
},
2025-04-03 13:30:57 +08:00
want: func(t *testing.T) error {
// 检查通道是否被软删除
var count int64
2025-04-03 13:30:57 +08:00
md.Model(&models.Channel{}).Where("id IN ? AND deleted_at IS NULL", []int32{1, 2, 3}).Count(&count)
if count > 0 {
2025-04-03 13:30:57 +08:00
return fmt.Errorf("应该软删除了所有通道,但仍有 %d 个未删除", count)
}
// 检查Redis缓存是否被删除
2025-04-03 13:30:57 +08:00
for _, id := range []int32{1, 2, 3} {
key := fmt.Sprintf("channel:%d", id)
if mr.Exists(key) {
return fmt.Errorf("通道缓存 %s 应被删除但仍存在", key)
}
2025-04-01 11:32:17 +08:00
}
2025-04-03 13:30:57 +08:00
return nil
2025-04-01 11:32:17 +08:00
},
},
{
name: "用户删除不属于自己的通道",
args: args{
2025-04-03 13:30:57 +08:00
ctx: ctx,
auth: userAuth,
id: []int32{1, 2, 3},
2025-04-01 11:32:17 +08:00
},
setup: func() {
mr.FlushAll()
2025-04-03 13:30:57 +08:00
clearDb()
// 创建通道
channels := []models.Channel{
{ID: 1, UserID: 101, ProxyID: 1, ProxyPort: 10001, Protocol: "http", Expiration: time.Now().Add(24 * time.Hour)},
{ID: 2, UserID: 101, ProxyID: 1, ProxyPort: 10002, Protocol: "http", Expiration: time.Now().Add(24 * time.Hour)},
{ID: 3, UserID: 102, ProxyID: 2, ProxyPort: 10001, Protocol: "socks5", Expiration: time.Now().Add(24 * time.Hour)},
}
2025-04-01 11:32:17 +08:00
2025-04-03 13:30:57 +08:00
// 保存预设数据
md.Create(channels)
for _, channel := range channels {
key := fmt.Sprintf("channel:%d", channel.ID)
data, _ := json.Marshal(channel)
_ = mr.Set(key, string(data))
}
2025-04-01 11:32:17 +08:00
},
wantErr: true,
wantErrContains: "无权限访问",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.setup != nil {
tt.setup()
}
s := &channelService{}
err := s.RemoveChannels(tt.args.ctx, tt.args.auth, tt.args.id...)
// 检查错误
if tt.wantErr {
if err == nil {
t.Errorf("RemoveChannels() 应当返回错误")
return
}
if tt.wantErrContains != "" && !strings.Contains(err.Error(), tt.wantErrContains) {
t.Errorf("RemoveChannels() 错误 = %v, 应包含 %v", err, tt.wantErrContains)
}
return
}
if err != nil {
t.Errorf("RemoveChannels() 错误 = %v, wantErr %v", err, tt.wantErr)
return
}
2025-04-03 13:30:57 +08:00
// 检查数据库和缓存是否正确设置
want := tt.want(t)
if tt.want(t) != nil {
t.Errorf("RemoveChannels() 结果验证失败: %v", want)
2025-04-01 11:32:17 +08:00
}
})
}
}