diff --git a/go.mod b/go.mod index 4f18c78..e35b17d 100644 --- a/go.mod +++ b/go.mod @@ -18,6 +18,7 @@ require ( ) require ( + github.com/DATA-DOG/go-sqlmock v1.5.2 // indirect github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302 // indirect github.com/andybalholm/brotli v1.1.0 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect diff --git a/go.sum b/go.sum index e192c4f..5c924e0 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= +github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302 h1:uvdUDbHQHO85qeSydJtItA4T55Pw6BtAejd0APRJOCE= github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc= github.com/alicebob/miniredis/v2 v2.34.0 h1:mBFWMaJSNL9RwdGRyEDoAAv8OQc5UlEhLDQggTglU/0= @@ -42,6 +44,7 @@ github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/jxskiss/base62 v1.1.0 h1:A5zbF8v8WXx2xixnAKD2w+abC+sIzYJX+nxmhA6HWFw= github.com/jxskiss/base62 v1.1.0/go.mod h1:HhWAlUXvxKThfOlZbcuFzsqwtF5TcqS9ru3y5GfjWAc= +github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE= github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= github.com/lmittmann/tint v1.0.7 h1:D/0OqWZ0YOGZ6AyC+5Y2kD8PBEzBk6rFHVSfOqCkF9Y= diff --git a/pkg/remote/remote.go b/pkg/remote/remote.go index 719ee42..2d78907 100644 --- a/pkg/remote/remote.go +++ b/pkg/remote/remote.go @@ -13,15 +13,29 @@ import ( "strings" ) -type client struct { +// CloudClient 定义云服务接口 +type CloudClient interface { + CloudEdges(param CloudEdgesReq) (*CloudEdgesResp, error) + CloudConnect(param CloudConnectReq) error + CloudDisconnect(param CloudDisconnectReq) (int, error) + CloudAutoQuery() (CloudConnectResp, error) +} + +// GatewayClient 定义网关接口 +type GatewayClient interface { + GatewayPortConfigs(params []PortConfigsReq) error + GatewayPortActive(param ...PortActiveReq) (map[string]PortData, error) +} + +type cloud struct { url string token string } -var Client client +var Cloud CloudClient func Init() { - Client = client{ + Cloud = &cloud{ url: env.RemoteAddr, token: env.RemoteToken, } @@ -61,7 +75,7 @@ type Edge struct { PacketLoss int `json:"packet_loss"` } -func (c *client) CloudEdges(param CloudEdgesReq) (*CloudEdgesResp, error) { +func (c *cloud) CloudEdges(param CloudEdgesReq) (*CloudEdgesResp, error) { data := strings.Builder{} data.WriteString("province=") data.WriteString(param.Province) @@ -110,7 +124,7 @@ type CloudConnectReq struct { AutoConfig []AutoConfig `json:"auto_config,omitempty"` } -func (c *client) CloudConnect(param CloudConnectReq) error { +func (c *cloud) CloudConnect(param CloudConnectReq) error { data, err := json.Marshal(param) if err != nil { return err @@ -165,7 +179,7 @@ type Config struct { Online bool `json:"online"` } -func (c *client) CloudDisconnect(param CloudDisconnectReq) (int, error) { +func (c *cloud) CloudDisconnect(param CloudDisconnectReq) (int, error) { data, err := json.Marshal(param) if err != nil { return 0, err @@ -208,7 +222,7 @@ func (c *client) CloudDisconnect(param CloudDisconnectReq) (int, error) { type CloudConnectResp map[string][]AutoConfig -func (c *client) CloudAutoQuery() (CloudConnectResp, error) { +func (c *cloud) CloudAutoQuery() (CloudConnectResp, error) { resp, err := c.requestCloud("GET", "/auto_query", "") if err != nil { return nil, err @@ -237,7 +251,7 @@ func (c *client) CloudAutoQuery() (CloudConnectResp, error) { // endregion -func (c *client) requestCloud(method string, url string, data string) (*http.Response, error) { +func (c *cloud) requestCloud(method string, url string, data string) (*http.Response, error) { url = fmt.Sprintf("%s/api%s", c.url, url) req, err := http.NewRequest(method, url, strings.NewReader(data)) @@ -274,14 +288,22 @@ func (c *client) requestCloud(method string, url string, data string) (*http.Res return resp, nil } -type Gateway struct { +type gateway struct { url string username string password string } -func InitGateway(url, username, password string) *Gateway { - return &Gateway{url, username, password} +var GatewayInitializer = func(url, username, password string) GatewayClient { + return &gateway{ + url: url, + username: username, + password: password, + } +} + +func NewGateway(url, username, password string) GatewayClient { + return GatewayInitializer(url, username, password) } // region gateway:/port/configs @@ -306,7 +328,7 @@ type AutoEdgeConfig struct { PacketLoss int `json:"packet_loss,omitempty"` } -func (c *Gateway) GatewayPortConfigs(params []PortConfigsReq) error { +func (c *gateway) GatewayPortConfigs(params []PortConfigsReq) error { if len(params) == 0 { return errors.New("params is empty") } @@ -372,7 +394,7 @@ type PortData struct { Userpass string `json:"userpass"` } -func (c *Gateway) GatewayPortActive(param ...PortActiveReq) (map[string]PortData, error) { +func (c *gateway) GatewayPortActive(param ...PortActiveReq) (map[string]PortData, error) { _param := PortActiveReq{} if len(param) != 0 { _param = param[0] @@ -431,7 +453,7 @@ func (c *Gateway) GatewayPortActive(param ...PortActiveReq) (map[string]PortData // endregion -func (c *Gateway) requestGateway(method string, url string, data string) (*http.Response, error) { +func (c *gateway) requestGateway(method string, url string, data string) (*http.Response, error) { url = fmt.Sprintf("http://%s:%s@%s:9990%s", c.username, c.password, c.url, url) req, err := http.NewRequest(method, url, strings.NewReader(data)) if err != nil { diff --git a/pkg/testutil/db.go b/pkg/testutil/db.go new file mode 100644 index 0000000..513f1e7 --- /dev/null +++ b/pkg/testutil/db.go @@ -0,0 +1,40 @@ +package testutil + +import ( + "platform/pkg/orm" + q "platform/web/queries" + "testing" + + "github.com/DATA-DOG/go-sqlmock" + "gorm.io/driver/postgres" + "gorm.io/gorm" +) + +// SetupDBTest 创建一个带有 sqlmock 的 GORM 数据库连接 +func SetupDBTest(t *testing.T) sqlmock.Sqlmock { + + // 创建 sqlmock + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("创建sqlmock失败: %v", err) + } + + // 配置 gorm 连接 + gormDB, err := gorm.Open(postgres.New(postgres.Config{ + Conn: db, + PreferSimpleProtocol: true, // 禁用 prepared statement 缓存 + }), &gorm.Config{}) + if err != nil { + t.Fatalf("gorm 打开数据库连接失败: %v", err) + } + + q.SetDefault(gormDB) + orm.DB = gormDB + + // 使用 t.Cleanup 确保测试结束后关闭数据库连接 + t.Cleanup(func() { + db.Close() + }) + + return mock +} diff --git a/pkg/testutil/redis.go b/pkg/testutil/redis.go new file mode 100644 index 0000000..b7fccdc --- /dev/null +++ b/pkg/testutil/redis.go @@ -0,0 +1,30 @@ +package testutil + +import ( + "platform/pkg/rds" + "testing" + + "github.com/alicebob/miniredis/v2" + "github.com/redis/go-redis/v9" +) + +// SetupRedisTest 创建一个测试用的Redis实例 +// 返回miniredis实例,使用t.Cleanup自动清理资源 +func SetupRedisTest(t *testing.T) *miniredis.Miniredis { + mr, err := miniredis.Run() + if err != nil { + t.Fatalf("设置 miniredis 失败: %v", err) + } + + // 替换 Redis 客户端为测试客户端 + rds.Client = redis.NewClient(&redis.Options{ + Addr: mr.Addr(), + }) + + // 使用t.Cleanup确保测试结束后恢复原始客户端并关闭miniredis + t.Cleanup(func() { + mr.Close() + }) + + return mr +} diff --git a/pkg/testutil/remote.go b/pkg/testutil/remote.go new file mode 100644 index 0000000..4082efa --- /dev/null +++ b/pkg/testutil/remote.go @@ -0,0 +1,128 @@ +package testutil + +import ( + "platform/pkg/remote" + "sync" + "testing" +) + +// MockCloudClient 是CloudClient接口的测试实现 +type MockCloudClient struct { + // 存储预期结果的字段 + EdgesMock func(param remote.CloudEdgesReq) (*remote.CloudEdgesResp, error) + ConnectMock func(param remote.CloudConnectReq) error + DisconnectMock func(param remote.CloudDisconnectReq) (int, error) + AutoQueryMock func() (remote.CloudConnectResp, error) + + // 记录调用历史 + EdgesCalls []remote.CloudEdgesReq + ConnectCalls []remote.CloudConnectReq + DisconnectCalls []remote.CloudDisconnectReq + AutoQueryCalls int + + // 用于并发安全 + mu sync.Mutex +} + +// 确保MockCloudClient实现了CloudClient接口 +var _ remote.CloudClient = (*MockCloudClient)(nil) + +func (m *MockCloudClient) CloudEdges(param remote.CloudEdgesReq) (*remote.CloudEdgesResp, error) { + m.mu.Lock() + defer m.mu.Unlock() + m.EdgesCalls = append(m.EdgesCalls, param) + if m.EdgesMock != nil { + return m.EdgesMock(param) + } + return &remote.CloudEdgesResp{}, nil +} + +func (m *MockCloudClient) CloudConnect(param remote.CloudConnectReq) error { + m.mu.Lock() + defer m.mu.Unlock() + m.ConnectCalls = append(m.ConnectCalls, param) + if m.ConnectMock != nil { + return m.ConnectMock(param) + } + return nil +} + +func (m *MockCloudClient) CloudDisconnect(param remote.CloudDisconnectReq) (int, error) { + m.mu.Lock() + defer m.mu.Unlock() + m.DisconnectCalls = append(m.DisconnectCalls, param) + if m.DisconnectMock != nil { + return m.DisconnectMock(param) + } + return 0, nil +} + +func (m *MockCloudClient) CloudAutoQuery() (remote.CloudConnectResp, error) { + m.mu.Lock() + defer m.mu.Unlock() + m.AutoQueryCalls++ + if m.AutoQueryMock != nil { + return m.AutoQueryMock() + } + return remote.CloudConnectResp{}, nil +} + +// MockGatewayClient 是GatewayClient接口的测试实现 +type MockGatewayClient struct { + // 存储预期结果的字段 + PortConfigsMock func(params []remote.PortConfigsReq) error + PortActiveMock func(param ...remote.PortActiveReq) (map[string]remote.PortData, error) + + // 记录调用历史 + PortConfigsCalls [][]remote.PortConfigsReq + PortActiveCalls [][]remote.PortActiveReq + + // 用于并发安全 + mu sync.Mutex +} + +// 确保MockGatewayClient实现了GatewayClient接口 +var _ remote.GatewayClient = (*MockGatewayClient)(nil) + +func (m *MockGatewayClient) GatewayPortConfigs(params []remote.PortConfigsReq) error { + m.mu.Lock() + defer m.mu.Unlock() + m.PortConfigsCalls = append(m.PortConfigsCalls, params) + if m.PortConfigsMock != nil { + return m.PortConfigsMock(params) + } + return nil +} + +func (m *MockGatewayClient) GatewayPortActive(param ...remote.PortActiveReq) (map[string]remote.PortData, error) { + m.mu.Lock() + defer m.mu.Unlock() + m.PortActiveCalls = append(m.PortActiveCalls, param) + if m.PortActiveMock != nil { + return m.PortActiveMock(param...) + } + return map[string]remote.PortData{}, nil +} + +// SetupCloudClientMock 替换全局CloudClient为测试实现并在测试完成后恢复 +func SetupCloudClientMock(t *testing.T) *MockCloudClient { + mock := &MockCloudClient{} + remote.Cloud = mock + + return mock +} + +// SetupGatewayClientMock 创建一个MockGatewayClient并提供替换函数 +func SetupGatewayClientMock(t *testing.T) *MockGatewayClient { + mock := &MockGatewayClient{} + remote.GatewayInitializer = func(url, username, password string) remote.GatewayClient { + return mock + } + return mock +} + +// NewMockGatewayClient 创建一个新的MockGatewayClient +// 保留此函数以保持向后兼容性 +func NewMockGatewayClient() *MockGatewayClient { + return &MockGatewayClient{} +} diff --git a/web/services/channel.go b/web/services/channel.go index 0d9bc4a..80388f7 100644 --- a/web/services/channel.go +++ b/web/services/channel.go @@ -99,6 +99,9 @@ func (s *channelService) RemoveChannels(ctx context.Context, auth *AuthContext, proxies, err := tx.Proxy.Where( q.Proxy.ID.In(proxyIds...), ).Find() + if err != nil { + return err + } slog.Debug("查找代理", "rid", rid, "step", time.Since(step)) @@ -163,7 +166,7 @@ func (s *channelService) RemoveChannels(ctx context.Context, auth *AuthContext, } var secret = strings.Split(proxy.Secret, ":") - gateway := remote.InitGateway( + gateway := remote.NewGateway( proxy.Host, secret[0], secret[1], @@ -204,7 +207,7 @@ func (s *channelService) RemoveChannels(ctx context.Context, auth *AuthContext, } } if len(edges) > 0 { - _, err := remote.Client.CloudDisconnect(remote.CloudDisconnectReq{ + _, err := remote.Cloud.CloudDisconnect(remote.CloudDisconnectReq{ Uuid: proxy.Name, Edge: edges, }) @@ -395,7 +398,7 @@ func assignEdge(count int, filter NodeFilterConfig) (*AssignEdgeResult, error) { // 查询已配置的节点 step = time.Now() - rProxyConfigs, err := remote.Client.CloudAutoQuery() + rProxyConfigs, err := remote.Cloud.CloudAutoQuery() if err != nil { return nil, err } @@ -466,7 +469,7 @@ func assignEdge(count int, filter NodeFilterConfig) (*AssignEdgeResult, error) { step = time.Now() slog.Debug("新增新节点", "proxy", info.proxy.Name, "used", info.used, "count", info.count) - err := remote.Client.CloudConnect(remote.CloudConnectReq{ + err := remote.Cloud.CloudConnect(remote.CloudConnectReq{ Uuid: info.proxy.Name, Edge: nil, AutoConfig: []remote.AutoConfig{{ @@ -520,7 +523,7 @@ func assignPort( expiration time.Time, filter NodeFilterConfig, ) ([]string, []*models.Channel, error) { - var step = time.Now() + var step time.Time var configs = proxies.configs var exists = proxies.channels @@ -639,7 +642,7 @@ func assignPort( step = time.Now() var secret = strings.Split(proxy.Secret, ":") - gateway := remote.InitGateway( + gateway := remote.NewGateway( proxy.Host, secret[0], secret[1], @@ -677,6 +680,10 @@ func chKey(channel *models.Channel) string { } func cache(ctx context.Context, channels []*models.Channel) error { + if len(channels) == 0 { + return nil + } + pipe := rds.Client.TxPipeline() zList := make([]redis.Z, 0, len(channels)) @@ -685,7 +692,7 @@ func cache(ctx context.Context, channels []*models.Channel) error { if err != nil { return err } - pipe.Set(ctx, chKey(channel), string(marshal), channel.Expiration.Sub(time.Now())) + pipe.Set(ctx, chKey(channel), string(marshal), time.Until(channel.Expiration)) zList = append(zList, redis.Z{ Score: float64(channel.Expiration.Unix()), Member: channel.ID, @@ -702,6 +709,10 @@ func cache(ctx context.Context, channels []*models.Channel) error { } func deleteCache(ctx context.Context, channels []*models.Channel) error { + if len(channels) == 0 { + return nil + } + keys := make([]string, len(channels)) for i := range channels { keys[i] = chKey(channels[i]) diff --git a/web/services/channel_test.go b/web/services/channel_test.go new file mode 100644 index 0000000..505b43e --- /dev/null +++ b/web/services/channel_test.go @@ -0,0 +1,985 @@ +package services + +import ( + "context" + "encoding/json" + "fmt" + "platform/pkg/env" + "platform/pkg/remote" + "platform/pkg/testutil" + "platform/web/models" + "regexp" + "strings" + "testing" + "time" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/gofiber/fiber/v2/middleware/requestid" + "gorm.io/gorm" +) + +func Test_genPassPair(t *testing.T) { + tests := []struct { + name string + }{ + { + name: "正常生成随机用户名和密码", + }, + { + name: "多次调用生成不同的值", + }, + } + + // 第一个测试:检查生成的用户名和密码是否有效 + t.Run(tests[0].name, func(t *testing.T) { + username, password := genPassPair() + if username == "" { + t.Errorf("genPassPair() username is empty") + } + if password == "" { + t.Errorf("genPassPair() password is empty") + } + }) + + // 第二个测试:确保多次调用生成不同的值 + t.Run(tests[1].name, func(t *testing.T) { + username1, password1 := genPassPair() + username2, password2 := genPassPair() + + if username1 == username2 { + t.Errorf("genPassPair() generated the same username twice: %v", username1) + } + if password1 == password2 { + t.Errorf("genPassPair() generated the same password twice: %v", password1) + } + }) +} + +func Test_chKey(t *testing.T) { + type args struct { + channel *models.Channel + } + tests := []struct { + name string + args args + want string + }{ + { + name: "ID为1的通道", + args: args{ + channel: &models.Channel{ + ID: 1, + }, + }, + want: "channel:1", + }, + { + name: "ID为100的通道", + args: args{ + channel: &models.Channel{ + ID: 100, + }, + }, + want: "channel:100", + }, + { + name: "ID为0的通道", + args: args{ + channel: &models.Channel{ + ID: 0, + }, + }, + want: "channel:0", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := chKey(tt.args.channel); got != tt.want { + t.Errorf("chKey() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_cache(t *testing.T) { + mr := testutil.SetupRedisTest(t) + + type args struct { + ctx context.Context + channels []*models.Channel + } + + // 准备测试数据 + now := time.Now() + expiration := now.Add(24 * time.Hour) + + testChannels := []*models.Channel{ + { + ID: 1, + UserID: 100, + ProxyID: 10, + ProxyPort: 8080, + Protocol: "http", + Expiration: expiration, + }, + { + ID: 2, + UserID: 101, + ProxyID: 11, + ProxyPort: 8081, + Protocol: "socks5", + Expiration: expiration, + }, + } + + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "正常缓存多个通道", + args: args{ + ctx: context.Background(), + channels: testChannels, + }, + wantErr: false, + }, + { + name: "空通道列表", + args: args{ + ctx: context.Background(), + channels: []*models.Channel{}, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mr.FlushAll() // 清空 Redis 数据 + + if err := cache(tt.args.ctx, tt.args.channels); (err != nil) != tt.wantErr { + t.Errorf("cache() error = %v, wantErr %v", err, tt.wantErr) + return + } + + // 验证缓存结果 + if len(tt.args.channels) > 0 { + for _, channel := range tt.args.channels { + key := fmt.Sprintf("channel:%d", channel.ID) + if !mr.Exists(key) { + t.Errorf("缓存未包含通道键 %s", key) + } else { + // 验证缓存的数据是否正确 + data, _ := mr.Get(key) + var cachedChannel models.Channel + err := json.Unmarshal([]byte(data), &cachedChannel) + if err != nil { + t.Errorf("无法解析缓存数据: %v", err) + } + if cachedChannel.ID != channel.ID { + t.Errorf("缓存数据不匹配: 期望 ID %d, 得到 %d", channel.ID, cachedChannel.ID) + } + } + } + + // 验证是否设置了过期时间 + for _, channel := range tt.args.channels { + key := fmt.Sprintf("channel:%d", channel.ID) + ttl := mr.TTL(key) + if ttl <= 0 { + t.Errorf("键 %s 没有设置过期时间", key) + } + } + + // 验证是否添加了有序集合 + if !mr.Exists("tasks:channel") { + t.Errorf("ZAdd未创建有序集合 tasks:channel") + } + } + }) + } +} + +func Test_deleteCache(t *testing.T) { + mr := testutil.SetupRedisTest(t) + + type args struct { + ctx context.Context + channels []*models.Channel + } + + // 准备测试数据 + testChannels := []*models.Channel{ + {ID: 1}, + {ID: 2}, + {ID: 3}, + } + + ctx := context.Background() + + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "正常删除多个通道缓存", + args: args{ + ctx: ctx, + channels: testChannels, + }, + wantErr: false, + }, + { + name: "空通道列表", + args: args{ + ctx: ctx, + channels: []*models.Channel{}, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mr.FlushAll() // 清空 Redis 数据 + + // 预先设置缓存数据 + for _, channel := range testChannels { + key := fmt.Sprintf("channel:%d", channel.ID) + data, _ := json.Marshal(channel) + mr.Set(key, string(data)) + mr.SetTTL(key, 1*time.Hour) // 设置1小时的过期时间 + } + + if err := deleteCache(tt.args.ctx, tt.args.channels); (err != nil) != tt.wantErr { + t.Errorf("deleteCache() error = %v, wantErr %v", err, tt.wantErr) + return + } + + // 验证删除结果 + for _, channel := range tt.args.channels { + key := fmt.Sprintf("channel:%d", channel.ID) + if mr.Exists(key) { + t.Errorf("通道键 %s 未被删除", key) + } + } + }) + } +} + +func Test_channelService_CreateChannel(t *testing.T) { + mr := testutil.SetupRedisTest(t) + mdb := testutil.SetupDBTest(t) + mc := testutil.SetupCloudClientMock(t) + env.DebugExternalChange = false + + type args struct { + ctx context.Context + auth *AuthContext + resourceId int32 + protocol ChannelProtocol + authType ChannelAuthType + count int + nodeFilter []NodeFilterConfig + } + + // 准备测试数据 + ctx := context.WithValue(context.Background(), requestid.ConfigDefault.ContextKey, "test-request-id") + tests := []struct { + name string + args args + setup func() + want []string + wantErr bool + wantErrContains string + checkCache func(t *testing.T) + }{ + { + name: "用户创建HTTP密码通道", + args: args{ + ctx: ctx, + auth: &AuthContext{Payload: Payload{Type: PayloadUser, Id: 100}}, + resourceId: 4, + protocol: ProtocolHTTP, + authType: ChannelAuthTypePass, + count: 3, + nodeFilter: []NodeFilterConfig{{Prov: "河南", City: "郑州", Isp: "电信"}}, + }, + setup: func() { + // 清空Redis + mr.FlushAll() + + // 设置CloudAutoQuery的模拟返回 + mc.AutoQueryMock = func() (remote.CloudConnectResp, error) { + return remote.CloudConnectResp{ + "proxy3": []remote.AutoConfig{ + {Province: "河南", City: "郑州", Isp: "电信", Count: 10}, + }, + }, nil + } + + // 开始事务 + mdb.ExpectBegin() + + // 模拟查询套餐 + resourceRows := sqlmock.NewRows([]string{ + "id", "user_id", "active", + "type", "live", "daily_used", "daily_limit", "daily_last", "quota", "used", "expire", + }).AddRow( + 4, 100, true, + 0, 86400, 0, 100, time.Now(), 1000, 0, time.Now().Add(24*time.Hour), + ) + mdb.ExpectQuery("SELECT").WithArgs(int32(4)).WillReturnRows(resourceRows) + + // 模拟查询代理 + proxyRows := sqlmock.NewRows([]string{"id", "name", "host", "secret", "type"}). + AddRow(3, "proxy3", "proxy3.example.com", "key:secret", 1) + mdb.ExpectQuery("SELECT"). + WithArgs(1). + WillReturnRows(proxyRows) + + // 模拟查询通道 + channelRows := sqlmock.NewRows([]string{"proxy_id", "proxy_port"}) + mdb.ExpectQuery("SELECT"). + WillReturnRows(channelRows) + + // 模拟保存通道 - PostgreSQL返回ID + mdb.ExpectQuery("INSERT INTO").WillReturnRows( + sqlmock.NewRows([]string{"id"}).AddRow(4).AddRow(5).AddRow(6), + ) + + // 模拟更新套餐使用记录 + mdb.ExpectExec("UPDATE").WillReturnResult(sqlmock.NewResult(0, 1)) + + // 提交事务 + mdb.ExpectCommit() + }, + want: []string{ + "http://proxy3.example.com:10000", + "http://proxy3.example.com:10001", + "http://proxy3.example.com:10002", + }, + checkCache: func(t *testing.T) { + // 检查总共创建了3个通道 + for i := 4; i <= 6; i++ { + key := fmt.Sprintf("channel:%d", i) + if !mr.Exists(key) { + t.Errorf("Redis缓存中应有键 %s", key) + } + } + }, + }, + { + name: "用户创建HTTP白名单通道", + args: args{ + ctx: ctx, + auth: &AuthContext{ + Payload: Payload{ + Type: PayloadUser, + Id: 100, + }, + }, + resourceId: 5, + protocol: ProtocolHTTP, + authType: ChannelAuthTypeIp, + count: 2, + }, + setup: func() { + // 清空Redis + mr.FlushAll() + + // 设置CloudAutoQuery的模拟返回 + mc.AutoQueryMock = func() (remote.CloudConnectResp, error) { + return remote.CloudConnectResp{ + "proxy3": []remote.AutoConfig{ + {Province: "河南", City: "郑州", Isp: "电信", Count: 10}, + }, + }, nil + } + + // 开始事务 + mdb.ExpectBegin() + + // 模拟查询套餐 + resourceRows := sqlmock.NewRows([]string{ + "id", "user_id", "active", + "type", "live", "daily_used", "daily_limit", "daily_last", "quota", "used", "expire", + }).AddRow( + 5, 100, true, + 0, 86400, 0, 100, time.Now(), 1000, 0, time.Now().Add(24*time.Hour), + ) + mdb.ExpectQuery("SELECT").WithArgs(int32(5)).WillReturnRows(resourceRows) + + // 模拟查询代理 + proxyRows := sqlmock.NewRows([]string{"id", "name", "host", "secret", "type"}). + AddRow(3, "proxy3", "proxy3.example.com", "key:secret", 1) + mdb.ExpectQuery("SELECT"). + WithArgs(1). + WillReturnRows(proxyRows) + + // 模拟查询通道 + channelRows := sqlmock.NewRows([]string{"proxy_id", "proxy_port"}) + mdb.ExpectQuery("SELECT"). + WillReturnRows(channelRows) + + // 模拟查询白名单 - 3个IP + whitelistRows := sqlmock.NewRows([]string{"host"}). + AddRow("192.168.1.1"). + AddRow("192.168.1.2"). + AddRow("192.168.1.3") + mdb.ExpectQuery("SELECT"). + WithArgs(int32(100)). + WillReturnRows(whitelistRows) + + // 模拟保存通道 - 2个通道 * 3个白名单 = 6个 + mdb.ExpectQuery("INSERT INTO").WillReturnRows( + sqlmock.NewRows([]string{"id"}). + AddRow(7).AddRow(8).AddRow(9). + AddRow(10).AddRow(11).AddRow(12), + ) + + // 模拟更新套餐使用记录 + mdb.ExpectExec("UPDATE").WillReturnResult(sqlmock.NewResult(0, 1)) + + // 提交事务 + mdb.ExpectCommit() + }, + want: []string{ + "http://proxy3.example.com:10000", + "http://proxy3.example.com:10001", + }, + checkCache: func(t *testing.T) { + // 检查应该创建了6个通道(2个通道 * 3个白名单) + for i := 7; i <= 12; i++ { + key := fmt.Sprintf("channel:%d", i) + if !mr.Exists(key) { + t.Errorf("Redis缓存中应有键 %s", key) + } + } + }, + }, + { + name: "管理员创建SOCKS5密码通道", + args: args{ + ctx: ctx, + auth: &AuthContext{ + Payload: Payload{ + Type: PayloadAdmin, + Id: 1, + }, + }, + resourceId: 6, + protocol: ProtocolSocks5, + authType: ChannelAuthTypePass, + count: 2, + }, + setup: func() { + // 清空Redis + mr.FlushAll() + + // 设置CloudAutoQuery的模拟返回 + mc.AutoQueryMock = func() (remote.CloudConnectResp, error) { + return remote.CloudConnectResp{ + "proxy4": []remote.AutoConfig{ + {Province: "河南", City: "郑州", Isp: "电信", Count: 5}, + }, + }, nil + } + + // 设置CloudConnect的模拟逻辑 + mc.ConnectMock = func(param remote.CloudConnectReq) error { + return nil + } + + // 开始事务 + mdb.ExpectBegin() + + // 模拟查询套餐 + resourceRows := sqlmock.NewRows([]string{ + "id", "user_id", "active", + "type", "live", "daily_used", "daily_limit", "daily_last", "quota", "used", "expire", + }).AddRow( + 6, 102, true, + 1, 86400, 0, 100, time.Now(), 0, 0, time.Now().Add(24*time.Hour), + ) + mdb.ExpectQuery("SELECT").WithArgs(int32(6)).WillReturnRows(resourceRows) + + // 模拟查询代理 + proxyRows := sqlmock.NewRows([]string{"id", "name", "host", "secret", "type"}). + AddRow(4, "proxy4", "proxy4.example.com", "key:secret", 1) + mdb.ExpectQuery("SELECT"). + WithArgs(1). + WillReturnRows(proxyRows) + + // 模拟查询通道 + channelRows := sqlmock.NewRows([]string{"proxy_id", "proxy_port"}) + mdb.ExpectQuery("SELECT"). + WillReturnRows(channelRows) + + // 模拟保存通道 + mdb.ExpectQuery("INSERT INTO").WillReturnRows( + sqlmock.NewRows([]string{"id"}).AddRow(13).AddRow(14), + ) + + // 模拟更新套餐使用记录 + mdb.ExpectExec("UPDATE").WillReturnResult(sqlmock.NewResult(0, 1)) + + // 提交事务 + mdb.ExpectCommit() + }, + want: []string{ + "socks5://proxy4.example.com:10000", + "socks5://proxy4.example.com:10001", + }, + checkCache: func(t *testing.T) { + for i := 13; i <= 14; i++ { + key := fmt.Sprintf("channel:%d", i) + if !mr.Exists(key) { + t.Errorf("Redis缓存中应有键 %s", key) + } + } + }, + }, + { + name: "套餐不存在", + args: args{ + ctx: ctx, + auth: &AuthContext{ + Payload: Payload{ + Type: PayloadUser, + Id: 100, + }, + }, + resourceId: 999, + protocol: ProtocolHTTP, + authType: ChannelAuthTypeIp, + count: 1, + }, + setup: func() { + // 清空Redis + mr.FlushAll() + + // 开始事务 + mdb.ExpectBegin() + + // 模拟查询套餐不存在 + mdb.ExpectQuery("SELECT").WithArgs(int32(999)).WillReturnError(gorm.ErrRecordNotFound) + + // 回滚事务 + mdb.ExpectRollback() + }, + wantErr: true, + wantErrContains: "套餐不存在", + }, + { + name: "套餐没有权限", + args: args{ + ctx: ctx, + auth: &AuthContext{ + Payload: Payload{ + Type: PayloadUser, + Id: 101, + }, + }, + resourceId: 7, + protocol: ProtocolHTTP, + authType: ChannelAuthTypeIp, + count: 1, + }, + setup: func() { + // 清空Redis + mr.FlushAll() + + // 开始事务 + mdb.ExpectBegin() + + // 模拟查询套餐 + resourceRows := sqlmock.NewRows([]string{ + "id", "user_id", "active", + "type", "live", "daily_used", "daily_limit", "daily_last", "quota", "used", "expire", + }).AddRow( + 7, 102, true, // 注意:user_id 与 auth.Id 不匹配 + 0, 86400, 0, 100, time.Now(), 1000, 0, time.Now().Add(24*time.Hour), + ) + mdb.ExpectQuery("SELECT").WithArgs(int32(7)).WillReturnRows(resourceRows) + + // 回滚事务 + mdb.ExpectRollback() + }, + wantErr: true, + wantErrContains: "无权限访问", + }, + { + name: "套餐配额不足", + args: args{ + ctx: ctx, + auth: &AuthContext{ + Payload: Payload{ + Type: PayloadUser, + Id: 100, + }, + }, + resourceId: 2, + protocol: ProtocolHTTP, + authType: ChannelAuthTypeIp, + count: 10, + }, + setup: func() { + // 清空Redis + mr.FlushAll() + + // 开始事务 + mdb.ExpectBegin() + + // 模拟查询套餐 + resourceRows := sqlmock.NewRows([]string{ + "id", "user_id", "active", + "type", "live", "daily_used", "daily_limit", "daily_last", "quota", "used", "expire", + }).AddRow( + 2, 100, true, + 0, 86400, 95, 100, time.Now(), 100, 95, time.Now().Add(24*time.Hour), + ) + mdb.ExpectQuery("SELECT").WithArgs(int32(2)).WillReturnRows(resourceRows) + + // 回滚事务 + mdb.ExpectRollback() + }, + wantErr: true, + wantErrContains: "套餐配额不足", + }, + { + name: "端口数量达到上限", + args: args{ + ctx: ctx, + auth: &AuthContext{ + Payload: Payload{ + Type: PayloadUser, + Id: 100, + }, + }, + resourceId: 8, + protocol: ProtocolHTTP, + authType: ChannelAuthTypeIp, + count: 1, + }, + setup: func() { + // 清空Redis + mr.FlushAll() + + // 设置CloudAutoQuery的模拟返回 + mc.AutoQueryMock = func() (remote.CloudConnectResp, error) { + return remote.CloudConnectResp{ + "proxy5": []remote.AutoConfig{ + {Province: "河南", City: "郑州", Isp: "电信", Count: 10}, + }, + }, nil + } + + // 开始事务 + mdb.ExpectBegin() + + // 模拟查询套餐 + resourceRows := sqlmock.NewRows([]string{ + "id", "user_id", "active", + "type", "live", "daily_used", "daily_limit", "daily_last", "quota", "used", "expire", + }).AddRow( + 8, 100, true, + 0, 86400, 0, 100, time.Now(), 1000, 0, time.Now().Add(24*time.Hour), + ) + mdb.ExpectQuery("SELECT").WithArgs(int32(8)).WillReturnRows(resourceRows) + + // 模拟查询代理 + proxyRows := sqlmock.NewRows([]string{"id", "name", "host", "secret", "type"}). + AddRow(5, "proxy5", "proxy5.example.com", "key:secret", 1) + mdb.ExpectQuery("SELECT"). + WithArgs(1). + WillReturnRows(proxyRows) + + // 模拟通道端口已用尽 + // 构建一个大量已使用端口的结果集 + channelRows := sqlmock.NewRows([]string{"proxy_id", "proxy_port"}) + for i := 10000; i < 65535; i++ { + channelRows.AddRow(5, i) + } + mdb.ExpectQuery("SELECT"). + WillReturnRows(channelRows) + + // 回滚事务 + mdb.ExpectRollback() + }, + wantErr: true, + wantErrContains: "端口数量不足", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.setup != nil { + tt.setup() + } + + s := &channelService{} + got, err := s.CreateChannel(tt.args.ctx, tt.args.auth, tt.args.resourceId, tt.args.protocol, tt.args.authType, tt.args.count, tt.args.nodeFilter...) + + // 检查错误或结果 + if tt.wantErr { + if err == nil { + t.Errorf("CreateChannel() 应当返回错误") + return + } + if tt.wantErrContains != "" && !strings.Contains(err.Error(), tt.wantErrContains) { + t.Errorf("CreateChannel() 错误 = %v, 应包含 %v", err, tt.wantErrContains) + } + return + } + + if err != nil { + t.Errorf("CreateChannel() 错误 = %v, wantErr %v", err, tt.wantErr) + return + } + + if len(got) != len(tt.want) { + t.Errorf("CreateChannel() 返回长度 = %v, want %v", len(got), len(tt.want)) + return + } + + // 检查返回地址格式 + for _, addr := range got { + protocol := string(tt.args.protocol) + if !strings.HasPrefix(addr, protocol+"://") { + t.Errorf("CreateChannel() 地址 %v 不是有效的 %s 地址", addr, protocol) + } + } + + // 验证所有期望的 SQL 已执行 + if err := mdb.ExpectationsWereMet(); err != nil { + t.Errorf("有未满足的SQL期望: %s", err) + } + + // 检查 Redis 缓存是否正确设置 + if tt.checkCache != nil { + tt.checkCache(t) + } + }) + } +} + +func Test_channelService_RemoveChannels(t *testing.T) { + mr := testutil.SetupRedisTest(t) + mdb := testutil.SetupDBTest(t) + mg := testutil.SetupGatewayClientMock(t) + env.DebugExternalChange = false + + type args struct { + ctx context.Context + auth *AuthContext + id []int32 + } + + // 准备测试数据 + ctx := context.WithValue(context.Background(), requestid.ConfigDefault.ContextKey, "test-request-id") + tests := []struct { + name string + args args + setup func() + wantErr bool + wantErrContains string + checkCache func(t *testing.T) + }{ + { + name: "管理员删除多个通道", + args: args{ + ctx: ctx, + auth: &AuthContext{ + Payload: Payload{ + Type: PayloadAdmin, + Id: 1, + }, + }, + id: []int32{1, 2, 3}, + }, + setup: func() { + // 预设 Redis 缓存 + mr.FlushAll() + for _, id := range []int32{1, 2, 3} { + key := fmt.Sprintf("channel:%d", id) + channel := models.Channel{ID: id, UserID: 100} + data, _ := json.Marshal(channel) + mr.Set(key, string(data)) + } + + // 开始事务 + mdb.ExpectBegin() + + // 查找通道 + channelRows := sqlmock.NewRows([]string{"id", "user_id", "proxy_id", "proxy_port", "protocol", "expiration"}). + AddRow(1, 100, 1, 10001, "http", time.Now().Add(24*time.Hour)). + AddRow(2, 100, 1, 10002, "http", time.Now().Add(24*time.Hour)). + AddRow(3, 101, 2, 10001, "socks5", time.Now().Add(24*time.Hour)) + mdb.ExpectQuery(regexp.QuoteMeta("SELECT * FROM `channel` WHERE `channel`.`id` IN")). + WithArgs(int32(1), int32(2), int32(3)). + WillReturnRows(channelRows) + + // 查找代理 + proxyRows := sqlmock.NewRows([]string{"id", "name", "host", "secret", "type"}). + AddRow(1, "proxy1", "proxy1.example.com", "key:secret", 1). + AddRow(2, "proxy2", "proxy2.example.com", "key:secret", 1) + mdb.ExpectQuery(regexp.QuoteMeta("SELECT * FROM `proxy` WHERE `proxy`.`id` IN")). + WithArgs(int32(1), int32(2)). + WillReturnRows(proxyRows) + + // 软删除通道 + mdb.ExpectExec(regexp.QuoteMeta("UPDATE `channel` SET")). + WillReturnResult(sqlmock.NewResult(0, 3)) + + // 提交事务 + mdb.ExpectCommit() + }, + checkCache: func(t *testing.T) { + for _, id := range []int32{1, 2, 3} { + key := fmt.Sprintf("channel:%d", id) + if mr.Exists(key) { + t.Errorf("通道缓存 %s 应被删除但仍存在", key) + } + } + }, + }, + { + name: "用户删除自己的通道", + args: args{ + ctx: ctx, + auth: &AuthContext{ + Payload: Payload{ + Type: PayloadUser, + Id: 100, + }, + }, + id: []int32{1}, + }, + setup: func() { + // 预设 Redis 缓存 + mr.FlushAll() + key := "channel:1" + channel := models.Channel{ID: 1, UserID: 100} + data, _ := json.Marshal(channel) + mr.Set(key, string(data)) + + // 模拟查询已激活的端口 + mg.PortActiveMock = func(param ...remote.PortActiveReq) (map[string]remote.PortData, error) { + return map[string]remote.PortData{ + "10001": { + Edge: []string{"edge1", "edge2"}, + }, + }, nil + } + + // 开始事务 + mdb.ExpectBegin() + + // 查找通道 + channelRows := sqlmock.NewRows([]string{"id", "user_id", "proxy_id", "proxy_port", "protocol", "expiration"}). + AddRow(1, 100, 1, 10001, "http", time.Now().Add(24*time.Hour)) + mdb.ExpectQuery(regexp.QuoteMeta("SELECT * FROM `channel` WHERE `channel`.`id` IN")). + WithArgs(int32(1)). + WillReturnRows(channelRows) + + // 查找代理 + proxyRows := sqlmock.NewRows([]string{"id", "name", "host", "secret", "type"}). + AddRow(1, "proxy1", "proxy1.example.com", "key:secret", 1) + mdb.ExpectQuery(regexp.QuoteMeta("SELECT * FROM `proxy` WHERE `proxy`.`id` IN")). + WithArgs(int32(1)). + WillReturnRows(proxyRows) + + // 软删除通道 + mdb.ExpectExec(regexp.QuoteMeta("UPDATE `channel` SET")). + WillReturnResult(sqlmock.NewResult(0, 1)) + + // 提交事务 + mdb.ExpectCommit() + }, + checkCache: func(t *testing.T) { + key := "channel:1" + if mr.Exists(key) { + t.Errorf("通道缓存 %s 应被删除但仍存在", key) + } + }, + }, + { + name: "用户删除不属于自己的通道", + args: args{ + ctx: ctx, + auth: &AuthContext{ + Payload: Payload{ + Type: PayloadUser, + Id: 100, + }, + }, + id: []int32{5}, + }, + setup: func() { + // 预设 Redis 缓存 + mr.FlushAll() + key := "channel:5" + channel := models.Channel{ID: 5, UserID: 101} + data, _ := json.Marshal(channel) + mr.Set(key, string(data)) + + // 开始事务 + mdb.ExpectBegin() + + // 查找通道 + channelRows := sqlmock.NewRows([]string{"id", "user_id", "proxy_id", "proxy_port", "protocol", "expiration"}). + AddRow(5, 101, 1, 10005, "http", time.Now().Add(24*time.Hour)) + mdb.ExpectQuery(regexp.QuoteMeta("SELECT * FROM `channel` WHERE `channel`.`id` IN")). + WithArgs(int32(5)). + WillReturnRows(channelRows) + + // 回滚事务 + mdb.ExpectRollback() + }, + wantErr: true, + wantErrContains: "无权限访问", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.setup != nil { + tt.setup() + } + + s := &channelService{} + err := s.RemoveChannels(tt.args.ctx, tt.args.auth, tt.args.id...) + + // 检查错误 + if tt.wantErr { + if err == nil { + t.Errorf("RemoveChannels() 应当返回错误") + return + } + if tt.wantErrContains != "" && !strings.Contains(err.Error(), tt.wantErrContains) { + t.Errorf("RemoveChannels() 错误 = %v, 应包含 %v", err, tt.wantErrContains) + } + return + } + + if err != nil { + t.Errorf("RemoveChannels() 错误 = %v, wantErr %v", err, tt.wantErr) + return + } + + // 验证所有期望的 SQL 已执行 + if err := mdb.ExpectationsWereMet(); err != nil { + t.Errorf("有未满足的SQL期望: %s", err) + } + + // 检查 Redis 缓存是否正确设置 + if tt.checkCache != nil { + tt.checkCache(t) + } + }) + } +} diff --git a/web/services/session_test.go b/web/services/session_test.go index d5f0cb3..ef0b178 100644 --- a/web/services/session_test.go +++ b/web/services/session_test.go @@ -3,36 +3,12 @@ package services import ( "context" "errors" - "platform/pkg/rds" + "platform/pkg/testutil" "reflect" "testing" "time" - - "github.com/alicebob/miniredis/v2" - "github.com/redis/go-redis/v9" ) -// 设置 Redis 模拟服务器 -func setupTestRedis(t *testing.T) *miniredis.Miniredis { - mr, err := miniredis.Run() - if err != nil { - t.Fatalf("无法启动 miniredis: %v", err) - } - - // 替换 Redis 客户端为测试客户端 - origClient := rds.Client - rds.Client = redis.NewClient(&redis.Options{ - Addr: mr.Addr(), - }) - - t.Cleanup(func() { - mr.Close() - rds.Client = origClient - }) - - return mr -} - // 创建测试用的认证上下文 func createTestAuthContext() AuthContext { return AuthContext{ @@ -52,7 +28,7 @@ func createTestAuthContext() AuthContext { } func Test_sessionService_Create(t *testing.T) { - mr := setupTestRedis(t) + mr := testutil.SetupRedisTest(t) ctx := context.Background() auth := createTestAuthContext() @@ -162,7 +138,7 @@ func Test_sessionService_Create(t *testing.T) { } func Test_sessionService_Find(t *testing.T) { - _ = setupTestRedis(t) + testutil.SetupRedisTest(t) ctx := context.Background() auth := createTestAuthContext() s := &sessionService{} @@ -221,7 +197,7 @@ func Test_sessionService_Find(t *testing.T) { } func Test_sessionService_Refresh(t *testing.T) { - mr := setupTestRedis(t) + mr := testutil.SetupRedisTest(t) ctx := context.Background() auth := createTestAuthContext() s := &sessionService{} @@ -314,7 +290,7 @@ func Test_sessionService_Refresh(t *testing.T) { } func Test_sessionService_Remove(t *testing.T) { - mr := setupTestRedis(t) + mr := testutil.SetupRedisTest(t) ctx := context.Background() auth := createTestAuthContext() s := &sessionService{} diff --git a/web/services/verifier_test.go b/web/services/verifier_test.go index cf7c064..f3151a3 100644 --- a/web/services/verifier_test.go +++ b/web/services/verifier_test.go @@ -2,30 +2,14 @@ package services import ( "context" - "platform/pkg/rds" + "platform/pkg/testutil" "strconv" "testing" "time" "github.com/alicebob/miniredis/v2" - "github.com/redis/go-redis/v9" ) -// 设置测试的 Redis 环境 -func setupRedisTest(t *testing.T) *miniredis.Miniredis { - mr, err := miniredis.Run() - if err != nil { - t.Fatalf("设置 miniredis 失败: %v", err) - } - - // 替换 redis 客户端为测试客户端 - rds.Client = redis.NewClient(&redis.Options{ - Addr: mr.Addr(), - }) - - return mr -} - func Test_verifierService_SendSms(t *testing.T) { type args struct { ctx context.Context @@ -82,7 +66,7 @@ func Test_verifierService_SendSms(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { // 设置 Redis 测试环境 - mr := setupRedisTest(t) + mr := testutil.SetupRedisTest(t) defer mr.Close() // 执行测试前的设置 @@ -216,7 +200,7 @@ func Test_verifierService_VerifySms(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { // 设置 Redis 测试环境 - mr := setupRedisTest(t) + mr := testutil.SetupRedisTest(t) defer mr.Close() // 执行测试前的设置