Files
platform/web/services/channel_test.go
2025-05-07 17:39:36 +08:00

1323 lines
38 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package services
import (
"context"
"encoding/json"
"fmt"
"platform/pkg/testutil"
"platform/web/core"
g "platform/web/globals"
"platform/web/models"
"reflect"
"strings"
"testing"
"time"
"github.com/gofiber/fiber/v2/middleware/requestid"
)
func Test_genPassPair(t *testing.T) {
tests := []struct {
name string
}{
{
name: "正常生成随机用户名和密码",
},
{
name: "多次调用生成不同的值",
},
}
// 第一个测试:检查生成的用户名和密码是否有效
t.Run(tests[0].name, func(t *testing.T) {
username, password := genPassPair()
if username == "" {
t.Errorf("genPassPair() username is empty")
}
if password == "" {
t.Errorf("genPassPair() password is empty")
}
})
// 第二个测试:确保多次调用生成不同的值
t.Run(tests[1].name, func(t *testing.T) {
username1, password1 := genPassPair()
username2, password2 := genPassPair()
if username1 == username2 {
t.Errorf("genPassPair() generated the same username twice: %v", username1)
}
if password1 == password2 {
t.Errorf("genPassPair() generated the same password twice: %v", password1)
}
})
}
func Test_chKey(t *testing.T) {
type args struct {
channel *models.Channel
}
tests := []struct {
name string
args args
want string
}{
{
name: "ID为1的通道",
args: args{
channel: &models.Channel{
ID: 1,
},
},
want: "channel:1",
},
{
name: "ID为100的通道",
args: args{
channel: &models.Channel{
ID: 100,
},
},
want: "channel:100",
},
{
name: "ID为0的通道",
args: args{
channel: &models.Channel{
ID: 0,
},
},
want: "channel:0",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := chKey(tt.args.channel); got != tt.want {
t.Errorf("chKey() = %v, want %v", got, tt.want)
}
})
}
}
func Test_cache(t *testing.T) {
mr := testutil.SetupRedisTest(t)
type args struct {
ctx context.Context
channels []*models.Channel
}
// 准备测试数据
now := time.Now()
expiration := core.LocalDateTime(now.Add(24 * time.Hour))
testChannels := []*models.Channel{
{
ID: 1,
UserID: 100,
ProxyID: 10,
ProxyPort: 8080,
Protocol: 1,
Expiration: expiration,
},
{
ID: 2,
UserID: 101,
ProxyID: 11,
ProxyPort: 8081,
Protocol: 3,
Expiration: expiration,
},
}
tests := []struct {
name string
args args
wantErr bool
}{
{
name: "正常缓存多个通道",
args: args{
ctx: context.Background(),
channels: testChannels,
},
wantErr: false,
},
{
name: "空通道列表",
args: args{
ctx: context.Background(),
channels: []*models.Channel{},
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mr.FlushAll() // 清空 Redis 数据
if err := cache(tt.args.ctx, tt.args.channels); (err != nil) != tt.wantErr {
t.Errorf("cache() error = %v, wantErr %v", err, tt.wantErr)
return
}
// 验证缓存结果
if len(tt.args.channels) > 0 {
for _, channel := range tt.args.channels {
key := fmt.Sprintf("channel:%d", channel.ID)
if !mr.Exists(key) {
t.Errorf("缓存未包含通道键 %s", key)
} else {
// 验证缓存的数据是否正确
data, _ := mr.Get(key)
var cachedChannel models.Channel
err := json.Unmarshal([]byte(data), &cachedChannel)
if err != nil {
t.Errorf("无法解析缓存数据: %v", err)
}
if cachedChannel.ID != channel.ID {
t.Errorf("缓存数据不匹配: 期望 ID %d, 得到 %d", channel.ID, cachedChannel.ID)
}
}
}
// 验证是否设置了过期时间
for _, channel := range tt.args.channels {
key := fmt.Sprintf("channel:%d", channel.ID)
ttl := mr.TTL(key)
if ttl <= 0 {
t.Errorf("键 %s 没有设置过期时间", key)
}
}
// 验证是否添加了有序集合
if !mr.Exists("tasks:channel") {
t.Errorf("ZAdd未创建有序集合 tasks:channel")
}
}
})
}
}
func Test_deleteCache(t *testing.T) {
mr := testutil.SetupRedisTest(t)
type args struct {
ctx context.Context
channels []*models.Channel
}
// 准备测试数据
testChannels := []*models.Channel{
{ID: 1},
{ID: 2},
{ID: 3},
}
ctx := context.Background()
tests := []struct {
name string
args args
wantErr bool
}{
{
name: "正常删除多个通道缓存",
args: args{
ctx: ctx,
channels: testChannels,
},
wantErr: false,
},
{
name: "空通道列表",
args: args{
ctx: ctx,
channels: []*models.Channel{},
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mr.FlushAll() // 清空 Redis 数据
// 预先设置缓存数据
for _, channel := range testChannels {
key := fmt.Sprintf("channel:%d", channel.ID)
data, _ := json.Marshal(channel)
mr.Set(key, string(data))
mr.SetTTL(key, 1*time.Hour) // 设置1小时的过期时间
}
if err := deleteCache(tt.args.ctx, tt.args.channels); (err != nil) != tt.wantErr {
t.Errorf("deleteCache() error = %v, wantErr %v", err, tt.wantErr)
return
}
// 验证删除结果
for _, channel := range tt.args.channels {
key := fmt.Sprintf("channel:%d", channel.ID)
if mr.Exists(key) {
t.Errorf("通道键 %s 未被删除", key)
}
}
})
}
}
func Test_channelService_CreateChannel(t *testing.T) {
mr := testutil.SetupRedisTest(t)
db := testutil.SetupDBTest(t)
mc := testutil.SetupCloudClientMock(t)
mg := testutil.SetupGatewayClientMock(t)
type args struct {
ctx context.Context
auth *AuthContext
resourceId int32
protocol ChannelProtocol
authType ChannelAuthType
count int
nodeFilter []NodeFilterConfig
}
// 准备测试数据
ctx := context.WithValue(context.Background(), requestid.ConfigDefault.ContextKey, "test-request-id")
var adminAuth = &AuthContext{Payload: Payload{Id: 100, Type: PayloadAdmin}}
var userAuth = &AuthContext{Payload: Payload{Id: 101, Type: PayloadUser}}
mc.AutoQueryMock = func() (g.CloudConnectResp, error) {
return g.CloudConnectResp{
"test-proxy": []g.AutoConfig{
{Province: "河南省", Count: 10},
},
}, nil
}
var user *models.User
var whitelists []*models.Whitelist
var proxy *models.Proxy
var resource *models.Resource
var resourcePss *models.ResourcePss
var resetDb = func() {
user = &models.User{
ID: 101,
Phone: "12312341234",
}
db.Exec("delete from user where true")
db.Create(user)
whitelists = []*models.Whitelist{
{ID: 1, UserID: 101, Host: "123.123.123.123"},
{ID: 2, UserID: 101, Host: "456.456.456.456"},
{ID: 3, UserID: 101, Host: "789.789.789.789"},
}
db.Exec("delete from whitelist where true")
db.Create(whitelists)
proxy = &models.Proxy{
ID: 1,
Version: 1,
Name: "test-proxy",
Host: "111.111.111.111",
Type: 1,
Secret: "test:secret",
}
db.Exec("delete from proxy where true")
db.Create(proxy)
resource = &models.Resource{
ID: 1,
UserID: 101,
Active: true,
}
db.Exec("delete from resource where true")
db.Create(resource)
resourcePss = &models.ResourcePss{
ID: 1,
ResourceID: 1,
Type: 1,
Live: 180,
Expire: core.LocalDateTime(time.Now().AddDate(1, 0, 0)),
DailyLimit: 10000,
}
db.Exec("delete from resource_pss where true")
db.Create(resourcePss)
db.Exec("delete from channel where true")
}
tests := []struct {
name string
args args
setup func()
wantErr bool
wantErrContains string
want func(t *testing.T, got []*PortInfo) error
}{
{
name: "用户创建HTTP密码通道",
args: args{
ctx: ctx,
auth: userAuth,
resourceId: 1,
protocol: ProtocolHTTP,
authType: ChannelAuthTypePass,
count: 3,
nodeFilter: []NodeFilterConfig{{Prov: "北京市"}},
},
setup: func() {
mr.FlushAll()
resetDb()
mc.ConnectMock = func(param g.CloudConnectReq) error {
if param.Uuid != proxy.Name {
return fmt.Errorf("代理名称不符合预期: %s", param.Uuid)
}
if len(param.Edge) != 0 {
return fmt.Errorf("边缘节点不符合预期: %v", param.Edge)
}
if !reflect.DeepEqual(param.AutoConfig, []g.AutoConfig{
{Province: "河南省", Count: 10},
{Province: "北京市", Count: 6},
}) {
return fmt.Errorf("自动配置不符合预期: %v", param.AutoConfig)
}
return nil
}
mg.PortConfigsMock = func(c *testutil.MockGatewayClient, params []g.PortConfigsReq) error {
if c.Host != proxy.Host {
return fmt.Errorf("代理主机不符合预期: %s", c.Host)
}
if len(params) != 3 {
return fmt.Errorf("端口数量不符合预期: %d", len(params))
}
for _, param := range params {
if param.Status != true {
return fmt.Errorf("端口状态不符合预期: %v", param.Status)
}
if param.AutoEdgeConfig == nil {
return fmt.Errorf("自动边缘节点配置不符合预期: %v", param.AutoEdgeConfig)
}
if param.Userpass == nil || *param.Userpass == "" {
return fmt.Errorf("用户名密码不符合预期: %v", param.Userpass)
}
if param.Whitelist == nil || len(*param.Whitelist) != 0 {
return fmt.Errorf("白名单不符合预期: %v", param.Whitelist)
}
config := param.AutoEdgeConfig
if config.Province != "北京市" {
return fmt.Errorf("自动边缘节点省份不符合预期: %s", param.AutoEdgeConfig.Province)
}
if *config.Count != 1 {
return fmt.Errorf("自动边缘节点数量不符合预期: %d", param.AutoEdgeConfig.Count)
}
if config.PacketLoss != 30 {
return fmt.Errorf("自动边缘节点丢包率不符合预期: %d", param.AutoEdgeConfig.PacketLoss)
}
}
return nil
}
},
want: func(t *testing.T, got []*PortInfo) error {
// 验证返回结果
if len(got) != 3 {
return fmt.Errorf("返回的 PortInfo 数量不正确,期望 3得到 %d", len(got))
}
// 验证结果
var gotMap = make(map[int]PortInfo)
for _, port := range got {
if port.Proto != 1 {
return fmt.Errorf("期望协议为 1(http),得到 %d", port.Proto)
}
if port.Host != proxy.Host {
return fmt.Errorf("期望主机为 %s得到 %s", proxy.Host, port.Host)
}
gotMap[port.Port] = *port
}
// 验证数据库字段
var channels []*models.Channel
db.Where("user_id = ? and deleted_at is null", userAuth.Payload.Id).Find(&channels)
for _, ch := range channels {
if ch.Protocol != 1 {
return fmt.Errorf("通道协议不正确,期望 1(http),得到 %d", ch.Protocol)
}
if ch.UserID != userAuth.Payload.Id {
return fmt.Errorf("通道用户ID不正确期望 %d得到 %d", userAuth.Payload.Id, ch.UserID)
}
// todo 多代理分配策略,验证 proxy_host
if ch.ProxyID != proxy.ID {
return fmt.Errorf("通道代理ID不正确期望 %d得到 %d", proxy.ID, ch.ProxyID)
}
var info, ok = gotMap[int(ch.ProxyPort)]
if !ok {
return fmt.Errorf("通道端口 %d 不在返回结果中", ch.ProxyPort)
}
if ch.AuthPass != true && ch.AuthIP != false {
return fmt.Errorf("通道认证类型不正确,期望 Pass得到 %v", ch.AuthPass)
}
if ch.Protocol != int32(info.Proto) {
return fmt.Errorf("通道协议不正确,期望 %d得到 %d", info.Proto, ch.Protocol)
}
if ch.Username != *info.Username {
return fmt.Errorf("通道用户名不正确,期望 %s得到 %s", *info.Username, ch.Username)
}
if ch.Password != *info.Password {
return fmt.Errorf("通道密码不正确,期望 %s得到 %s", *info.Password, ch.Password)
}
if time.Time(ch.Expiration).IsZero() {
return fmt.Errorf("通道过期时间不应为空")
}
// 检查Redis缓存中的字段
key := fmt.Sprintf("channel:%d", ch.ID)
if !mr.Exists(key) {
return fmt.Errorf("redis缓存中应有键 %s", key)
}
data, _ := mr.Get(key)
var cache models.Channel
err := json.Unmarshal([]byte(data), &cache)
if err != nil {
return fmt.Errorf("无法解析缓存数据: %v", err)
}
if reflect.DeepEqual(cache, *ch) {
return fmt.Errorf("缓存数据与数据库不匹配: %v", cache)
}
}
// 检查跨天用量更新
var pss models.ResourcePss
db.Where("resource_id = ?", 1).First(&pss)
if pss.DailyUsed != 3 {
return fmt.Errorf("套餐每日用量不正确,期望 3得到 %d", pss.DailyUsed)
}
if time.Time(pss.DailyLast).IsZero() {
return fmt.Errorf("套餐每日最后更新时间不应为空")
}
if pss.Used != 3 {
return fmt.Errorf("套餐总用量不正确,期望 3得到 %d", pss.Used)
}
return nil
},
},
{
name: "用户创建HTTP白名单通道",
args: args{
ctx: ctx,
auth: userAuth,
resourceId: 1,
protocol: ProtocolHTTP,
authType: ChannelAuthTypeIp,
count: 3,
nodeFilter: []NodeFilterConfig{{Prov: "北京市"}},
},
setup: func() {
mr.FlushAll()
resetDb()
mc.ConnectMock = func(param g.CloudConnectReq) error {
if param.Uuid != proxy.Name {
return fmt.Errorf("代理名称不符合预期: %s", param.Uuid)
}
if len(param.Edge) != 0 {
return fmt.Errorf("边缘节点不符合预期: %v", param.Edge)
}
if !reflect.DeepEqual(param.AutoConfig, []g.AutoConfig{
{Province: "河南省", Count: 10},
{Province: "北京市", Count: 6},
}) {
return fmt.Errorf("自动配置不符合预期: %v", param.AutoConfig)
}
return nil
}
mg.PortConfigsMock = func(c *testutil.MockGatewayClient, params []g.PortConfigsReq) error {
if c.Host != proxy.Host {
return fmt.Errorf("代理主机不符合预期: %s", c.Host)
}
if len(params) != 3 {
return fmt.Errorf("端口数量不符合预期: %d", len(params))
}
for _, param := range params {
if param.Status != true {
return fmt.Errorf("端口状态不符合预期: %v", param.Status)
}
if param.AutoEdgeConfig == nil {
return fmt.Errorf("自动边缘节点配置不符合预期: %v", param.AutoEdgeConfig)
}
if param.Userpass == nil || *param.Userpass != "" {
return fmt.Errorf("用户名密码不符合预期: %v", *param.Userpass)
}
if param.Whitelist == nil || len(*param.Whitelist) == 0 {
return fmt.Errorf("白名单不符合预期: %v", param.Whitelist)
}
config := param.AutoEdgeConfig
if config.Province != "北京市" {
return fmt.Errorf("自动边缘节点省份不符合预期: %s", param.AutoEdgeConfig.Province)
}
if *config.Count != 1 {
return fmt.Errorf("自动边缘节点数量不符合预期: %d", param.AutoEdgeConfig.Count)
}
if config.PacketLoss != 30 {
return fmt.Errorf("自动边缘节点丢包率不符合预期: %d", param.AutoEdgeConfig.PacketLoss)
}
}
return nil
}
},
want: func(t *testing.T, got []*PortInfo) error {
// 验证返回结果
if len(got) != 3 {
return fmt.Errorf("返回的 PortInfo 数量不正确,期望 3得到 %d", len(got))
}
// 验证结果
var gotMap = make(map[int]PortInfo)
for _, port := range got {
if port.Proto != 1 {
return fmt.Errorf("期望协议为 1(http),得到 %d", port.Proto)
}
if port.Host != proxy.Host {
return fmt.Errorf("期望主机为 %s得到 %s", proxy.Host, port.Host)
}
gotMap[port.Port] = *port
}
// 验证数据库字段
var channels []*models.Channel
db.Where("user_id = ? and deleted_at is null", userAuth.Payload.Id).Find(&channels)
for _, ch := range channels {
if ch.Protocol != 1 {
return fmt.Errorf("通道协议不正确,期望 1(http),得到 %d", ch.Protocol)
}
if ch.UserID != userAuth.Payload.Id {
return fmt.Errorf("通道用户ID不正确期望 %d得到 %d", userAuth.Payload.Id, ch.UserID)
}
// todo 多代理分配策略,验证 proxy_host
if ch.ProxyID != proxy.ID {
return fmt.Errorf("通道代理ID不正确期望 %d得到 %d", proxy.ID, ch.ProxyID)
}
var info, ok = gotMap[int(ch.ProxyPort)]
if !ok {
return fmt.Errorf("通道端口 %d 不在返回结果中", ch.ProxyPort)
}
if ch.AuthPass != false && ch.AuthIP != true {
return fmt.Errorf("通道认证类型不正确,期望 Pass得到 %v", ch.AuthPass)
}
if ch.Protocol != int32(info.Proto) {
return fmt.Errorf("通道协议不正确,期望 %d得到 %d", info.Proto, ch.Protocol)
}
if time.Time(ch.Expiration).IsZero() {
return fmt.Errorf("通道过期时间不应为空")
}
// 检查Redis缓存中的字段
key := fmt.Sprintf("channel:%d", ch.ID)
if !mr.Exists(key) {
return fmt.Errorf("redis缓存中应有键 %s", key)
}
data, _ := mr.Get(key)
var cache models.Channel
err := json.Unmarshal([]byte(data), &cache)
if err != nil {
return fmt.Errorf("无法解析缓存数据: %v", err)
}
if reflect.DeepEqual(cache, *ch) {
return fmt.Errorf("缓存数据与数据库不匹配: %v", cache)
}
}
// 检查跨天用量更新
var pss models.ResourcePss
db.Where("resource_id = ?", 1).First(&pss)
if pss.DailyUsed != 3 {
return fmt.Errorf("套餐每日用量不正确,期望 3得到 %d", pss.DailyUsed)
}
if time.Time(pss.DailyLast).IsZero() {
return fmt.Errorf("套餐每日最后更新时间不应为空")
}
if pss.Used != 3 {
return fmt.Errorf("套餐总用量不正确,期望 3得到 %d", pss.Used)
}
return nil
},
},
{
name: "管理员替用户创建HTTP密码通道",
args: args{
ctx: ctx,
auth: adminAuth,
resourceId: 1,
protocol: ProtocolSocks5,
authType: ChannelAuthTypePass,
count: 3,
nodeFilter: []NodeFilterConfig{{Prov: "北京市"}},
},
setup: func() {
mr.FlushAll()
resetDb()
mc.ConnectMock = func(param g.CloudConnectReq) error {
if param.Uuid != proxy.Name {
return fmt.Errorf("代理名称不符合预期: %s", param.Uuid)
}
if len(param.Edge) != 0 {
return fmt.Errorf("边缘节点不符合预期: %v", param.Edge)
}
if !reflect.DeepEqual(param.AutoConfig, []g.AutoConfig{
{Province: "河南省", Count: 10},
{Province: "北京市", Count: 6},
}) {
return fmt.Errorf("自动配置不符合预期: %v", param.AutoConfig)
}
return nil
}
mg.PortConfigsMock = func(c *testutil.MockGatewayClient, params []g.PortConfigsReq) error {
if c.Host != proxy.Host {
return fmt.Errorf("代理主机不符合预期: %s", c.Host)
}
if len(params) != 3 {
return fmt.Errorf("端口数量不符合预期: %d", len(params))
}
for _, param := range params {
if param.Status != true {
return fmt.Errorf("端口状态不符合预期: %v", param.Status)
}
if param.AutoEdgeConfig == nil {
return fmt.Errorf("自动边缘节点配置不符合预期: %v", param.AutoEdgeConfig)
}
if param.Userpass == nil || *param.Userpass == "" {
return fmt.Errorf("用户名密码不符合预期: %v", param.Userpass)
}
if param.Whitelist == nil || len(*param.Whitelist) != 0 {
return fmt.Errorf("白名单不符合预期: %v", param.Whitelist)
}
config := param.AutoEdgeConfig
if config.Province != "北京市" {
return fmt.Errorf("自动边缘节点省份不符合预期: %s", param.AutoEdgeConfig.Province)
}
if *config.Count != 1 {
return fmt.Errorf("自动边缘节点数量不符合预期: %d", param.AutoEdgeConfig.Count)
}
if config.PacketLoss != 30 {
return fmt.Errorf("自动边缘节点丢包率不符合预期: %d", param.AutoEdgeConfig.PacketLoss)
}
}
return nil
}
},
want: func(t *testing.T, got []*PortInfo) error {
// 验证返回结果
if len(got) != 3 {
return fmt.Errorf("返回的 PortInfo 数量不正确,期望 3得到 %d", len(got))
}
// 验证结果
var gotMap = make(map[int]PortInfo)
for _, port := range got {
if port.Proto != 3 {
return fmt.Errorf("期望协议为 1(http),得到 %d", port.Proto)
}
if port.Host != proxy.Host {
return fmt.Errorf("期望主机为 %s得到 %s", proxy.Host, port.Host)
}
gotMap[port.Port] = *port
}
// 验证数据库字段
var channels []*models.Channel
db.Where("user_id = ? and deleted_at is null", userAuth.Payload.Id).Find(&channels)
for _, ch := range channels {
if ch.Protocol != 1 {
return fmt.Errorf("通道协议不正确,期望 1(http),得到 %d", ch.Protocol)
}
if ch.UserID != userAuth.Payload.Id {
return fmt.Errorf("通道用户ID不正确期望 %d得到 %d", userAuth.Payload.Id, ch.UserID)
}
// todo 多代理分配策略,验证 proxy_host
if ch.ProxyID != proxy.ID {
return fmt.Errorf("通道代理ID不正确期望 %d得到 %d", proxy.ID, ch.ProxyID)
}
var info, ok = gotMap[int(ch.ProxyPort)]
if !ok {
return fmt.Errorf("通道端口 %d 不在返回结果中", ch.ProxyPort)
}
if ch.AuthPass != true && ch.AuthIP != false {
return fmt.Errorf("通道认证类型不正确,期望 Pass得到 %v", ch.AuthPass)
}
if ch.Protocol != int32(info.Proto) {
return fmt.Errorf("通道协议不正确,期望 %d得到 %d", info.Proto, ch.Protocol)
}
if ch.Username != *info.Username {
return fmt.Errorf("通道用户名不正确,期望 %s得到 %s", *info.Username, ch.Username)
}
if ch.Password != *info.Password {
return fmt.Errorf("通道密码不正确,期望 %s得到 %s", *info.Password, ch.Password)
}
if time.Time(ch.Expiration).IsZero() {
return fmt.Errorf("通道过期时间不应为空")
}
// 检查Redis缓存中的字段
key := fmt.Sprintf("channel:%d", ch.ID)
if !mr.Exists(key) {
return fmt.Errorf("redis缓存中应有键 %s", key)
}
data, _ := mr.Get(key)
var cache models.Channel
err := json.Unmarshal([]byte(data), &cache)
if err != nil {
return fmt.Errorf("无法解析缓存数据: %v", err)
}
if reflect.DeepEqual(cache, *ch) {
return fmt.Errorf("缓存数据与数据库不匹配: %v", cache)
}
}
// 检查跨天用量更新
var pss models.ResourcePss
db.Where("resource_id = ?", 1).First(&pss)
if pss.DailyUsed != 3 {
return fmt.Errorf("套餐每日用量不正确,期望 3得到 %d", pss.DailyUsed)
}
if time.Time(pss.DailyLast).IsZero() {
return fmt.Errorf("套餐每日最后更新时间不应为空")
}
if pss.Used != 3 {
return fmt.Errorf("套餐总用量不正确,期望 3得到 %d", pss.Used)
}
return nil
},
},
{
name: "套餐不存在",
args: args{
ctx: ctx,
auth: userAuth,
resourceId: 999,
protocol: ProtocolHTTP,
authType: ChannelAuthTypeIp,
count: 1,
},
setup: func() {
mr.FlushAll()
resetDb()
},
wantErr: true,
wantErrContains: "无权限访问",
},
{
name: "套餐没有权限",
args: args{
ctx: ctx,
auth: userAuth,
resourceId: 2,
protocol: ProtocolHTTP,
authType: ChannelAuthTypeIp,
count: 1,
},
setup: func() {
mr.FlushAll()
resetDb()
resource2 := &models.Resource{
ID: 2,
UserID: 102,
Active: true,
}
db.Create(resource2)
var resourcePss2 = &models.ResourcePss{
ID: 2,
ResourceID: 2,
Type: 1,
Live: 180,
Expire: core.LocalDateTime(time.Now().AddDate(1, 0, 0)),
DailyLimit: 10000,
}
db.Create(resourcePss2)
},
wantErr: true,
wantErrContains: "无权限访问",
},
{
name: "套餐配额不足",
args: args{
ctx: ctx,
auth: userAuth,
resourceId: 2,
protocol: ProtocolHTTP,
authType: ChannelAuthTypeIp,
count: 10,
},
setup: func() {
mr.FlushAll()
resetDb()
// 创建一个配额几乎用完的资源包
resource2 := models.Resource{
ID: 2,
UserID: 101,
Active: true,
}
resourcePss2 := models.ResourcePss{
ID: 2,
ResourceID: 2,
Type: 2,
Quota: 100,
Used: 91,
Live: 180,
DailyLimit: 10000,
}
db.Create(&resource2).Create(&resourcePss2)
},
wantErr: true,
wantErrContains: "套餐配额不足",
},
{
name: "端口数量达到上限",
args: args{
ctx: ctx,
auth: userAuth,
resourceId: 1,
protocol: ProtocolHTTP,
authType: ChannelAuthTypeIp,
count: 1,
},
setup: func() {
mr.FlushAll()
resetDb()
mc.AutoQueryMock = func() (g.CloudConnectResp, error) {
return g.CloudConnectResp{
"test-proxy": []g.AutoConfig{
{Count: 20000},
},
}, nil
}
// 创建大量占用端口的通道
var channels = make([]models.Channel, 10000)
var expr = time.Now().Add(time.Hour)
for i := range channels {
channels[i] = models.Channel{
ProxyID: 1,
ProxyPort: int32(i + 10000),
UserID: 101,
Expiration: core.LocalDateTime(expr),
}
}
db.CreateInBatches(channels, 1000)
},
wantErr: true,
wantErrContains: "端口数量到达上限",
},
// todo 多地区混杂条件提取
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.setup != nil {
tt.setup()
}
s := &channelService{}
got, err := s.CreateChannel(tt.args.ctx, tt.args.auth, tt.args.resourceId, tt.args.protocol, tt.args.authType, tt.args.count, tt.args.nodeFilter...)
// 检查错误或结果
if tt.wantErr {
if err == nil {
t.Errorf("CreateChannel() 应当返回错误")
return
}
if tt.wantErrContains != "" && !strings.Contains(err.Error(), tt.wantErrContains) {
t.Errorf("CreateChannel() 错误 = %v, 应包含 %v", err, tt.wantErrContains)
}
return
}
if err != nil {
t.Errorf("CreateChannel() 错误 = %v, wantErr %v", err, tt.wantErr)
return
}
// 使用检查函数验证结果
if err := tt.want(t, got); err != nil {
t.Errorf("结果验证失败: %v", err)
}
})
}
}
func Test_channelService_RemoveChannels(t *testing.T) {
mr := testutil.SetupRedisTest(t)
md := testutil.SetupDBTest(t)
mg := testutil.SetupGatewayClientMock(t)
mc := testutil.SetupCloudClientMock(t)
type args struct {
ctx context.Context
auth *AuthContext
id []int32
}
// 准备测试数据
ctx := context.WithValue(context.Background(), requestid.ConfigDefault.ContextKey, "test-request-id")
// 创建用户
var user = &models.User{
ID: 101,
Phone: "12312341234",
}
md.Create(user)
// 创建管理员
var adminUser = &models.User{
ID: 100,
Phone: "99999999999",
}
md.Create(adminUser)
// 认证上下文
var adminAuth = &AuthContext{Payload: Payload{Id: 100, Type: PayloadAdmin}}
var userAuth = &AuthContext{Payload: Payload{Id: 101, Type: PayloadUser}}
// 创建代理
var proxy = &models.Proxy{
ID: 1,
Version: 1,
Name: "test-proxy",
Host: "111.111.111.111",
Type: 1,
Secret: "test:secret",
}
md.Create(proxy)
var proxy2 = &models.Proxy{
ID: 2,
Version: 1,
Name: "test-proxy-2",
Host: "222.222.222.222",
Type: 1,
Secret: "test:secret2",
}
md.Create(proxy2)
// 清空数据库函数
var clearDb = func() {
md.Exec("delete from channel where true")
mr.FlushAll()
}
tests := []struct {
name string
args args
setup func()
wantErr bool
wantErrContains string
want func(t *testing.T) error
}{
{
name: "管理员删除多个通道",
args: args{
ctx: ctx,
auth: adminAuth,
id: []int32{1, 2, 3},
},
setup: func() {
mr.FlushAll()
clearDb()
// 创建通道
channels := []models.Channel{
{ID: 1, UserID: 101, ProxyID: 1, ProxyPort: 10001, Protocol: 1, Expiration: core.LocalDateTime(time.Now().Add(24 * time.Hour))},
{ID: 2, UserID: 101, ProxyID: 1, ProxyPort: 10002, Protocol: 1, Expiration: core.LocalDateTime(time.Now().Add(24 * time.Hour))},
{ID: 3, UserID: 101, ProxyID: 2, ProxyPort: 10001, Protocol: 3, Expiration: core.LocalDateTime(time.Now().Add(24 * time.Hour))},
}
// 保存预设数据
md.Create(channels)
for _, channel := range channels {
key := fmt.Sprintf("channel:%d", channel.ID)
data, _ := json.Marshal(channel)
_ = mr.Set(key, string(data))
}
// 模拟网关客户端的响应
mg.PortActiveMock = func(m *testutil.MockGatewayClient, param ...g.PortActiveReq) (map[string]g.PortData, error) {
switch {
case m.Host == proxy.Host:
return map[string]g.PortData{
"10001": {Edge: []string{"edge1", "edge4"}},
"10002": {Edge: []string{"edge2"}},
}, nil
case m.Host == proxy2.Host:
return map[string]g.PortData{
"10001": {Edge: []string{"edge3"}},
}, nil
}
return nil, fmt.Errorf("代理主机不符合预期: %s", m.Host)
}
mg.PortConfigsMock = func(m *testutil.MockGatewayClient, params []g.PortConfigsReq) error {
switch {
case m.Host == proxy.Host:
for _, param := range params {
if param.Port != 10001 && param.Port != 10002 {
return fmt.Errorf("端口配置不符合预期: %d", param.Port)
}
if param.Status != false {
return fmt.Errorf("端口状态不符合预期: %v", param.Status)
}
if param.Edge == nil || len(*param.Edge) != 0 {
return fmt.Errorf("边缘节点不符合预期1: %v", param.Edge)
}
}
return nil
case m.Host == proxy2.Host:
for _, param := range params {
if param.Port != 10001 {
return fmt.Errorf("端口配置不符合预期: %d", param.Port)
}
if param.Status != false {
return fmt.Errorf("端口状态不符合预期: %v", param.Status)
}
if param.Edge == nil || len(*param.Edge) != 0 {
return fmt.Errorf("边缘节点不符合预期2: %v", param.Edge)
}
}
return nil
}
return fmt.Errorf("代理主机不符合预期: %s", m.Host)
}
mc.DisconnectMock = func(param g.CloudDisconnectReq) (int, error) {
switch {
case param.Uuid == proxy.Name:
var edges = []string{"edge1", "edge2", "edge4"}
if !testutil.SliceEqual(edges, param.Edge) {
return 0, fmt.Errorf("边缘节点不符合预期3: %v", param.Edge)
}
if len(param.Config) != 0 {
return 0, fmt.Errorf("配置不符合预期: %v", param.Config)
}
return len(param.Edge), nil
case param.Uuid == proxy2.Name:
var edges = []string{"edge3"}
if !testutil.SliceEqual(edges, param.Edge) {
return 0, fmt.Errorf("边缘节点不符合预期4: %v", param.Edge)
}
if len(param.Config) != 0 {
return 0, fmt.Errorf("配置不符合预期: %v", param.Config)
}
return len(param.Edge), nil
}
return 0, fmt.Errorf("代理名称不符合预期: %s", param.Uuid)
}
},
want: func(t *testing.T) error {
// 检查通道是否被软删除
var count int64
md.Model(&models.Channel{}).Where("id IN ? AND deleted_at IS NULL", []int32{1, 2, 3}).Count(&count)
if count > 0 {
return fmt.Errorf("应该软删除了所有通道,但仍有 %d 个未删除", count)
}
// 检查Redis缓存是否被删除
for _, id := range []int32{1, 2, 3} {
key := fmt.Sprintf("channel:%d", id)
if mr.Exists(key) {
return fmt.Errorf("通道缓存 %s 应被删除但仍存在", key)
}
}
return nil
},
},
{
name: "用户删除自己的通道",
args: args{
ctx: ctx,
auth: userAuth,
id: []int32{1, 2, 3},
},
setup: func() {
mr.FlushAll()
clearDb()
// 创建通道
channels := []models.Channel{
{ID: 1, UserID: 101, ProxyID: 1, ProxyPort: 10001, Protocol: 1, Expiration: core.LocalDateTime(time.Now().Add(24 * time.Hour))},
{ID: 2, UserID: 101, ProxyID: 1, ProxyPort: 10002, Protocol: 1, Expiration: core.LocalDateTime(time.Now().Add(24 * time.Hour))},
{ID: 3, UserID: 101, ProxyID: 2, ProxyPort: 10001, Protocol: 3, Expiration: core.LocalDateTime(time.Now().Add(24 * time.Hour))},
}
// 保存预设数据
md.Create(channels)
for _, channel := range channels {
key := fmt.Sprintf("channel:%d", channel.ID)
data, _ := json.Marshal(channel)
_ = mr.Set(key, string(data))
}
// 模拟网关客户端的响应
mg.PortActiveMock = func(m *testutil.MockGatewayClient, param ...g.PortActiveReq) (map[string]g.PortData, error) {
switch {
case m.Host == proxy.Host:
return map[string]g.PortData{
"10001": {Edge: []string{"edge1", "edge4"}},
"10002": {Edge: []string{"edge2"}},
}, nil
case m.Host == proxy2.Host:
return map[string]g.PortData{
"10001": {Edge: []string{"edge3"}},
}, nil
}
return nil, fmt.Errorf("代理主机不符合预期: %s", m.Host)
}
mg.PortConfigsMock = func(m *testutil.MockGatewayClient, params []g.PortConfigsReq) error {
switch {
case m.Host == proxy.Host:
for _, param := range params {
if param.Port != 10001 && param.Port != 10002 {
return fmt.Errorf("端口配置不符合预期: %d", param.Port)
}
if param.Status != false {
return fmt.Errorf("端口状态不符合预期: %v", param.Status)
}
if param.Edge == nil || len(*param.Edge) != 0 {
return fmt.Errorf("边缘节点不符合预期5: %v", param.Edge)
}
}
return nil
case m.Host == proxy2.Host:
for _, param := range params {
if param.Port != 10001 {
return fmt.Errorf("端口配置不符合预期: %d", param.Port)
}
if param.Status != false {
return fmt.Errorf("端口状态不符合预期: %v", param.Status)
}
if param.Edge == nil || len(*param.Edge) != 0 {
return fmt.Errorf("边缘节点不符合预期6: %v", param.Edge)
}
}
return nil
}
return fmt.Errorf("代理主机不符合预期: %s", m.Host)
}
mc.DisconnectMock = func(param g.CloudDisconnectReq) (int, error) {
switch {
case param.Uuid == proxy.Name:
var edges = []string{"edge1", "edge2", "edge4"}
if !testutil.SliceEqual(edges, param.Edge) {
return 0, fmt.Errorf("边缘节点不符合预期7: %v", param.Edge)
}
if len(param.Config) != 0 {
return 0, fmt.Errorf("配置不符合预期: %v", param.Config)
}
return len(param.Edge), nil
case param.Uuid == proxy2.Name:
var edges = []string{"edge3"}
if !testutil.SliceEqual(edges, param.Edge) {
return 0, fmt.Errorf("边缘节点不符合预期8: %v", param.Edge)
}
if len(param.Config) != 0 {
return 0, fmt.Errorf("配置不符合预期: %v", param.Config)
}
return len(param.Edge), nil
}
return 0, fmt.Errorf("代理名称不符合预期: %s", param.Uuid)
}
},
want: func(t *testing.T) error {
// 检查通道是否被软删除
var count int64
md.Model(&models.Channel{}).Where("id IN ? AND deleted_at IS NULL", []int32{1, 2, 3}).Count(&count)
if count > 0 {
return fmt.Errorf("应该软删除了所有通道,但仍有 %d 个未删除", count)
}
// 检查Redis缓存是否被删除
for _, id := range []int32{1, 2, 3} {
key := fmt.Sprintf("channel:%d", id)
if mr.Exists(key) {
return fmt.Errorf("通道缓存 %s 应被删除但仍存在", key)
}
}
return nil
},
},
{
name: "用户删除不属于自己的通道",
args: args{
ctx: ctx,
auth: userAuth,
id: []int32{1, 2, 3},
},
setup: func() {
mr.FlushAll()
clearDb()
// 创建通道
channels := []models.Channel{
{ID: 1, UserID: 101, ProxyID: 1, ProxyPort: 10001, Protocol: 1, Expiration: core.LocalDateTime(time.Now().Add(24 * time.Hour))},
{ID: 2, UserID: 101, ProxyID: 1, ProxyPort: 10002, Protocol: 1, Expiration: core.LocalDateTime(time.Now().Add(24 * time.Hour))},
{ID: 3, UserID: 102, ProxyID: 2, ProxyPort: 10001, Protocol: 3, Expiration: core.LocalDateTime(time.Now().Add(24 * time.Hour))},
}
// 保存预设数据
md.Create(channels)
for _, channel := range channels {
key := fmt.Sprintf("channel:%d", channel.ID)
data, _ := json.Marshal(channel)
_ = mr.Set(key, string(data))
}
},
wantErr: true,
wantErrContains: "无权限访问",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.setup != nil {
tt.setup()
}
s := &channelService{}
err := s.RemoveChannels(tt.args.ctx, tt.args.auth, tt.args.id...)
// 检查错误
if tt.wantErr {
if err == nil {
t.Errorf("RemoveChannels() 应当返回错误")
return
}
if tt.wantErrContains != "" && !strings.Contains(err.Error(), tt.wantErrContains) {
t.Errorf("RemoveChannels() 错误 = %v, 应包含 %v", err, tt.wantErrContains)
}
return
}
if err != nil {
t.Errorf("RemoveChannels() 错误 = %v, wantErr %v", err, tt.wantErr)
return
}
// 检查数据库和缓存是否正确设置
if err := tt.want(t); err != nil {
t.Errorf("RemoveChannels() 结果验证失败: %v", err)
}
})
}
}