package services import ( "context" "encoding/json" "fmt" "platform/pkg/testutil" "platform/web/auth" "platform/web/core" g "platform/web/globals" "platform/web/models" "reflect" "strings" "testing" "time" "github.com/gofiber/fiber/v2/middleware/requestid" ) func Test_genPassPair(t *testing.T) { tests := []struct { name string }{ { name: "正常生成随机用户名和密码", }, { name: "多次调用生成不同的值", }, } // 第一个测试:检查生成的用户名和密码是否有效 t.Run(tests[0].name, func(t *testing.T) { username, password := genPassPair() if username == "" { t.Errorf("genPassPair() username is empty") } if password == "" { t.Errorf("genPassPair() password is empty") } }) // 第二个测试:确保多次调用生成不同的值 t.Run(tests[1].name, func(t *testing.T) { username1, password1 := genPassPair() username2, password2 := genPassPair() if username1 == username2 { t.Errorf("genPassPair() generated the same username twice: %v", username1) } if password1 == password2 { t.Errorf("genPassPair() generated the same password twice: %v", password1) } }) } func Test_chKey(t *testing.T) { type args struct { channel *models.Channel } tests := []struct { name string args args want string }{ { name: "ID为1的通道", args: args{ channel: &models.Channel{ ID: 1, }, }, want: "channel:1", }, { name: "ID为100的通道", args: args{ channel: &models.Channel{ ID: 100, }, }, want: "channel:100", }, { name: "ID为0的通道", args: args{ channel: &models.Channel{ ID: 0, }, }, want: "channel:0", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := chKey(tt.args.channel); got != tt.want { t.Errorf("chKey() = %v, want %v", got, tt.want) } }) } } func Test_cache(t *testing.T) { mr := testutil.SetupRedisTest(t) type args struct { ctx context.Context channels []*models.Channel } // 准备测试数据 now := time.Now() expiration := core.LocalDateTime(now.Add(24 * time.Hour)) testChannels := []*models.Channel{ { ID: 1, UserID: 100, ProxyID: 10, ProxyPort: 8080, Protocol: 1, Expiration: expiration, }, { ID: 2, UserID: 101, ProxyID: 11, ProxyPort: 8081, Protocol: 3, Expiration: expiration, }, } tests := []struct { name string args args wantErr bool }{ { name: "正常缓存多个通道", args: args{ ctx: context.Background(), channels: testChannels, }, wantErr: false, }, { name: "空通道列表", args: args{ ctx: context.Background(), channels: []*models.Channel{}, }, wantErr: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { mr.FlushAll() // 清空 Redis 数据 if err := cache(tt.args.ctx, tt.args.channels); (err != nil) != tt.wantErr { t.Errorf("cache() error = %v, wantErr %v", err, tt.wantErr) return } // 验证缓存结果 if len(tt.args.channels) > 0 { for _, channel := range tt.args.channels { key := fmt.Sprintf("channel:%d", channel.ID) if !mr.Exists(key) { t.Errorf("缓存未包含通道键 %s", key) } else { // 验证缓存的数据是否正确 data, _ := mr.Get(key) var cachedChannel models.Channel err := json.Unmarshal([]byte(data), &cachedChannel) if err != nil { t.Errorf("无法解析缓存数据: %v", err) } if cachedChannel.ID != channel.ID { t.Errorf("缓存数据不匹配: 期望 ID %d, 得到 %d", channel.ID, cachedChannel.ID) } } } // 验证是否设置了过期时间 for _, channel := range tt.args.channels { key := fmt.Sprintf("channel:%d", channel.ID) ttl := mr.TTL(key) if ttl <= 0 { t.Errorf("键 %s 没有设置过期时间", key) } } // 验证是否添加了有序集合 if !mr.Exists("tasks:channel") { t.Errorf("ZAdd未创建有序集合 tasks:channel") } } }) } } func Test_deleteCache(t *testing.T) { mr := testutil.SetupRedisTest(t) type args struct { ctx context.Context channels []*models.Channel } // 准备测试数据 testChannels := []*models.Channel{ {ID: 1}, {ID: 2}, {ID: 3}, } ctx := context.Background() tests := []struct { name string args args wantErr bool }{ { name: "正常删除多个通道缓存", args: args{ ctx: ctx, channels: testChannels, }, wantErr: false, }, { name: "空通道列表", args: args{ ctx: ctx, channels: []*models.Channel{}, }, wantErr: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { mr.FlushAll() // 清空 Redis 数据 // 预先设置缓存数据 for _, channel := range testChannels { key := fmt.Sprintf("channel:%d", channel.ID) data, _ := json.Marshal(channel) mr.Set(key, string(data)) mr.SetTTL(key, 1*time.Hour) // 设置1小时的过期时间 } if err := deleteCache(tt.args.ctx, tt.args.channels); (err != nil) != tt.wantErr { t.Errorf("deleteCache() error = %v, wantErr %v", err, tt.wantErr) return } // 验证删除结果 for _, channel := range tt.args.channels { key := fmt.Sprintf("channel:%d", channel.ID) if mr.Exists(key) { t.Errorf("通道键 %s 未被删除", key) } } }) } } func Test_channelService_CreateChannel(t *testing.T) { mr := testutil.SetupRedisTest(t) db := testutil.SetupDBTest(t) mc := testutil.SetupCloudClientMock(t) mg := testutil.SetupGatewayClientMock(t) type args struct { ctx context.Context auth *auth.Context resourceId int32 protocol ChannelProtocol authType ChannelAuthType count int nodeFilter []NodeFilterConfig } // 准备测试数据 ctx := context.WithValue(context.Background(), requestid.ConfigDefault.ContextKey, "test-request-id") var adminAuth = &auth.Context{Payload: auth.Payload{Id: 100, Type: auth.PayloadAdmin}} var userAuth = &auth.Context{Payload: auth.Payload{Id: 101, Type: auth.PayloadUser}} mc.AutoQueryMock = func() (g.CloudConnectResp, error) { return g.CloudConnectResp{ "test-proxy": []g.AutoConfig{ {Province: "河南省", Count: 10}, }, }, nil } var user *models.User var whitelists []*models.Whitelist var proxy *models.Proxy var resource *models.Resource var resourcePss *models.ResourcePss var resetDb = func() { user = &models.User{ ID: 101, Phone: "12312341234", } db.Exec("delete from user where true") db.Create(user) whitelists = []*models.Whitelist{ {ID: 1, UserID: 101, Host: "123.123.123.123"}, {ID: 2, UserID: 101, Host: "456.456.456.456"}, {ID: 3, UserID: 101, Host: "789.789.789.789"}, } db.Exec("delete from whitelist where true") db.Create(whitelists) proxy = &models.Proxy{ ID: 1, Version: 1, Name: "test-proxy", Host: "111.111.111.111", Type: 1, Secret: "test:secret", } db.Exec("delete from proxy where true") db.Create(proxy) resource = &models.Resource{ ID: 1, UserID: 101, Active: true, } db.Exec("delete from resource where true") db.Create(resource) resourcePss = &models.ResourcePss{ ID: 1, ResourceID: 1, Type: 1, Live: 180, Expire: core.LocalDateTime(time.Now().AddDate(1, 0, 0)), DailyLimit: 10000, } db.Exec("delete from resource_pss where true") db.Create(resourcePss) db.Exec("delete from channel where true") } tests := []struct { name string args args setup func() wantErr bool wantErrContains string want func(t *testing.T, got []*PortInfo) error }{ { name: "用户创建HTTP密码通道", args: args{ ctx: ctx, auth: userAuth, resourceId: 1, protocol: ProtocolHTTP, authType: ChannelAuthTypePass, count: 3, nodeFilter: []NodeFilterConfig{{Prov: "北京市"}}, }, setup: func() { mr.FlushAll() resetDb() mc.ConnectMock = func(param g.CloudConnectReq) error { if param.Uuid != proxy.Name { return fmt.Errorf("代理名称不符合预期: %s", param.Uuid) } if len(param.Edge) != 0 { return fmt.Errorf("边缘节点不符合预期: %v", param.Edge) } if !reflect.DeepEqual(param.AutoConfig, []g.AutoConfig{ {Province: "河南省", Count: 10}, {Province: "北京市", Count: 6}, }) { return fmt.Errorf("自动配置不符合预期: %v", param.AutoConfig) } return nil } mg.PortConfigsMock = func(c *testutil.MockGatewayClient, params []g.PortConfigsReq) error { if c.Host != proxy.Host { return fmt.Errorf("代理主机不符合预期: %s", c.Host) } if len(params) != 3 { return fmt.Errorf("端口数量不符合预期: %d", len(params)) } for _, param := range params { if param.Status != true { return fmt.Errorf("端口状态不符合预期: %v", param.Status) } if param.AutoEdgeConfig == nil { return fmt.Errorf("自动边缘节点配置不符合预期: %v", param.AutoEdgeConfig) } if param.Userpass == nil || *param.Userpass == "" { return fmt.Errorf("用户名密码不符合预期: %v", param.Userpass) } if param.Whitelist == nil || len(*param.Whitelist) != 0 { return fmt.Errorf("白名单不符合预期: %v", param.Whitelist) } config := param.AutoEdgeConfig if config.Province != "北京市" { return fmt.Errorf("自动边缘节点省份不符合预期: %s", param.AutoEdgeConfig.Province) } if *config.Count != 1 { return fmt.Errorf("自动边缘节点数量不符合预期: %d", param.AutoEdgeConfig.Count) } if config.PacketLoss != 30 { return fmt.Errorf("自动边缘节点丢包率不符合预期: %d", param.AutoEdgeConfig.PacketLoss) } } return nil } }, want: func(t *testing.T, got []*PortInfo) error { // 验证返回结果 if len(got) != 3 { return fmt.Errorf("返回的 PortInfo 数量不正确,期望 3,得到 %d", len(got)) } // 验证结果 var gotMap = make(map[int]PortInfo) for _, port := range got { if port.Proto != 1 { return fmt.Errorf("期望协议为 1(http),得到 %d", port.Proto) } if port.Host != proxy.Host { return fmt.Errorf("期望主机为 %s,得到 %s", proxy.Host, port.Host) } gotMap[port.Port] = *port } // 验证数据库字段 var channels []*models.Channel db.Where("user_id = ? and deleted_at is null", userAuth.Payload.Id).Find(&channels) for _, ch := range channels { if ch.Protocol != 1 { return fmt.Errorf("通道协议不正确,期望 1(http),得到 %d", ch.Protocol) } if ch.UserID != userAuth.Payload.Id { return fmt.Errorf("通道用户ID不正确,期望 %d,得到 %d", userAuth.Payload.Id, ch.UserID) } // todo 多代理分配策略,验证 proxy_host if ch.ProxyID != proxy.ID { return fmt.Errorf("通道代理ID不正确,期望 %d,得到 %d", proxy.ID, ch.ProxyID) } var info, ok = gotMap[int(ch.ProxyPort)] if !ok { return fmt.Errorf("通道端口 %d 不在返回结果中", ch.ProxyPort) } if ch.AuthPass != true && ch.AuthIP != false { return fmt.Errorf("通道认证类型不正确,期望 Pass,得到 %v", ch.AuthPass) } if ch.Protocol != int32(info.Proto) { return fmt.Errorf("通道协议不正确,期望 %d,得到 %d", info.Proto, ch.Protocol) } if ch.Username != *info.Username { return fmt.Errorf("通道用户名不正确,期望 %s,得到 %s", *info.Username, ch.Username) } if ch.Password != *info.Password { return fmt.Errorf("通道密码不正确,期望 %s,得到 %s", *info.Password, ch.Password) } if time.Time(ch.Expiration).IsZero() { return fmt.Errorf("通道过期时间不应为空") } // 检查Redis缓存中的字段 key := fmt.Sprintf("channel:%d", ch.ID) if !mr.Exists(key) { return fmt.Errorf("redis缓存中应有键 %s", key) } data, _ := mr.Get(key) var cache models.Channel err := json.Unmarshal([]byte(data), &cache) if err != nil { return fmt.Errorf("无法解析缓存数据: %v", err) } if reflect.DeepEqual(cache, *ch) { return fmt.Errorf("缓存数据与数据库不匹配: %v", cache) } } // 检查跨天用量更新 var pss models.ResourcePss db.Where("resource_id = ?", 1).First(&pss) if pss.DailyUsed != 3 { return fmt.Errorf("套餐每日用量不正确,期望 3,得到 %d", pss.DailyUsed) } if time.Time(pss.DailyLast).IsZero() { return fmt.Errorf("套餐每日最后更新时间不应为空") } if pss.Used != 3 { return fmt.Errorf("套餐总用量不正确,期望 3,得到 %d", pss.Used) } return nil }, }, { name: "用户创建HTTP白名单通道", args: args{ ctx: ctx, auth: userAuth, resourceId: 1, protocol: ProtocolHTTP, authType: ChannelAuthTypeIp, count: 3, nodeFilter: []NodeFilterConfig{{Prov: "北京市"}}, }, setup: func() { mr.FlushAll() resetDb() mc.ConnectMock = func(param g.CloudConnectReq) error { if param.Uuid != proxy.Name { return fmt.Errorf("代理名称不符合预期: %s", param.Uuid) } if len(param.Edge) != 0 { return fmt.Errorf("边缘节点不符合预期: %v", param.Edge) } if !reflect.DeepEqual(param.AutoConfig, []g.AutoConfig{ {Province: "河南省", Count: 10}, {Province: "北京市", Count: 6}, }) { return fmt.Errorf("自动配置不符合预期: %v", param.AutoConfig) } return nil } mg.PortConfigsMock = func(c *testutil.MockGatewayClient, params []g.PortConfigsReq) error { if c.Host != proxy.Host { return fmt.Errorf("代理主机不符合预期: %s", c.Host) } if len(params) != 3 { return fmt.Errorf("端口数量不符合预期: %d", len(params)) } for _, param := range params { if param.Status != true { return fmt.Errorf("端口状态不符合预期: %v", param.Status) } if param.AutoEdgeConfig == nil { return fmt.Errorf("自动边缘节点配置不符合预期: %v", param.AutoEdgeConfig) } if param.Userpass == nil || *param.Userpass != "" { return fmt.Errorf("用户名密码不符合预期: %v", *param.Userpass) } if param.Whitelist == nil || len(*param.Whitelist) == 0 { return fmt.Errorf("白名单不符合预期: %v", param.Whitelist) } config := param.AutoEdgeConfig if config.Province != "北京市" { return fmt.Errorf("自动边缘节点省份不符合预期: %s", param.AutoEdgeConfig.Province) } if *config.Count != 1 { return fmt.Errorf("自动边缘节点数量不符合预期: %d", param.AutoEdgeConfig.Count) } if config.PacketLoss != 30 { return fmt.Errorf("自动边缘节点丢包率不符合预期: %d", param.AutoEdgeConfig.PacketLoss) } } return nil } }, want: func(t *testing.T, got []*PortInfo) error { // 验证返回结果 if len(got) != 3 { return fmt.Errorf("返回的 PortInfo 数量不正确,期望 3,得到 %d", len(got)) } // 验证结果 var gotMap = make(map[int]PortInfo) for _, port := range got { if port.Proto != 1 { return fmt.Errorf("期望协议为 1(http),得到 %d", port.Proto) } if port.Host != proxy.Host { return fmt.Errorf("期望主机为 %s,得到 %s", proxy.Host, port.Host) } gotMap[port.Port] = *port } // 验证数据库字段 var channels []*models.Channel db.Where("user_id = ? and deleted_at is null", userAuth.Payload.Id).Find(&channels) for _, ch := range channels { if ch.Protocol != 1 { return fmt.Errorf("通道协议不正确,期望 1(http),得到 %d", ch.Protocol) } if ch.UserID != userAuth.Payload.Id { return fmt.Errorf("通道用户ID不正确,期望 %d,得到 %d", userAuth.Payload.Id, ch.UserID) } // todo 多代理分配策略,验证 proxy_host if ch.ProxyID != proxy.ID { return fmt.Errorf("通道代理ID不正确,期望 %d,得到 %d", proxy.ID, ch.ProxyID) } var info, ok = gotMap[int(ch.ProxyPort)] if !ok { return fmt.Errorf("通道端口 %d 不在返回结果中", ch.ProxyPort) } if ch.AuthPass != false && ch.AuthIP != true { return fmt.Errorf("通道认证类型不正确,期望 Pass,得到 %v", ch.AuthPass) } if ch.Protocol != int32(info.Proto) { return fmt.Errorf("通道协议不正确,期望 %d,得到 %d", info.Proto, ch.Protocol) } if time.Time(ch.Expiration).IsZero() { return fmt.Errorf("通道过期时间不应为空") } // 检查Redis缓存中的字段 key := fmt.Sprintf("channel:%d", ch.ID) if !mr.Exists(key) { return fmt.Errorf("redis缓存中应有键 %s", key) } data, _ := mr.Get(key) var cache models.Channel err := json.Unmarshal([]byte(data), &cache) if err != nil { return fmt.Errorf("无法解析缓存数据: %v", err) } if reflect.DeepEqual(cache, *ch) { return fmt.Errorf("缓存数据与数据库不匹配: %v", cache) } } // 检查跨天用量更新 var pss models.ResourcePss db.Where("resource_id = ?", 1).First(&pss) if pss.DailyUsed != 3 { return fmt.Errorf("套餐每日用量不正确,期望 3,得到 %d", pss.DailyUsed) } if time.Time(pss.DailyLast).IsZero() { return fmt.Errorf("套餐每日最后更新时间不应为空") } if pss.Used != 3 { return fmt.Errorf("套餐总用量不正确,期望 3,得到 %d", pss.Used) } return nil }, }, { name: "管理员替用户创建HTTP密码通道", args: args{ ctx: ctx, auth: adminAuth, resourceId: 1, protocol: ProtocolSocks5, authType: ChannelAuthTypePass, count: 3, nodeFilter: []NodeFilterConfig{{Prov: "北京市"}}, }, setup: func() { mr.FlushAll() resetDb() mc.ConnectMock = func(param g.CloudConnectReq) error { if param.Uuid != proxy.Name { return fmt.Errorf("代理名称不符合预期: %s", param.Uuid) } if len(param.Edge) != 0 { return fmt.Errorf("边缘节点不符合预期: %v", param.Edge) } if !reflect.DeepEqual(param.AutoConfig, []g.AutoConfig{ {Province: "河南省", Count: 10}, {Province: "北京市", Count: 6}, }) { return fmt.Errorf("自动配置不符合预期: %v", param.AutoConfig) } return nil } mg.PortConfigsMock = func(c *testutil.MockGatewayClient, params []g.PortConfigsReq) error { if c.Host != proxy.Host { return fmt.Errorf("代理主机不符合预期: %s", c.Host) } if len(params) != 3 { return fmt.Errorf("端口数量不符合预期: %d", len(params)) } for _, param := range params { if param.Status != true { return fmt.Errorf("端口状态不符合预期: %v", param.Status) } if param.AutoEdgeConfig == nil { return fmt.Errorf("自动边缘节点配置不符合预期: %v", param.AutoEdgeConfig) } if param.Userpass == nil || *param.Userpass == "" { return fmt.Errorf("用户名密码不符合预期: %v", param.Userpass) } if param.Whitelist == nil || len(*param.Whitelist) != 0 { return fmt.Errorf("白名单不符合预期: %v", param.Whitelist) } config := param.AutoEdgeConfig if config.Province != "北京市" { return fmt.Errorf("自动边缘节点省份不符合预期: %s", param.AutoEdgeConfig.Province) } if *config.Count != 1 { return fmt.Errorf("自动边缘节点数量不符合预期: %d", param.AutoEdgeConfig.Count) } if config.PacketLoss != 30 { return fmt.Errorf("自动边缘节点丢包率不符合预期: %d", param.AutoEdgeConfig.PacketLoss) } } return nil } }, want: func(t *testing.T, got []*PortInfo) error { // 验证返回结果 if len(got) != 3 { return fmt.Errorf("返回的 PortInfo 数量不正确,期望 3,得到 %d", len(got)) } // 验证结果 var gotMap = make(map[int]PortInfo) for _, port := range got { if port.Proto != 3 { return fmt.Errorf("期望协议为 1(http),得到 %d", port.Proto) } if port.Host != proxy.Host { return fmt.Errorf("期望主机为 %s,得到 %s", proxy.Host, port.Host) } gotMap[port.Port] = *port } // 验证数据库字段 var channels []*models.Channel db.Where("user_id = ? and deleted_at is null", userAuth.Payload.Id).Find(&channels) for _, ch := range channels { if ch.Protocol != 1 { return fmt.Errorf("通道协议不正确,期望 1(http),得到 %d", ch.Protocol) } if ch.UserID != userAuth.Payload.Id { return fmt.Errorf("通道用户ID不正确,期望 %d,得到 %d", userAuth.Payload.Id, ch.UserID) } // todo 多代理分配策略,验证 proxy_host if ch.ProxyID != proxy.ID { return fmt.Errorf("通道代理ID不正确,期望 %d,得到 %d", proxy.ID, ch.ProxyID) } var info, ok = gotMap[int(ch.ProxyPort)] if !ok { return fmt.Errorf("通道端口 %d 不在返回结果中", ch.ProxyPort) } if ch.AuthPass != true && ch.AuthIP != false { return fmt.Errorf("通道认证类型不正确,期望 Pass,得到 %v", ch.AuthPass) } if ch.Protocol != int32(info.Proto) { return fmt.Errorf("通道协议不正确,期望 %d,得到 %d", info.Proto, ch.Protocol) } if ch.Username != *info.Username { return fmt.Errorf("通道用户名不正确,期望 %s,得到 %s", *info.Username, ch.Username) } if ch.Password != *info.Password { return fmt.Errorf("通道密码不正确,期望 %s,得到 %s", *info.Password, ch.Password) } if time.Time(ch.Expiration).IsZero() { return fmt.Errorf("通道过期时间不应为空") } // 检查Redis缓存中的字段 key := fmt.Sprintf("channel:%d", ch.ID) if !mr.Exists(key) { return fmt.Errorf("redis缓存中应有键 %s", key) } data, _ := mr.Get(key) var cache models.Channel err := json.Unmarshal([]byte(data), &cache) if err != nil { return fmt.Errorf("无法解析缓存数据: %v", err) } if reflect.DeepEqual(cache, *ch) { return fmt.Errorf("缓存数据与数据库不匹配: %v", cache) } } // 检查跨天用量更新 var pss models.ResourcePss db.Where("resource_id = ?", 1).First(&pss) if pss.DailyUsed != 3 { return fmt.Errorf("套餐每日用量不正确,期望 3,得到 %d", pss.DailyUsed) } if time.Time(pss.DailyLast).IsZero() { return fmt.Errorf("套餐每日最后更新时间不应为空") } if pss.Used != 3 { return fmt.Errorf("套餐总用量不正确,期望 3,得到 %d", pss.Used) } return nil }, }, { name: "套餐不存在", args: args{ ctx: ctx, auth: userAuth, resourceId: 999, protocol: ProtocolHTTP, authType: ChannelAuthTypeIp, count: 1, }, setup: func() { mr.FlushAll() resetDb() }, wantErr: true, wantErrContains: "无权限访问", }, { name: "套餐没有权限", args: args{ ctx: ctx, auth: userAuth, resourceId: 2, protocol: ProtocolHTTP, authType: ChannelAuthTypeIp, count: 1, }, setup: func() { mr.FlushAll() resetDb() resource2 := &models.Resource{ ID: 2, UserID: 102, Active: true, } db.Create(resource2) var resourcePss2 = &models.ResourcePss{ ID: 2, ResourceID: 2, Type: 1, Live: 180, Expire: core.LocalDateTime(time.Now().AddDate(1, 0, 0)), DailyLimit: 10000, } db.Create(resourcePss2) }, wantErr: true, wantErrContains: "无权限访问", }, { name: "套餐配额不足", args: args{ ctx: ctx, auth: userAuth, resourceId: 2, protocol: ProtocolHTTP, authType: ChannelAuthTypeIp, count: 10, }, setup: func() { mr.FlushAll() resetDb() // 创建一个配额几乎用完的资源包 resource2 := models.Resource{ ID: 2, UserID: 101, Active: true, } resourcePss2 := models.ResourcePss{ ID: 2, ResourceID: 2, Type: 2, Quota: 100, Used: 91, Live: 180, DailyLimit: 10000, } db.Create(&resource2).Create(&resourcePss2) }, wantErr: true, wantErrContains: "套餐配额不足", }, { name: "端口数量达到上限", args: args{ ctx: ctx, auth: userAuth, resourceId: 1, protocol: ProtocolHTTP, authType: ChannelAuthTypeIp, count: 1, }, setup: func() { mr.FlushAll() resetDb() mc.AutoQueryMock = func() (g.CloudConnectResp, error) { return g.CloudConnectResp{ "test-proxy": []g.AutoConfig{ {Count: 20000}, }, }, nil } // 创建大量占用端口的通道 var channels = make([]models.Channel, 10000) var expr = time.Now().Add(time.Hour) for i := range channels { channels[i] = models.Channel{ ProxyID: 1, ProxyPort: int32(i + 10000), UserID: 101, Expiration: core.LocalDateTime(expr), } } db.CreateInBatches(channels, 1000) }, wantErr: true, wantErrContains: "端口数量到达上限", }, // todo 多地区混杂条件提取 } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if tt.setup != nil { tt.setup() } s := &channelService{} got, err := s.CreateChannel(tt.args.ctx, tt.args.auth, tt.args.resourceId, tt.args.protocol, tt.args.authType, tt.args.count, tt.args.nodeFilter...) // 检查错误或结果 if tt.wantErr { if err == nil { t.Errorf("CreateChannel() 应当返回错误") return } if tt.wantErrContains != "" && !strings.Contains(err.Error(), tt.wantErrContains) { t.Errorf("CreateChannel() 错误 = %v, 应包含 %v", err, tt.wantErrContains) } return } if err != nil { t.Errorf("CreateChannel() 错误 = %v, wantErr %v", err, tt.wantErr) return } // 使用检查函数验证结果 if err := tt.want(t, got); err != nil { t.Errorf("结果验证失败: %v", err) } }) } } func Test_channelService_RemoveChannels(t *testing.T) { mr := testutil.SetupRedisTest(t) md := testutil.SetupDBTest(t) mg := testutil.SetupGatewayClientMock(t) mc := testutil.SetupCloudClientMock(t) type args struct { ctx context.Context auth *auth.Context id []int32 } // 准备测试数据 ctx := context.WithValue(context.Background(), requestid.ConfigDefault.ContextKey, "test-request-id") // 创建用户 var user = &models.User{ ID: 101, Phone: "12312341234", } md.Create(user) // 创建管理员 var adminUser = &models.User{ ID: 100, Phone: "99999999999", } md.Create(adminUser) // 认证上下文 var adminAuth = &auth.Context{Payload: auth.Payload{Id: 100, Type: auth.PayloadAdmin}} var userAuth = &auth.Context{Payload: auth.Payload{Id: 101, Type: auth.PayloadUser}} // 创建代理 var proxy = &models.Proxy{ ID: 1, Version: 1, Name: "test-proxy", Host: "111.111.111.111", Type: 1, Secret: "test:secret", } md.Create(proxy) var proxy2 = &models.Proxy{ ID: 2, Version: 1, Name: "test-proxy-2", Host: "222.222.222.222", Type: 1, Secret: "test:secret2", } md.Create(proxy2) // 清空数据库函数 var clearDb = func() { md.Exec("delete from channel where true") mr.FlushAll() } tests := []struct { name string args args setup func() wantErr bool wantErrContains string want func(t *testing.T) error }{ { name: "管理员删除多个通道", args: args{ ctx: ctx, auth: adminAuth, id: []int32{1, 2, 3}, }, setup: func() { mr.FlushAll() clearDb() // 创建通道 channels := []models.Channel{ {ID: 1, UserID: 101, ProxyID: 1, ProxyPort: 10001, Protocol: 1, Expiration: core.LocalDateTime(time.Now().Add(24 * time.Hour))}, {ID: 2, UserID: 101, ProxyID: 1, ProxyPort: 10002, Protocol: 1, Expiration: core.LocalDateTime(time.Now().Add(24 * time.Hour))}, {ID: 3, UserID: 101, ProxyID: 2, ProxyPort: 10001, Protocol: 3, Expiration: core.LocalDateTime(time.Now().Add(24 * time.Hour))}, } // 保存预设数据 md.Create(channels) for _, channel := range channels { key := fmt.Sprintf("channel:%d", channel.ID) data, _ := json.Marshal(channel) _ = mr.Set(key, string(data)) } // 模拟网关客户端的响应 mg.PortActiveMock = func(m *testutil.MockGatewayClient, param ...g.PortActiveReq) (map[string]g.PortData, error) { switch { case m.Host == proxy.Host: return map[string]g.PortData{ "10001": {Edge: []string{"edge1", "edge4"}}, "10002": {Edge: []string{"edge2"}}, }, nil case m.Host == proxy2.Host: return map[string]g.PortData{ "10001": {Edge: []string{"edge3"}}, }, nil } return nil, fmt.Errorf("代理主机不符合预期: %s", m.Host) } mg.PortConfigsMock = func(m *testutil.MockGatewayClient, params []g.PortConfigsReq) error { switch { case m.Host == proxy.Host: for _, param := range params { if param.Port != 10001 && param.Port != 10002 { return fmt.Errorf("端口配置不符合预期: %d", param.Port) } if param.Status != false { return fmt.Errorf("端口状态不符合预期: %v", param.Status) } if param.Edge == nil || len(*param.Edge) != 0 { return fmt.Errorf("边缘节点不符合预期1: %v", param.Edge) } } return nil case m.Host == proxy2.Host: for _, param := range params { if param.Port != 10001 { return fmt.Errorf("端口配置不符合预期: %d", param.Port) } if param.Status != false { return fmt.Errorf("端口状态不符合预期: %v", param.Status) } if param.Edge == nil || len(*param.Edge) != 0 { return fmt.Errorf("边缘节点不符合预期2: %v", param.Edge) } } return nil } return fmt.Errorf("代理主机不符合预期: %s", m.Host) } mc.DisconnectMock = func(param g.CloudDisconnectReq) (int, error) { switch { case param.Uuid == proxy.Name: var edges = []string{"edge1", "edge2", "edge4"} if !testutil.SliceEqual(edges, param.Edge) { return 0, fmt.Errorf("边缘节点不符合预期3: %v", param.Edge) } if len(param.Config) != 0 { return 0, fmt.Errorf("配置不符合预期: %v", param.Config) } return len(param.Edge), nil case param.Uuid == proxy2.Name: var edges = []string{"edge3"} if !testutil.SliceEqual(edges, param.Edge) { return 0, fmt.Errorf("边缘节点不符合预期4: %v", param.Edge) } if len(param.Config) != 0 { return 0, fmt.Errorf("配置不符合预期: %v", param.Config) } return len(param.Edge), nil } return 0, fmt.Errorf("代理名称不符合预期: %s", param.Uuid) } }, want: func(t *testing.T) error { // 检查通道是否被软删除 var count int64 md.Model(&models.Channel{}).Where("id IN ? AND deleted_at IS NULL", []int32{1, 2, 3}).Count(&count) if count > 0 { return fmt.Errorf("应该软删除了所有通道,但仍有 %d 个未删除", count) } // 检查Redis缓存是否被删除 for _, id := range []int32{1, 2, 3} { key := fmt.Sprintf("channel:%d", id) if mr.Exists(key) { return fmt.Errorf("通道缓存 %s 应被删除但仍存在", key) } } return nil }, }, { name: "用户删除自己的通道", args: args{ ctx: ctx, auth: userAuth, id: []int32{1, 2, 3}, }, setup: func() { mr.FlushAll() clearDb() // 创建通道 channels := []models.Channel{ {ID: 1, UserID: 101, ProxyID: 1, ProxyPort: 10001, Protocol: 1, Expiration: core.LocalDateTime(time.Now().Add(24 * time.Hour))}, {ID: 2, UserID: 101, ProxyID: 1, ProxyPort: 10002, Protocol: 1, Expiration: core.LocalDateTime(time.Now().Add(24 * time.Hour))}, {ID: 3, UserID: 101, ProxyID: 2, ProxyPort: 10001, Protocol: 3, Expiration: core.LocalDateTime(time.Now().Add(24 * time.Hour))}, } // 保存预设数据 md.Create(channels) for _, channel := range channels { key := fmt.Sprintf("channel:%d", channel.ID) data, _ := json.Marshal(channel) _ = mr.Set(key, string(data)) } // 模拟网关客户端的响应 mg.PortActiveMock = func(m *testutil.MockGatewayClient, param ...g.PortActiveReq) (map[string]g.PortData, error) { switch { case m.Host == proxy.Host: return map[string]g.PortData{ "10001": {Edge: []string{"edge1", "edge4"}}, "10002": {Edge: []string{"edge2"}}, }, nil case m.Host == proxy2.Host: return map[string]g.PortData{ "10001": {Edge: []string{"edge3"}}, }, nil } return nil, fmt.Errorf("代理主机不符合预期: %s", m.Host) } mg.PortConfigsMock = func(m *testutil.MockGatewayClient, params []g.PortConfigsReq) error { switch { case m.Host == proxy.Host: for _, param := range params { if param.Port != 10001 && param.Port != 10002 { return fmt.Errorf("端口配置不符合预期: %d", param.Port) } if param.Status != false { return fmt.Errorf("端口状态不符合预期: %v", param.Status) } if param.Edge == nil || len(*param.Edge) != 0 { return fmt.Errorf("边缘节点不符合预期5: %v", param.Edge) } } return nil case m.Host == proxy2.Host: for _, param := range params { if param.Port != 10001 { return fmt.Errorf("端口配置不符合预期: %d", param.Port) } if param.Status != false { return fmt.Errorf("端口状态不符合预期: %v", param.Status) } if param.Edge == nil || len(*param.Edge) != 0 { return fmt.Errorf("边缘节点不符合预期6: %v", param.Edge) } } return nil } return fmt.Errorf("代理主机不符合预期: %s", m.Host) } mc.DisconnectMock = func(param g.CloudDisconnectReq) (int, error) { switch { case param.Uuid == proxy.Name: var edges = []string{"edge1", "edge2", "edge4"} if !testutil.SliceEqual(edges, param.Edge) { return 0, fmt.Errorf("边缘节点不符合预期7: %v", param.Edge) } if len(param.Config) != 0 { return 0, fmt.Errorf("配置不符合预期: %v", param.Config) } return len(param.Edge), nil case param.Uuid == proxy2.Name: var edges = []string{"edge3"} if !testutil.SliceEqual(edges, param.Edge) { return 0, fmt.Errorf("边缘节点不符合预期8: %v", param.Edge) } if len(param.Config) != 0 { return 0, fmt.Errorf("配置不符合预期: %v", param.Config) } return len(param.Edge), nil } return 0, fmt.Errorf("代理名称不符合预期: %s", param.Uuid) } }, want: func(t *testing.T) error { // 检查通道是否被软删除 var count int64 md.Model(&models.Channel{}).Where("id IN ? AND deleted_at IS NULL", []int32{1, 2, 3}).Count(&count) if count > 0 { return fmt.Errorf("应该软删除了所有通道,但仍有 %d 个未删除", count) } // 检查Redis缓存是否被删除 for _, id := range []int32{1, 2, 3} { key := fmt.Sprintf("channel:%d", id) if mr.Exists(key) { return fmt.Errorf("通道缓存 %s 应被删除但仍存在", key) } } return nil }, }, { name: "用户删除不属于自己的通道", args: args{ ctx: ctx, auth: userAuth, id: []int32{1, 2, 3}, }, setup: func() { mr.FlushAll() clearDb() // 创建通道 channels := []models.Channel{ {ID: 1, UserID: 101, ProxyID: 1, ProxyPort: 10001, Protocol: 1, Expiration: core.LocalDateTime(time.Now().Add(24 * time.Hour))}, {ID: 2, UserID: 101, ProxyID: 1, ProxyPort: 10002, Protocol: 1, Expiration: core.LocalDateTime(time.Now().Add(24 * time.Hour))}, {ID: 3, UserID: 102, ProxyID: 2, ProxyPort: 10001, Protocol: 3, Expiration: core.LocalDateTime(time.Now().Add(24 * time.Hour))}, } // 保存预设数据 md.Create(channels) for _, channel := range channels { key := fmt.Sprintf("channel:%d", channel.ID) data, _ := json.Marshal(channel) _ = mr.Set(key, string(data)) } }, wantErr: true, wantErrContains: "无权限访问", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if tt.setup != nil { tt.setup() } s := &channelService{} err := s.RemoveChannels(tt.args.ctx, tt.args.auth, tt.args.id...) // 检查错误 if tt.wantErr { if err == nil { t.Errorf("RemoveChannels() 应当返回错误") return } if tt.wantErrContains != "" && !strings.Contains(err.Error(), tt.wantErrContains) { t.Errorf("RemoveChannels() 错误 = %v, 应包含 %v", err, tt.wantErrContains) } return } if err != nil { t.Errorf("RemoveChannels() 错误 = %v, wantErr %v", err, tt.wantErr) return } // 检查数据库和缓存是否正确设置 if err := tt.want(t); err != nil { t.Errorf("RemoveChannels() 结果验证失败: %v", err) } }) } }