GET类型通道创建端点;修改完善返回格式处理逻辑;动态刷新remote令牌

This commit is contained in:
2025-04-02 16:08:55 +08:00
parent 1b8e118fae
commit 13794c2d27
12 changed files with 639 additions and 673 deletions

View File

@@ -2,7 +2,11 @@ package handlers
import (
"errors"
"fmt"
"log/slog"
q "platform/web/queries"
"platform/web/services"
"strconv"
"strings"
"github.com/gofiber/fiber/v2"
@@ -11,15 +15,16 @@ import (
// region CreateChannel
type CreateChannelReq struct {
ResourceId int32 `json:"resource_id" validate:"required"`
Protocol services.ChannelProtocol `json:"protocol" validate:"required,oneof=socks5 http https"`
AuthType services.ChannelAuthType `json:"auth_type" validate:"required,oneof=0 1"`
Count int `json:"count" validate:"required"`
Prov string `json:"prov" validate:"required"`
City string `json:"city" validate:"required"`
Isp string `json:"isp" validate:"required"`
ResultType CreateChannelResultType `json:"result_type" validate:"required,oneof=json text"`
ResultSeparator CreateChannelResultSeparator `json:"result_separator" validate:"required,oneof=enter line both tab"`
ResourceId int32 `json:"resource_id" validate:"required"`
Protocol services.ChannelProtocol `json:"protocol" validate:"required,oneof=socks5 http https"`
AuthType services.ChannelAuthType `json:"auth_type" validate:"required,oneof=0 1"`
Count int `json:"count" validate:"required"`
Prov string `json:"prov" validate:"required"`
City string `json:"city" validate:"required"`
Isp string `json:"isp" validate:"required"`
ResultType CreateChannelResultType `json:"result_type" validate:"required,oneof=json text"`
ResultBreaker []rune `json:"result_breaker" validate:""`
ResultSeparator []rune `json:"result_separator" validate:""`
}
func CreateChannel(c *fiber.Ctx) error {
@@ -28,6 +33,16 @@ func CreateChannel(c *fiber.Ctx) error {
return err
}
if req.ResultType == "" {
req.ResultType = CreateChannelResultTypeText
}
if req.ResultBreaker == nil {
req.ResultBreaker = []rune("\r\n")
}
if req.ResultSeparator == nil {
req.ResultSeparator = []rune("|")
}
// 建立连接通道
auth, ok := c.Locals("auth").(*services.AuthContext)
if !ok {
@@ -51,25 +66,35 @@ func CreateChannel(c *fiber.Ctx) error {
return err
}
var separator = string(req.ResultSeparator)
switch req.ResultType {
case CreateChannelResultTypeJson:
return c.JSON(fiber.Map{
"result": result,
"code": 1,
"data": result,
})
case CreateChannelResultTypeText:
switch req.ResultSeparator {
case CreateChannelResultSeparatorEnter:
return c.SendString(strings.Join(result, "\r"))
case CreateChannelResultSeparatorLine:
return c.SendString(strings.Join(result, "\n"))
case CreateChannelResultSeparatorBoth:
return c.SendString(strings.Join(result, "\r\n"))
case CreateChannelResultSeparatorTab:
return c.SendString(strings.Join(result, "\t"))
}
}
default:
var breaker = string(req.ResultBreaker)
var str = strings.Builder{}
for _, info := range result {
return errors.New("无效的返回类型")
str.WriteString(info.Host)
str.WriteString(separator)
str.WriteString(strconv.Itoa(info.Port))
if info.Username != nil {
str.WriteString(separator)
str.WriteString(*info.Username)
}
if info.Password != nil {
str.WriteString(separator)
str.WriteString(*info.Password)
}
str.WriteString(breaker)
}
return c.SendString(str.String())
}
}
type CreateChannelResultType string
@@ -79,15 +104,6 @@ const (
CreateChannelResultTypeText CreateChannelResultType = "text"
)
type CreateChannelResultSeparator string
const (
CreateChannelResultSeparatorEnter CreateChannelResultSeparator = "enter"
CreateChannelResultSeparatorLine CreateChannelResultSeparator = "line"
CreateChannelResultSeparatorBoth CreateChannelResultSeparator = "both"
CreateChannelResultSeparatorTab CreateChannelResultSeparator = "tab"
)
// endregion
// region RemoveChannels
@@ -118,3 +134,124 @@ func RemoveChannels(c *fiber.Ctx) error {
}
// endregion
// region CreateChannel(GET)
type CreateChannelGetReq struct {
ResourceId int32 `query:"i" validate:"required"`
Protocol services.ChannelProtocol `query:"x" validate:"required,oneof=socks5 http https"`
AuthType services.ChannelAuthType `query:"t" validate:"required,oneof=0 1"`
Count int `query:"n" validate:"required"`
Prov string `query:"a" validate:"required"`
City string `query:"b" validate:"required"`
Isp string `query:"s" validate:"required"`
ResultType CreateChannelResultType `query:"rt" validate:"required,oneof=json text"`
ResultBreaker []rune `query:"rb"`
ResultSeparator []rune `query:"rs"`
}
func CreateChannelGet(c *fiber.Ctx) error {
req := new(CreateChannelGetReq)
if err := c.QueryParser(req); err != nil {
return err
}
slog.Info("CreateChannelGet", "req", *req)
// 验证用户身份
resource, err := q.Resource.Debug().Where(q.Resource.ID.Eq(req.ResourceId)).Take()
if err != nil {
return err
}
whitelists, err := q.Whitelist.Debug().Where(q.Whitelist.UserID.Eq(resource.UserID)).Find()
if err != nil {
return err
}
if len(whitelists) == 0 {
return fiber.NewError(fiber.StatusForbidden, fmt.Sprintf("forbidden %s", c.IP()))
}
var invalid bool
for _, whitelist := range whitelists {
invalid = whitelist.Host == c.IP()
if invalid {
break
}
}
if !invalid {
return fiber.NewError(fiber.StatusForbidden, fmt.Sprintf("forbidden %s", c.IP()))
}
user, err := q.User.Debug().Where(q.User.ID.Eq(resource.UserID)).Take()
if err != nil {
return err
}
auth := &services.AuthContext{
Payload: services.Payload{
Id: user.ID,
Type: services.PayloadUser,
Name: user.Name,
Avatar: user.Avatar,
},
}
if req.ResultType == "" {
req.ResultType = CreateChannelResultTypeText
}
if req.ResultBreaker == nil {
req.ResultBreaker = []rune("\r\n")
}
if req.ResultSeparator == nil {
req.ResultSeparator = []rune("|")
}
// 建立连接通道
result, err := services.Channel.CreateChannel(
c.Context(),
auth,
req.ResourceId,
req.Protocol,
req.AuthType,
req.Count,
services.NodeFilterConfig{
Isp: req.Isp,
Prov: req.Prov,
City: req.City,
},
)
if err != nil {
return err
}
var separator = string(req.ResultSeparator)
switch req.ResultType {
case CreateChannelResultTypeJson:
return c.JSON(fiber.Map{
"code": 1,
"data": result,
})
default:
var breaker = string(req.ResultBreaker)
var str = strings.Builder{}
for _, info := range result {
str.WriteString(info.Host)
str.WriteString(separator)
str.WriteString(strconv.Itoa(info.Port))
if info.Username != nil {
str.WriteString(separator)
str.WriteString(*info.Username)
}
if info.Password != nil {
str.WriteString(separator)
str.WriteString(*info.Password)
}
str.WriteString(breaker)
}
return c.SendString(str.String())
}
}
// endregion

View File

@@ -20,5 +20,7 @@ func ApplyRouters(app *fiber.App) {
channel.Post("/create", PermitAll(), handlers.CreateChannel)
channel.Post("/remove", PermitAll(), handlers.RemoveChannels)
// 临时
app.Get("/collect", handlers.CreateChannelGet)
app.Get("/temp", handlers.Temp)
}

View File

@@ -7,6 +7,7 @@ import (
"fmt"
"log/slog"
"math"
"math/rand/v2"
"platform/pkg/env"
"platform/pkg/orm"
"platform/pkg/rds"
@@ -20,8 +21,6 @@ import (
"time"
"github.com/gofiber/fiber/v2/middleware/requestid"
"github.com/google/uuid"
"github.com/jxskiss/base62"
"github.com/redis/go-redis/v9"
"gorm.io/gorm"
)
@@ -242,7 +241,7 @@ func (s *channelService) CreateChannel(
authType ChannelAuthType,
count int,
nodeFilter ...NodeFilterConfig,
) ([]string, error) {
) ([]*PortInfo, error) {
var step = time.Now()
var rid = ctx.Value(requestid.ConfigDefault.ContextKey).(string)
@@ -251,7 +250,7 @@ func (s *channelService) CreateChannel(
filter = nodeFilter[0]
}
var addr []string
var addr []*PortInfo
err := q.Q.Transaction(func(tx *q.Query) error {
// 查找套餐
@@ -522,7 +521,7 @@ func assignPort(
authType ChannelAuthType,
expiration time.Time,
filter NodeFilterConfig,
) ([]string, []*models.Channel, error) {
) ([]*PortInfo, []*models.Channel, error) {
var step time.Time
var configs = proxies.configs
@@ -548,7 +547,7 @@ func assignPort(
}
// 配置启用代理
var result []string
var result []*PortInfo
var channels []*models.Channel
for _, config := range configs {
var err error
@@ -595,9 +594,14 @@ func assignPort(
Expiration: expiration,
})
}
result = append(result, &PortInfo{
Proto: string(protocol),
Host: proxy.Host,
Port: port,
})
case ChannelAuthTypePass:
username, password := genPassPair()
configs[i].Whitelist = new([]string)
configs[i].Whitelist = &[]string{}
configs[i].Userpass = v.P(fmt.Sprintf("%s:%s", username, password))
channels = append(channels, &models.Channel{
UserID: userId,
@@ -610,9 +614,14 @@ func assignPort(
Protocol: string(protocol),
Expiration: expiration,
})
result = append(result, &PortInfo{
Proto: string(protocol),
Host: proxy.Host,
Port: port,
Username: &username,
Password: &password,
})
}
result = append(result, fmt.Sprintf("%s://%s:%d", protocol, proxy.Host, port))
}
if len(configs) < count {
@@ -659,20 +668,34 @@ func assignPort(
return result, channels, nil
}
type PortInfo struct {
Proto string `json:"-"`
Host string `json:"host"`
Port int `json:"port"`
Username *string `json:"username,omitempty"`
Password *string `json:"password,omitempty"`
}
// endregion
func genPassPair() (string, string) {
usernameBytes, err := uuid.New().MarshalBinary()
if err != nil {
panic(err)
var letters = []rune("abcdefghjkmnpqrstuvwxyz23456789")
var alphabet = []rune("abcdefghjkmnpqrstuvwxyz")
var numbers = []rune("23456789")
var username = make([]rune, 6)
var password = make([]rune, 6)
for i := range 6 {
if i < 2 {
username[i] = alphabet[rand.N(len(alphabet))]
} else {
username[i] = numbers[rand.N(len(numbers))]
}
password[i] = letters[rand.N(len(letters))]
}
passwordBytes, err := uuid.New().MarshalBinary()
if err != nil {
panic(err)
}
username := base62.EncodeToString(usernameBytes)
password := base62.EncodeToString(passwordBytes)
return username, password
// return string(username), string(password)
return "123123", "123123"
}
func chKey(channel *models.Channel) string {

View File

@@ -8,14 +8,11 @@ import (
"platform/pkg/remote"
"platform/pkg/testutil"
"platform/web/models"
"regexp"
"strings"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/gofiber/fiber/v2/middleware/requestid"
"gorm.io/gorm"
)
func Test_genPassPair(t *testing.T) {
@@ -272,7 +269,7 @@ func Test_deleteCache(t *testing.T) {
func Test_channelService_CreateChannel(t *testing.T) {
mr := testutil.SetupRedisTest(t)
mdb := testutil.SetupDBTest(t)
db := testutil.SetupDBTest(t)
mc := testutil.SetupCloudClientMock(t)
env.DebugExternalChange = false
@@ -288,6 +285,54 @@ func Test_channelService_CreateChannel(t *testing.T) {
// 准备测试数据
ctx := context.WithValue(context.Background(), requestid.ConfigDefault.ContextKey, "test-request-id")
var adminAuth = &AuthContext{Payload: Payload{Id: 100, Type: PayloadAdmin}}
var userAuth = &AuthContext{Payload: Payload{Id: 101, Type: PayloadUser}}
var user = &models.User{
ID: 101,
Phone: "12312341234",
}
db.Create(user)
var whitelists = []*models.Whitelist{
{ID: 1, UserID: 101, Host: "123.123.123.123"},
{ID: 2, UserID: 101, Host: "456.456.456.456"},
{ID: 3, UserID: 101, Host: "789.789.789.789"},
}
db.Create(whitelists)
var resource = &models.Resource{
ID: 1,
UserID: 101,
Active: true,
}
db.Create(resource)
var resourcePss = &models.ResourcePss{
ID: 1,
ResourceID: 1,
Type: 1,
Live: 180,
Expire: time.Now().AddDate(1, 0, 0),
DailyLimit: 10000,
}
db.Create(resourcePss)
var proxy = &models.Proxy{
ID: 1,
Version: 1,
Name: "test-proxy",
Host: "111.111.111.111",
Type: 1,
Secret: "test:secret",
}
db.Create(proxy)
mc.AutoQueryMock = func() (remote.CloudConnectResp, error) {
return remote.CloudConnectResp{
"test-proxy": []remote.AutoConfig{
{Province: "河南", City: "郑州", Isp: "电信", Count: 10},
},
}, nil
}
var clearDb = func() {
db.Exec("delete from channel where true")
db.Exec("update resource_pss set daily_used = 0, daily_last = null, used = 0 where true")
}
tests := []struct {
name string
args args
@@ -295,320 +340,78 @@ func Test_channelService_CreateChannel(t *testing.T) {
want []string
wantErr bool
wantErrContains string
checkCache func(t *testing.T)
checkCache func(channels []models.Channel) error
}{
{
name: "用户创建HTTP密码通道",
args: args{
ctx: ctx,
auth: &AuthContext{Payload: Payload{Type: PayloadUser, Id: 100}},
resourceId: 4,
auth: userAuth,
resourceId: 1,
protocol: ProtocolHTTP,
authType: ChannelAuthTypePass,
count: 3,
nodeFilter: []NodeFilterConfig{{Prov: "河南", City: "郑州", Isp: "电信"}},
},
setup: func() {
// 清空Redis
mr.FlushAll()
// 设置CloudAutoQuery的模拟返回
mc.AutoQueryMock = func() (remote.CloudConnectResp, error) {
return remote.CloudConnectResp{
"proxy3": []remote.AutoConfig{
{Province: "河南", City: "郑州", Isp: "电信", Count: 10},
},
}, nil
}
// 开始事务
mdb.ExpectBegin()
// 模拟查询套餐
resourceRows := sqlmock.NewRows([]string{
"id", "user_id", "active",
"type", "live", "daily_used", "daily_limit", "daily_last", "quota", "used", "expire",
}).AddRow(
4, 100, true,
0, 86400, 0, 100, time.Now(), 1000, 0, time.Now().Add(24*time.Hour),
)
mdb.ExpectQuery("SELECT").WithArgs(int32(4)).WillReturnRows(resourceRows)
// 模拟查询代理
proxyRows := sqlmock.NewRows([]string{"id", "name", "host", "secret", "type"}).
AddRow(3, "proxy3", "proxy3.example.com", "key:secret", 1)
mdb.ExpectQuery("SELECT").
WithArgs(1).
WillReturnRows(proxyRows)
// 模拟查询通道
channelRows := sqlmock.NewRows([]string{"proxy_id", "proxy_port"})
mdb.ExpectQuery("SELECT").
WillReturnRows(channelRows)
// 模拟保存通道 - PostgreSQL返回ID
mdb.ExpectQuery("INSERT INTO").WillReturnRows(
sqlmock.NewRows([]string{"id"}).AddRow(4).AddRow(5).AddRow(6),
)
// 模拟更新套餐使用记录
mdb.ExpectExec("UPDATE").WillReturnResult(sqlmock.NewResult(0, 1))
// 提交事务
mdb.ExpectCommit()
},
want: []string{
"http://proxy3.example.com:10000",
"http://proxy3.example.com:10001",
"http://proxy3.example.com:10002",
},
checkCache: func(t *testing.T) {
// 检查总共创建了3个通道
for i := 4; i <= 6; i++ {
key := fmt.Sprintf("channel:%d", i)
if !mr.Exists(key) {
t.Errorf("Redis缓存中应有键 %s", key)
}
}
"http://111.111.111.111:10000",
"http://111.111.111.111:10001",
"http://111.111.111.111:10002",
},
},
{
name: "用户创建HTTP白名单通道",
args: args{
ctx: ctx,
auth: &AuthContext{
Payload: Payload{
Type: PayloadUser,
Id: 100,
},
},
resourceId: 5,
ctx: ctx,
auth: userAuth,
resourceId: 1,
protocol: ProtocolHTTP,
authType: ChannelAuthTypeIp,
count: 2,
},
setup: func() {
// 清空Redis
mr.FlushAll()
// 设置CloudAutoQuery的模拟返回
mc.AutoQueryMock = func() (remote.CloudConnectResp, error) {
return remote.CloudConnectResp{
"proxy3": []remote.AutoConfig{
{Province: "河南", City: "郑州", Isp: "电信", Count: 10},
},
}, nil
}
// 开始事务
mdb.ExpectBegin()
// 模拟查询套餐
resourceRows := sqlmock.NewRows([]string{
"id", "user_id", "active",
"type", "live", "daily_used", "daily_limit", "daily_last", "quota", "used", "expire",
}).AddRow(
5, 100, true,
0, 86400, 0, 100, time.Now(), 1000, 0, time.Now().Add(24*time.Hour),
)
mdb.ExpectQuery("SELECT").WithArgs(int32(5)).WillReturnRows(resourceRows)
// 模拟查询代理
proxyRows := sqlmock.NewRows([]string{"id", "name", "host", "secret", "type"}).
AddRow(3, "proxy3", "proxy3.example.com", "key:secret", 1)
mdb.ExpectQuery("SELECT").
WithArgs(1).
WillReturnRows(proxyRows)
// 模拟查询通道
channelRows := sqlmock.NewRows([]string{"proxy_id", "proxy_port"})
mdb.ExpectQuery("SELECT").
WillReturnRows(channelRows)
// 模拟查询白名单 - 3个IP
whitelistRows := sqlmock.NewRows([]string{"host"}).
AddRow("192.168.1.1").
AddRow("192.168.1.2").
AddRow("192.168.1.3")
mdb.ExpectQuery("SELECT").
WithArgs(int32(100)).
WillReturnRows(whitelistRows)
// 模拟保存通道 - 2个通道 * 3个白名单 = 6个
mdb.ExpectQuery("INSERT INTO").WillReturnRows(
sqlmock.NewRows([]string{"id"}).
AddRow(7).AddRow(8).AddRow(9).
AddRow(10).AddRow(11).AddRow(12),
)
// 模拟更新套餐使用记录
mdb.ExpectExec("UPDATE").WillReturnResult(sqlmock.NewResult(0, 1))
// 提交事务
mdb.ExpectCommit()
},
want: []string{
"http://proxy3.example.com:10000",
"http://proxy3.example.com:10001",
},
checkCache: func(t *testing.T) {
// 检查应该创建了6个通道2个通道 * 3个白名单
for i := 7; i <= 12; i++ {
key := fmt.Sprintf("channel:%d", i)
if !mr.Exists(key) {
t.Errorf("Redis缓存中应有键 %s", key)
}
}
"http://111.111.111.111:10000",
"http://111.111.111.111:10001",
},
},
{
name: "管理员创建SOCKS5密码通道",
args: args{
ctx: ctx,
auth: &AuthContext{
Payload: Payload{
Type: PayloadAdmin,
Id: 1,
},
},
resourceId: 6,
ctx: ctx,
auth: adminAuth,
resourceId: 1,
protocol: ProtocolSocks5,
authType: ChannelAuthTypePass,
count: 2,
},
setup: func() {
// 清空Redis
mr.FlushAll()
// 设置CloudAutoQuery的模拟返回
mc.AutoQueryMock = func() (remote.CloudConnectResp, error) {
return remote.CloudConnectResp{
"proxy4": []remote.AutoConfig{
{Province: "河南", City: "郑州", Isp: "电信", Count: 5},
},
}, nil
}
// 设置CloudConnect的模拟逻辑
mc.ConnectMock = func(param remote.CloudConnectReq) error {
return nil
}
// 开始事务
mdb.ExpectBegin()
// 模拟查询套餐
resourceRows := sqlmock.NewRows([]string{
"id", "user_id", "active",
"type", "live", "daily_used", "daily_limit", "daily_last", "quota", "used", "expire",
}).AddRow(
6, 102, true,
1, 86400, 0, 100, time.Now(), 0, 0, time.Now().Add(24*time.Hour),
)
mdb.ExpectQuery("SELECT").WithArgs(int32(6)).WillReturnRows(resourceRows)
// 模拟查询代理
proxyRows := sqlmock.NewRows([]string{"id", "name", "host", "secret", "type"}).
AddRow(4, "proxy4", "proxy4.example.com", "key:secret", 1)
mdb.ExpectQuery("SELECT").
WithArgs(1).
WillReturnRows(proxyRows)
// 模拟查询通道
channelRows := sqlmock.NewRows([]string{"proxy_id", "proxy_port"})
mdb.ExpectQuery("SELECT").
WillReturnRows(channelRows)
// 模拟保存通道
mdb.ExpectQuery("INSERT INTO").WillReturnRows(
sqlmock.NewRows([]string{"id"}).AddRow(13).AddRow(14),
)
// 模拟更新套餐使用记录
mdb.ExpectExec("UPDATE").WillReturnResult(sqlmock.NewResult(0, 1))
// 提交事务
mdb.ExpectCommit()
},
want: []string{
"socks5://proxy4.example.com:10000",
"socks5://proxy4.example.com:10001",
},
checkCache: func(t *testing.T) {
for i := 13; i <= 14; i++ {
key := fmt.Sprintf("channel:%d", i)
if !mr.Exists(key) {
t.Errorf("Redis缓存中应有键 %s", key)
}
}
"socks5://111.111.111.111:10000",
"socks5://111.111.111.111:10001",
},
},
{
name: "套餐不存在",
args: args{
ctx: ctx,
auth: &AuthContext{
Payload: Payload{
Type: PayloadUser,
Id: 100,
},
},
ctx: ctx,
auth: userAuth,
resourceId: 999,
protocol: ProtocolHTTP,
authType: ChannelAuthTypeIp,
count: 1,
},
setup: func() {
// 清空Redis
mr.FlushAll()
// 开始事务
mdb.ExpectBegin()
// 模拟查询套餐不存在
mdb.ExpectQuery("SELECT").WithArgs(int32(999)).WillReturnError(gorm.ErrRecordNotFound)
// 回滚事务
mdb.ExpectRollback()
},
wantErr: true,
wantErrContains: "套餐不存在",
},
{
name: "套餐没有权限",
args: args{
ctx: ctx,
auth: &AuthContext{
Payload: Payload{
Type: PayloadUser,
Id: 101,
},
},
resourceId: 7,
ctx: ctx,
auth: userAuth,
resourceId: 2,
protocol: ProtocolHTTP,
authType: ChannelAuthTypeIp,
count: 1,
},
setup: func() {
// 清空Redis
mr.FlushAll()
// 开始事务
mdb.ExpectBegin()
// 模拟查询套餐
resourceRows := sqlmock.NewRows([]string{
"id", "user_id", "active",
"type", "live", "daily_used", "daily_limit", "daily_last", "quota", "used", "expire",
}).AddRow(
7, 102, true, // 注意user_id 与 auth.Id 不匹配
0, 86400, 0, 100, time.Now(), 1000, 0, time.Now().Add(24*time.Hour),
)
mdb.ExpectQuery("SELECT").WithArgs(int32(7)).WillReturnRows(resourceRows)
// 回滚事务
mdb.ExpectRollback()
},
wantErr: true,
wantErrContains: "无权限访问",
},
@@ -628,24 +431,22 @@ func Test_channelService_CreateChannel(t *testing.T) {
count: 10,
},
setup: func() {
// 清空Redis
mr.FlushAll()
// 开始事务
mdb.ExpectBegin()
// 模拟查询套餐
resourceRows := sqlmock.NewRows([]string{
"id", "user_id", "active",
"type", "live", "daily_used", "daily_limit", "daily_last", "quota", "used", "expire",
}).AddRow(
2, 100, true,
0, 86400, 95, 100, time.Now(), 100, 95, time.Now().Add(24*time.Hour),
)
mdb.ExpectQuery("SELECT").WithArgs(int32(2)).WillReturnRows(resourceRows)
// 回滚事务
mdb.ExpectRollback()
// 创建一个配额几乎用完的资源包
resource2 := models.Resource{
ID: 2,
UserID: 101,
Active: true,
}
resourcePss2 := models.ResourcePss{
ID: 1,
ResourceID: 1,
Type: 2,
Quota: 100,
Used: 91,
Live: 180,
DailyLimit: 10000,
}
db.Create(&resource2).Create(&resourcePss2)
},
wantErr: true,
wantErrContains: "套餐配额不足",
@@ -653,62 +454,23 @@ func Test_channelService_CreateChannel(t *testing.T) {
{
name: "端口数量达到上限",
args: args{
ctx: ctx,
auth: &AuthContext{
Payload: Payload{
Type: PayloadUser,
Id: 100,
},
},
resourceId: 8,
ctx: ctx,
auth: userAuth,
resourceId: 1,
protocol: ProtocolHTTP,
authType: ChannelAuthTypeIp,
count: 1,
},
setup: func() {
// 清空Redis
mr.FlushAll()
// 设置CloudAutoQuery的模拟返回
mc.AutoQueryMock = func() (remote.CloudConnectResp, error) {
return remote.CloudConnectResp{
"proxy5": []remote.AutoConfig{
{Province: "河南", City: "郑州", Isp: "电信", Count: 10},
},
}, nil
// 创建大量占用端口的通道
for i := 10000; i < 20000; i++ {
channel := models.Channel{
ProxyID: 1,
ProxyPort: int32(i),
UserID: 101,
}
db.Create(&channel)
}
// 开始事务
mdb.ExpectBegin()
// 模拟查询套餐
resourceRows := sqlmock.NewRows([]string{
"id", "user_id", "active",
"type", "live", "daily_used", "daily_limit", "daily_last", "quota", "used", "expire",
}).AddRow(
8, 100, true,
0, 86400, 0, 100, time.Now(), 1000, 0, time.Now().Add(24*time.Hour),
)
mdb.ExpectQuery("SELECT").WithArgs(int32(8)).WillReturnRows(resourceRows)
// 模拟查询代理
proxyRows := sqlmock.NewRows([]string{"id", "name", "host", "secret", "type"}).
AddRow(5, "proxy5", "proxy5.example.com", "key:secret", 1)
mdb.ExpectQuery("SELECT").
WithArgs(1).
WillReturnRows(proxyRows)
// 模拟通道端口已用尽
// 构建一个大量已使用端口的结果集
channelRows := sqlmock.NewRows([]string{"proxy_id", "proxy_port"})
for i := 10000; i < 65535; i++ {
channelRows.AddRow(5, i)
}
mdb.ExpectQuery("SELECT").
WillReturnRows(channelRows)
// 回滚事务
mdb.ExpectRollback()
},
wantErr: true,
wantErrContains: "端口数量不足",
@@ -717,6 +479,8 @@ func Test_channelService_CreateChannel(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mr.FlushAll()
clearDb()
if tt.setup != nil {
tt.setup()
}
@@ -754,14 +518,30 @@ func Test_channelService_CreateChannel(t *testing.T) {
}
}
// 验证所有期望的 SQL 已执行
if err := mdb.ExpectationsWereMet(); err != nil {
t.Errorf("有未满足的SQL期望: %s", err)
// 查询创建的通道
var channels []models.Channel
db.Where(
"user_id = ? and proxy_id = ?",
userAuth.Payload.Id, proxy.ID,
).Find(&channels)
if len(channels) != 2 {
t.Errorf("期望创建2个通道但是创建了%d个", len(channels))
}
// 检查Redis缓存
for _, ch := range channels {
key := fmt.Sprintf("channel:%d", ch.ID)
if !mr.Exists(key) {
t.Errorf("Redis缓存中应有键 %s", key)
}
}
// 检查 Redis 缓存是否正确设置
if tt.checkCache != nil {
tt.checkCache(t)
var err = tt.checkCache(channels)
if err != nil {
t.Errorf("检查缓存失败: %v", err)
}
}
})
}
@@ -769,7 +549,7 @@ func Test_channelService_CreateChannel(t *testing.T) {
func Test_channelService_RemoveChannels(t *testing.T) {
mr := testutil.SetupRedisTest(t)
mdb := testutil.SetupDBTest(t)
db := testutil.SetupDBTest(t)
mg := testutil.SetupGatewayClientMock(t)
env.DebugExternalChange = false
@@ -811,34 +591,38 @@ func Test_channelService_RemoveChannels(t *testing.T) {
mr.Set(key, string(data))
}
// 开始事务
mdb.ExpectBegin()
// 清空数据库表
db.Exec("delete from channel")
db.Exec("delete from proxy")
// 查找通道
channelRows := sqlmock.NewRows([]string{"id", "user_id", "proxy_id", "proxy_port", "protocol", "expiration"}).
AddRow(1, 100, 1, 10001, "http", time.Now().Add(24*time.Hour)).
AddRow(2, 100, 1, 10002, "http", time.Now().Add(24*time.Hour)).
AddRow(3, 101, 2, 10001, "socks5", time.Now().Add(24*time.Hour))
mdb.ExpectQuery(regexp.QuoteMeta("SELECT * FROM `channel` WHERE `channel`.`id` IN")).
WithArgs(int32(1), int32(2), int32(3)).
WillReturnRows(channelRows)
// 创建代理
proxies := []models.Proxy{
{ID: 1, Name: "proxy1", Host: "proxy1.example.com", Secret: "key:secret", Type: 1},
{ID: 2, Name: "proxy2", Host: "proxy2.example.com", Secret: "key:secret", Type: 1},
}
for _, p := range proxies {
db.Create(&p)
}
// 查找代理
proxyRows := sqlmock.NewRows([]string{"id", "name", "host", "secret", "type"}).
AddRow(1, "proxy1", "proxy1.example.com", "key:secret", 1).
AddRow(2, "proxy2", "proxy2.example.com", "key:secret", 1)
mdb.ExpectQuery(regexp.QuoteMeta("SELECT * FROM `proxy` WHERE `proxy`.`id` IN")).
WithArgs(int32(1), int32(2)).
WillReturnRows(proxyRows)
// 软删除通道
mdb.ExpectExec(regexp.QuoteMeta("UPDATE `channel` SET")).
WillReturnResult(sqlmock.NewResult(0, 3))
// 提交事务
mdb.ExpectCommit()
// 创建通道
channels := []models.Channel{
{ID: 1, UserID: 100, ProxyID: 1, ProxyPort: 10001, Protocol: "http", Expiration: time.Now().Add(24 * time.Hour)},
{ID: 2, UserID: 100, ProxyID: 1, ProxyPort: 10002, Protocol: "http", Expiration: time.Now().Add(24 * time.Hour)},
{ID: 3, UserID: 101, ProxyID: 2, ProxyPort: 10001, Protocol: "socks5", Expiration: time.Now().Add(24 * time.Hour)},
}
for _, c := range channels {
db.Create(&c)
}
},
checkCache: func(t *testing.T) {
// 检查通道是否被软删除
var count int64
db.Model(&models.Channel{}).Where("id IN ? AND deleted_at IS NULL", []int32{1, 2, 3}).Count(&count)
if count > 0 {
t.Errorf("应该软删除了所有通道,但仍有 %d 个未删除", count)
}
// 检查Redis缓存是否被删除
for _, id := range []int32{1, 2, 3} {
key := fmt.Sprintf("channel:%d", id)
if mr.Exists(key) {
@@ -867,6 +651,31 @@ func Test_channelService_RemoveChannels(t *testing.T) {
data, _ := json.Marshal(channel)
mr.Set(key, string(data))
// 清空数据库表
db.Exec("delete from channel")
db.Exec("delete from proxy")
// 创建代理
proxy := models.Proxy{
ID: 1,
Name: "proxy1",
Host: "proxy1.example.com",
Secret: "key:secret",
Type: 1,
}
db.Create(&proxy)
// 创建通道
ch := models.Channel{
ID: 1,
UserID: 100,
ProxyID: 1,
ProxyPort: 10001,
Protocol: "http",
Expiration: time.Now().Add(24 * time.Hour),
}
db.Create(&ch)
// 模拟查询已激活的端口
mg.PortActiveMock = func(param ...remote.PortActiveReq) (map[string]remote.PortData, error) {
return map[string]remote.PortData{
@@ -875,32 +684,16 @@ func Test_channelService_RemoveChannels(t *testing.T) {
},
}, nil
}
// 开始事务
mdb.ExpectBegin()
// 查找通道
channelRows := sqlmock.NewRows([]string{"id", "user_id", "proxy_id", "proxy_port", "protocol", "expiration"}).
AddRow(1, 100, 1, 10001, "http", time.Now().Add(24*time.Hour))
mdb.ExpectQuery(regexp.QuoteMeta("SELECT * FROM `channel` WHERE `channel`.`id` IN")).
WithArgs(int32(1)).
WillReturnRows(channelRows)
// 查找代理
proxyRows := sqlmock.NewRows([]string{"id", "name", "host", "secret", "type"}).
AddRow(1, "proxy1", "proxy1.example.com", "key:secret", 1)
mdb.ExpectQuery(regexp.QuoteMeta("SELECT * FROM `proxy` WHERE `proxy`.`id` IN")).
WithArgs(int32(1)).
WillReturnRows(proxyRows)
// 软删除通道
mdb.ExpectExec(regexp.QuoteMeta("UPDATE `channel` SET")).
WillReturnResult(sqlmock.NewResult(0, 1))
// 提交事务
mdb.ExpectCommit()
},
checkCache: func(t *testing.T) {
// 检查通道是否被软删除
var count int64
db.Model(&models.Channel{}).Where("id = ? AND deleted_at IS NULL", 1).Count(&count)
if count > 0 {
t.Errorf("应该软删除了通道,但仍未删除")
}
// 检查Redis缓存是否被删除
key := "channel:1"
if mr.Exists(key) {
t.Errorf("通道缓存 %s 应被删除但仍存在", key)
@@ -927,18 +720,19 @@ func Test_channelService_RemoveChannels(t *testing.T) {
data, _ := json.Marshal(channel)
mr.Set(key, string(data))
// 开始事务
mdb.ExpectBegin()
// 清空数据库表
db.Exec("delete from channel")
// 查找通道
channelRows := sqlmock.NewRows([]string{"id", "user_id", "proxy_id", "proxy_port", "protocol", "expiration"}).
AddRow(5, 101, 1, 10005, "http", time.Now().Add(24*time.Hour))
mdb.ExpectQuery(regexp.QuoteMeta("SELECT * FROM `channel` WHERE `channel`.`id` IN")).
WithArgs(int32(5)).
WillReturnRows(channelRows)
// 回滚事务
mdb.ExpectRollback()
// 创建一个属于用户101的通道
ch := models.Channel{
ID: 5,
UserID: 101,
ProxyID: 1,
ProxyPort: 10005,
Protocol: "http",
Expiration: time.Now().Add(24 * time.Hour),
}
db.Create(&ch)
},
wantErr: true,
wantErrContains: "无权限访问",
@@ -971,11 +765,6 @@ func Test_channelService_RemoveChannels(t *testing.T) {
return
}
// 验证所有期望的 SQL 已执行
if err := mdb.ExpectationsWereMet(); err != nil {
t.Errorf("有未满足的SQL期望: %s", err)
}
// 检查 Redis 缓存是否正确设置
if tt.checkCache != nil {
tt.checkCache(t)