1323 lines
38 KiB
Go
1323 lines
38 KiB
Go
package services
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"fmt"
|
||
"platform/pkg/sdks/baiyin"
|
||
"platform/pkg/testutil"
|
||
"platform/web/common"
|
||
"platform/web/models"
|
||
"reflect"
|
||
"strings"
|
||
"testing"
|
||
"time"
|
||
|
||
"github.com/gofiber/fiber/v2/middleware/requestid"
|
||
)
|
||
|
||
func Test_genPassPair(t *testing.T) {
|
||
tests := []struct {
|
||
name string
|
||
}{
|
||
{
|
||
name: "正常生成随机用户名和密码",
|
||
},
|
||
{
|
||
name: "多次调用生成不同的值",
|
||
},
|
||
}
|
||
|
||
// 第一个测试:检查生成的用户名和密码是否有效
|
||
t.Run(tests[0].name, func(t *testing.T) {
|
||
username, password := genPassPair()
|
||
if username == "" {
|
||
t.Errorf("genPassPair() username is empty")
|
||
}
|
||
if password == "" {
|
||
t.Errorf("genPassPair() password is empty")
|
||
}
|
||
})
|
||
|
||
// 第二个测试:确保多次调用生成不同的值
|
||
t.Run(tests[1].name, func(t *testing.T) {
|
||
username1, password1 := genPassPair()
|
||
username2, password2 := genPassPair()
|
||
|
||
if username1 == username2 {
|
||
t.Errorf("genPassPair() generated the same username twice: %v", username1)
|
||
}
|
||
if password1 == password2 {
|
||
t.Errorf("genPassPair() generated the same password twice: %v", password1)
|
||
}
|
||
})
|
||
}
|
||
|
||
func Test_chKey(t *testing.T) {
|
||
type args struct {
|
||
channel *models.Channel
|
||
}
|
||
tests := []struct {
|
||
name string
|
||
args args
|
||
want string
|
||
}{
|
||
{
|
||
name: "ID为1的通道",
|
||
args: args{
|
||
channel: &models.Channel{
|
||
ID: 1,
|
||
},
|
||
},
|
||
want: "channel:1",
|
||
},
|
||
{
|
||
name: "ID为100的通道",
|
||
args: args{
|
||
channel: &models.Channel{
|
||
ID: 100,
|
||
},
|
||
},
|
||
want: "channel:100",
|
||
},
|
||
{
|
||
name: "ID为0的通道",
|
||
args: args{
|
||
channel: &models.Channel{
|
||
ID: 0,
|
||
},
|
||
},
|
||
want: "channel:0",
|
||
},
|
||
}
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
if got := chKey(tt.args.channel); got != tt.want {
|
||
t.Errorf("chKey() = %v, want %v", got, tt.want)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
func Test_cache(t *testing.T) {
|
||
mr := testutil.SetupRedisTest(t)
|
||
|
||
type args struct {
|
||
ctx context.Context
|
||
channels []*models.Channel
|
||
}
|
||
|
||
// 准备测试数据
|
||
now := time.Now()
|
||
expiration := now.Add(24 * time.Hour)
|
||
|
||
testChannels := []*models.Channel{
|
||
{
|
||
ID: 1,
|
||
UserID: 100,
|
||
ProxyID: 10,
|
||
ProxyPort: 8080,
|
||
Protocol: 1,
|
||
Expiration: expiration,
|
||
},
|
||
{
|
||
ID: 2,
|
||
UserID: 101,
|
||
ProxyID: 11,
|
||
ProxyPort: 8081,
|
||
Protocol: 3,
|
||
Expiration: expiration,
|
||
},
|
||
}
|
||
|
||
tests := []struct {
|
||
name string
|
||
args args
|
||
wantErr bool
|
||
}{
|
||
{
|
||
name: "正常缓存多个通道",
|
||
args: args{
|
||
ctx: context.Background(),
|
||
channels: testChannels,
|
||
},
|
||
wantErr: false,
|
||
},
|
||
{
|
||
name: "空通道列表",
|
||
args: args{
|
||
ctx: context.Background(),
|
||
channels: []*models.Channel{},
|
||
},
|
||
wantErr: false,
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
mr.FlushAll() // 清空 Redis 数据
|
||
|
||
if err := cache(tt.args.ctx, tt.args.channels); (err != nil) != tt.wantErr {
|
||
t.Errorf("cache() error = %v, wantErr %v", err, tt.wantErr)
|
||
return
|
||
}
|
||
|
||
// 验证缓存结果
|
||
if len(tt.args.channels) > 0 {
|
||
for _, channel := range tt.args.channels {
|
||
key := fmt.Sprintf("channel:%d", channel.ID)
|
||
if !mr.Exists(key) {
|
||
t.Errorf("缓存未包含通道键 %s", key)
|
||
} else {
|
||
// 验证缓存的数据是否正确
|
||
data, _ := mr.Get(key)
|
||
var cachedChannel models.Channel
|
||
err := json.Unmarshal([]byte(data), &cachedChannel)
|
||
if err != nil {
|
||
t.Errorf("无法解析缓存数据: %v", err)
|
||
}
|
||
if cachedChannel.ID != channel.ID {
|
||
t.Errorf("缓存数据不匹配: 期望 ID %d, 得到 %d", channel.ID, cachedChannel.ID)
|
||
}
|
||
}
|
||
}
|
||
|
||
// 验证是否设置了过期时间
|
||
for _, channel := range tt.args.channels {
|
||
key := fmt.Sprintf("channel:%d", channel.ID)
|
||
ttl := mr.TTL(key)
|
||
if ttl <= 0 {
|
||
t.Errorf("键 %s 没有设置过期时间", key)
|
||
}
|
||
}
|
||
|
||
// 验证是否添加了有序集合
|
||
if !mr.Exists("tasks:channel") {
|
||
t.Errorf("ZAdd未创建有序集合 tasks:channel")
|
||
}
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
func Test_deleteCache(t *testing.T) {
|
||
mr := testutil.SetupRedisTest(t)
|
||
|
||
type args struct {
|
||
ctx context.Context
|
||
channels []*models.Channel
|
||
}
|
||
|
||
// 准备测试数据
|
||
testChannels := []*models.Channel{
|
||
{ID: 1},
|
||
{ID: 2},
|
||
{ID: 3},
|
||
}
|
||
|
||
ctx := context.Background()
|
||
|
||
tests := []struct {
|
||
name string
|
||
args args
|
||
wantErr bool
|
||
}{
|
||
{
|
||
name: "正常删除多个通道缓存",
|
||
args: args{
|
||
ctx: ctx,
|
||
channels: testChannels,
|
||
},
|
||
wantErr: false,
|
||
},
|
||
{
|
||
name: "空通道列表",
|
||
args: args{
|
||
ctx: ctx,
|
||
channels: []*models.Channel{},
|
||
},
|
||
wantErr: false,
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
mr.FlushAll() // 清空 Redis 数据
|
||
|
||
// 预先设置缓存数据
|
||
for _, channel := range testChannels {
|
||
key := fmt.Sprintf("channel:%d", channel.ID)
|
||
data, _ := json.Marshal(channel)
|
||
mr.Set(key, string(data))
|
||
mr.SetTTL(key, 1*time.Hour) // 设置1小时的过期时间
|
||
}
|
||
|
||
if err := deleteCache(tt.args.ctx, tt.args.channels); (err != nil) != tt.wantErr {
|
||
t.Errorf("deleteCache() error = %v, wantErr %v", err, tt.wantErr)
|
||
return
|
||
}
|
||
|
||
// 验证删除结果
|
||
for _, channel := range tt.args.channels {
|
||
key := fmt.Sprintf("channel:%d", channel.ID)
|
||
if mr.Exists(key) {
|
||
t.Errorf("通道键 %s 未被删除", key)
|
||
}
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
func Test_channelService_CreateChannel(t *testing.T) {
|
||
mr := testutil.SetupRedisTest(t)
|
||
db := testutil.SetupDBTest(t)
|
||
mc := testutil.SetupCloudClientMock(t)
|
||
mg := testutil.SetupGatewayClientMock(t)
|
||
|
||
type args struct {
|
||
ctx context.Context
|
||
auth *AuthContext
|
||
resourceId int32
|
||
protocol ChannelProtocol
|
||
authType ChannelAuthType
|
||
count int
|
||
nodeFilter []NodeFilterConfig
|
||
}
|
||
|
||
// 准备测试数据
|
||
ctx := context.WithValue(context.Background(), requestid.ConfigDefault.ContextKey, "test-request-id")
|
||
var adminAuth = &AuthContext{Payload: Payload{Id: 100, Type: PayloadAdmin}}
|
||
var userAuth = &AuthContext{Payload: Payload{Id: 101, Type: PayloadUser}}
|
||
mc.AutoQueryMock = func() (baiyin.CloudConnectResp, error) {
|
||
return baiyin.CloudConnectResp{
|
||
"test-proxy": []baiyin.AutoConfig{
|
||
{Province: "河南省", Count: 10},
|
||
},
|
||
}, nil
|
||
}
|
||
|
||
var user *models.User
|
||
var whitelists []*models.Whitelist
|
||
var proxy *models.Proxy
|
||
var resource *models.Resource
|
||
var resourcePss *models.ResourcePss
|
||
var resetDb = func() {
|
||
user = &models.User{
|
||
ID: 101,
|
||
Phone: "12312341234",
|
||
}
|
||
db.Exec("delete from user where true")
|
||
db.Create(user)
|
||
|
||
whitelists = []*models.Whitelist{
|
||
{ID: 1, UserID: 101, Host: "123.123.123.123"},
|
||
{ID: 2, UserID: 101, Host: "456.456.456.456"},
|
||
{ID: 3, UserID: 101, Host: "789.789.789.789"},
|
||
}
|
||
db.Exec("delete from whitelist where true")
|
||
db.Create(whitelists)
|
||
|
||
proxy = &models.Proxy{
|
||
ID: 1,
|
||
Version: 1,
|
||
Name: "test-proxy",
|
||
Host: "111.111.111.111",
|
||
Type: 1,
|
||
Secret: "test:secret",
|
||
}
|
||
db.Exec("delete from proxy where true")
|
||
db.Create(proxy)
|
||
|
||
resource = &models.Resource{
|
||
ID: 1,
|
||
UserID: 101,
|
||
Active: true,
|
||
}
|
||
db.Exec("delete from resource where true")
|
||
db.Create(resource)
|
||
|
||
resourcePss = &models.ResourcePss{
|
||
ID: 1,
|
||
ResourceID: 1,
|
||
Type: 1,
|
||
Live: 180,
|
||
Expire: common.LocalDateTime(time.Now().AddDate(1, 0, 0)),
|
||
DailyLimit: 10000,
|
||
}
|
||
db.Exec("delete from resource_pss where true")
|
||
db.Create(resourcePss)
|
||
|
||
db.Exec("delete from channel where true")
|
||
}
|
||
|
||
tests := []struct {
|
||
name string
|
||
args args
|
||
setup func()
|
||
wantErr bool
|
||
wantErrContains string
|
||
want func(t *testing.T, got []*PortInfo) error
|
||
}{
|
||
{
|
||
name: "用户创建HTTP密码通道",
|
||
args: args{
|
||
ctx: ctx,
|
||
auth: userAuth,
|
||
resourceId: 1,
|
||
protocol: ProtocolHTTP,
|
||
authType: ChannelAuthTypePass,
|
||
count: 3,
|
||
nodeFilter: []NodeFilterConfig{{Prov: "北京市"}},
|
||
},
|
||
setup: func() {
|
||
mr.FlushAll()
|
||
resetDb()
|
||
|
||
mc.ConnectMock = func(param baiyin.CloudConnectReq) error {
|
||
if param.Uuid != proxy.Name {
|
||
return fmt.Errorf("代理名称不符合预期: %s", param.Uuid)
|
||
}
|
||
if len(param.Edge) != 0 {
|
||
return fmt.Errorf("边缘节点不符合预期: %v", param.Edge)
|
||
}
|
||
if !reflect.DeepEqual(param.AutoConfig, []baiyin.AutoConfig{
|
||
{Province: "河南省", Count: 10},
|
||
{Province: "北京市", Count: 6},
|
||
}) {
|
||
return fmt.Errorf("自动配置不符合预期: %v", param.AutoConfig)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
mg.PortConfigsMock = func(c *testutil.MockGatewayClient, params []baiyin.PortConfigsReq) error {
|
||
if c.Host != proxy.Host {
|
||
return fmt.Errorf("代理主机不符合预期: %s", c.Host)
|
||
}
|
||
if len(params) != 3 {
|
||
return fmt.Errorf("端口数量不符合预期: %d", len(params))
|
||
}
|
||
for _, param := range params {
|
||
if param.Status != true {
|
||
return fmt.Errorf("端口状态不符合预期: %v", param.Status)
|
||
}
|
||
if param.AutoEdgeConfig == nil {
|
||
return fmt.Errorf("自动边缘节点配置不符合预期: %v", param.AutoEdgeConfig)
|
||
}
|
||
if param.Userpass == nil || *param.Userpass == "" {
|
||
return fmt.Errorf("用户名密码不符合预期: %v", param.Userpass)
|
||
}
|
||
if param.Whitelist == nil || len(*param.Whitelist) != 0 {
|
||
return fmt.Errorf("白名单不符合预期: %v", param.Whitelist)
|
||
}
|
||
config := param.AutoEdgeConfig
|
||
if config.Province != "北京市" {
|
||
return fmt.Errorf("自动边缘节点省份不符合预期: %s", param.AutoEdgeConfig.Province)
|
||
}
|
||
if *config.Count != 1 {
|
||
return fmt.Errorf("自动边缘节点数量不符合预期: %d", param.AutoEdgeConfig.Count)
|
||
}
|
||
if config.PacketLoss != 30 {
|
||
return fmt.Errorf("自动边缘节点丢包率不符合预期: %d", param.AutoEdgeConfig.PacketLoss)
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
},
|
||
want: func(t *testing.T, got []*PortInfo) error {
|
||
// 验证返回结果
|
||
if len(got) != 3 {
|
||
return fmt.Errorf("返回的 PortInfo 数量不正确,期望 3,得到 %d", len(got))
|
||
}
|
||
|
||
// 验证结果
|
||
var gotMap = make(map[int]PortInfo)
|
||
for _, port := range got {
|
||
if port.Proto != 1 {
|
||
return fmt.Errorf("期望协议为 1(http),得到 %d", port.Proto)
|
||
}
|
||
if port.Host != proxy.Host {
|
||
return fmt.Errorf("期望主机为 %s,得到 %s", proxy.Host, port.Host)
|
||
}
|
||
gotMap[port.Port] = *port
|
||
}
|
||
|
||
// 验证数据库字段
|
||
var channels []*models.Channel
|
||
db.Where("user_id = ? and deleted_at is null", userAuth.Payload.Id).Find(&channels)
|
||
for _, ch := range channels {
|
||
if ch.Protocol != 1 {
|
||
return fmt.Errorf("通道协议不正确,期望 1(http),得到 %d", ch.Protocol)
|
||
}
|
||
if ch.UserID != userAuth.Payload.Id {
|
||
return fmt.Errorf("通道用户ID不正确,期望 %d,得到 %d", userAuth.Payload.Id, ch.UserID)
|
||
}
|
||
// todo 多代理分配策略,验证 proxy_host
|
||
if ch.ProxyID != proxy.ID {
|
||
return fmt.Errorf("通道代理ID不正确,期望 %d,得到 %d", proxy.ID, ch.ProxyID)
|
||
}
|
||
var info, ok = gotMap[int(ch.ProxyPort)]
|
||
if !ok {
|
||
return fmt.Errorf("通道端口 %d 不在返回结果中", ch.ProxyPort)
|
||
}
|
||
if ch.AuthPass != true && ch.AuthIP != false {
|
||
return fmt.Errorf("通道认证类型不正确,期望 Pass,得到 %v", ch.AuthPass)
|
||
}
|
||
if ch.Protocol != int32(info.Proto) {
|
||
return fmt.Errorf("通道协议不正确,期望 %d,得到 %d", info.Proto, ch.Protocol)
|
||
}
|
||
if ch.Username != *info.Username {
|
||
return fmt.Errorf("通道用户名不正确,期望 %s,得到 %s", *info.Username, ch.Username)
|
||
}
|
||
if ch.Password != *info.Password {
|
||
return fmt.Errorf("通道密码不正确,期望 %s,得到 %s", *info.Password, ch.Password)
|
||
}
|
||
if ch.Expiration.IsZero() {
|
||
return fmt.Errorf("通道过期时间不应为空")
|
||
}
|
||
|
||
// 检查Redis缓存中的字段
|
||
key := fmt.Sprintf("channel:%d", ch.ID)
|
||
if !mr.Exists(key) {
|
||
return fmt.Errorf("redis缓存中应有键 %s", key)
|
||
}
|
||
|
||
data, _ := mr.Get(key)
|
||
var cache models.Channel
|
||
err := json.Unmarshal([]byte(data), &cache)
|
||
if err != nil {
|
||
return fmt.Errorf("无法解析缓存数据: %v", err)
|
||
}
|
||
if reflect.DeepEqual(cache, *ch) {
|
||
return fmt.Errorf("缓存数据与数据库不匹配: %v", cache)
|
||
}
|
||
}
|
||
|
||
// 检查跨天用量更新
|
||
var pss models.ResourcePss
|
||
db.Where("resource_id = ?", 1).First(&pss)
|
||
if pss.DailyUsed != 3 {
|
||
return fmt.Errorf("套餐每日用量不正确,期望 3,得到 %d", pss.DailyUsed)
|
||
}
|
||
if time.Time(pss.DailyLast).IsZero() {
|
||
return fmt.Errorf("套餐每日最后更新时间不应为空")
|
||
}
|
||
if pss.Used != 3 {
|
||
return fmt.Errorf("套餐总用量不正确,期望 3,得到 %d", pss.Used)
|
||
}
|
||
|
||
return nil
|
||
},
|
||
},
|
||
{
|
||
name: "用户创建HTTP白名单通道",
|
||
args: args{
|
||
ctx: ctx,
|
||
auth: userAuth,
|
||
resourceId: 1,
|
||
protocol: ProtocolHTTP,
|
||
authType: ChannelAuthTypeIp,
|
||
count: 3,
|
||
nodeFilter: []NodeFilterConfig{{Prov: "北京市"}},
|
||
},
|
||
setup: func() {
|
||
mr.FlushAll()
|
||
resetDb()
|
||
|
||
mc.ConnectMock = func(param baiyin.CloudConnectReq) error {
|
||
if param.Uuid != proxy.Name {
|
||
return fmt.Errorf("代理名称不符合预期: %s", param.Uuid)
|
||
}
|
||
if len(param.Edge) != 0 {
|
||
return fmt.Errorf("边缘节点不符合预期: %v", param.Edge)
|
||
}
|
||
if !reflect.DeepEqual(param.AutoConfig, []baiyin.AutoConfig{
|
||
{Province: "河南省", Count: 10},
|
||
{Province: "北京市", Count: 6},
|
||
}) {
|
||
return fmt.Errorf("自动配置不符合预期: %v", param.AutoConfig)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
mg.PortConfigsMock = func(c *testutil.MockGatewayClient, params []baiyin.PortConfigsReq) error {
|
||
if c.Host != proxy.Host {
|
||
return fmt.Errorf("代理主机不符合预期: %s", c.Host)
|
||
}
|
||
if len(params) != 3 {
|
||
return fmt.Errorf("端口数量不符合预期: %d", len(params))
|
||
}
|
||
for _, param := range params {
|
||
if param.Status != true {
|
||
return fmt.Errorf("端口状态不符合预期: %v", param.Status)
|
||
}
|
||
if param.AutoEdgeConfig == nil {
|
||
return fmt.Errorf("自动边缘节点配置不符合预期: %v", param.AutoEdgeConfig)
|
||
}
|
||
if param.Userpass == nil || *param.Userpass != "" {
|
||
return fmt.Errorf("用户名密码不符合预期: %v", *param.Userpass)
|
||
}
|
||
if param.Whitelist == nil || len(*param.Whitelist) == 0 {
|
||
return fmt.Errorf("白名单不符合预期: %v", param.Whitelist)
|
||
}
|
||
config := param.AutoEdgeConfig
|
||
if config.Province != "北京市" {
|
||
return fmt.Errorf("自动边缘节点省份不符合预期: %s", param.AutoEdgeConfig.Province)
|
||
}
|
||
if *config.Count != 1 {
|
||
return fmt.Errorf("自动边缘节点数量不符合预期: %d", param.AutoEdgeConfig.Count)
|
||
}
|
||
if config.PacketLoss != 30 {
|
||
return fmt.Errorf("自动边缘节点丢包率不符合预期: %d", param.AutoEdgeConfig.PacketLoss)
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
},
|
||
want: func(t *testing.T, got []*PortInfo) error {
|
||
// 验证返回结果
|
||
if len(got) != 3 {
|
||
return fmt.Errorf("返回的 PortInfo 数量不正确,期望 3,得到 %d", len(got))
|
||
}
|
||
|
||
// 验证结果
|
||
var gotMap = make(map[int]PortInfo)
|
||
for _, port := range got {
|
||
if port.Proto != 1 {
|
||
return fmt.Errorf("期望协议为 1(http),得到 %d", port.Proto)
|
||
}
|
||
if port.Host != proxy.Host {
|
||
return fmt.Errorf("期望主机为 %s,得到 %s", proxy.Host, port.Host)
|
||
}
|
||
gotMap[port.Port] = *port
|
||
}
|
||
|
||
// 验证数据库字段
|
||
var channels []*models.Channel
|
||
db.Where("user_id = ? and deleted_at is null", userAuth.Payload.Id).Find(&channels)
|
||
for _, ch := range channels {
|
||
if ch.Protocol != 1 {
|
||
return fmt.Errorf("通道协议不正确,期望 1(http),得到 %d", ch.Protocol)
|
||
}
|
||
if ch.UserID != userAuth.Payload.Id {
|
||
return fmt.Errorf("通道用户ID不正确,期望 %d,得到 %d", userAuth.Payload.Id, ch.UserID)
|
||
}
|
||
// todo 多代理分配策略,验证 proxy_host
|
||
if ch.ProxyID != proxy.ID {
|
||
return fmt.Errorf("通道代理ID不正确,期望 %d,得到 %d", proxy.ID, ch.ProxyID)
|
||
}
|
||
var info, ok = gotMap[int(ch.ProxyPort)]
|
||
if !ok {
|
||
return fmt.Errorf("通道端口 %d 不在返回结果中", ch.ProxyPort)
|
||
}
|
||
if ch.AuthPass != false && ch.AuthIP != true {
|
||
return fmt.Errorf("通道认证类型不正确,期望 Pass,得到 %v", ch.AuthPass)
|
||
}
|
||
if ch.Protocol != int32(info.Proto) {
|
||
return fmt.Errorf("通道协议不正确,期望 %d,得到 %d", info.Proto, ch.Protocol)
|
||
}
|
||
if ch.Expiration.IsZero() {
|
||
return fmt.Errorf("通道过期时间不应为空")
|
||
}
|
||
|
||
// 检查Redis缓存中的字段
|
||
key := fmt.Sprintf("channel:%d", ch.ID)
|
||
if !mr.Exists(key) {
|
||
return fmt.Errorf("redis缓存中应有键 %s", key)
|
||
}
|
||
|
||
data, _ := mr.Get(key)
|
||
var cache models.Channel
|
||
err := json.Unmarshal([]byte(data), &cache)
|
||
if err != nil {
|
||
return fmt.Errorf("无法解析缓存数据: %v", err)
|
||
}
|
||
if reflect.DeepEqual(cache, *ch) {
|
||
return fmt.Errorf("缓存数据与数据库不匹配: %v", cache)
|
||
}
|
||
}
|
||
|
||
// 检查跨天用量更新
|
||
var pss models.ResourcePss
|
||
db.Where("resource_id = ?", 1).First(&pss)
|
||
if pss.DailyUsed != 3 {
|
||
return fmt.Errorf("套餐每日用量不正确,期望 3,得到 %d", pss.DailyUsed)
|
||
}
|
||
if time.Time(pss.DailyLast).IsZero() {
|
||
return fmt.Errorf("套餐每日最后更新时间不应为空")
|
||
}
|
||
if pss.Used != 3 {
|
||
return fmt.Errorf("套餐总用量不正确,期望 3,得到 %d", pss.Used)
|
||
}
|
||
|
||
return nil
|
||
},
|
||
},
|
||
{
|
||
name: "管理员替用户创建HTTP密码通道",
|
||
args: args{
|
||
ctx: ctx,
|
||
auth: adminAuth,
|
||
resourceId: 1,
|
||
protocol: ProtocolSocks5,
|
||
authType: ChannelAuthTypePass,
|
||
count: 3,
|
||
nodeFilter: []NodeFilterConfig{{Prov: "北京市"}},
|
||
},
|
||
setup: func() {
|
||
mr.FlushAll()
|
||
resetDb()
|
||
|
||
mc.ConnectMock = func(param baiyin.CloudConnectReq) error {
|
||
if param.Uuid != proxy.Name {
|
||
return fmt.Errorf("代理名称不符合预期: %s", param.Uuid)
|
||
}
|
||
if len(param.Edge) != 0 {
|
||
return fmt.Errorf("边缘节点不符合预期: %v", param.Edge)
|
||
}
|
||
if !reflect.DeepEqual(param.AutoConfig, []baiyin.AutoConfig{
|
||
{Province: "河南省", Count: 10},
|
||
{Province: "北京市", Count: 6},
|
||
}) {
|
||
return fmt.Errorf("自动配置不符合预期: %v", param.AutoConfig)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
mg.PortConfigsMock = func(c *testutil.MockGatewayClient, params []baiyin.PortConfigsReq) error {
|
||
if c.Host != proxy.Host {
|
||
return fmt.Errorf("代理主机不符合预期: %s", c.Host)
|
||
}
|
||
if len(params) != 3 {
|
||
return fmt.Errorf("端口数量不符合预期: %d", len(params))
|
||
}
|
||
for _, param := range params {
|
||
if param.Status != true {
|
||
return fmt.Errorf("端口状态不符合预期: %v", param.Status)
|
||
}
|
||
if param.AutoEdgeConfig == nil {
|
||
return fmt.Errorf("自动边缘节点配置不符合预期: %v", param.AutoEdgeConfig)
|
||
}
|
||
if param.Userpass == nil || *param.Userpass == "" {
|
||
return fmt.Errorf("用户名密码不符合预期: %v", param.Userpass)
|
||
}
|
||
if param.Whitelist == nil || len(*param.Whitelist) != 0 {
|
||
return fmt.Errorf("白名单不符合预期: %v", param.Whitelist)
|
||
}
|
||
config := param.AutoEdgeConfig
|
||
if config.Province != "北京市" {
|
||
return fmt.Errorf("自动边缘节点省份不符合预期: %s", param.AutoEdgeConfig.Province)
|
||
}
|
||
if *config.Count != 1 {
|
||
return fmt.Errorf("自动边缘节点数量不符合预期: %d", param.AutoEdgeConfig.Count)
|
||
}
|
||
if config.PacketLoss != 30 {
|
||
return fmt.Errorf("自动边缘节点丢包率不符合预期: %d", param.AutoEdgeConfig.PacketLoss)
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
},
|
||
want: func(t *testing.T, got []*PortInfo) error {
|
||
// 验证返回结果
|
||
if len(got) != 3 {
|
||
return fmt.Errorf("返回的 PortInfo 数量不正确,期望 3,得到 %d", len(got))
|
||
}
|
||
|
||
// 验证结果
|
||
var gotMap = make(map[int]PortInfo)
|
||
for _, port := range got {
|
||
if port.Proto != 3 {
|
||
return fmt.Errorf("期望协议为 1(http),得到 %d", port.Proto)
|
||
}
|
||
if port.Host != proxy.Host {
|
||
return fmt.Errorf("期望主机为 %s,得到 %s", proxy.Host, port.Host)
|
||
}
|
||
gotMap[port.Port] = *port
|
||
}
|
||
|
||
// 验证数据库字段
|
||
var channels []*models.Channel
|
||
db.Where("user_id = ? and deleted_at is null", userAuth.Payload.Id).Find(&channels)
|
||
for _, ch := range channels {
|
||
if ch.Protocol != 1 {
|
||
return fmt.Errorf("通道协议不正确,期望 1(http),得到 %d", ch.Protocol)
|
||
}
|
||
if ch.UserID != userAuth.Payload.Id {
|
||
return fmt.Errorf("通道用户ID不正确,期望 %d,得到 %d", userAuth.Payload.Id, ch.UserID)
|
||
}
|
||
// todo 多代理分配策略,验证 proxy_host
|
||
if ch.ProxyID != proxy.ID {
|
||
return fmt.Errorf("通道代理ID不正确,期望 %d,得到 %d", proxy.ID, ch.ProxyID)
|
||
}
|
||
var info, ok = gotMap[int(ch.ProxyPort)]
|
||
if !ok {
|
||
return fmt.Errorf("通道端口 %d 不在返回结果中", ch.ProxyPort)
|
||
}
|
||
if ch.AuthPass != true && ch.AuthIP != false {
|
||
return fmt.Errorf("通道认证类型不正确,期望 Pass,得到 %v", ch.AuthPass)
|
||
}
|
||
if ch.Protocol != int32(info.Proto) {
|
||
return fmt.Errorf("通道协议不正确,期望 %d,得到 %d", info.Proto, ch.Protocol)
|
||
}
|
||
if ch.Username != *info.Username {
|
||
return fmt.Errorf("通道用户名不正确,期望 %s,得到 %s", *info.Username, ch.Username)
|
||
}
|
||
if ch.Password != *info.Password {
|
||
return fmt.Errorf("通道密码不正确,期望 %s,得到 %s", *info.Password, ch.Password)
|
||
}
|
||
if ch.Expiration.IsZero() {
|
||
return fmt.Errorf("通道过期时间不应为空")
|
||
}
|
||
|
||
// 检查Redis缓存中的字段
|
||
key := fmt.Sprintf("channel:%d", ch.ID)
|
||
if !mr.Exists(key) {
|
||
return fmt.Errorf("redis缓存中应有键 %s", key)
|
||
}
|
||
|
||
data, _ := mr.Get(key)
|
||
var cache models.Channel
|
||
err := json.Unmarshal([]byte(data), &cache)
|
||
if err != nil {
|
||
return fmt.Errorf("无法解析缓存数据: %v", err)
|
||
}
|
||
if reflect.DeepEqual(cache, *ch) {
|
||
return fmt.Errorf("缓存数据与数据库不匹配: %v", cache)
|
||
}
|
||
}
|
||
|
||
// 检查跨天用量更新
|
||
var pss models.ResourcePss
|
||
db.Where("resource_id = ?", 1).First(&pss)
|
||
if pss.DailyUsed != 3 {
|
||
return fmt.Errorf("套餐每日用量不正确,期望 3,得到 %d", pss.DailyUsed)
|
||
}
|
||
if time.Time(pss.DailyLast).IsZero() {
|
||
return fmt.Errorf("套餐每日最后更新时间不应为空")
|
||
}
|
||
if pss.Used != 3 {
|
||
return fmt.Errorf("套餐总用量不正确,期望 3,得到 %d", pss.Used)
|
||
}
|
||
|
||
return nil
|
||
},
|
||
},
|
||
{
|
||
name: "套餐不存在",
|
||
args: args{
|
||
ctx: ctx,
|
||
auth: userAuth,
|
||
resourceId: 999,
|
||
protocol: ProtocolHTTP,
|
||
authType: ChannelAuthTypeIp,
|
||
count: 1,
|
||
},
|
||
setup: func() {
|
||
mr.FlushAll()
|
||
resetDb()
|
||
},
|
||
wantErr: true,
|
||
wantErrContains: "无权限访问",
|
||
},
|
||
{
|
||
name: "套餐没有权限",
|
||
args: args{
|
||
ctx: ctx,
|
||
auth: userAuth,
|
||
resourceId: 2,
|
||
protocol: ProtocolHTTP,
|
||
authType: ChannelAuthTypeIp,
|
||
count: 1,
|
||
},
|
||
setup: func() {
|
||
mr.FlushAll()
|
||
resetDb()
|
||
|
||
resource2 := &models.Resource{
|
||
ID: 2,
|
||
UserID: 102,
|
||
Active: true,
|
||
}
|
||
db.Create(resource2)
|
||
var resourcePss2 = &models.ResourcePss{
|
||
ID: 2,
|
||
ResourceID: 2,
|
||
Type: 1,
|
||
Live: 180,
|
||
Expire: common.LocalDateTime(time.Now().AddDate(1, 0, 0)),
|
||
DailyLimit: 10000,
|
||
}
|
||
db.Create(resourcePss2)
|
||
},
|
||
wantErr: true,
|
||
wantErrContains: "无权限访问",
|
||
},
|
||
{
|
||
name: "套餐配额不足",
|
||
args: args{
|
||
ctx: ctx,
|
||
auth: userAuth,
|
||
resourceId: 2,
|
||
protocol: ProtocolHTTP,
|
||
authType: ChannelAuthTypeIp,
|
||
count: 10,
|
||
},
|
||
setup: func() {
|
||
mr.FlushAll()
|
||
resetDb()
|
||
|
||
// 创建一个配额几乎用完的资源包
|
||
resource2 := models.Resource{
|
||
ID: 2,
|
||
UserID: 101,
|
||
Active: true,
|
||
}
|
||
resourcePss2 := models.ResourcePss{
|
||
ID: 2,
|
||
ResourceID: 2,
|
||
Type: 2,
|
||
Quota: 100,
|
||
Used: 91,
|
||
Live: 180,
|
||
DailyLimit: 10000,
|
||
}
|
||
db.Create(&resource2).Create(&resourcePss2)
|
||
},
|
||
wantErr: true,
|
||
wantErrContains: "套餐配额不足",
|
||
},
|
||
{
|
||
name: "端口数量达到上限",
|
||
args: args{
|
||
ctx: ctx,
|
||
auth: userAuth,
|
||
resourceId: 1,
|
||
protocol: ProtocolHTTP,
|
||
authType: ChannelAuthTypeIp,
|
||
count: 1,
|
||
},
|
||
setup: func() {
|
||
mr.FlushAll()
|
||
resetDb()
|
||
mc.AutoQueryMock = func() (baiyin.CloudConnectResp, error) {
|
||
return baiyin.CloudConnectResp{
|
||
"test-proxy": []baiyin.AutoConfig{
|
||
{Count: 20000},
|
||
},
|
||
}, nil
|
||
}
|
||
// 创建大量占用端口的通道
|
||
var channels = make([]models.Channel, 10000)
|
||
var expr = time.Now().Add(time.Hour)
|
||
for i := range channels {
|
||
channels[i] = models.Channel{
|
||
ProxyID: 1,
|
||
ProxyPort: int32(i + 10000),
|
||
UserID: 101,
|
||
Expiration: expr,
|
||
}
|
||
}
|
||
db.CreateInBatches(channels, 1000)
|
||
},
|
||
wantErr: true,
|
||
wantErrContains: "端口数量到达上限",
|
||
},
|
||
// todo 多地区混杂条件提取
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
if tt.setup != nil {
|
||
tt.setup()
|
||
}
|
||
|
||
s := &channelService{}
|
||
got, err := s.CreateChannel(tt.args.ctx, tt.args.auth, tt.args.resourceId, tt.args.protocol, tt.args.authType, tt.args.count, tt.args.nodeFilter...)
|
||
|
||
// 检查错误或结果
|
||
if tt.wantErr {
|
||
if err == nil {
|
||
t.Errorf("CreateChannel() 应当返回错误")
|
||
return
|
||
}
|
||
if tt.wantErrContains != "" && !strings.Contains(err.Error(), tt.wantErrContains) {
|
||
t.Errorf("CreateChannel() 错误 = %v, 应包含 %v", err, tt.wantErrContains)
|
||
}
|
||
return
|
||
}
|
||
|
||
if err != nil {
|
||
t.Errorf("CreateChannel() 错误 = %v, wantErr %v", err, tt.wantErr)
|
||
return
|
||
}
|
||
|
||
// 使用检查函数验证结果
|
||
if err := tt.want(t, got); err != nil {
|
||
t.Errorf("结果验证失败: %v", err)
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
func Test_channelService_RemoveChannels(t *testing.T) {
|
||
mr := testutil.SetupRedisTest(t)
|
||
md := testutil.SetupDBTest(t)
|
||
mg := testutil.SetupGatewayClientMock(t)
|
||
mc := testutil.SetupCloudClientMock(t)
|
||
|
||
type args struct {
|
||
ctx context.Context
|
||
auth *AuthContext
|
||
id []int32
|
||
}
|
||
|
||
// 准备测试数据
|
||
ctx := context.WithValue(context.Background(), requestid.ConfigDefault.ContextKey, "test-request-id")
|
||
|
||
// 创建用户
|
||
var user = &models.User{
|
||
ID: 101,
|
||
Phone: "12312341234",
|
||
}
|
||
md.Create(user)
|
||
|
||
// 创建管理员
|
||
var adminUser = &models.User{
|
||
ID: 100,
|
||
Phone: "99999999999",
|
||
}
|
||
md.Create(adminUser)
|
||
|
||
// 认证上下文
|
||
var adminAuth = &AuthContext{Payload: Payload{Id: 100, Type: PayloadAdmin}}
|
||
var userAuth = &AuthContext{Payload: Payload{Id: 101, Type: PayloadUser}}
|
||
|
||
// 创建代理
|
||
var proxy = &models.Proxy{
|
||
ID: 1,
|
||
Version: 1,
|
||
Name: "test-proxy",
|
||
Host: "111.111.111.111",
|
||
Type: 1,
|
||
Secret: "test:secret",
|
||
}
|
||
md.Create(proxy)
|
||
|
||
var proxy2 = &models.Proxy{
|
||
ID: 2,
|
||
Version: 1,
|
||
Name: "test-proxy-2",
|
||
Host: "222.222.222.222",
|
||
Type: 1,
|
||
Secret: "test:secret2",
|
||
}
|
||
md.Create(proxy2)
|
||
|
||
// 清空数据库函数
|
||
var clearDb = func() {
|
||
md.Exec("delete from channel where true")
|
||
mr.FlushAll()
|
||
}
|
||
|
||
tests := []struct {
|
||
name string
|
||
args args
|
||
setup func()
|
||
wantErr bool
|
||
wantErrContains string
|
||
want func(t *testing.T) error
|
||
}{
|
||
{
|
||
name: "管理员删除多个通道",
|
||
args: args{
|
||
ctx: ctx,
|
||
auth: adminAuth,
|
||
id: []int32{1, 2, 3},
|
||
},
|
||
setup: func() {
|
||
mr.FlushAll()
|
||
clearDb()
|
||
|
||
// 创建通道
|
||
channels := []models.Channel{
|
||
{ID: 1, UserID: 101, ProxyID: 1, ProxyPort: 10001, Protocol: 1, Expiration: time.Now().Add(24 * time.Hour)},
|
||
{ID: 2, UserID: 101, ProxyID: 1, ProxyPort: 10002, Protocol: 1, Expiration: time.Now().Add(24 * time.Hour)},
|
||
{ID: 3, UserID: 101, ProxyID: 2, ProxyPort: 10001, Protocol: 3, Expiration: time.Now().Add(24 * time.Hour)},
|
||
}
|
||
|
||
// 保存预设数据
|
||
md.Create(channels)
|
||
for _, channel := range channels {
|
||
key := fmt.Sprintf("channel:%d", channel.ID)
|
||
data, _ := json.Marshal(channel)
|
||
_ = mr.Set(key, string(data))
|
||
}
|
||
|
||
// 模拟网关客户端的响应
|
||
mg.PortActiveMock = func(m *testutil.MockGatewayClient, param ...baiyin.PortActiveReq) (map[string]baiyin.PortData, error) {
|
||
switch {
|
||
case m.Host == proxy.Host:
|
||
return map[string]baiyin.PortData{
|
||
"10001": {Edge: []string{"edge1", "edge4"}},
|
||
"10002": {Edge: []string{"edge2"}},
|
||
}, nil
|
||
case m.Host == proxy2.Host:
|
||
return map[string]baiyin.PortData{
|
||
"10001": {Edge: []string{"edge3"}},
|
||
}, nil
|
||
}
|
||
return nil, fmt.Errorf("代理主机不符合预期: %s", m.Host)
|
||
}
|
||
mg.PortConfigsMock = func(m *testutil.MockGatewayClient, params []baiyin.PortConfigsReq) error {
|
||
switch {
|
||
case m.Host == proxy.Host:
|
||
for _, param := range params {
|
||
if param.Port != 10001 && param.Port != 10002 {
|
||
return fmt.Errorf("端口配置不符合预期: %d", param.Port)
|
||
}
|
||
if param.Status != false {
|
||
return fmt.Errorf("端口状态不符合预期: %v", param.Status)
|
||
}
|
||
if param.Edge == nil || len(*param.Edge) != 0 {
|
||
return fmt.Errorf("边缘节点不符合预期1: %v", param.Edge)
|
||
}
|
||
}
|
||
return nil
|
||
case m.Host == proxy2.Host:
|
||
for _, param := range params {
|
||
if param.Port != 10001 {
|
||
return fmt.Errorf("端口配置不符合预期: %d", param.Port)
|
||
}
|
||
if param.Status != false {
|
||
return fmt.Errorf("端口状态不符合预期: %v", param.Status)
|
||
}
|
||
if param.Edge == nil || len(*param.Edge) != 0 {
|
||
return fmt.Errorf("边缘节点不符合预期2: %v", param.Edge)
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
return fmt.Errorf("代理主机不符合预期: %s", m.Host)
|
||
}
|
||
mc.DisconnectMock = func(param baiyin.CloudDisconnectReq) (int, error) {
|
||
switch {
|
||
case param.Uuid == proxy.Name:
|
||
var edges = []string{"edge1", "edge2", "edge4"}
|
||
if !testutil.SliceEqual(edges, param.Edge) {
|
||
return 0, fmt.Errorf("边缘节点不符合预期3: %v", param.Edge)
|
||
}
|
||
if len(param.Config) != 0 {
|
||
return 0, fmt.Errorf("配置不符合预期: %v", param.Config)
|
||
}
|
||
return len(param.Edge), nil
|
||
case param.Uuid == proxy2.Name:
|
||
var edges = []string{"edge3"}
|
||
if !testutil.SliceEqual(edges, param.Edge) {
|
||
return 0, fmt.Errorf("边缘节点不符合预期4: %v", param.Edge)
|
||
}
|
||
if len(param.Config) != 0 {
|
||
return 0, fmt.Errorf("配置不符合预期: %v", param.Config)
|
||
}
|
||
return len(param.Edge), nil
|
||
}
|
||
return 0, fmt.Errorf("代理名称不符合预期: %s", param.Uuid)
|
||
}
|
||
},
|
||
want: func(t *testing.T) error {
|
||
// 检查通道是否被软删除
|
||
var count int64
|
||
md.Model(&models.Channel{}).Where("id IN ? AND deleted_at IS NULL", []int32{1, 2, 3}).Count(&count)
|
||
if count > 0 {
|
||
return fmt.Errorf("应该软删除了所有通道,但仍有 %d 个未删除", count)
|
||
}
|
||
|
||
// 检查Redis缓存是否被删除
|
||
for _, id := range []int32{1, 2, 3} {
|
||
key := fmt.Sprintf("channel:%d", id)
|
||
if mr.Exists(key) {
|
||
return fmt.Errorf("通道缓存 %s 应被删除但仍存在", key)
|
||
}
|
||
}
|
||
return nil
|
||
},
|
||
},
|
||
{
|
||
name: "用户删除自己的通道",
|
||
args: args{
|
||
ctx: ctx,
|
||
auth: userAuth,
|
||
id: []int32{1, 2, 3},
|
||
},
|
||
setup: func() {
|
||
mr.FlushAll()
|
||
clearDb()
|
||
|
||
// 创建通道
|
||
channels := []models.Channel{
|
||
{ID: 1, UserID: 101, ProxyID: 1, ProxyPort: 10001, Protocol: 1, Expiration: time.Now().Add(24 * time.Hour)},
|
||
{ID: 2, UserID: 101, ProxyID: 1, ProxyPort: 10002, Protocol: 1, Expiration: time.Now().Add(24 * time.Hour)},
|
||
{ID: 3, UserID: 101, ProxyID: 2, ProxyPort: 10001, Protocol: 3, Expiration: time.Now().Add(24 * time.Hour)},
|
||
}
|
||
|
||
// 保存预设数据
|
||
md.Create(channels)
|
||
for _, channel := range channels {
|
||
key := fmt.Sprintf("channel:%d", channel.ID)
|
||
data, _ := json.Marshal(channel)
|
||
_ = mr.Set(key, string(data))
|
||
}
|
||
|
||
// 模拟网关客户端的响应
|
||
mg.PortActiveMock = func(m *testutil.MockGatewayClient, param ...baiyin.PortActiveReq) (map[string]baiyin.PortData, error) {
|
||
switch {
|
||
case m.Host == proxy.Host:
|
||
return map[string]baiyin.PortData{
|
||
"10001": {Edge: []string{"edge1", "edge4"}},
|
||
"10002": {Edge: []string{"edge2"}},
|
||
}, nil
|
||
case m.Host == proxy2.Host:
|
||
return map[string]baiyin.PortData{
|
||
"10001": {Edge: []string{"edge3"}},
|
||
}, nil
|
||
}
|
||
return nil, fmt.Errorf("代理主机不符合预期: %s", m.Host)
|
||
}
|
||
mg.PortConfigsMock = func(m *testutil.MockGatewayClient, params []baiyin.PortConfigsReq) error {
|
||
switch {
|
||
case m.Host == proxy.Host:
|
||
for _, param := range params {
|
||
if param.Port != 10001 && param.Port != 10002 {
|
||
return fmt.Errorf("端口配置不符合预期: %d", param.Port)
|
||
}
|
||
if param.Status != false {
|
||
return fmt.Errorf("端口状态不符合预期: %v", param.Status)
|
||
}
|
||
if param.Edge == nil || len(*param.Edge) != 0 {
|
||
return fmt.Errorf("边缘节点不符合预期5: %v", param.Edge)
|
||
}
|
||
}
|
||
return nil
|
||
case m.Host == proxy2.Host:
|
||
for _, param := range params {
|
||
if param.Port != 10001 {
|
||
return fmt.Errorf("端口配置不符合预期: %d", param.Port)
|
||
}
|
||
if param.Status != false {
|
||
return fmt.Errorf("端口状态不符合预期: %v", param.Status)
|
||
}
|
||
if param.Edge == nil || len(*param.Edge) != 0 {
|
||
return fmt.Errorf("边缘节点不符合预期6: %v", param.Edge)
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
return fmt.Errorf("代理主机不符合预期: %s", m.Host)
|
||
}
|
||
mc.DisconnectMock = func(param baiyin.CloudDisconnectReq) (int, error) {
|
||
switch {
|
||
case param.Uuid == proxy.Name:
|
||
var edges = []string{"edge1", "edge2", "edge4"}
|
||
if !testutil.SliceEqual(edges, param.Edge) {
|
||
return 0, fmt.Errorf("边缘节点不符合预期7: %v", param.Edge)
|
||
}
|
||
if len(param.Config) != 0 {
|
||
return 0, fmt.Errorf("配置不符合预期: %v", param.Config)
|
||
}
|
||
return len(param.Edge), nil
|
||
case param.Uuid == proxy2.Name:
|
||
var edges = []string{"edge3"}
|
||
if !testutil.SliceEqual(edges, param.Edge) {
|
||
return 0, fmt.Errorf("边缘节点不符合预期8: %v", param.Edge)
|
||
}
|
||
if len(param.Config) != 0 {
|
||
return 0, fmt.Errorf("配置不符合预期: %v", param.Config)
|
||
}
|
||
return len(param.Edge), nil
|
||
}
|
||
return 0, fmt.Errorf("代理名称不符合预期: %s", param.Uuid)
|
||
}
|
||
},
|
||
want: func(t *testing.T) error {
|
||
// 检查通道是否被软删除
|
||
var count int64
|
||
md.Model(&models.Channel{}).Where("id IN ? AND deleted_at IS NULL", []int32{1, 2, 3}).Count(&count)
|
||
if count > 0 {
|
||
return fmt.Errorf("应该软删除了所有通道,但仍有 %d 个未删除", count)
|
||
}
|
||
|
||
// 检查Redis缓存是否被删除
|
||
for _, id := range []int32{1, 2, 3} {
|
||
key := fmt.Sprintf("channel:%d", id)
|
||
if mr.Exists(key) {
|
||
return fmt.Errorf("通道缓存 %s 应被删除但仍存在", key)
|
||
}
|
||
}
|
||
return nil
|
||
},
|
||
},
|
||
{
|
||
name: "用户删除不属于自己的通道",
|
||
args: args{
|
||
ctx: ctx,
|
||
auth: userAuth,
|
||
id: []int32{1, 2, 3},
|
||
},
|
||
setup: func() {
|
||
mr.FlushAll()
|
||
clearDb()
|
||
|
||
// 创建通道
|
||
channels := []models.Channel{
|
||
{ID: 1, UserID: 101, ProxyID: 1, ProxyPort: 10001, Protocol: 1, Expiration: time.Now().Add(24 * time.Hour)},
|
||
{ID: 2, UserID: 101, ProxyID: 1, ProxyPort: 10002, Protocol: 1, Expiration: time.Now().Add(24 * time.Hour)},
|
||
{ID: 3, UserID: 102, ProxyID: 2, ProxyPort: 10001, Protocol: 3, Expiration: time.Now().Add(24 * time.Hour)},
|
||
}
|
||
|
||
// 保存预设数据
|
||
md.Create(channels)
|
||
for _, channel := range channels {
|
||
key := fmt.Sprintf("channel:%d", channel.ID)
|
||
data, _ := json.Marshal(channel)
|
||
_ = mr.Set(key, string(data))
|
||
}
|
||
},
|
||
wantErr: true,
|
||
wantErrContains: "无权限访问",
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
if tt.setup != nil {
|
||
tt.setup()
|
||
}
|
||
|
||
s := &channelService{}
|
||
err := s.RemoveChannels(tt.args.ctx, tt.args.auth, tt.args.id...)
|
||
|
||
// 检查错误
|
||
if tt.wantErr {
|
||
if err == nil {
|
||
t.Errorf("RemoveChannels() 应当返回错误")
|
||
return
|
||
}
|
||
if tt.wantErrContains != "" && !strings.Contains(err.Error(), tt.wantErrContains) {
|
||
t.Errorf("RemoveChannels() 错误 = %v, 应包含 %v", err, tt.wantErrContains)
|
||
}
|
||
return
|
||
}
|
||
|
||
if err != nil {
|
||
t.Errorf("RemoveChannels() 错误 = %v, wantErr %v", err, tt.wantErr)
|
||
return
|
||
}
|
||
|
||
// 检查数据库和缓存是否正确设置
|
||
|
||
if err := tt.want(t); err != nil {
|
||
t.Errorf("RemoveChannels() 结果验证失败: %v", err)
|
||
}
|
||
})
|
||
}
|
||
}
|