完善 channel remove 测试用例
This commit is contained in:
@@ -1,8 +1,11 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
m "platform/web/models"
|
||||
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gen"
|
||||
"gorm.io/gen/field"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/schema"
|
||||
)
|
||||
@@ -26,5 +29,19 @@ func main() {
|
||||
|
||||
models := g.GenerateAllTable()
|
||||
g.ApplyBasic(models...)
|
||||
|
||||
modelChannel := g.GenerateModel("channel",
|
||||
gen.FieldRelateModel(field.BelongsTo, "Node", &m.Node{}, &field.RelateConfig{
|
||||
RelatePointer: true,
|
||||
}),
|
||||
gen.FieldRelateModel(field.BelongsTo, "User", &m.User{}, &field.RelateConfig{
|
||||
RelatePointer: true,
|
||||
}),
|
||||
gen.FieldRelateModel(field.BelongsTo, "Proxy", &m.Proxy{}, &field.RelateConfig{
|
||||
RelatePointer: true,
|
||||
}),
|
||||
)
|
||||
g.ApplyBasic(modelChannel)
|
||||
|
||||
g.Execute()
|
||||
}
|
||||
|
||||
@@ -1,78 +1,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"platform/pkg/orm"
|
||||
m "platform/web/models"
|
||||
q "platform/web/queries"
|
||||
"time"
|
||||
|
||||
"github.com/glebarez/sqlite"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
func main() {
|
||||
open, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
for i := range 3 {
|
||||
println(i)
|
||||
}
|
||||
|
||||
err = open.AutoMigrate(&m.Resource{}, &m.ResourcePss{})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
q.SetDefault(open)
|
||||
|
||||
var r = &m.Resource{
|
||||
ID: 1,
|
||||
UserID: 101,
|
||||
Active: true,
|
||||
}
|
||||
open.Create(r)
|
||||
var resourcePss = &m.ResourcePss{
|
||||
ID: 1,
|
||||
ResourceID: 1,
|
||||
Type: 1,
|
||||
Live: 180,
|
||||
Expire: time.Now().AddDate(1, 0, 0),
|
||||
DailyLimit: 10000,
|
||||
}
|
||||
open.Create(resourcePss)
|
||||
|
||||
var resource = new(ResourceInfo)
|
||||
data := q.Resource.As("data")
|
||||
pss := q.ResourcePss.As("pss")
|
||||
err = data.Scopes(orm.Alias(data)).
|
||||
Select(
|
||||
data.ID, data.UserID, data.Active,
|
||||
pss.Type, pss.Live, pss.DailyUsed, pss.DailyLimit, pss.DailyLast, pss.Quota, pss.Used, pss.Expire,
|
||||
).
|
||||
LeftJoin(q.ResourcePss.As("pss"), pss.ResourceID.EqCol(data.ID)).
|
||||
Where(data.ID.Eq(1)).
|
||||
Scan(&resource)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
bytes, err := json.MarshalIndent(resource, "", " ")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
println(string(bytes))
|
||||
}
|
||||
|
||||
type ResourceInfo struct {
|
||||
Id int32
|
||||
UserId int32
|
||||
Active bool
|
||||
Type int32
|
||||
Live int32
|
||||
DailyLimit int32
|
||||
DailyUsed int32
|
||||
DailyLast time.Time
|
||||
Quota int32
|
||||
Used int32
|
||||
Expire time.Time
|
||||
}
|
||||
|
||||
1
go.mod
1
go.mod
@@ -40,6 +40,7 @@ require (
|
||||
github.com/mattn/go-sqlite3 v1.14.24 // indirect
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/stripe/pg-schema-diff v0.9.0 // indirect
|
||||
github.com/valyala/bytebufferpool v1.0.0 // indirect
|
||||
github.com/valyala/fasthttp v1.59.0 // indirect
|
||||
github.com/yuin/gopher-lua v1.1.1 // indirect
|
||||
|
||||
3
go.sum
3
go.sum
@@ -85,6 +85,9 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk=
|
||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
|
||||
github.com/stripe/pg-schema-diff v0.9.0 h1:qzm2VUdbZ2kYwqxoQqtEP3uLQI0B+ymS947zqFTZGBk=
|
||||
github.com/stripe/pg-schema-diff v0.9.0/go.mod h1:cl2VC6te/cCTOewTRvv4pYsgQqAOhvRQmatCHfYwy8c=
|
||||
github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw=
|
||||
github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc=
|
||||
github.com/valyala/fasthttp v1.59.0 h1:Qu0qYHfXvPk1mSLNqcFtEk6DpxgA26hy6bmydotDpRI=
|
||||
|
||||
@@ -67,11 +67,47 @@ func (m *MockCloudClient) CloudAutoQuery() (remote.CloudConnectResp, error) {
|
||||
return remote.CloudConnectResp{}, nil
|
||||
}
|
||||
|
||||
// SetupCloudClientMock 替换全局CloudClient为测试实现并在测试完成后恢复
|
||||
func SetupCloudClientMock(t *testing.T) *MockCloudClient {
|
||||
mock := &MockCloudClient{}
|
||||
remote.Cloud = mock
|
||||
|
||||
return mock
|
||||
}
|
||||
|
||||
// MockGatewayClient 是GatewayClient接口的测试实现
|
||||
type MockGatewayClient struct {
|
||||
Host string
|
||||
}
|
||||
|
||||
// 确保MockGatewayClient实现了GatewayClient接口
|
||||
var _ remote.GatewayClient = (*MockGatewayClient)(nil)
|
||||
|
||||
func (m *MockGatewayClient) GatewayPortConfigs(params []remote.PortConfigsReq) error {
|
||||
testGatewayBase.mu.Lock()
|
||||
defer testGatewayBase.mu.Unlock()
|
||||
testGatewayBase.PortConfigsCalls = append(testGatewayBase.PortConfigsCalls, params)
|
||||
if testGatewayBase.PortConfigsMock != nil {
|
||||
return testGatewayBase.PortConfigsMock(m, params)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockGatewayClient) GatewayPortActive(param ...remote.PortActiveReq) (map[string]remote.PortData, error) {
|
||||
testGatewayBase.mu.Lock()
|
||||
defer testGatewayBase.mu.Unlock()
|
||||
testGatewayBase.PortActiveCalls = append(testGatewayBase.PortActiveCalls, param)
|
||||
if testGatewayBase.PortActiveMock != nil {
|
||||
return testGatewayBase.PortActiveMock(m, param...)
|
||||
}
|
||||
return map[string]remote.PortData{}, nil
|
||||
}
|
||||
|
||||
type GatewayClientIns struct {
|
||||
|
||||
// 存储预期结果的字段
|
||||
PortConfigsMock func(params []remote.PortConfigsReq) error
|
||||
PortActiveMock func(param ...remote.PortActiveReq) (map[string]remote.PortData, error)
|
||||
PortConfigsMock func(c *MockGatewayClient, params []remote.PortConfigsReq) error
|
||||
PortActiveMock func(c *MockGatewayClient, param ...remote.PortActiveReq) (map[string]remote.PortData, error)
|
||||
|
||||
// 记录调用历史
|
||||
PortConfigsCalls [][]remote.PortConfigsReq
|
||||
@@ -81,48 +117,14 @@ type MockGatewayClient struct {
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// 确保MockGatewayClient实现了GatewayClient接口
|
||||
var _ remote.GatewayClient = (*MockGatewayClient)(nil)
|
||||
|
||||
func (m *MockGatewayClient) GatewayPortConfigs(params []remote.PortConfigsReq) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.PortConfigsCalls = append(m.PortConfigsCalls, params)
|
||||
if m.PortConfigsMock != nil {
|
||||
return m.PortConfigsMock(params)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MockGatewayClient) GatewayPortActive(param ...remote.PortActiveReq) (map[string]remote.PortData, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.PortActiveCalls = append(m.PortActiveCalls, param)
|
||||
if m.PortActiveMock != nil {
|
||||
return m.PortActiveMock(param...)
|
||||
}
|
||||
return map[string]remote.PortData{}, nil
|
||||
}
|
||||
|
||||
// SetupCloudClientMock 替换全局CloudClient为测试实现并在测试完成后恢复
|
||||
func SetupCloudClientMock(t *testing.T) *MockCloudClient {
|
||||
mock := &MockCloudClient{}
|
||||
remote.Cloud = mock
|
||||
|
||||
return mock
|
||||
}
|
||||
var testGatewayBase = &GatewayClientIns{}
|
||||
|
||||
// SetupGatewayClientMock 创建一个MockGatewayClient并提供替换函数
|
||||
func SetupGatewayClientMock(t *testing.T) *MockGatewayClient {
|
||||
mock := &MockGatewayClient{}
|
||||
func SetupGatewayClientMock(t *testing.T) *GatewayClientIns {
|
||||
remote.GatewayInitializer = func(url, username, password string) remote.GatewayClient {
|
||||
return mock
|
||||
return &MockGatewayClient{
|
||||
Host: url,
|
||||
}
|
||||
}
|
||||
return mock
|
||||
}
|
||||
|
||||
// NewMockGatewayClient 创建一个新的MockGatewayClient
|
||||
// 保留此函数以保持向后兼容性
|
||||
func NewMockGatewayClient() *MockGatewayClient {
|
||||
return &MockGatewayClient{}
|
||||
return testGatewayBase
|
||||
}
|
||||
|
||||
0
scripts/dev/speed.sh
Normal file
0
scripts/dev/speed.sh
Normal file
@@ -30,6 +30,9 @@ type Channel struct {
|
||||
CreatedAt time.Time `gorm:"column:created_at;default:CURRENT_TIMESTAMP;comment:创建时间" json:"created_at"` // 创建时间
|
||||
UpdatedAt time.Time `gorm:"column:updated_at;default:CURRENT_TIMESTAMP;comment:更新时间" json:"updated_at"` // 更新时间
|
||||
DeletedAt gorm.DeletedAt `gorm:"column:deleted_at;comment:删除时间" json:"deleted_at"` // 删除时间
|
||||
Node *Node `json:"node"`
|
||||
User *User `json:"user"`
|
||||
Proxy *Proxy `json:"proxy"`
|
||||
}
|
||||
|
||||
// TableName Channel's table name
|
||||
|
||||
@@ -481,7 +481,7 @@ func assignEdge(q *q.Query, count int, filter NodeFilterConfig) (*AssignEdgeResu
|
||||
Province: filter.Prov,
|
||||
City: filter.City,
|
||||
Isp: filter.Isp,
|
||||
Count: int(math.Ceil(float64(info.used) * 11 / 10)),
|
||||
Count: int(math.Ceil(float64(info.used) * 2)),
|
||||
}
|
||||
var newConfigs []remote.AutoConfig
|
||||
var update = false
|
||||
@@ -596,10 +596,11 @@ func assignPort(
|
||||
Edge: nil,
|
||||
Status: true,
|
||||
AutoEdgeConfig: &remote.AutoEdgeConfig{
|
||||
Province: filter.Prov,
|
||||
City: filter.City,
|
||||
Isp: filter.Isp,
|
||||
Count: v.P(1),
|
||||
Province: filter.Prov,
|
||||
City: filter.City,
|
||||
Isp: filter.Isp,
|
||||
Count: v.P(1),
|
||||
PacketLoss: 30,
|
||||
},
|
||||
})
|
||||
|
||||
@@ -704,7 +705,6 @@ type PortInfo struct {
|
||||
// endregion
|
||||
|
||||
func genPassPair() (string, string) {
|
||||
var letters = []rune("abcdefghjkmnpqrstuvwxyz23456789")
|
||||
var alphabet = []rune("abcdefghjkmnpqrstuvwxyz")
|
||||
var numbers = []rune("23456789")
|
||||
|
||||
@@ -716,7 +716,7 @@ func genPassPair() (string, string) {
|
||||
} else {
|
||||
username[i] = numbers[rand.N(len(numbers))]
|
||||
}
|
||||
password[i] = letters[rand.N(len(letters))]
|
||||
password[i] = numbers[rand.N(len(numbers))]
|
||||
}
|
||||
|
||||
return string(username), string(password)
|
||||
|
||||
@@ -4,11 +4,10 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"platform/pkg/env"
|
||||
"platform/pkg/remote"
|
||||
"platform/pkg/testutil"
|
||||
"platform/web/models"
|
||||
"reflect"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -272,7 +271,6 @@ func Test_channelService_CreateChannel(t *testing.T) {
|
||||
mr := testutil.SetupRedisTest(t)
|
||||
db := testutil.SetupDBTest(t)
|
||||
mc := testutil.SetupCloudClientMock(t)
|
||||
env.DebugExternalChange = false
|
||||
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
@@ -338,10 +336,9 @@ func Test_channelService_CreateChannel(t *testing.T) {
|
||||
name string
|
||||
args args
|
||||
setup func()
|
||||
want []*PortInfo
|
||||
wantErr bool
|
||||
wantErrContains string
|
||||
checkCache func(channels []models.Channel) error
|
||||
want func(t *testing.T, got []*PortInfo) error
|
||||
}{
|
||||
{
|
||||
name: "用户创建HTTP密码通道",
|
||||
@@ -354,12 +351,58 @@ func Test_channelService_CreateChannel(t *testing.T) {
|
||||
count: 3,
|
||||
nodeFilter: []NodeFilterConfig{{Prov: "河南", City: "郑州", Isp: "电信"}},
|
||||
},
|
||||
want: []*PortInfo{
|
||||
{
|
||||
Proto: "http",
|
||||
Host: proxy.Host,
|
||||
Port: 10000,
|
||||
},
|
||||
|
||||
want: func(t *testing.T, got []*PortInfo) error {
|
||||
// 验证返回结果
|
||||
if len(got) == 0 {
|
||||
return fmt.Errorf("返回的 PortInfo 不应为空")
|
||||
}
|
||||
|
||||
// 验证协议正确
|
||||
for _, port := range got {
|
||||
if port.Proto != "http" {
|
||||
return fmt.Errorf("期望协议为 http,得到 %s", port.Proto)
|
||||
}
|
||||
if port.Host != proxy.Host {
|
||||
return fmt.Errorf("期望主机为 %s,得到 %s", proxy.Host, port.Host)
|
||||
}
|
||||
}
|
||||
|
||||
// 验证数据库字段
|
||||
var channels []*models.Channel
|
||||
db.Where("user_id = ? AND proxy_id = ?", userAuth.Payload.Id, proxy.ID).Find(&channels)
|
||||
for _, ch := range channels {
|
||||
if ch.Protocol != "http" {
|
||||
return fmt.Errorf("通道协议不正确,期望 http,得到 %s", ch.Protocol)
|
||||
}
|
||||
if ch.UserID != userAuth.Payload.Id {
|
||||
return fmt.Errorf("通道用户ID不正确,期望 %d,得到 %d", userAuth.Payload.Id, ch.UserID)
|
||||
}
|
||||
if ch.ProxyID != proxy.ID {
|
||||
return fmt.Errorf("通道代理ID不正确,期望 %d,得到 %d", proxy.ID, ch.ProxyID)
|
||||
}
|
||||
|
||||
// 检查Redis缓存中的字段
|
||||
key := fmt.Sprintf("channel:%d", ch.ID)
|
||||
if !mr.Exists(key) {
|
||||
return fmt.Errorf("Redis缓存中应有键 %s", key)
|
||||
}
|
||||
|
||||
data, _ := mr.Get(key)
|
||||
var cachedChannel models.Channel
|
||||
err := json.Unmarshal([]byte(data), &cachedChannel)
|
||||
if err != nil {
|
||||
return fmt.Errorf("无法解析缓存数据: %v", err)
|
||||
}
|
||||
|
||||
if cachedChannel.ID != ch.ID {
|
||||
return fmt.Errorf("缓存ID不正确,期望 %d,得到 %d", ch.ID, cachedChannel.ID)
|
||||
}
|
||||
if cachedChannel.Protocol != ch.Protocol {
|
||||
return fmt.Errorf("缓存协议不正确,期望 %s,得到 %s", ch.Protocol, cachedChannel.Protocol)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -372,7 +415,9 @@ func Test_channelService_CreateChannel(t *testing.T) {
|
||||
authType: ChannelAuthTypeIp,
|
||||
count: 2,
|
||||
},
|
||||
want: []*PortInfo{},
|
||||
want: func(t *testing.T, got []*PortInfo) error {
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "管理员创建SOCKS5密码通道",
|
||||
@@ -384,7 +429,9 @@ func Test_channelService_CreateChannel(t *testing.T) {
|
||||
authType: ChannelAuthTypePass,
|
||||
count: 2,
|
||||
},
|
||||
want: []*PortInfo{},
|
||||
want: func(t *testing.T, got []*PortInfo) error {
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "套餐不存在",
|
||||
@@ -504,34 +551,10 @@ func Test_channelService_CreateChannel(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
// 检查返回值
|
||||
if reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("CreateChannel() 返回值 = %v, 期望 %v", got, tt.want)
|
||||
}
|
||||
|
||||
// 查询创建的通道
|
||||
var channels []models.Channel
|
||||
db.Where(
|
||||
"user_id = ? and proxy_id = ?",
|
||||
userAuth.Payload.Id, proxy.ID,
|
||||
).Find(&channels)
|
||||
|
||||
if len(channels) != 2 {
|
||||
t.Errorf("期望创建2个通道,但是创建了%d个", len(channels))
|
||||
}
|
||||
|
||||
// 检查Redis缓存
|
||||
for _, ch := range channels {
|
||||
key := fmt.Sprintf("channel:%d", ch.ID)
|
||||
if !mr.Exists(key) {
|
||||
t.Errorf("Redis缓存中应有键 %s", key)
|
||||
}
|
||||
}
|
||||
|
||||
if tt.checkCache != nil {
|
||||
var err = tt.checkCache(channels)
|
||||
if err != nil {
|
||||
t.Errorf("检查缓存失败: %v", err)
|
||||
// 使用检查函数验证结果
|
||||
if tt.want != nil {
|
||||
if err := tt.want(t, got); err != nil {
|
||||
t.Errorf("结果验证失败: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -540,9 +563,9 @@ func Test_channelService_CreateChannel(t *testing.T) {
|
||||
|
||||
func Test_channelService_RemoveChannels(t *testing.T) {
|
||||
mr := testutil.SetupRedisTest(t)
|
||||
db := testutil.SetupDBTest(t)
|
||||
md := testutil.SetupDBTest(t)
|
||||
mg := testutil.SetupGatewayClientMock(t)
|
||||
env.DebugExternalChange = false
|
||||
mc := testutil.SetupCloudClientMock(t)
|
||||
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
@@ -552,178 +575,309 @@ func Test_channelService_RemoveChannels(t *testing.T) {
|
||||
|
||||
// 准备测试数据
|
||||
ctx := context.WithValue(context.Background(), requestid.ConfigDefault.ContextKey, "test-request-id")
|
||||
|
||||
// 创建用户
|
||||
var user = &models.User{
|
||||
ID: 101,
|
||||
Phone: "12312341234",
|
||||
}
|
||||
md.Create(user)
|
||||
|
||||
// 创建管理员
|
||||
var adminUser = &models.User{
|
||||
ID: 100,
|
||||
Phone: "99999999999",
|
||||
}
|
||||
md.Create(adminUser)
|
||||
|
||||
// 认证上下文
|
||||
var adminAuth = &AuthContext{Payload: Payload{Id: 100, Type: PayloadAdmin}}
|
||||
var userAuth = &AuthContext{Payload: Payload{Id: 101, Type: PayloadUser}}
|
||||
|
||||
// 创建代理
|
||||
var proxy = &models.Proxy{
|
||||
ID: 1,
|
||||
Version: 1,
|
||||
Name: "test-proxy",
|
||||
Host: "111.111.111.111",
|
||||
Type: 1,
|
||||
Secret: "test:secret",
|
||||
}
|
||||
md.Create(proxy)
|
||||
|
||||
var proxy2 = &models.Proxy{
|
||||
ID: 2,
|
||||
Version: 1,
|
||||
Name: "test-proxy-2",
|
||||
Host: "222.222.222.222",
|
||||
Type: 1,
|
||||
Secret: "test:secret2",
|
||||
}
|
||||
md.Create(proxy2)
|
||||
|
||||
// 清空数据库函数
|
||||
var clearDb = func() {
|
||||
md.Exec("delete from channel where true")
|
||||
mr.FlushAll()
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
setup func()
|
||||
wantErr bool
|
||||
wantErrContains string
|
||||
checkCache func(t *testing.T)
|
||||
want func(t *testing.T) error
|
||||
}{
|
||||
{
|
||||
name: "管理员删除多个通道",
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
auth: &AuthContext{
|
||||
Payload: Payload{
|
||||
Type: PayloadAdmin,
|
||||
Id: 1,
|
||||
},
|
||||
},
|
||||
id: []int32{1, 2, 3},
|
||||
ctx: ctx,
|
||||
auth: adminAuth,
|
||||
id: []int32{1, 2, 3},
|
||||
},
|
||||
setup: func() {
|
||||
// 预设 Redis 缓存
|
||||
mr.FlushAll()
|
||||
for _, id := range []int32{1, 2, 3} {
|
||||
key := fmt.Sprintf("channel:%d", id)
|
||||
channel := models.Channel{ID: id, UserID: 100}
|
||||
data, _ := json.Marshal(channel)
|
||||
mr.Set(key, string(data))
|
||||
}
|
||||
|
||||
// 清空数据库表
|
||||
db.Exec("delete from channel")
|
||||
db.Exec("delete from proxy")
|
||||
|
||||
// 创建代理
|
||||
proxies := []models.Proxy{
|
||||
{ID: 1, Name: "proxy1", Host: "proxy1.example.com", Secret: "key:secret", Type: 1},
|
||||
{ID: 2, Name: "proxy2", Host: "proxy2.example.com", Secret: "key:secret", Type: 1},
|
||||
}
|
||||
for _, p := range proxies {
|
||||
db.Create(&p)
|
||||
}
|
||||
clearDb()
|
||||
|
||||
// 创建通道
|
||||
channels := []models.Channel{
|
||||
{ID: 1, UserID: 100, ProxyID: 1, ProxyPort: 10001, Protocol: "http", Expiration: time.Now().Add(24 * time.Hour)},
|
||||
{ID: 2, UserID: 100, ProxyID: 1, ProxyPort: 10002, Protocol: "http", Expiration: time.Now().Add(24 * time.Hour)},
|
||||
{ID: 1, UserID: 101, ProxyID: 1, ProxyPort: 10001, Protocol: "http", Expiration: time.Now().Add(24 * time.Hour)},
|
||||
{ID: 2, UserID: 101, ProxyID: 1, ProxyPort: 10002, Protocol: "http", Expiration: time.Now().Add(24 * time.Hour)},
|
||||
{ID: 3, UserID: 101, ProxyID: 2, ProxyPort: 10001, Protocol: "socks5", Expiration: time.Now().Add(24 * time.Hour)},
|
||||
}
|
||||
for _, c := range channels {
|
||||
db.Create(&c)
|
||||
|
||||
// 保存预设数据
|
||||
md.Create(channels)
|
||||
for _, channel := range channels {
|
||||
key := fmt.Sprintf("channel:%d", channel.ID)
|
||||
data, _ := json.Marshal(channel)
|
||||
_ = mr.Set(key, string(data))
|
||||
}
|
||||
|
||||
// 模拟网关客户端的响应
|
||||
mg.PortActiveMock = func(m *testutil.MockGatewayClient, param ...remote.PortActiveReq) (map[string]remote.PortData, error) {
|
||||
switch {
|
||||
case m.Host == proxy.Host:
|
||||
return map[string]remote.PortData{
|
||||
"10001": {Edge: []string{"edge1", "edge4"}},
|
||||
"10002": {Edge: []string{"edge2"}},
|
||||
}, nil
|
||||
case m.Host == proxy2.Host:
|
||||
return map[string]remote.PortData{
|
||||
"10001": {Edge: []string{"edge3"}},
|
||||
}, nil
|
||||
}
|
||||
return nil, fmt.Errorf("代理主机不符合预期: %s", m.Host)
|
||||
}
|
||||
mg.PortConfigsMock = func(m *testutil.MockGatewayClient, params []remote.PortConfigsReq) error {
|
||||
switch {
|
||||
case m.Host == proxy.Host:
|
||||
for _, param := range params {
|
||||
if param.Port != 10001 && param.Port != 10002 {
|
||||
return fmt.Errorf("端口配置不符合预期: %d", param.Port)
|
||||
}
|
||||
if param.Status != false {
|
||||
return fmt.Errorf("端口状态不符合预期: %v", param.Status)
|
||||
}
|
||||
if param.Edge == nil || len(*param.Edge) != 0 {
|
||||
return fmt.Errorf("边缘节点不符合预期: %v", param.Edge)
|
||||
}
|
||||
}
|
||||
case m.Host == proxy2.Host:
|
||||
for _, param := range params {
|
||||
if param.Port != 10001 {
|
||||
return fmt.Errorf("端口配置不符合预期: %d", param.Port)
|
||||
}
|
||||
if param.Status != false {
|
||||
return fmt.Errorf("端口状态不符合预期: %v", param.Status)
|
||||
}
|
||||
if param.Edge == nil || len(*param.Edge) != 0 {
|
||||
return fmt.Errorf("边缘节点不符合预期: %v", param.Edge)
|
||||
}
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("代理主机不符合预期: %s", m.Host)
|
||||
}
|
||||
mc.DisconnectMock = func(param remote.CloudDisconnectReq) (int, error) {
|
||||
switch {
|
||||
case param.Uuid == proxy.Name:
|
||||
var edges = []string{"edge1", "edge2", "edge4"}
|
||||
if !slices.Equal(edges, param.Edge) {
|
||||
return 0, fmt.Errorf("边缘节点不符合预期: %v", param.Edge)
|
||||
}
|
||||
if len(param.Config) != 0 {
|
||||
return 0, fmt.Errorf("配置不符合预期: %v", param.Config)
|
||||
}
|
||||
return len(param.Edge), nil
|
||||
case param.Uuid == proxy2.Name:
|
||||
var edges = []string{"edge3"}
|
||||
if !slices.Equal(edges, param.Edge) {
|
||||
return 0, fmt.Errorf("边缘节点不符合预期: %v", param.Edge)
|
||||
}
|
||||
if len(param.Config) != 0 {
|
||||
return 0, fmt.Errorf("配置不符合预期: %v", param.Config)
|
||||
}
|
||||
return len(param.Edge), nil
|
||||
}
|
||||
return 0, fmt.Errorf("代理名称不符合预期: %s", param.Uuid)
|
||||
}
|
||||
},
|
||||
checkCache: func(t *testing.T) {
|
||||
want: func(t *testing.T) error {
|
||||
// 检查通道是否被软删除
|
||||
var count int64
|
||||
db.Model(&models.Channel{}).Where("id IN ? AND deleted_at IS NULL", []int32{1, 2, 3}).Count(&count)
|
||||
md.Model(&models.Channel{}).Where("id IN ? AND deleted_at IS NULL", []int32{1, 2, 3}).Count(&count)
|
||||
if count > 0 {
|
||||
t.Errorf("应该软删除了所有通道,但仍有 %d 个未删除", count)
|
||||
return fmt.Errorf("应该软删除了所有通道,但仍有 %d 个未删除", count)
|
||||
}
|
||||
|
||||
// 检查Redis缓存是否被删除
|
||||
for _, id := range []int32{1, 2, 3} {
|
||||
key := fmt.Sprintf("channel:%d", id)
|
||||
if mr.Exists(key) {
|
||||
t.Errorf("通道缓存 %s 应被删除但仍存在", key)
|
||||
return fmt.Errorf("通道缓存 %s 应被删除但仍存在", key)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "用户删除自己的通道",
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
auth: &AuthContext{
|
||||
Payload: Payload{
|
||||
Type: PayloadUser,
|
||||
Id: 100,
|
||||
},
|
||||
},
|
||||
id: []int32{1},
|
||||
ctx: ctx,
|
||||
auth: userAuth,
|
||||
id: []int32{1, 2, 3},
|
||||
},
|
||||
setup: func() {
|
||||
// 预设 Redis 缓存
|
||||
mr.FlushAll()
|
||||
key := "channel:1"
|
||||
channel := models.Channel{ID: 1, UserID: 100}
|
||||
data, _ := json.Marshal(channel)
|
||||
mr.Set(key, string(data))
|
||||
|
||||
// 清空数据库表
|
||||
db.Exec("delete from channel")
|
||||
db.Exec("delete from proxy")
|
||||
|
||||
// 创建代理
|
||||
proxy := models.Proxy{
|
||||
ID: 1,
|
||||
Name: "proxy1",
|
||||
Host: "proxy1.example.com",
|
||||
Secret: "key:secret",
|
||||
Type: 1,
|
||||
}
|
||||
db.Create(&proxy)
|
||||
clearDb()
|
||||
|
||||
// 创建通道
|
||||
ch := models.Channel{
|
||||
ID: 1,
|
||||
UserID: 100,
|
||||
ProxyID: 1,
|
||||
ProxyPort: 10001,
|
||||
Protocol: "http",
|
||||
Expiration: time.Now().Add(24 * time.Hour),
|
||||
channels := []models.Channel{
|
||||
{ID: 1, UserID: 101, ProxyID: 1, ProxyPort: 10001, Protocol: "http", Expiration: time.Now().Add(24 * time.Hour)},
|
||||
{ID: 2, UserID: 101, ProxyID: 1, ProxyPort: 10002, Protocol: "http", Expiration: time.Now().Add(24 * time.Hour)},
|
||||
{ID: 3, UserID: 101, ProxyID: 2, ProxyPort: 10001, Protocol: "socks5", Expiration: time.Now().Add(24 * time.Hour)},
|
||||
}
|
||||
db.Create(&ch)
|
||||
|
||||
// 模拟查询已激活的端口
|
||||
mg.PortActiveMock = func(param ...remote.PortActiveReq) (map[string]remote.PortData, error) {
|
||||
return map[string]remote.PortData{
|
||||
"10001": {
|
||||
Edge: []string{"edge1", "edge2"},
|
||||
},
|
||||
}, nil
|
||||
// 保存预设数据
|
||||
md.Create(channels)
|
||||
for _, channel := range channels {
|
||||
key := fmt.Sprintf("channel:%d", channel.ID)
|
||||
data, _ := json.Marshal(channel)
|
||||
_ = mr.Set(key, string(data))
|
||||
}
|
||||
|
||||
// 模拟网关客户端的响应
|
||||
mg.PortActiveMock = func(m *testutil.MockGatewayClient, param ...remote.PortActiveReq) (map[string]remote.PortData, error) {
|
||||
switch {
|
||||
case m.Host == proxy.Host:
|
||||
return map[string]remote.PortData{
|
||||
"10001": {Edge: []string{"edge1", "edge4"}},
|
||||
"10002": {Edge: []string{"edge2"}},
|
||||
}, nil
|
||||
case m.Host == proxy2.Host:
|
||||
return map[string]remote.PortData{
|
||||
"10001": {Edge: []string{"edge3"}},
|
||||
}, nil
|
||||
}
|
||||
return nil, fmt.Errorf("代理主机不符合预期: %s", m.Host)
|
||||
}
|
||||
mg.PortConfigsMock = func(m *testutil.MockGatewayClient, params []remote.PortConfigsReq) error {
|
||||
switch {
|
||||
case m.Host == proxy.Host:
|
||||
for _, param := range params {
|
||||
if param.Port != 10001 && param.Port != 10002 {
|
||||
return fmt.Errorf("端口配置不符合预期: %d", param.Port)
|
||||
}
|
||||
if param.Status != false {
|
||||
return fmt.Errorf("端口状态不符合预期: %v", param.Status)
|
||||
}
|
||||
if param.Edge == nil || len(*param.Edge) != 0 {
|
||||
return fmt.Errorf("边缘节点不符合预期: %v", param.Edge)
|
||||
}
|
||||
}
|
||||
case m.Host == proxy2.Host:
|
||||
for _, param := range params {
|
||||
if param.Port != 10001 {
|
||||
return fmt.Errorf("端口配置不符合预期: %d", param.Port)
|
||||
}
|
||||
if param.Status != false {
|
||||
return fmt.Errorf("端口状态不符合预期: %v", param.Status)
|
||||
}
|
||||
if param.Edge == nil || len(*param.Edge) != 0 {
|
||||
return fmt.Errorf("边缘节点不符合预期: %v", param.Edge)
|
||||
}
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("代理主机不符合预期: %s", m.Host)
|
||||
}
|
||||
mc.DisconnectMock = func(param remote.CloudDisconnectReq) (int, error) {
|
||||
switch {
|
||||
case param.Uuid == proxy.Name:
|
||||
var edges = []string{"edge1", "edge2", "edge4"}
|
||||
if !slices.Equal(edges, param.Edge) {
|
||||
return 0, fmt.Errorf("边缘节点不符合预期: %v", param.Edge)
|
||||
}
|
||||
if len(param.Config) != 0 {
|
||||
return 0, fmt.Errorf("配置不符合预期: %v", param.Config)
|
||||
}
|
||||
return len(param.Edge), nil
|
||||
case param.Uuid == proxy2.Name:
|
||||
var edges = []string{"edge3"}
|
||||
if !slices.Equal(edges, param.Edge) {
|
||||
return 0, fmt.Errorf("边缘节点不符合预期: %v", param.Edge)
|
||||
}
|
||||
if len(param.Config) != 0 {
|
||||
return 0, fmt.Errorf("配置不符合预期: %v", param.Config)
|
||||
}
|
||||
return len(param.Edge), nil
|
||||
}
|
||||
return 0, fmt.Errorf("代理名称不符合预期: %s", param.Uuid)
|
||||
}
|
||||
},
|
||||
checkCache: func(t *testing.T) {
|
||||
want: func(t *testing.T) error {
|
||||
// 检查通道是否被软删除
|
||||
var count int64
|
||||
db.Model(&models.Channel{}).Where("id = ? AND deleted_at IS NULL", 1).Count(&count)
|
||||
md.Model(&models.Channel{}).Where("id IN ? AND deleted_at IS NULL", []int32{1, 2, 3}).Count(&count)
|
||||
if count > 0 {
|
||||
t.Errorf("应该软删除了通道,但仍未删除")
|
||||
return fmt.Errorf("应该软删除了所有通道,但仍有 %d 个未删除", count)
|
||||
}
|
||||
|
||||
// 检查Redis缓存是否被删除
|
||||
key := "channel:1"
|
||||
if mr.Exists(key) {
|
||||
t.Errorf("通道缓存 %s 应被删除但仍存在", key)
|
||||
for _, id := range []int32{1, 2, 3} {
|
||||
key := fmt.Sprintf("channel:%d", id)
|
||||
if mr.Exists(key) {
|
||||
return fmt.Errorf("通道缓存 %s 应被删除但仍存在", key)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "用户删除不属于自己的通道",
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
auth: &AuthContext{
|
||||
Payload: Payload{
|
||||
Type: PayloadUser,
|
||||
Id: 100,
|
||||
},
|
||||
},
|
||||
id: []int32{5},
|
||||
ctx: ctx,
|
||||
auth: userAuth,
|
||||
id: []int32{1, 2, 3},
|
||||
},
|
||||
setup: func() {
|
||||
// 预设 Redis 缓存
|
||||
mr.FlushAll()
|
||||
key := "channel:5"
|
||||
channel := models.Channel{ID: 5, UserID: 101}
|
||||
data, _ := json.Marshal(channel)
|
||||
mr.Set(key, string(data))
|
||||
clearDb()
|
||||
|
||||
// 清空数据库表
|
||||
db.Exec("delete from channel")
|
||||
|
||||
// 创建一个属于用户101的通道
|
||||
ch := models.Channel{
|
||||
ID: 5,
|
||||
UserID: 101,
|
||||
ProxyID: 1,
|
||||
ProxyPort: 10005,
|
||||
Protocol: "http",
|
||||
Expiration: time.Now().Add(24 * time.Hour),
|
||||
// 创建通道
|
||||
channels := []models.Channel{
|
||||
{ID: 1, UserID: 101, ProxyID: 1, ProxyPort: 10001, Protocol: "http", Expiration: time.Now().Add(24 * time.Hour)},
|
||||
{ID: 2, UserID: 101, ProxyID: 1, ProxyPort: 10002, Protocol: "http", Expiration: time.Now().Add(24 * time.Hour)},
|
||||
{ID: 3, UserID: 102, ProxyID: 2, ProxyPort: 10001, Protocol: "socks5", Expiration: time.Now().Add(24 * time.Hour)},
|
||||
}
|
||||
|
||||
// 保存预设数据
|
||||
md.Create(channels)
|
||||
for _, channel := range channels {
|
||||
key := fmt.Sprintf("channel:%d", channel.ID)
|
||||
data, _ := json.Marshal(channel)
|
||||
_ = mr.Set(key, string(data))
|
||||
}
|
||||
db.Create(&ch)
|
||||
},
|
||||
wantErr: true,
|
||||
wantErrContains: "无权限访问",
|
||||
@@ -756,9 +910,10 @@ func Test_channelService_RemoveChannels(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
// 检查 Redis 缓存是否正确设置
|
||||
if tt.checkCache != nil {
|
||||
tt.checkCache(t)
|
||||
// 检查数据库和缓存是否正确设置
|
||||
want := tt.want(t)
|
||||
if tt.want(t) != nil {
|
||||
t.Errorf("RemoveChannels() 结果验证失败: %v", want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user