新增通道服务相关测试用例
This commit is contained in:
1
go.mod
1
go.mod
@@ -18,6 +18,7 @@ require (
|
|||||||
)
|
)
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
github.com/DATA-DOG/go-sqlmock v1.5.2 // indirect
|
||||||
github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302 // indirect
|
github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302 // indirect
|
||||||
github.com/andybalholm/brotli v1.1.0 // indirect
|
github.com/andybalholm/brotli v1.1.0 // indirect
|
||||||
github.com/cespare/xxhash/v2 v2.2.0 // indirect
|
github.com/cespare/xxhash/v2 v2.2.0 // indirect
|
||||||
|
|||||||
3
go.sum
3
go.sum
@@ -1,3 +1,5 @@
|
|||||||
|
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
|
||||||
|
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
|
||||||
github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302 h1:uvdUDbHQHO85qeSydJtItA4T55Pw6BtAejd0APRJOCE=
|
github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302 h1:uvdUDbHQHO85qeSydJtItA4T55Pw6BtAejd0APRJOCE=
|
||||||
github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc=
|
github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc=
|
||||||
github.com/alicebob/miniredis/v2 v2.34.0 h1:mBFWMaJSNL9RwdGRyEDoAAv8OQc5UlEhLDQggTglU/0=
|
github.com/alicebob/miniredis/v2 v2.34.0 h1:mBFWMaJSNL9RwdGRyEDoAAv8OQc5UlEhLDQggTglU/0=
|
||||||
@@ -42,6 +44,7 @@ github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
|
|||||||
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
|
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
|
||||||
github.com/jxskiss/base62 v1.1.0 h1:A5zbF8v8WXx2xixnAKD2w+abC+sIzYJX+nxmhA6HWFw=
|
github.com/jxskiss/base62 v1.1.0 h1:A5zbF8v8WXx2xixnAKD2w+abC+sIzYJX+nxmhA6HWFw=
|
||||||
github.com/jxskiss/base62 v1.1.0/go.mod h1:HhWAlUXvxKThfOlZbcuFzsqwtF5TcqS9ru3y5GfjWAc=
|
github.com/jxskiss/base62 v1.1.0/go.mod h1:HhWAlUXvxKThfOlZbcuFzsqwtF5TcqS9ru3y5GfjWAc=
|
||||||
|
github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE=
|
||||||
github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA=
|
github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA=
|
||||||
github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
|
github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
|
||||||
github.com/lmittmann/tint v1.0.7 h1:D/0OqWZ0YOGZ6AyC+5Y2kD8PBEzBk6rFHVSfOqCkF9Y=
|
github.com/lmittmann/tint v1.0.7 h1:D/0OqWZ0YOGZ6AyC+5Y2kD8PBEzBk6rFHVSfOqCkF9Y=
|
||||||
|
|||||||
@@ -13,15 +13,29 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
type client struct {
|
// CloudClient 定义云服务接口
|
||||||
|
type CloudClient interface {
|
||||||
|
CloudEdges(param CloudEdgesReq) (*CloudEdgesResp, error)
|
||||||
|
CloudConnect(param CloudConnectReq) error
|
||||||
|
CloudDisconnect(param CloudDisconnectReq) (int, error)
|
||||||
|
CloudAutoQuery() (CloudConnectResp, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GatewayClient 定义网关接口
|
||||||
|
type GatewayClient interface {
|
||||||
|
GatewayPortConfigs(params []PortConfigsReq) error
|
||||||
|
GatewayPortActive(param ...PortActiveReq) (map[string]PortData, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type cloud struct {
|
||||||
url string
|
url string
|
||||||
token string
|
token string
|
||||||
}
|
}
|
||||||
|
|
||||||
var Client client
|
var Cloud CloudClient
|
||||||
|
|
||||||
func Init() {
|
func Init() {
|
||||||
Client = client{
|
Cloud = &cloud{
|
||||||
url: env.RemoteAddr,
|
url: env.RemoteAddr,
|
||||||
token: env.RemoteToken,
|
token: env.RemoteToken,
|
||||||
}
|
}
|
||||||
@@ -61,7 +75,7 @@ type Edge struct {
|
|||||||
PacketLoss int `json:"packet_loss"`
|
PacketLoss int `json:"packet_loss"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) CloudEdges(param CloudEdgesReq) (*CloudEdgesResp, error) {
|
func (c *cloud) CloudEdges(param CloudEdgesReq) (*CloudEdgesResp, error) {
|
||||||
data := strings.Builder{}
|
data := strings.Builder{}
|
||||||
data.WriteString("province=")
|
data.WriteString("province=")
|
||||||
data.WriteString(param.Province)
|
data.WriteString(param.Province)
|
||||||
@@ -110,7 +124,7 @@ type CloudConnectReq struct {
|
|||||||
AutoConfig []AutoConfig `json:"auto_config,omitempty"`
|
AutoConfig []AutoConfig `json:"auto_config,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) CloudConnect(param CloudConnectReq) error {
|
func (c *cloud) CloudConnect(param CloudConnectReq) error {
|
||||||
data, err := json.Marshal(param)
|
data, err := json.Marshal(param)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -165,7 +179,7 @@ type Config struct {
|
|||||||
Online bool `json:"online"`
|
Online bool `json:"online"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) CloudDisconnect(param CloudDisconnectReq) (int, error) {
|
func (c *cloud) CloudDisconnect(param CloudDisconnectReq) (int, error) {
|
||||||
data, err := json.Marshal(param)
|
data, err := json.Marshal(param)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
@@ -208,7 +222,7 @@ func (c *client) CloudDisconnect(param CloudDisconnectReq) (int, error) {
|
|||||||
|
|
||||||
type CloudConnectResp map[string][]AutoConfig
|
type CloudConnectResp map[string][]AutoConfig
|
||||||
|
|
||||||
func (c *client) CloudAutoQuery() (CloudConnectResp, error) {
|
func (c *cloud) CloudAutoQuery() (CloudConnectResp, error) {
|
||||||
resp, err := c.requestCloud("GET", "/auto_query", "")
|
resp, err := c.requestCloud("GET", "/auto_query", "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -237,7 +251,7 @@ func (c *client) CloudAutoQuery() (CloudConnectResp, error) {
|
|||||||
|
|
||||||
// endregion
|
// endregion
|
||||||
|
|
||||||
func (c *client) requestCloud(method string, url string, data string) (*http.Response, error) {
|
func (c *cloud) requestCloud(method string, url string, data string) (*http.Response, error) {
|
||||||
|
|
||||||
url = fmt.Sprintf("%s/api%s", c.url, url)
|
url = fmt.Sprintf("%s/api%s", c.url, url)
|
||||||
req, err := http.NewRequest(method, url, strings.NewReader(data))
|
req, err := http.NewRequest(method, url, strings.NewReader(data))
|
||||||
@@ -274,14 +288,22 @@ func (c *client) requestCloud(method string, url string, data string) (*http.Res
|
|||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type Gateway struct {
|
type gateway struct {
|
||||||
url string
|
url string
|
||||||
username string
|
username string
|
||||||
password string
|
password string
|
||||||
}
|
}
|
||||||
|
|
||||||
func InitGateway(url, username, password string) *Gateway {
|
var GatewayInitializer = func(url, username, password string) GatewayClient {
|
||||||
return &Gateway{url, username, password}
|
return &gateway{
|
||||||
|
url: url,
|
||||||
|
username: username,
|
||||||
|
password: password,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewGateway(url, username, password string) GatewayClient {
|
||||||
|
return GatewayInitializer(url, username, password)
|
||||||
}
|
}
|
||||||
|
|
||||||
// region gateway:/port/configs
|
// region gateway:/port/configs
|
||||||
@@ -306,7 +328,7 @@ type AutoEdgeConfig struct {
|
|||||||
PacketLoss int `json:"packet_loss,omitempty"`
|
PacketLoss int `json:"packet_loss,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Gateway) GatewayPortConfigs(params []PortConfigsReq) error {
|
func (c *gateway) GatewayPortConfigs(params []PortConfigsReq) error {
|
||||||
if len(params) == 0 {
|
if len(params) == 0 {
|
||||||
return errors.New("params is empty")
|
return errors.New("params is empty")
|
||||||
}
|
}
|
||||||
@@ -372,7 +394,7 @@ type PortData struct {
|
|||||||
Userpass string `json:"userpass"`
|
Userpass string `json:"userpass"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Gateway) GatewayPortActive(param ...PortActiveReq) (map[string]PortData, error) {
|
func (c *gateway) GatewayPortActive(param ...PortActiveReq) (map[string]PortData, error) {
|
||||||
_param := PortActiveReq{}
|
_param := PortActiveReq{}
|
||||||
if len(param) != 0 {
|
if len(param) != 0 {
|
||||||
_param = param[0]
|
_param = param[0]
|
||||||
@@ -431,7 +453,7 @@ func (c *Gateway) GatewayPortActive(param ...PortActiveReq) (map[string]PortData
|
|||||||
|
|
||||||
// endregion
|
// endregion
|
||||||
|
|
||||||
func (c *Gateway) requestGateway(method string, url string, data string) (*http.Response, error) {
|
func (c *gateway) requestGateway(method string, url string, data string) (*http.Response, error) {
|
||||||
url = fmt.Sprintf("http://%s:%s@%s:9990%s", c.username, c.password, c.url, url)
|
url = fmt.Sprintf("http://%s:%s@%s:9990%s", c.username, c.password, c.url, url)
|
||||||
req, err := http.NewRequest(method, url, strings.NewReader(data))
|
req, err := http.NewRequest(method, url, strings.NewReader(data))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
40
pkg/testutil/db.go
Normal file
40
pkg/testutil/db.go
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
package testutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"platform/pkg/orm"
|
||||||
|
q "platform/web/queries"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/DATA-DOG/go-sqlmock"
|
||||||
|
"gorm.io/driver/postgres"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SetupDBTest 创建一个带有 sqlmock 的 GORM 数据库连接
|
||||||
|
func SetupDBTest(t *testing.T) sqlmock.Sqlmock {
|
||||||
|
|
||||||
|
// 创建 sqlmock
|
||||||
|
db, mock, err := sqlmock.New()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("创建sqlmock失败: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 配置 gorm 连接
|
||||||
|
gormDB, err := gorm.Open(postgres.New(postgres.Config{
|
||||||
|
Conn: db,
|
||||||
|
PreferSimpleProtocol: true, // 禁用 prepared statement 缓存
|
||||||
|
}), &gorm.Config{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("gorm 打开数据库连接失败: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
q.SetDefault(gormDB)
|
||||||
|
orm.DB = gormDB
|
||||||
|
|
||||||
|
// 使用 t.Cleanup 确保测试结束后关闭数据库连接
|
||||||
|
t.Cleanup(func() {
|
||||||
|
db.Close()
|
||||||
|
})
|
||||||
|
|
||||||
|
return mock
|
||||||
|
}
|
||||||
30
pkg/testutil/redis.go
Normal file
30
pkg/testutil/redis.go
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
package testutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"platform/pkg/rds"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/alicebob/miniredis/v2"
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SetupRedisTest 创建一个测试用的Redis实例
|
||||||
|
// 返回miniredis实例,使用t.Cleanup自动清理资源
|
||||||
|
func SetupRedisTest(t *testing.T) *miniredis.Miniredis {
|
||||||
|
mr, err := miniredis.Run()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("设置 miniredis 失败: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 替换 Redis 客户端为测试客户端
|
||||||
|
rds.Client = redis.NewClient(&redis.Options{
|
||||||
|
Addr: mr.Addr(),
|
||||||
|
})
|
||||||
|
|
||||||
|
// 使用t.Cleanup确保测试结束后恢复原始客户端并关闭miniredis
|
||||||
|
t.Cleanup(func() {
|
||||||
|
mr.Close()
|
||||||
|
})
|
||||||
|
|
||||||
|
return mr
|
||||||
|
}
|
||||||
128
pkg/testutil/remote.go
Normal file
128
pkg/testutil/remote.go
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
package testutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"platform/pkg/remote"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// MockCloudClient 是CloudClient接口的测试实现
|
||||||
|
type MockCloudClient struct {
|
||||||
|
// 存储预期结果的字段
|
||||||
|
EdgesMock func(param remote.CloudEdgesReq) (*remote.CloudEdgesResp, error)
|
||||||
|
ConnectMock func(param remote.CloudConnectReq) error
|
||||||
|
DisconnectMock func(param remote.CloudDisconnectReq) (int, error)
|
||||||
|
AutoQueryMock func() (remote.CloudConnectResp, error)
|
||||||
|
|
||||||
|
// 记录调用历史
|
||||||
|
EdgesCalls []remote.CloudEdgesReq
|
||||||
|
ConnectCalls []remote.CloudConnectReq
|
||||||
|
DisconnectCalls []remote.CloudDisconnectReq
|
||||||
|
AutoQueryCalls int
|
||||||
|
|
||||||
|
// 用于并发安全
|
||||||
|
mu sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// 确保MockCloudClient实现了CloudClient接口
|
||||||
|
var _ remote.CloudClient = (*MockCloudClient)(nil)
|
||||||
|
|
||||||
|
func (m *MockCloudClient) CloudEdges(param remote.CloudEdgesReq) (*remote.CloudEdgesResp, error) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
m.EdgesCalls = append(m.EdgesCalls, param)
|
||||||
|
if m.EdgesMock != nil {
|
||||||
|
return m.EdgesMock(param)
|
||||||
|
}
|
||||||
|
return &remote.CloudEdgesResp{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockCloudClient) CloudConnect(param remote.CloudConnectReq) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
m.ConnectCalls = append(m.ConnectCalls, param)
|
||||||
|
if m.ConnectMock != nil {
|
||||||
|
return m.ConnectMock(param)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockCloudClient) CloudDisconnect(param remote.CloudDisconnectReq) (int, error) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
m.DisconnectCalls = append(m.DisconnectCalls, param)
|
||||||
|
if m.DisconnectMock != nil {
|
||||||
|
return m.DisconnectMock(param)
|
||||||
|
}
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockCloudClient) CloudAutoQuery() (remote.CloudConnectResp, error) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
m.AutoQueryCalls++
|
||||||
|
if m.AutoQueryMock != nil {
|
||||||
|
return m.AutoQueryMock()
|
||||||
|
}
|
||||||
|
return remote.CloudConnectResp{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockGatewayClient 是GatewayClient接口的测试实现
|
||||||
|
type MockGatewayClient struct {
|
||||||
|
// 存储预期结果的字段
|
||||||
|
PortConfigsMock func(params []remote.PortConfigsReq) error
|
||||||
|
PortActiveMock func(param ...remote.PortActiveReq) (map[string]remote.PortData, error)
|
||||||
|
|
||||||
|
// 记录调用历史
|
||||||
|
PortConfigsCalls [][]remote.PortConfigsReq
|
||||||
|
PortActiveCalls [][]remote.PortActiveReq
|
||||||
|
|
||||||
|
// 用于并发安全
|
||||||
|
mu sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// 确保MockGatewayClient实现了GatewayClient接口
|
||||||
|
var _ remote.GatewayClient = (*MockGatewayClient)(nil)
|
||||||
|
|
||||||
|
func (m *MockGatewayClient) GatewayPortConfigs(params []remote.PortConfigsReq) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
m.PortConfigsCalls = append(m.PortConfigsCalls, params)
|
||||||
|
if m.PortConfigsMock != nil {
|
||||||
|
return m.PortConfigsMock(params)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockGatewayClient) GatewayPortActive(param ...remote.PortActiveReq) (map[string]remote.PortData, error) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
m.PortActiveCalls = append(m.PortActiveCalls, param)
|
||||||
|
if m.PortActiveMock != nil {
|
||||||
|
return m.PortActiveMock(param...)
|
||||||
|
}
|
||||||
|
return map[string]remote.PortData{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetupCloudClientMock 替换全局CloudClient为测试实现并在测试完成后恢复
|
||||||
|
func SetupCloudClientMock(t *testing.T) *MockCloudClient {
|
||||||
|
mock := &MockCloudClient{}
|
||||||
|
remote.Cloud = mock
|
||||||
|
|
||||||
|
return mock
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetupGatewayClientMock 创建一个MockGatewayClient并提供替换函数
|
||||||
|
func SetupGatewayClientMock(t *testing.T) *MockGatewayClient {
|
||||||
|
mock := &MockGatewayClient{}
|
||||||
|
remote.GatewayInitializer = func(url, username, password string) remote.GatewayClient {
|
||||||
|
return mock
|
||||||
|
}
|
||||||
|
return mock
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMockGatewayClient 创建一个新的MockGatewayClient
|
||||||
|
// 保留此函数以保持向后兼容性
|
||||||
|
func NewMockGatewayClient() *MockGatewayClient {
|
||||||
|
return &MockGatewayClient{}
|
||||||
|
}
|
||||||
@@ -99,6 +99,9 @@ func (s *channelService) RemoveChannels(ctx context.Context, auth *AuthContext,
|
|||||||
proxies, err := tx.Proxy.Where(
|
proxies, err := tx.Proxy.Where(
|
||||||
q.Proxy.ID.In(proxyIds...),
|
q.Proxy.ID.In(proxyIds...),
|
||||||
).Find()
|
).Find()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
slog.Debug("查找代理", "rid", rid, "step", time.Since(step))
|
slog.Debug("查找代理", "rid", rid, "step", time.Since(step))
|
||||||
|
|
||||||
@@ -163,7 +166,7 @@ func (s *channelService) RemoveChannels(ctx context.Context, auth *AuthContext,
|
|||||||
}
|
}
|
||||||
|
|
||||||
var secret = strings.Split(proxy.Secret, ":")
|
var secret = strings.Split(proxy.Secret, ":")
|
||||||
gateway := remote.InitGateway(
|
gateway := remote.NewGateway(
|
||||||
proxy.Host,
|
proxy.Host,
|
||||||
secret[0],
|
secret[0],
|
||||||
secret[1],
|
secret[1],
|
||||||
@@ -204,7 +207,7 @@ func (s *channelService) RemoveChannels(ctx context.Context, auth *AuthContext,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(edges) > 0 {
|
if len(edges) > 0 {
|
||||||
_, err := remote.Client.CloudDisconnect(remote.CloudDisconnectReq{
|
_, err := remote.Cloud.CloudDisconnect(remote.CloudDisconnectReq{
|
||||||
Uuid: proxy.Name,
|
Uuid: proxy.Name,
|
||||||
Edge: edges,
|
Edge: edges,
|
||||||
})
|
})
|
||||||
@@ -395,7 +398,7 @@ func assignEdge(count int, filter NodeFilterConfig) (*AssignEdgeResult, error) {
|
|||||||
// 查询已配置的节点
|
// 查询已配置的节点
|
||||||
step = time.Now()
|
step = time.Now()
|
||||||
|
|
||||||
rProxyConfigs, err := remote.Client.CloudAutoQuery()
|
rProxyConfigs, err := remote.Cloud.CloudAutoQuery()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -466,7 +469,7 @@ func assignEdge(count int, filter NodeFilterConfig) (*AssignEdgeResult, error) {
|
|||||||
step = time.Now()
|
step = time.Now()
|
||||||
|
|
||||||
slog.Debug("新增新节点", "proxy", info.proxy.Name, "used", info.used, "count", info.count)
|
slog.Debug("新增新节点", "proxy", info.proxy.Name, "used", info.used, "count", info.count)
|
||||||
err := remote.Client.CloudConnect(remote.CloudConnectReq{
|
err := remote.Cloud.CloudConnect(remote.CloudConnectReq{
|
||||||
Uuid: info.proxy.Name,
|
Uuid: info.proxy.Name,
|
||||||
Edge: nil,
|
Edge: nil,
|
||||||
AutoConfig: []remote.AutoConfig{{
|
AutoConfig: []remote.AutoConfig{{
|
||||||
@@ -520,7 +523,7 @@ func assignPort(
|
|||||||
expiration time.Time,
|
expiration time.Time,
|
||||||
filter NodeFilterConfig,
|
filter NodeFilterConfig,
|
||||||
) ([]string, []*models.Channel, error) {
|
) ([]string, []*models.Channel, error) {
|
||||||
var step = time.Now()
|
var step time.Time
|
||||||
|
|
||||||
var configs = proxies.configs
|
var configs = proxies.configs
|
||||||
var exists = proxies.channels
|
var exists = proxies.channels
|
||||||
@@ -639,7 +642,7 @@ func assignPort(
|
|||||||
step = time.Now()
|
step = time.Now()
|
||||||
|
|
||||||
var secret = strings.Split(proxy.Secret, ":")
|
var secret = strings.Split(proxy.Secret, ":")
|
||||||
gateway := remote.InitGateway(
|
gateway := remote.NewGateway(
|
||||||
proxy.Host,
|
proxy.Host,
|
||||||
secret[0],
|
secret[0],
|
||||||
secret[1],
|
secret[1],
|
||||||
@@ -677,6 +680,10 @@ func chKey(channel *models.Channel) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func cache(ctx context.Context, channels []*models.Channel) error {
|
func cache(ctx context.Context, channels []*models.Channel) error {
|
||||||
|
if len(channels) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
pipe := rds.Client.TxPipeline()
|
pipe := rds.Client.TxPipeline()
|
||||||
|
|
||||||
zList := make([]redis.Z, 0, len(channels))
|
zList := make([]redis.Z, 0, len(channels))
|
||||||
@@ -685,7 +692,7 @@ func cache(ctx context.Context, channels []*models.Channel) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
pipe.Set(ctx, chKey(channel), string(marshal), channel.Expiration.Sub(time.Now()))
|
pipe.Set(ctx, chKey(channel), string(marshal), time.Until(channel.Expiration))
|
||||||
zList = append(zList, redis.Z{
|
zList = append(zList, redis.Z{
|
||||||
Score: float64(channel.Expiration.Unix()),
|
Score: float64(channel.Expiration.Unix()),
|
||||||
Member: channel.ID,
|
Member: channel.ID,
|
||||||
@@ -702,6 +709,10 @@ func cache(ctx context.Context, channels []*models.Channel) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func deleteCache(ctx context.Context, channels []*models.Channel) error {
|
func deleteCache(ctx context.Context, channels []*models.Channel) error {
|
||||||
|
if len(channels) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
keys := make([]string, len(channels))
|
keys := make([]string, len(channels))
|
||||||
for i := range channels {
|
for i := range channels {
|
||||||
keys[i] = chKey(channels[i])
|
keys[i] = chKey(channels[i])
|
||||||
|
|||||||
985
web/services/channel_test.go
Normal file
985
web/services/channel_test.go
Normal file
@@ -0,0 +1,985 @@
|
|||||||
|
package services
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"platform/pkg/env"
|
||||||
|
"platform/pkg/remote"
|
||||||
|
"platform/pkg/testutil"
|
||||||
|
"platform/web/models"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/DATA-DOG/go-sqlmock"
|
||||||
|
"github.com/gofiber/fiber/v2/middleware/requestid"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_genPassPair(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "正常生成随机用户名和密码",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "多次调用生成不同的值",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// 第一个测试:检查生成的用户名和密码是否有效
|
||||||
|
t.Run(tests[0].name, func(t *testing.T) {
|
||||||
|
username, password := genPassPair()
|
||||||
|
if username == "" {
|
||||||
|
t.Errorf("genPassPair() username is empty")
|
||||||
|
}
|
||||||
|
if password == "" {
|
||||||
|
t.Errorf("genPassPair() password is empty")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// 第二个测试:确保多次调用生成不同的值
|
||||||
|
t.Run(tests[1].name, func(t *testing.T) {
|
||||||
|
username1, password1 := genPassPair()
|
||||||
|
username2, password2 := genPassPair()
|
||||||
|
|
||||||
|
if username1 == username2 {
|
||||||
|
t.Errorf("genPassPair() generated the same username twice: %v", username1)
|
||||||
|
}
|
||||||
|
if password1 == password2 {
|
||||||
|
t.Errorf("genPassPair() generated the same password twice: %v", password1)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_chKey(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
channel *models.Channel
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "ID为1的通道",
|
||||||
|
args: args{
|
||||||
|
channel: &models.Channel{
|
||||||
|
ID: 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: "channel:1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ID为100的通道",
|
||||||
|
args: args{
|
||||||
|
channel: &models.Channel{
|
||||||
|
ID: 100,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: "channel:100",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ID为0的通道",
|
||||||
|
args: args{
|
||||||
|
channel: &models.Channel{
|
||||||
|
ID: 0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: "channel:0",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if got := chKey(tt.args.channel); got != tt.want {
|
||||||
|
t.Errorf("chKey() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_cache(t *testing.T) {
|
||||||
|
mr := testutil.SetupRedisTest(t)
|
||||||
|
|
||||||
|
type args struct {
|
||||||
|
ctx context.Context
|
||||||
|
channels []*models.Channel
|
||||||
|
}
|
||||||
|
|
||||||
|
// 准备测试数据
|
||||||
|
now := time.Now()
|
||||||
|
expiration := now.Add(24 * time.Hour)
|
||||||
|
|
||||||
|
testChannels := []*models.Channel{
|
||||||
|
{
|
||||||
|
ID: 1,
|
||||||
|
UserID: 100,
|
||||||
|
ProxyID: 10,
|
||||||
|
ProxyPort: 8080,
|
||||||
|
Protocol: "http",
|
||||||
|
Expiration: expiration,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: 2,
|
||||||
|
UserID: 101,
|
||||||
|
ProxyID: 11,
|
||||||
|
ProxyPort: 8081,
|
||||||
|
Protocol: "socks5",
|
||||||
|
Expiration: expiration,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "正常缓存多个通道",
|
||||||
|
args: args{
|
||||||
|
ctx: context.Background(),
|
||||||
|
channels: testChannels,
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "空通道列表",
|
||||||
|
args: args{
|
||||||
|
ctx: context.Background(),
|
||||||
|
channels: []*models.Channel{},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
mr.FlushAll() // 清空 Redis 数据
|
||||||
|
|
||||||
|
if err := cache(tt.args.ctx, tt.args.channels); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("cache() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证缓存结果
|
||||||
|
if len(tt.args.channels) > 0 {
|
||||||
|
for _, channel := range tt.args.channels {
|
||||||
|
key := fmt.Sprintf("channel:%d", channel.ID)
|
||||||
|
if !mr.Exists(key) {
|
||||||
|
t.Errorf("缓存未包含通道键 %s", key)
|
||||||
|
} else {
|
||||||
|
// 验证缓存的数据是否正确
|
||||||
|
data, _ := mr.Get(key)
|
||||||
|
var cachedChannel models.Channel
|
||||||
|
err := json.Unmarshal([]byte(data), &cachedChannel)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("无法解析缓存数据: %v", err)
|
||||||
|
}
|
||||||
|
if cachedChannel.ID != channel.ID {
|
||||||
|
t.Errorf("缓存数据不匹配: 期望 ID %d, 得到 %d", channel.ID, cachedChannel.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证是否设置了过期时间
|
||||||
|
for _, channel := range tt.args.channels {
|
||||||
|
key := fmt.Sprintf("channel:%d", channel.ID)
|
||||||
|
ttl := mr.TTL(key)
|
||||||
|
if ttl <= 0 {
|
||||||
|
t.Errorf("键 %s 没有设置过期时间", key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证是否添加了有序集合
|
||||||
|
if !mr.Exists("tasks:channel") {
|
||||||
|
t.Errorf("ZAdd未创建有序集合 tasks:channel")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_deleteCache(t *testing.T) {
|
||||||
|
mr := testutil.SetupRedisTest(t)
|
||||||
|
|
||||||
|
type args struct {
|
||||||
|
ctx context.Context
|
||||||
|
channels []*models.Channel
|
||||||
|
}
|
||||||
|
|
||||||
|
// 准备测试数据
|
||||||
|
testChannels := []*models.Channel{
|
||||||
|
{ID: 1},
|
||||||
|
{ID: 2},
|
||||||
|
{ID: 3},
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "正常删除多个通道缓存",
|
||||||
|
args: args{
|
||||||
|
ctx: ctx,
|
||||||
|
channels: testChannels,
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "空通道列表",
|
||||||
|
args: args{
|
||||||
|
ctx: ctx,
|
||||||
|
channels: []*models.Channel{},
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
mr.FlushAll() // 清空 Redis 数据
|
||||||
|
|
||||||
|
// 预先设置缓存数据
|
||||||
|
for _, channel := range testChannels {
|
||||||
|
key := fmt.Sprintf("channel:%d", channel.ID)
|
||||||
|
data, _ := json.Marshal(channel)
|
||||||
|
mr.Set(key, string(data))
|
||||||
|
mr.SetTTL(key, 1*time.Hour) // 设置1小时的过期时间
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := deleteCache(tt.args.ctx, tt.args.channels); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("deleteCache() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证删除结果
|
||||||
|
for _, channel := range tt.args.channels {
|
||||||
|
key := fmt.Sprintf("channel:%d", channel.ID)
|
||||||
|
if mr.Exists(key) {
|
||||||
|
t.Errorf("通道键 %s 未被删除", key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_channelService_CreateChannel(t *testing.T) {
|
||||||
|
mr := testutil.SetupRedisTest(t)
|
||||||
|
mdb := testutil.SetupDBTest(t)
|
||||||
|
mc := testutil.SetupCloudClientMock(t)
|
||||||
|
env.DebugExternalChange = false
|
||||||
|
|
||||||
|
type args struct {
|
||||||
|
ctx context.Context
|
||||||
|
auth *AuthContext
|
||||||
|
resourceId int32
|
||||||
|
protocol ChannelProtocol
|
||||||
|
authType ChannelAuthType
|
||||||
|
count int
|
||||||
|
nodeFilter []NodeFilterConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
// 准备测试数据
|
||||||
|
ctx := context.WithValue(context.Background(), requestid.ConfigDefault.ContextKey, "test-request-id")
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
setup func()
|
||||||
|
want []string
|
||||||
|
wantErr bool
|
||||||
|
wantErrContains string
|
||||||
|
checkCache func(t *testing.T)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "用户创建HTTP密码通道",
|
||||||
|
args: args{
|
||||||
|
ctx: ctx,
|
||||||
|
auth: &AuthContext{Payload: Payload{Type: PayloadUser, Id: 100}},
|
||||||
|
resourceId: 4,
|
||||||
|
protocol: ProtocolHTTP,
|
||||||
|
authType: ChannelAuthTypePass,
|
||||||
|
count: 3,
|
||||||
|
nodeFilter: []NodeFilterConfig{{Prov: "河南", City: "郑州", Isp: "电信"}},
|
||||||
|
},
|
||||||
|
setup: func() {
|
||||||
|
// 清空Redis
|
||||||
|
mr.FlushAll()
|
||||||
|
|
||||||
|
// 设置CloudAutoQuery的模拟返回
|
||||||
|
mc.AutoQueryMock = func() (remote.CloudConnectResp, error) {
|
||||||
|
return remote.CloudConnectResp{
|
||||||
|
"proxy3": []remote.AutoConfig{
|
||||||
|
{Province: "河南", City: "郑州", Isp: "电信", Count: 10},
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 开始事务
|
||||||
|
mdb.ExpectBegin()
|
||||||
|
|
||||||
|
// 模拟查询套餐
|
||||||
|
resourceRows := sqlmock.NewRows([]string{
|
||||||
|
"id", "user_id", "active",
|
||||||
|
"type", "live", "daily_used", "daily_limit", "daily_last", "quota", "used", "expire",
|
||||||
|
}).AddRow(
|
||||||
|
4, 100, true,
|
||||||
|
0, 86400, 0, 100, time.Now(), 1000, 0, time.Now().Add(24*time.Hour),
|
||||||
|
)
|
||||||
|
mdb.ExpectQuery("SELECT").WithArgs(int32(4)).WillReturnRows(resourceRows)
|
||||||
|
|
||||||
|
// 模拟查询代理
|
||||||
|
proxyRows := sqlmock.NewRows([]string{"id", "name", "host", "secret", "type"}).
|
||||||
|
AddRow(3, "proxy3", "proxy3.example.com", "key:secret", 1)
|
||||||
|
mdb.ExpectQuery("SELECT").
|
||||||
|
WithArgs(1).
|
||||||
|
WillReturnRows(proxyRows)
|
||||||
|
|
||||||
|
// 模拟查询通道
|
||||||
|
channelRows := sqlmock.NewRows([]string{"proxy_id", "proxy_port"})
|
||||||
|
mdb.ExpectQuery("SELECT").
|
||||||
|
WillReturnRows(channelRows)
|
||||||
|
|
||||||
|
// 模拟保存通道 - PostgreSQL返回ID
|
||||||
|
mdb.ExpectQuery("INSERT INTO").WillReturnRows(
|
||||||
|
sqlmock.NewRows([]string{"id"}).AddRow(4).AddRow(5).AddRow(6),
|
||||||
|
)
|
||||||
|
|
||||||
|
// 模拟更新套餐使用记录
|
||||||
|
mdb.ExpectExec("UPDATE").WillReturnResult(sqlmock.NewResult(0, 1))
|
||||||
|
|
||||||
|
// 提交事务
|
||||||
|
mdb.ExpectCommit()
|
||||||
|
},
|
||||||
|
want: []string{
|
||||||
|
"http://proxy3.example.com:10000",
|
||||||
|
"http://proxy3.example.com:10001",
|
||||||
|
"http://proxy3.example.com:10002",
|
||||||
|
},
|
||||||
|
checkCache: func(t *testing.T) {
|
||||||
|
// 检查总共创建了3个通道
|
||||||
|
for i := 4; i <= 6; i++ {
|
||||||
|
key := fmt.Sprintf("channel:%d", i)
|
||||||
|
if !mr.Exists(key) {
|
||||||
|
t.Errorf("Redis缓存中应有键 %s", key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "用户创建HTTP白名单通道",
|
||||||
|
args: args{
|
||||||
|
ctx: ctx,
|
||||||
|
auth: &AuthContext{
|
||||||
|
Payload: Payload{
|
||||||
|
Type: PayloadUser,
|
||||||
|
Id: 100,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
resourceId: 5,
|
||||||
|
protocol: ProtocolHTTP,
|
||||||
|
authType: ChannelAuthTypeIp,
|
||||||
|
count: 2,
|
||||||
|
},
|
||||||
|
setup: func() {
|
||||||
|
// 清空Redis
|
||||||
|
mr.FlushAll()
|
||||||
|
|
||||||
|
// 设置CloudAutoQuery的模拟返回
|
||||||
|
mc.AutoQueryMock = func() (remote.CloudConnectResp, error) {
|
||||||
|
return remote.CloudConnectResp{
|
||||||
|
"proxy3": []remote.AutoConfig{
|
||||||
|
{Province: "河南", City: "郑州", Isp: "电信", Count: 10},
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 开始事务
|
||||||
|
mdb.ExpectBegin()
|
||||||
|
|
||||||
|
// 模拟查询套餐
|
||||||
|
resourceRows := sqlmock.NewRows([]string{
|
||||||
|
"id", "user_id", "active",
|
||||||
|
"type", "live", "daily_used", "daily_limit", "daily_last", "quota", "used", "expire",
|
||||||
|
}).AddRow(
|
||||||
|
5, 100, true,
|
||||||
|
0, 86400, 0, 100, time.Now(), 1000, 0, time.Now().Add(24*time.Hour),
|
||||||
|
)
|
||||||
|
mdb.ExpectQuery("SELECT").WithArgs(int32(5)).WillReturnRows(resourceRows)
|
||||||
|
|
||||||
|
// 模拟查询代理
|
||||||
|
proxyRows := sqlmock.NewRows([]string{"id", "name", "host", "secret", "type"}).
|
||||||
|
AddRow(3, "proxy3", "proxy3.example.com", "key:secret", 1)
|
||||||
|
mdb.ExpectQuery("SELECT").
|
||||||
|
WithArgs(1).
|
||||||
|
WillReturnRows(proxyRows)
|
||||||
|
|
||||||
|
// 模拟查询通道
|
||||||
|
channelRows := sqlmock.NewRows([]string{"proxy_id", "proxy_port"})
|
||||||
|
mdb.ExpectQuery("SELECT").
|
||||||
|
WillReturnRows(channelRows)
|
||||||
|
|
||||||
|
// 模拟查询白名单 - 3个IP
|
||||||
|
whitelistRows := sqlmock.NewRows([]string{"host"}).
|
||||||
|
AddRow("192.168.1.1").
|
||||||
|
AddRow("192.168.1.2").
|
||||||
|
AddRow("192.168.1.3")
|
||||||
|
mdb.ExpectQuery("SELECT").
|
||||||
|
WithArgs(int32(100)).
|
||||||
|
WillReturnRows(whitelistRows)
|
||||||
|
|
||||||
|
// 模拟保存通道 - 2个通道 * 3个白名单 = 6个
|
||||||
|
mdb.ExpectQuery("INSERT INTO").WillReturnRows(
|
||||||
|
sqlmock.NewRows([]string{"id"}).
|
||||||
|
AddRow(7).AddRow(8).AddRow(9).
|
||||||
|
AddRow(10).AddRow(11).AddRow(12),
|
||||||
|
)
|
||||||
|
|
||||||
|
// 模拟更新套餐使用记录
|
||||||
|
mdb.ExpectExec("UPDATE").WillReturnResult(sqlmock.NewResult(0, 1))
|
||||||
|
|
||||||
|
// 提交事务
|
||||||
|
mdb.ExpectCommit()
|
||||||
|
},
|
||||||
|
want: []string{
|
||||||
|
"http://proxy3.example.com:10000",
|
||||||
|
"http://proxy3.example.com:10001",
|
||||||
|
},
|
||||||
|
checkCache: func(t *testing.T) {
|
||||||
|
// 检查应该创建了6个通道(2个通道 * 3个白名单)
|
||||||
|
for i := 7; i <= 12; i++ {
|
||||||
|
key := fmt.Sprintf("channel:%d", i)
|
||||||
|
if !mr.Exists(key) {
|
||||||
|
t.Errorf("Redis缓存中应有键 %s", key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "管理员创建SOCKS5密码通道",
|
||||||
|
args: args{
|
||||||
|
ctx: ctx,
|
||||||
|
auth: &AuthContext{
|
||||||
|
Payload: Payload{
|
||||||
|
Type: PayloadAdmin,
|
||||||
|
Id: 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
resourceId: 6,
|
||||||
|
protocol: ProtocolSocks5,
|
||||||
|
authType: ChannelAuthTypePass,
|
||||||
|
count: 2,
|
||||||
|
},
|
||||||
|
setup: func() {
|
||||||
|
// 清空Redis
|
||||||
|
mr.FlushAll()
|
||||||
|
|
||||||
|
// 设置CloudAutoQuery的模拟返回
|
||||||
|
mc.AutoQueryMock = func() (remote.CloudConnectResp, error) {
|
||||||
|
return remote.CloudConnectResp{
|
||||||
|
"proxy4": []remote.AutoConfig{
|
||||||
|
{Province: "河南", City: "郑州", Isp: "电信", Count: 5},
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 设置CloudConnect的模拟逻辑
|
||||||
|
mc.ConnectMock = func(param remote.CloudConnectReq) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 开始事务
|
||||||
|
mdb.ExpectBegin()
|
||||||
|
|
||||||
|
// 模拟查询套餐
|
||||||
|
resourceRows := sqlmock.NewRows([]string{
|
||||||
|
"id", "user_id", "active",
|
||||||
|
"type", "live", "daily_used", "daily_limit", "daily_last", "quota", "used", "expire",
|
||||||
|
}).AddRow(
|
||||||
|
6, 102, true,
|
||||||
|
1, 86400, 0, 100, time.Now(), 0, 0, time.Now().Add(24*time.Hour),
|
||||||
|
)
|
||||||
|
mdb.ExpectQuery("SELECT").WithArgs(int32(6)).WillReturnRows(resourceRows)
|
||||||
|
|
||||||
|
// 模拟查询代理
|
||||||
|
proxyRows := sqlmock.NewRows([]string{"id", "name", "host", "secret", "type"}).
|
||||||
|
AddRow(4, "proxy4", "proxy4.example.com", "key:secret", 1)
|
||||||
|
mdb.ExpectQuery("SELECT").
|
||||||
|
WithArgs(1).
|
||||||
|
WillReturnRows(proxyRows)
|
||||||
|
|
||||||
|
// 模拟查询通道
|
||||||
|
channelRows := sqlmock.NewRows([]string{"proxy_id", "proxy_port"})
|
||||||
|
mdb.ExpectQuery("SELECT").
|
||||||
|
WillReturnRows(channelRows)
|
||||||
|
|
||||||
|
// 模拟保存通道
|
||||||
|
mdb.ExpectQuery("INSERT INTO").WillReturnRows(
|
||||||
|
sqlmock.NewRows([]string{"id"}).AddRow(13).AddRow(14),
|
||||||
|
)
|
||||||
|
|
||||||
|
// 模拟更新套餐使用记录
|
||||||
|
mdb.ExpectExec("UPDATE").WillReturnResult(sqlmock.NewResult(0, 1))
|
||||||
|
|
||||||
|
// 提交事务
|
||||||
|
mdb.ExpectCommit()
|
||||||
|
},
|
||||||
|
want: []string{
|
||||||
|
"socks5://proxy4.example.com:10000",
|
||||||
|
"socks5://proxy4.example.com:10001",
|
||||||
|
},
|
||||||
|
checkCache: func(t *testing.T) {
|
||||||
|
for i := 13; i <= 14; i++ {
|
||||||
|
key := fmt.Sprintf("channel:%d", i)
|
||||||
|
if !mr.Exists(key) {
|
||||||
|
t.Errorf("Redis缓存中应有键 %s", key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "套餐不存在",
|
||||||
|
args: args{
|
||||||
|
ctx: ctx,
|
||||||
|
auth: &AuthContext{
|
||||||
|
Payload: Payload{
|
||||||
|
Type: PayloadUser,
|
||||||
|
Id: 100,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
resourceId: 999,
|
||||||
|
protocol: ProtocolHTTP,
|
||||||
|
authType: ChannelAuthTypeIp,
|
||||||
|
count: 1,
|
||||||
|
},
|
||||||
|
setup: func() {
|
||||||
|
// 清空Redis
|
||||||
|
mr.FlushAll()
|
||||||
|
|
||||||
|
// 开始事务
|
||||||
|
mdb.ExpectBegin()
|
||||||
|
|
||||||
|
// 模拟查询套餐不存在
|
||||||
|
mdb.ExpectQuery("SELECT").WithArgs(int32(999)).WillReturnError(gorm.ErrRecordNotFound)
|
||||||
|
|
||||||
|
// 回滚事务
|
||||||
|
mdb.ExpectRollback()
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
wantErrContains: "套餐不存在",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "套餐没有权限",
|
||||||
|
args: args{
|
||||||
|
ctx: ctx,
|
||||||
|
auth: &AuthContext{
|
||||||
|
Payload: Payload{
|
||||||
|
Type: PayloadUser,
|
||||||
|
Id: 101,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
resourceId: 7,
|
||||||
|
protocol: ProtocolHTTP,
|
||||||
|
authType: ChannelAuthTypeIp,
|
||||||
|
count: 1,
|
||||||
|
},
|
||||||
|
setup: func() {
|
||||||
|
// 清空Redis
|
||||||
|
mr.FlushAll()
|
||||||
|
|
||||||
|
// 开始事务
|
||||||
|
mdb.ExpectBegin()
|
||||||
|
|
||||||
|
// 模拟查询套餐
|
||||||
|
resourceRows := sqlmock.NewRows([]string{
|
||||||
|
"id", "user_id", "active",
|
||||||
|
"type", "live", "daily_used", "daily_limit", "daily_last", "quota", "used", "expire",
|
||||||
|
}).AddRow(
|
||||||
|
7, 102, true, // 注意:user_id 与 auth.Id 不匹配
|
||||||
|
0, 86400, 0, 100, time.Now(), 1000, 0, time.Now().Add(24*time.Hour),
|
||||||
|
)
|
||||||
|
mdb.ExpectQuery("SELECT").WithArgs(int32(7)).WillReturnRows(resourceRows)
|
||||||
|
|
||||||
|
// 回滚事务
|
||||||
|
mdb.ExpectRollback()
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
wantErrContains: "无权限访问",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "套餐配额不足",
|
||||||
|
args: args{
|
||||||
|
ctx: ctx,
|
||||||
|
auth: &AuthContext{
|
||||||
|
Payload: Payload{
|
||||||
|
Type: PayloadUser,
|
||||||
|
Id: 100,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
resourceId: 2,
|
||||||
|
protocol: ProtocolHTTP,
|
||||||
|
authType: ChannelAuthTypeIp,
|
||||||
|
count: 10,
|
||||||
|
},
|
||||||
|
setup: func() {
|
||||||
|
// 清空Redis
|
||||||
|
mr.FlushAll()
|
||||||
|
|
||||||
|
// 开始事务
|
||||||
|
mdb.ExpectBegin()
|
||||||
|
|
||||||
|
// 模拟查询套餐
|
||||||
|
resourceRows := sqlmock.NewRows([]string{
|
||||||
|
"id", "user_id", "active",
|
||||||
|
"type", "live", "daily_used", "daily_limit", "daily_last", "quota", "used", "expire",
|
||||||
|
}).AddRow(
|
||||||
|
2, 100, true,
|
||||||
|
0, 86400, 95, 100, time.Now(), 100, 95, time.Now().Add(24*time.Hour),
|
||||||
|
)
|
||||||
|
mdb.ExpectQuery("SELECT").WithArgs(int32(2)).WillReturnRows(resourceRows)
|
||||||
|
|
||||||
|
// 回滚事务
|
||||||
|
mdb.ExpectRollback()
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
wantErrContains: "套餐配额不足",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "端口数量达到上限",
|
||||||
|
args: args{
|
||||||
|
ctx: ctx,
|
||||||
|
auth: &AuthContext{
|
||||||
|
Payload: Payload{
|
||||||
|
Type: PayloadUser,
|
||||||
|
Id: 100,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
resourceId: 8,
|
||||||
|
protocol: ProtocolHTTP,
|
||||||
|
authType: ChannelAuthTypeIp,
|
||||||
|
count: 1,
|
||||||
|
},
|
||||||
|
setup: func() {
|
||||||
|
// 清空Redis
|
||||||
|
mr.FlushAll()
|
||||||
|
|
||||||
|
// 设置CloudAutoQuery的模拟返回
|
||||||
|
mc.AutoQueryMock = func() (remote.CloudConnectResp, error) {
|
||||||
|
return remote.CloudConnectResp{
|
||||||
|
"proxy5": []remote.AutoConfig{
|
||||||
|
{Province: "河南", City: "郑州", Isp: "电信", Count: 10},
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 开始事务
|
||||||
|
mdb.ExpectBegin()
|
||||||
|
|
||||||
|
// 模拟查询套餐
|
||||||
|
resourceRows := sqlmock.NewRows([]string{
|
||||||
|
"id", "user_id", "active",
|
||||||
|
"type", "live", "daily_used", "daily_limit", "daily_last", "quota", "used", "expire",
|
||||||
|
}).AddRow(
|
||||||
|
8, 100, true,
|
||||||
|
0, 86400, 0, 100, time.Now(), 1000, 0, time.Now().Add(24*time.Hour),
|
||||||
|
)
|
||||||
|
mdb.ExpectQuery("SELECT").WithArgs(int32(8)).WillReturnRows(resourceRows)
|
||||||
|
|
||||||
|
// 模拟查询代理
|
||||||
|
proxyRows := sqlmock.NewRows([]string{"id", "name", "host", "secret", "type"}).
|
||||||
|
AddRow(5, "proxy5", "proxy5.example.com", "key:secret", 1)
|
||||||
|
mdb.ExpectQuery("SELECT").
|
||||||
|
WithArgs(1).
|
||||||
|
WillReturnRows(proxyRows)
|
||||||
|
|
||||||
|
// 模拟通道端口已用尽
|
||||||
|
// 构建一个大量已使用端口的结果集
|
||||||
|
channelRows := sqlmock.NewRows([]string{"proxy_id", "proxy_port"})
|
||||||
|
for i := 10000; i < 65535; i++ {
|
||||||
|
channelRows.AddRow(5, i)
|
||||||
|
}
|
||||||
|
mdb.ExpectQuery("SELECT").
|
||||||
|
WillReturnRows(channelRows)
|
||||||
|
|
||||||
|
// 回滚事务
|
||||||
|
mdb.ExpectRollback()
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
wantErrContains: "端口数量不足",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if tt.setup != nil {
|
||||||
|
tt.setup()
|
||||||
|
}
|
||||||
|
|
||||||
|
s := &channelService{}
|
||||||
|
got, err := s.CreateChannel(tt.args.ctx, tt.args.auth, tt.args.resourceId, tt.args.protocol, tt.args.authType, tt.args.count, tt.args.nodeFilter...)
|
||||||
|
|
||||||
|
// 检查错误或结果
|
||||||
|
if tt.wantErr {
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("CreateChannel() 应当返回错误")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if tt.wantErrContains != "" && !strings.Contains(err.Error(), tt.wantErrContains) {
|
||||||
|
t.Errorf("CreateChannel() 错误 = %v, 应包含 %v", err, tt.wantErrContains)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("CreateChannel() 错误 = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(got) != len(tt.want) {
|
||||||
|
t.Errorf("CreateChannel() 返回长度 = %v, want %v", len(got), len(tt.want))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查返回地址格式
|
||||||
|
for _, addr := range got {
|
||||||
|
protocol := string(tt.args.protocol)
|
||||||
|
if !strings.HasPrefix(addr, protocol+"://") {
|
||||||
|
t.Errorf("CreateChannel() 地址 %v 不是有效的 %s 地址", addr, protocol)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证所有期望的 SQL 已执行
|
||||||
|
if err := mdb.ExpectationsWereMet(); err != nil {
|
||||||
|
t.Errorf("有未满足的SQL期望: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查 Redis 缓存是否正确设置
|
||||||
|
if tt.checkCache != nil {
|
||||||
|
tt.checkCache(t)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_channelService_RemoveChannels(t *testing.T) {
|
||||||
|
mr := testutil.SetupRedisTest(t)
|
||||||
|
mdb := testutil.SetupDBTest(t)
|
||||||
|
mg := testutil.SetupGatewayClientMock(t)
|
||||||
|
env.DebugExternalChange = false
|
||||||
|
|
||||||
|
type args struct {
|
||||||
|
ctx context.Context
|
||||||
|
auth *AuthContext
|
||||||
|
id []int32
|
||||||
|
}
|
||||||
|
|
||||||
|
// 准备测试数据
|
||||||
|
ctx := context.WithValue(context.Background(), requestid.ConfigDefault.ContextKey, "test-request-id")
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
setup func()
|
||||||
|
wantErr bool
|
||||||
|
wantErrContains string
|
||||||
|
checkCache func(t *testing.T)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "管理员删除多个通道",
|
||||||
|
args: args{
|
||||||
|
ctx: ctx,
|
||||||
|
auth: &AuthContext{
|
||||||
|
Payload: Payload{
|
||||||
|
Type: PayloadAdmin,
|
||||||
|
Id: 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
id: []int32{1, 2, 3},
|
||||||
|
},
|
||||||
|
setup: func() {
|
||||||
|
// 预设 Redis 缓存
|
||||||
|
mr.FlushAll()
|
||||||
|
for _, id := range []int32{1, 2, 3} {
|
||||||
|
key := fmt.Sprintf("channel:%d", id)
|
||||||
|
channel := models.Channel{ID: id, UserID: 100}
|
||||||
|
data, _ := json.Marshal(channel)
|
||||||
|
mr.Set(key, string(data))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 开始事务
|
||||||
|
mdb.ExpectBegin()
|
||||||
|
|
||||||
|
// 查找通道
|
||||||
|
channelRows := sqlmock.NewRows([]string{"id", "user_id", "proxy_id", "proxy_port", "protocol", "expiration"}).
|
||||||
|
AddRow(1, 100, 1, 10001, "http", time.Now().Add(24*time.Hour)).
|
||||||
|
AddRow(2, 100, 1, 10002, "http", time.Now().Add(24*time.Hour)).
|
||||||
|
AddRow(3, 101, 2, 10001, "socks5", time.Now().Add(24*time.Hour))
|
||||||
|
mdb.ExpectQuery(regexp.QuoteMeta("SELECT * FROM `channel` WHERE `channel`.`id` IN")).
|
||||||
|
WithArgs(int32(1), int32(2), int32(3)).
|
||||||
|
WillReturnRows(channelRows)
|
||||||
|
|
||||||
|
// 查找代理
|
||||||
|
proxyRows := sqlmock.NewRows([]string{"id", "name", "host", "secret", "type"}).
|
||||||
|
AddRow(1, "proxy1", "proxy1.example.com", "key:secret", 1).
|
||||||
|
AddRow(2, "proxy2", "proxy2.example.com", "key:secret", 1)
|
||||||
|
mdb.ExpectQuery(regexp.QuoteMeta("SELECT * FROM `proxy` WHERE `proxy`.`id` IN")).
|
||||||
|
WithArgs(int32(1), int32(2)).
|
||||||
|
WillReturnRows(proxyRows)
|
||||||
|
|
||||||
|
// 软删除通道
|
||||||
|
mdb.ExpectExec(regexp.QuoteMeta("UPDATE `channel` SET")).
|
||||||
|
WillReturnResult(sqlmock.NewResult(0, 3))
|
||||||
|
|
||||||
|
// 提交事务
|
||||||
|
mdb.ExpectCommit()
|
||||||
|
},
|
||||||
|
checkCache: func(t *testing.T) {
|
||||||
|
for _, id := range []int32{1, 2, 3} {
|
||||||
|
key := fmt.Sprintf("channel:%d", id)
|
||||||
|
if mr.Exists(key) {
|
||||||
|
t.Errorf("通道缓存 %s 应被删除但仍存在", key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "用户删除自己的通道",
|
||||||
|
args: args{
|
||||||
|
ctx: ctx,
|
||||||
|
auth: &AuthContext{
|
||||||
|
Payload: Payload{
|
||||||
|
Type: PayloadUser,
|
||||||
|
Id: 100,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
id: []int32{1},
|
||||||
|
},
|
||||||
|
setup: func() {
|
||||||
|
// 预设 Redis 缓存
|
||||||
|
mr.FlushAll()
|
||||||
|
key := "channel:1"
|
||||||
|
channel := models.Channel{ID: 1, UserID: 100}
|
||||||
|
data, _ := json.Marshal(channel)
|
||||||
|
mr.Set(key, string(data))
|
||||||
|
|
||||||
|
// 模拟查询已激活的端口
|
||||||
|
mg.PortActiveMock = func(param ...remote.PortActiveReq) (map[string]remote.PortData, error) {
|
||||||
|
return map[string]remote.PortData{
|
||||||
|
"10001": {
|
||||||
|
Edge: []string{"edge1", "edge2"},
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 开始事务
|
||||||
|
mdb.ExpectBegin()
|
||||||
|
|
||||||
|
// 查找通道
|
||||||
|
channelRows := sqlmock.NewRows([]string{"id", "user_id", "proxy_id", "proxy_port", "protocol", "expiration"}).
|
||||||
|
AddRow(1, 100, 1, 10001, "http", time.Now().Add(24*time.Hour))
|
||||||
|
mdb.ExpectQuery(regexp.QuoteMeta("SELECT * FROM `channel` WHERE `channel`.`id` IN")).
|
||||||
|
WithArgs(int32(1)).
|
||||||
|
WillReturnRows(channelRows)
|
||||||
|
|
||||||
|
// 查找代理
|
||||||
|
proxyRows := sqlmock.NewRows([]string{"id", "name", "host", "secret", "type"}).
|
||||||
|
AddRow(1, "proxy1", "proxy1.example.com", "key:secret", 1)
|
||||||
|
mdb.ExpectQuery(regexp.QuoteMeta("SELECT * FROM `proxy` WHERE `proxy`.`id` IN")).
|
||||||
|
WithArgs(int32(1)).
|
||||||
|
WillReturnRows(proxyRows)
|
||||||
|
|
||||||
|
// 软删除通道
|
||||||
|
mdb.ExpectExec(regexp.QuoteMeta("UPDATE `channel` SET")).
|
||||||
|
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||||
|
|
||||||
|
// 提交事务
|
||||||
|
mdb.ExpectCommit()
|
||||||
|
},
|
||||||
|
checkCache: func(t *testing.T) {
|
||||||
|
key := "channel:1"
|
||||||
|
if mr.Exists(key) {
|
||||||
|
t.Errorf("通道缓存 %s 应被删除但仍存在", key)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "用户删除不属于自己的通道",
|
||||||
|
args: args{
|
||||||
|
ctx: ctx,
|
||||||
|
auth: &AuthContext{
|
||||||
|
Payload: Payload{
|
||||||
|
Type: PayloadUser,
|
||||||
|
Id: 100,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
id: []int32{5},
|
||||||
|
},
|
||||||
|
setup: func() {
|
||||||
|
// 预设 Redis 缓存
|
||||||
|
mr.FlushAll()
|
||||||
|
key := "channel:5"
|
||||||
|
channel := models.Channel{ID: 5, UserID: 101}
|
||||||
|
data, _ := json.Marshal(channel)
|
||||||
|
mr.Set(key, string(data))
|
||||||
|
|
||||||
|
// 开始事务
|
||||||
|
mdb.ExpectBegin()
|
||||||
|
|
||||||
|
// 查找通道
|
||||||
|
channelRows := sqlmock.NewRows([]string{"id", "user_id", "proxy_id", "proxy_port", "protocol", "expiration"}).
|
||||||
|
AddRow(5, 101, 1, 10005, "http", time.Now().Add(24*time.Hour))
|
||||||
|
mdb.ExpectQuery(regexp.QuoteMeta("SELECT * FROM `channel` WHERE `channel`.`id` IN")).
|
||||||
|
WithArgs(int32(5)).
|
||||||
|
WillReturnRows(channelRows)
|
||||||
|
|
||||||
|
// 回滚事务
|
||||||
|
mdb.ExpectRollback()
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
wantErrContains: "无权限访问",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if tt.setup != nil {
|
||||||
|
tt.setup()
|
||||||
|
}
|
||||||
|
|
||||||
|
s := &channelService{}
|
||||||
|
err := s.RemoveChannels(tt.args.ctx, tt.args.auth, tt.args.id...)
|
||||||
|
|
||||||
|
// 检查错误
|
||||||
|
if tt.wantErr {
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("RemoveChannels() 应当返回错误")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if tt.wantErrContains != "" && !strings.Contains(err.Error(), tt.wantErrContains) {
|
||||||
|
t.Errorf("RemoveChannels() 错误 = %v, 应包含 %v", err, tt.wantErrContains)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("RemoveChannels() 错误 = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证所有期望的 SQL 已执行
|
||||||
|
if err := mdb.ExpectationsWereMet(); err != nil {
|
||||||
|
t.Errorf("有未满足的SQL期望: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查 Redis 缓存是否正确设置
|
||||||
|
if tt.checkCache != nil {
|
||||||
|
tt.checkCache(t)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -3,36 +3,12 @@ package services
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"platform/pkg/rds"
|
"platform/pkg/testutil"
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/alicebob/miniredis/v2"
|
|
||||||
"github.com/redis/go-redis/v9"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// 设置 Redis 模拟服务器
|
|
||||||
func setupTestRedis(t *testing.T) *miniredis.Miniredis {
|
|
||||||
mr, err := miniredis.Run()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("无法启动 miniredis: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 替换 Redis 客户端为测试客户端
|
|
||||||
origClient := rds.Client
|
|
||||||
rds.Client = redis.NewClient(&redis.Options{
|
|
||||||
Addr: mr.Addr(),
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Cleanup(func() {
|
|
||||||
mr.Close()
|
|
||||||
rds.Client = origClient
|
|
||||||
})
|
|
||||||
|
|
||||||
return mr
|
|
||||||
}
|
|
||||||
|
|
||||||
// 创建测试用的认证上下文
|
// 创建测试用的认证上下文
|
||||||
func createTestAuthContext() AuthContext {
|
func createTestAuthContext() AuthContext {
|
||||||
return AuthContext{
|
return AuthContext{
|
||||||
@@ -52,7 +28,7 @@ func createTestAuthContext() AuthContext {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func Test_sessionService_Create(t *testing.T) {
|
func Test_sessionService_Create(t *testing.T) {
|
||||||
mr := setupTestRedis(t)
|
mr := testutil.SetupRedisTest(t)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
auth := createTestAuthContext()
|
auth := createTestAuthContext()
|
||||||
|
|
||||||
@@ -162,7 +138,7 @@ func Test_sessionService_Create(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func Test_sessionService_Find(t *testing.T) {
|
func Test_sessionService_Find(t *testing.T) {
|
||||||
_ = setupTestRedis(t)
|
testutil.SetupRedisTest(t)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
auth := createTestAuthContext()
|
auth := createTestAuthContext()
|
||||||
s := &sessionService{}
|
s := &sessionService{}
|
||||||
@@ -221,7 +197,7 @@ func Test_sessionService_Find(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func Test_sessionService_Refresh(t *testing.T) {
|
func Test_sessionService_Refresh(t *testing.T) {
|
||||||
mr := setupTestRedis(t)
|
mr := testutil.SetupRedisTest(t)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
auth := createTestAuthContext()
|
auth := createTestAuthContext()
|
||||||
s := &sessionService{}
|
s := &sessionService{}
|
||||||
@@ -314,7 +290,7 @@ func Test_sessionService_Refresh(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func Test_sessionService_Remove(t *testing.T) {
|
func Test_sessionService_Remove(t *testing.T) {
|
||||||
mr := setupTestRedis(t)
|
mr := testutil.SetupRedisTest(t)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
auth := createTestAuthContext()
|
auth := createTestAuthContext()
|
||||||
s := &sessionService{}
|
s := &sessionService{}
|
||||||
|
|||||||
@@ -2,30 +2,14 @@ package services
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"platform/pkg/rds"
|
"platform/pkg/testutil"
|
||||||
"strconv"
|
"strconv"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/alicebob/miniredis/v2"
|
"github.com/alicebob/miniredis/v2"
|
||||||
"github.com/redis/go-redis/v9"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// 设置测试的 Redis 环境
|
|
||||||
func setupRedisTest(t *testing.T) *miniredis.Miniredis {
|
|
||||||
mr, err := miniredis.Run()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("设置 miniredis 失败: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 替换 redis 客户端为测试客户端
|
|
||||||
rds.Client = redis.NewClient(&redis.Options{
|
|
||||||
Addr: mr.Addr(),
|
|
||||||
})
|
|
||||||
|
|
||||||
return mr
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_verifierService_SendSms(t *testing.T) {
|
func Test_verifierService_SendSms(t *testing.T) {
|
||||||
type args struct {
|
type args struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
@@ -82,7 +66,7 @@ func Test_verifierService_SendSms(t *testing.T) {
|
|||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
// 设置 Redis 测试环境
|
// 设置 Redis 测试环境
|
||||||
mr := setupRedisTest(t)
|
mr := testutil.SetupRedisTest(t)
|
||||||
defer mr.Close()
|
defer mr.Close()
|
||||||
|
|
||||||
// 执行测试前的设置
|
// 执行测试前的设置
|
||||||
@@ -216,7 +200,7 @@ func Test_verifierService_VerifySms(t *testing.T) {
|
|||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
// 设置 Redis 测试环境
|
// 设置 Redis 测试环境
|
||||||
mr := setupRedisTest(t)
|
mr := testutil.SetupRedisTest(t)
|
||||||
defer mr.Close()
|
defer mr.Close()
|
||||||
|
|
||||||
// 执行测试前的设置
|
// 执行测试前的设置
|
||||||
|
|||||||
Reference in New Issue
Block a user