From 309aa6d0e281cf1e31240335bc5b49cf6961757c Mon Sep 17 00:00:00 2001 From: luorijun Date: Thu, 3 Apr 2025 13:30:57 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AE=8C=E5=96=84=20channel=20remove=20?= =?UTF-8?q?=E6=B5=8B=E8=AF=95=E7=94=A8=E4=BE=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cmd/gen/main.go | 17 ++ cmd/playground/main.go | 75 +----- go.mod | 1 + go.sum | 3 + pkg/testutil/remote.go | 86 +++---- scripts/dev/speed.sh | 0 web/models/channel.gen.go | 3 + web/services/channel.go | 14 +- web/services/channel_test.go | 467 +++++++++++++++++++++++------------ 9 files changed, 388 insertions(+), 278 deletions(-) create mode 100644 scripts/dev/speed.sh diff --git a/cmd/gen/main.go b/cmd/gen/main.go index 2adc6e6..049d95e 100644 --- a/cmd/gen/main.go +++ b/cmd/gen/main.go @@ -1,8 +1,11 @@ package main import ( + m "platform/web/models" + "gorm.io/driver/postgres" "gorm.io/gen" + "gorm.io/gen/field" "gorm.io/gorm" "gorm.io/gorm/schema" ) @@ -26,5 +29,19 @@ func main() { models := g.GenerateAllTable() g.ApplyBasic(models...) + + modelChannel := g.GenerateModel("channel", + gen.FieldRelateModel(field.BelongsTo, "Node", &m.Node{}, &field.RelateConfig{ + RelatePointer: true, + }), + gen.FieldRelateModel(field.BelongsTo, "User", &m.User{}, &field.RelateConfig{ + RelatePointer: true, + }), + gen.FieldRelateModel(field.BelongsTo, "Proxy", &m.Proxy{}, &field.RelateConfig{ + RelatePointer: true, + }), + ) + g.ApplyBasic(modelChannel) + g.Execute() } diff --git a/cmd/playground/main.go b/cmd/playground/main.go index 6b8edf0..0766a8c 100644 --- a/cmd/playground/main.go +++ b/cmd/playground/main.go @@ -1,78 +1,7 @@ package main -import ( - "encoding/json" - "platform/pkg/orm" - m "platform/web/models" - q "platform/web/queries" - "time" - - "github.com/glebarez/sqlite" - "gorm.io/gorm" -) - func main() { - open, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{}) - if err != nil { - panic(err) + for i := range 3 { + println(i) } - - err = open.AutoMigrate(&m.Resource{}, &m.ResourcePss{}) - if err != nil { - panic(err) - } - - q.SetDefault(open) - - var r = &m.Resource{ - ID: 1, - UserID: 101, - Active: true, - } - open.Create(r) - var resourcePss = &m.ResourcePss{ - ID: 1, - ResourceID: 1, - Type: 1, - Live: 180, - Expire: time.Now().AddDate(1, 0, 0), - DailyLimit: 10000, - } - open.Create(resourcePss) - - var resource = new(ResourceInfo) - data := q.Resource.As("data") - pss := q.ResourcePss.As("pss") - err = data.Scopes(orm.Alias(data)). - Select( - data.ID, data.UserID, data.Active, - pss.Type, pss.Live, pss.DailyUsed, pss.DailyLimit, pss.DailyLast, pss.Quota, pss.Used, pss.Expire, - ). - LeftJoin(q.ResourcePss.As("pss"), pss.ResourceID.EqCol(data.ID)). - Where(data.ID.Eq(1)). - Scan(&resource) - if err != nil { - panic(err) - } - - bytes, err := json.MarshalIndent(resource, "", " ") - if err != nil { - panic(err) - } - - println(string(bytes)) -} - -type ResourceInfo struct { - Id int32 - UserId int32 - Active bool - Type int32 - Live int32 - DailyLimit int32 - DailyUsed int32 - DailyLast time.Time - Quota int32 - Used int32 - Expire time.Time } diff --git a/go.mod b/go.mod index 0382fe3..7e02d7c 100644 --- a/go.mod +++ b/go.mod @@ -40,6 +40,7 @@ require ( github.com/mattn/go-sqlite3 v1.14.24 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/rivo/uniseg v0.4.7 // indirect + github.com/stripe/pg-schema-diff v0.9.0 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/fasthttp v1.59.0 // indirect github.com/yuin/gopher-lua v1.1.1 // indirect diff --git a/go.sum b/go.sum index 0a81637..9abc75c 100644 --- a/go.sum +++ b/go.sum @@ -85,6 +85,9 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= +github.com/stripe/pg-schema-diff v0.9.0 h1:qzm2VUdbZ2kYwqxoQqtEP3uLQI0B+ymS947zqFTZGBk= +github.com/stripe/pg-schema-diff v0.9.0/go.mod h1:cl2VC6te/cCTOewTRvv4pYsgQqAOhvRQmatCHfYwy8c= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/fasthttp v1.59.0 h1:Qu0qYHfXvPk1mSLNqcFtEk6DpxgA26hy6bmydotDpRI= diff --git a/pkg/testutil/remote.go b/pkg/testutil/remote.go index 4082efa..e5f5983 100644 --- a/pkg/testutil/remote.go +++ b/pkg/testutil/remote.go @@ -67,11 +67,47 @@ func (m *MockCloudClient) CloudAutoQuery() (remote.CloudConnectResp, error) { return remote.CloudConnectResp{}, nil } +// SetupCloudClientMock 替换全局CloudClient为测试实现并在测试完成后恢复 +func SetupCloudClientMock(t *testing.T) *MockCloudClient { + mock := &MockCloudClient{} + remote.Cloud = mock + + return mock +} + // MockGatewayClient 是GatewayClient接口的测试实现 type MockGatewayClient struct { + Host string +} + +// 确保MockGatewayClient实现了GatewayClient接口 +var _ remote.GatewayClient = (*MockGatewayClient)(nil) + +func (m *MockGatewayClient) GatewayPortConfigs(params []remote.PortConfigsReq) error { + testGatewayBase.mu.Lock() + defer testGatewayBase.mu.Unlock() + testGatewayBase.PortConfigsCalls = append(testGatewayBase.PortConfigsCalls, params) + if testGatewayBase.PortConfigsMock != nil { + return testGatewayBase.PortConfigsMock(m, params) + } + return nil +} + +func (m *MockGatewayClient) GatewayPortActive(param ...remote.PortActiveReq) (map[string]remote.PortData, error) { + testGatewayBase.mu.Lock() + defer testGatewayBase.mu.Unlock() + testGatewayBase.PortActiveCalls = append(testGatewayBase.PortActiveCalls, param) + if testGatewayBase.PortActiveMock != nil { + return testGatewayBase.PortActiveMock(m, param...) + } + return map[string]remote.PortData{}, nil +} + +type GatewayClientIns struct { + // 存储预期结果的字段 - PortConfigsMock func(params []remote.PortConfigsReq) error - PortActiveMock func(param ...remote.PortActiveReq) (map[string]remote.PortData, error) + PortConfigsMock func(c *MockGatewayClient, params []remote.PortConfigsReq) error + PortActiveMock func(c *MockGatewayClient, param ...remote.PortActiveReq) (map[string]remote.PortData, error) // 记录调用历史 PortConfigsCalls [][]remote.PortConfigsReq @@ -81,48 +117,14 @@ type MockGatewayClient struct { mu sync.Mutex } -// 确保MockGatewayClient实现了GatewayClient接口 -var _ remote.GatewayClient = (*MockGatewayClient)(nil) - -func (m *MockGatewayClient) GatewayPortConfigs(params []remote.PortConfigsReq) error { - m.mu.Lock() - defer m.mu.Unlock() - m.PortConfigsCalls = append(m.PortConfigsCalls, params) - if m.PortConfigsMock != nil { - return m.PortConfigsMock(params) - } - return nil -} - -func (m *MockGatewayClient) GatewayPortActive(param ...remote.PortActiveReq) (map[string]remote.PortData, error) { - m.mu.Lock() - defer m.mu.Unlock() - m.PortActiveCalls = append(m.PortActiveCalls, param) - if m.PortActiveMock != nil { - return m.PortActiveMock(param...) - } - return map[string]remote.PortData{}, nil -} - -// SetupCloudClientMock 替换全局CloudClient为测试实现并在测试完成后恢复 -func SetupCloudClientMock(t *testing.T) *MockCloudClient { - mock := &MockCloudClient{} - remote.Cloud = mock - - return mock -} +var testGatewayBase = &GatewayClientIns{} // SetupGatewayClientMock 创建一个MockGatewayClient并提供替换函数 -func SetupGatewayClientMock(t *testing.T) *MockGatewayClient { - mock := &MockGatewayClient{} +func SetupGatewayClientMock(t *testing.T) *GatewayClientIns { remote.GatewayInitializer = func(url, username, password string) remote.GatewayClient { - return mock + return &MockGatewayClient{ + Host: url, + } } - return mock -} - -// NewMockGatewayClient 创建一个新的MockGatewayClient -// 保留此函数以保持向后兼容性 -func NewMockGatewayClient() *MockGatewayClient { - return &MockGatewayClient{} + return testGatewayBase } diff --git a/scripts/dev/speed.sh b/scripts/dev/speed.sh new file mode 100644 index 0000000..e69de29 diff --git a/web/models/channel.gen.go b/web/models/channel.gen.go index 829fe93..7257b57 100644 --- a/web/models/channel.gen.go +++ b/web/models/channel.gen.go @@ -30,6 +30,9 @@ type Channel struct { CreatedAt time.Time `gorm:"column:created_at;default:CURRENT_TIMESTAMP;comment:创建时间" json:"created_at"` // 创建时间 UpdatedAt time.Time `gorm:"column:updated_at;default:CURRENT_TIMESTAMP;comment:更新时间" json:"updated_at"` // 更新时间 DeletedAt gorm.DeletedAt `gorm:"column:deleted_at;comment:删除时间" json:"deleted_at"` // 删除时间 + Node *Node `json:"node"` + User *User `json:"user"` + Proxy *Proxy `json:"proxy"` } // TableName Channel's table name diff --git a/web/services/channel.go b/web/services/channel.go index 129172b..64af8c6 100644 --- a/web/services/channel.go +++ b/web/services/channel.go @@ -481,7 +481,7 @@ func assignEdge(q *q.Query, count int, filter NodeFilterConfig) (*AssignEdgeResu Province: filter.Prov, City: filter.City, Isp: filter.Isp, - Count: int(math.Ceil(float64(info.used) * 11 / 10)), + Count: int(math.Ceil(float64(info.used) * 2)), } var newConfigs []remote.AutoConfig var update = false @@ -596,10 +596,11 @@ func assignPort( Edge: nil, Status: true, AutoEdgeConfig: &remote.AutoEdgeConfig{ - Province: filter.Prov, - City: filter.City, - Isp: filter.Isp, - Count: v.P(1), + Province: filter.Prov, + City: filter.City, + Isp: filter.Isp, + Count: v.P(1), + PacketLoss: 30, }, }) @@ -704,7 +705,6 @@ type PortInfo struct { // endregion func genPassPair() (string, string) { - var letters = []rune("abcdefghjkmnpqrstuvwxyz23456789") var alphabet = []rune("abcdefghjkmnpqrstuvwxyz") var numbers = []rune("23456789") @@ -716,7 +716,7 @@ func genPassPair() (string, string) { } else { username[i] = numbers[rand.N(len(numbers))] } - password[i] = letters[rand.N(len(letters))] + password[i] = numbers[rand.N(len(numbers))] } return string(username), string(password) diff --git a/web/services/channel_test.go b/web/services/channel_test.go index 0dd086c..c29ed63 100644 --- a/web/services/channel_test.go +++ b/web/services/channel_test.go @@ -4,11 +4,10 @@ import ( "context" "encoding/json" "fmt" - "platform/pkg/env" "platform/pkg/remote" "platform/pkg/testutil" "platform/web/models" - "reflect" + "slices" "strings" "testing" "time" @@ -272,7 +271,6 @@ func Test_channelService_CreateChannel(t *testing.T) { mr := testutil.SetupRedisTest(t) db := testutil.SetupDBTest(t) mc := testutil.SetupCloudClientMock(t) - env.DebugExternalChange = false type args struct { ctx context.Context @@ -338,10 +336,9 @@ func Test_channelService_CreateChannel(t *testing.T) { name string args args setup func() - want []*PortInfo wantErr bool wantErrContains string - checkCache func(channels []models.Channel) error + want func(t *testing.T, got []*PortInfo) error }{ { name: "用户创建HTTP密码通道", @@ -354,12 +351,58 @@ func Test_channelService_CreateChannel(t *testing.T) { count: 3, nodeFilter: []NodeFilterConfig{{Prov: "河南", City: "郑州", Isp: "电信"}}, }, - want: []*PortInfo{ - { - Proto: "http", - Host: proxy.Host, - Port: 10000, - }, + + want: func(t *testing.T, got []*PortInfo) error { + // 验证返回结果 + if len(got) == 0 { + return fmt.Errorf("返回的 PortInfo 不应为空") + } + + // 验证协议正确 + for _, port := range got { + if port.Proto != "http" { + return fmt.Errorf("期望协议为 http,得到 %s", port.Proto) + } + if port.Host != proxy.Host { + return fmt.Errorf("期望主机为 %s,得到 %s", proxy.Host, port.Host) + } + } + + // 验证数据库字段 + var channels []*models.Channel + db.Where("user_id = ? AND proxy_id = ?", userAuth.Payload.Id, proxy.ID).Find(&channels) + for _, ch := range channels { + if ch.Protocol != "http" { + return fmt.Errorf("通道协议不正确,期望 http,得到 %s", ch.Protocol) + } + if ch.UserID != userAuth.Payload.Id { + return fmt.Errorf("通道用户ID不正确,期望 %d,得到 %d", userAuth.Payload.Id, ch.UserID) + } + if ch.ProxyID != proxy.ID { + return fmt.Errorf("通道代理ID不正确,期望 %d,得到 %d", proxy.ID, ch.ProxyID) + } + + // 检查Redis缓存中的字段 + key := fmt.Sprintf("channel:%d", ch.ID) + if !mr.Exists(key) { + return fmt.Errorf("Redis缓存中应有键 %s", key) + } + + data, _ := mr.Get(key) + var cachedChannel models.Channel + err := json.Unmarshal([]byte(data), &cachedChannel) + if err != nil { + return fmt.Errorf("无法解析缓存数据: %v", err) + } + + if cachedChannel.ID != ch.ID { + return fmt.Errorf("缓存ID不正确,期望 %d,得到 %d", ch.ID, cachedChannel.ID) + } + if cachedChannel.Protocol != ch.Protocol { + return fmt.Errorf("缓存协议不正确,期望 %s,得到 %s", ch.Protocol, cachedChannel.Protocol) + } + } + return nil }, }, { @@ -372,7 +415,9 @@ func Test_channelService_CreateChannel(t *testing.T) { authType: ChannelAuthTypeIp, count: 2, }, - want: []*PortInfo{}, + want: func(t *testing.T, got []*PortInfo) error { + return nil + }, }, { name: "管理员创建SOCKS5密码通道", @@ -384,7 +429,9 @@ func Test_channelService_CreateChannel(t *testing.T) { authType: ChannelAuthTypePass, count: 2, }, - want: []*PortInfo{}, + want: func(t *testing.T, got []*PortInfo) error { + return nil + }, }, { name: "套餐不存在", @@ -504,34 +551,10 @@ func Test_channelService_CreateChannel(t *testing.T) { return } - // 检查返回值 - if reflect.DeepEqual(got, tt.want) { - t.Errorf("CreateChannel() 返回值 = %v, 期望 %v", got, tt.want) - } - - // 查询创建的通道 - var channels []models.Channel - db.Where( - "user_id = ? and proxy_id = ?", - userAuth.Payload.Id, proxy.ID, - ).Find(&channels) - - if len(channels) != 2 { - t.Errorf("期望创建2个通道,但是创建了%d个", len(channels)) - } - - // 检查Redis缓存 - for _, ch := range channels { - key := fmt.Sprintf("channel:%d", ch.ID) - if !mr.Exists(key) { - t.Errorf("Redis缓存中应有键 %s", key) - } - } - - if tt.checkCache != nil { - var err = tt.checkCache(channels) - if err != nil { - t.Errorf("检查缓存失败: %v", err) + // 使用检查函数验证结果 + if tt.want != nil { + if err := tt.want(t, got); err != nil { + t.Errorf("结果验证失败: %v", err) } } }) @@ -540,9 +563,9 @@ func Test_channelService_CreateChannel(t *testing.T) { func Test_channelService_RemoveChannels(t *testing.T) { mr := testutil.SetupRedisTest(t) - db := testutil.SetupDBTest(t) + md := testutil.SetupDBTest(t) mg := testutil.SetupGatewayClientMock(t) - env.DebugExternalChange = false + mc := testutil.SetupCloudClientMock(t) type args struct { ctx context.Context @@ -552,178 +575,309 @@ func Test_channelService_RemoveChannels(t *testing.T) { // 准备测试数据 ctx := context.WithValue(context.Background(), requestid.ConfigDefault.ContextKey, "test-request-id") + + // 创建用户 + var user = &models.User{ + ID: 101, + Phone: "12312341234", + } + md.Create(user) + + // 创建管理员 + var adminUser = &models.User{ + ID: 100, + Phone: "99999999999", + } + md.Create(adminUser) + + // 认证上下文 + var adminAuth = &AuthContext{Payload: Payload{Id: 100, Type: PayloadAdmin}} + var userAuth = &AuthContext{Payload: Payload{Id: 101, Type: PayloadUser}} + + // 创建代理 + var proxy = &models.Proxy{ + ID: 1, + Version: 1, + Name: "test-proxy", + Host: "111.111.111.111", + Type: 1, + Secret: "test:secret", + } + md.Create(proxy) + + var proxy2 = &models.Proxy{ + ID: 2, + Version: 1, + Name: "test-proxy-2", + Host: "222.222.222.222", + Type: 1, + Secret: "test:secret2", + } + md.Create(proxy2) + + // 清空数据库函数 + var clearDb = func() { + md.Exec("delete from channel where true") + mr.FlushAll() + } + tests := []struct { name string args args setup func() wantErr bool wantErrContains string - checkCache func(t *testing.T) + want func(t *testing.T) error }{ { name: "管理员删除多个通道", args: args{ - ctx: ctx, - auth: &AuthContext{ - Payload: Payload{ - Type: PayloadAdmin, - Id: 1, - }, - }, - id: []int32{1, 2, 3}, + ctx: ctx, + auth: adminAuth, + id: []int32{1, 2, 3}, }, setup: func() { - // 预设 Redis 缓存 mr.FlushAll() - for _, id := range []int32{1, 2, 3} { - key := fmt.Sprintf("channel:%d", id) - channel := models.Channel{ID: id, UserID: 100} - data, _ := json.Marshal(channel) - mr.Set(key, string(data)) - } - - // 清空数据库表 - db.Exec("delete from channel") - db.Exec("delete from proxy") - - // 创建代理 - proxies := []models.Proxy{ - {ID: 1, Name: "proxy1", Host: "proxy1.example.com", Secret: "key:secret", Type: 1}, - {ID: 2, Name: "proxy2", Host: "proxy2.example.com", Secret: "key:secret", Type: 1}, - } - for _, p := range proxies { - db.Create(&p) - } + clearDb() // 创建通道 channels := []models.Channel{ - {ID: 1, UserID: 100, ProxyID: 1, ProxyPort: 10001, Protocol: "http", Expiration: time.Now().Add(24 * time.Hour)}, - {ID: 2, UserID: 100, ProxyID: 1, ProxyPort: 10002, Protocol: "http", Expiration: time.Now().Add(24 * time.Hour)}, + {ID: 1, UserID: 101, ProxyID: 1, ProxyPort: 10001, Protocol: "http", Expiration: time.Now().Add(24 * time.Hour)}, + {ID: 2, UserID: 101, ProxyID: 1, ProxyPort: 10002, Protocol: "http", Expiration: time.Now().Add(24 * time.Hour)}, {ID: 3, UserID: 101, ProxyID: 2, ProxyPort: 10001, Protocol: "socks5", Expiration: time.Now().Add(24 * time.Hour)}, } - for _, c := range channels { - db.Create(&c) + + // 保存预设数据 + md.Create(channels) + for _, channel := range channels { + key := fmt.Sprintf("channel:%d", channel.ID) + data, _ := json.Marshal(channel) + _ = mr.Set(key, string(data)) + } + + // 模拟网关客户端的响应 + mg.PortActiveMock = func(m *testutil.MockGatewayClient, param ...remote.PortActiveReq) (map[string]remote.PortData, error) { + switch { + case m.Host == proxy.Host: + return map[string]remote.PortData{ + "10001": {Edge: []string{"edge1", "edge4"}}, + "10002": {Edge: []string{"edge2"}}, + }, nil + case m.Host == proxy2.Host: + return map[string]remote.PortData{ + "10001": {Edge: []string{"edge3"}}, + }, nil + } + return nil, fmt.Errorf("代理主机不符合预期: %s", m.Host) + } + mg.PortConfigsMock = func(m *testutil.MockGatewayClient, params []remote.PortConfigsReq) error { + switch { + case m.Host == proxy.Host: + for _, param := range params { + if param.Port != 10001 && param.Port != 10002 { + return fmt.Errorf("端口配置不符合预期: %d", param.Port) + } + if param.Status != false { + return fmt.Errorf("端口状态不符合预期: %v", param.Status) + } + if param.Edge == nil || len(*param.Edge) != 0 { + return fmt.Errorf("边缘节点不符合预期: %v", param.Edge) + } + } + case m.Host == proxy2.Host: + for _, param := range params { + if param.Port != 10001 { + return fmt.Errorf("端口配置不符合预期: %d", param.Port) + } + if param.Status != false { + return fmt.Errorf("端口状态不符合预期: %v", param.Status) + } + if param.Edge == nil || len(*param.Edge) != 0 { + return fmt.Errorf("边缘节点不符合预期: %v", param.Edge) + } + } + } + return fmt.Errorf("代理主机不符合预期: %s", m.Host) + } + mc.DisconnectMock = func(param remote.CloudDisconnectReq) (int, error) { + switch { + case param.Uuid == proxy.Name: + var edges = []string{"edge1", "edge2", "edge4"} + if !slices.Equal(edges, param.Edge) { + return 0, fmt.Errorf("边缘节点不符合预期: %v", param.Edge) + } + if len(param.Config) != 0 { + return 0, fmt.Errorf("配置不符合预期: %v", param.Config) + } + return len(param.Edge), nil + case param.Uuid == proxy2.Name: + var edges = []string{"edge3"} + if !slices.Equal(edges, param.Edge) { + return 0, fmt.Errorf("边缘节点不符合预期: %v", param.Edge) + } + if len(param.Config) != 0 { + return 0, fmt.Errorf("配置不符合预期: %v", param.Config) + } + return len(param.Edge), nil + } + return 0, fmt.Errorf("代理名称不符合预期: %s", param.Uuid) } }, - checkCache: func(t *testing.T) { + want: func(t *testing.T) error { // 检查通道是否被软删除 var count int64 - db.Model(&models.Channel{}).Where("id IN ? AND deleted_at IS NULL", []int32{1, 2, 3}).Count(&count) + md.Model(&models.Channel{}).Where("id IN ? AND deleted_at IS NULL", []int32{1, 2, 3}).Count(&count) if count > 0 { - t.Errorf("应该软删除了所有通道,但仍有 %d 个未删除", count) + return fmt.Errorf("应该软删除了所有通道,但仍有 %d 个未删除", count) } // 检查Redis缓存是否被删除 for _, id := range []int32{1, 2, 3} { key := fmt.Sprintf("channel:%d", id) if mr.Exists(key) { - t.Errorf("通道缓存 %s 应被删除但仍存在", key) + return fmt.Errorf("通道缓存 %s 应被删除但仍存在", key) } } + return nil }, }, { name: "用户删除自己的通道", args: args{ - ctx: ctx, - auth: &AuthContext{ - Payload: Payload{ - Type: PayloadUser, - Id: 100, - }, - }, - id: []int32{1}, + ctx: ctx, + auth: userAuth, + id: []int32{1, 2, 3}, }, setup: func() { - // 预设 Redis 缓存 mr.FlushAll() - key := "channel:1" - channel := models.Channel{ID: 1, UserID: 100} - data, _ := json.Marshal(channel) - mr.Set(key, string(data)) - - // 清空数据库表 - db.Exec("delete from channel") - db.Exec("delete from proxy") - - // 创建代理 - proxy := models.Proxy{ - ID: 1, - Name: "proxy1", - Host: "proxy1.example.com", - Secret: "key:secret", - Type: 1, - } - db.Create(&proxy) + clearDb() // 创建通道 - ch := models.Channel{ - ID: 1, - UserID: 100, - ProxyID: 1, - ProxyPort: 10001, - Protocol: "http", - Expiration: time.Now().Add(24 * time.Hour), + channels := []models.Channel{ + {ID: 1, UserID: 101, ProxyID: 1, ProxyPort: 10001, Protocol: "http", Expiration: time.Now().Add(24 * time.Hour)}, + {ID: 2, UserID: 101, ProxyID: 1, ProxyPort: 10002, Protocol: "http", Expiration: time.Now().Add(24 * time.Hour)}, + {ID: 3, UserID: 101, ProxyID: 2, ProxyPort: 10001, Protocol: "socks5", Expiration: time.Now().Add(24 * time.Hour)}, } - db.Create(&ch) - // 模拟查询已激活的端口 - mg.PortActiveMock = func(param ...remote.PortActiveReq) (map[string]remote.PortData, error) { - return map[string]remote.PortData{ - "10001": { - Edge: []string{"edge1", "edge2"}, - }, - }, nil + // 保存预设数据 + md.Create(channels) + for _, channel := range channels { + key := fmt.Sprintf("channel:%d", channel.ID) + data, _ := json.Marshal(channel) + _ = mr.Set(key, string(data)) + } + + // 模拟网关客户端的响应 + mg.PortActiveMock = func(m *testutil.MockGatewayClient, param ...remote.PortActiveReq) (map[string]remote.PortData, error) { + switch { + case m.Host == proxy.Host: + return map[string]remote.PortData{ + "10001": {Edge: []string{"edge1", "edge4"}}, + "10002": {Edge: []string{"edge2"}}, + }, nil + case m.Host == proxy2.Host: + return map[string]remote.PortData{ + "10001": {Edge: []string{"edge3"}}, + }, nil + } + return nil, fmt.Errorf("代理主机不符合预期: %s", m.Host) + } + mg.PortConfigsMock = func(m *testutil.MockGatewayClient, params []remote.PortConfigsReq) error { + switch { + case m.Host == proxy.Host: + for _, param := range params { + if param.Port != 10001 && param.Port != 10002 { + return fmt.Errorf("端口配置不符合预期: %d", param.Port) + } + if param.Status != false { + return fmt.Errorf("端口状态不符合预期: %v", param.Status) + } + if param.Edge == nil || len(*param.Edge) != 0 { + return fmt.Errorf("边缘节点不符合预期: %v", param.Edge) + } + } + case m.Host == proxy2.Host: + for _, param := range params { + if param.Port != 10001 { + return fmt.Errorf("端口配置不符合预期: %d", param.Port) + } + if param.Status != false { + return fmt.Errorf("端口状态不符合预期: %v", param.Status) + } + if param.Edge == nil || len(*param.Edge) != 0 { + return fmt.Errorf("边缘节点不符合预期: %v", param.Edge) + } + } + } + return fmt.Errorf("代理主机不符合预期: %s", m.Host) + } + mc.DisconnectMock = func(param remote.CloudDisconnectReq) (int, error) { + switch { + case param.Uuid == proxy.Name: + var edges = []string{"edge1", "edge2", "edge4"} + if !slices.Equal(edges, param.Edge) { + return 0, fmt.Errorf("边缘节点不符合预期: %v", param.Edge) + } + if len(param.Config) != 0 { + return 0, fmt.Errorf("配置不符合预期: %v", param.Config) + } + return len(param.Edge), nil + case param.Uuid == proxy2.Name: + var edges = []string{"edge3"} + if !slices.Equal(edges, param.Edge) { + return 0, fmt.Errorf("边缘节点不符合预期: %v", param.Edge) + } + if len(param.Config) != 0 { + return 0, fmt.Errorf("配置不符合预期: %v", param.Config) + } + return len(param.Edge), nil + } + return 0, fmt.Errorf("代理名称不符合预期: %s", param.Uuid) } }, - checkCache: func(t *testing.T) { + want: func(t *testing.T) error { // 检查通道是否被软删除 var count int64 - db.Model(&models.Channel{}).Where("id = ? AND deleted_at IS NULL", 1).Count(&count) + md.Model(&models.Channel{}).Where("id IN ? AND deleted_at IS NULL", []int32{1, 2, 3}).Count(&count) if count > 0 { - t.Errorf("应该软删除了通道,但仍未删除") + return fmt.Errorf("应该软删除了所有通道,但仍有 %d 个未删除", count) } // 检查Redis缓存是否被删除 - key := "channel:1" - if mr.Exists(key) { - t.Errorf("通道缓存 %s 应被删除但仍存在", key) + for _, id := range []int32{1, 2, 3} { + key := fmt.Sprintf("channel:%d", id) + if mr.Exists(key) { + return fmt.Errorf("通道缓存 %s 应被删除但仍存在", key) + } } + return nil }, }, { name: "用户删除不属于自己的通道", args: args{ - ctx: ctx, - auth: &AuthContext{ - Payload: Payload{ - Type: PayloadUser, - Id: 100, - }, - }, - id: []int32{5}, + ctx: ctx, + auth: userAuth, + id: []int32{1, 2, 3}, }, setup: func() { - // 预设 Redis 缓存 mr.FlushAll() - key := "channel:5" - channel := models.Channel{ID: 5, UserID: 101} - data, _ := json.Marshal(channel) - mr.Set(key, string(data)) + clearDb() - // 清空数据库表 - db.Exec("delete from channel") - - // 创建一个属于用户101的通道 - ch := models.Channel{ - ID: 5, - UserID: 101, - ProxyID: 1, - ProxyPort: 10005, - Protocol: "http", - Expiration: time.Now().Add(24 * time.Hour), + // 创建通道 + channels := []models.Channel{ + {ID: 1, UserID: 101, ProxyID: 1, ProxyPort: 10001, Protocol: "http", Expiration: time.Now().Add(24 * time.Hour)}, + {ID: 2, UserID: 101, ProxyID: 1, ProxyPort: 10002, Protocol: "http", Expiration: time.Now().Add(24 * time.Hour)}, + {ID: 3, UserID: 102, ProxyID: 2, ProxyPort: 10001, Protocol: "socks5", Expiration: time.Now().Add(24 * time.Hour)}, + } + + // 保存预设数据 + md.Create(channels) + for _, channel := range channels { + key := fmt.Sprintf("channel:%d", channel.ID) + data, _ := json.Marshal(channel) + _ = mr.Set(key, string(data)) } - db.Create(&ch) }, wantErr: true, wantErrContains: "无权限访问", @@ -756,9 +910,10 @@ func Test_channelService_RemoveChannels(t *testing.T) { return } - // 检查 Redis 缓存是否正确设置 - if tt.checkCache != nil { - tt.checkCache(t) + // 检查数据库和缓存是否正确设置 + want := tt.want(t) + if tt.want(t) != nil { + t.Errorf("RemoveChannels() 结果验证失败: %v", want) } }) }