更新提取接口参数,完善流程与测试代码逻辑
This commit is contained in:
@@ -62,7 +62,7 @@ oauth token 验证授权范围
|
||||
开发环境数据库迁移:
|
||||
|
||||
```powershell
|
||||
pg-schema-diff apply --schema-dir .\scripts\sql --dsn "host=localhost user=test password=test dbname=app port=5432 sslmode=disable TimeZone=Asia/Shanghai" --allow-hazards INDEX_BUILD,INDEX_DROPPE
|
||||
pg-schema-diff apply --schema-dir .\scripts\sql --dsn "host=localhost user=test password=test dbname=app port=5432 sslmode=disable TimeZone=Asia/Shanghai"
|
||||
```
|
||||
|
||||
## 枚举字典
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
func main() {
|
||||
fmt.Printf("%v\n", time.Now())
|
||||
println('|')
|
||||
println(':')
|
||||
println('\t')
|
||||
println('\r')
|
||||
println('\n')
|
||||
}
|
||||
|
||||
2
go.mod
2
go.mod
@@ -10,6 +10,7 @@ require (
|
||||
github.com/joho/godotenv v1.5.1
|
||||
github.com/jxskiss/base62 v1.1.0
|
||||
github.com/lmittmann/tint v1.0.7
|
||||
github.com/mattn/go-sqlite3 v1.14.24
|
||||
github.com/redis/go-redis/v9 v9.7.3
|
||||
golang.org/x/crypto v0.36.0
|
||||
gorm.io/driver/postgres v1.5.11
|
||||
@@ -37,7 +38,6 @@ require (
|
||||
github.com/mattn/go-colorable v0.1.14 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.16 // indirect
|
||||
github.com/mattn/go-sqlite3 v1.14.24 // indirect
|
||||
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
|
||||
github.com/rivo/uniseg v0.4.7 // indirect
|
||||
github.com/stretchr/testify v1.8.2 // indirect
|
||||
|
||||
@@ -69,7 +69,20 @@ func (m *MockCloudClient) CloudAutoQuery() (remote.CloudConnectResp, error) {
|
||||
|
||||
// SetupCloudClientMock 替换全局CloudClient为测试实现并在测试完成后恢复
|
||||
func SetupCloudClientMock(t *testing.T) *MockCloudClient {
|
||||
mock := &MockCloudClient{}
|
||||
mock := &MockCloudClient{
|
||||
EdgesMock: func(param remote.CloudEdgesReq) (*remote.CloudEdgesResp, error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
ConnectMock: func(param remote.CloudConnectReq) error {
|
||||
panic("not implemented")
|
||||
},
|
||||
DisconnectMock: func(param remote.CloudDisconnectReq) (int, error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
AutoQueryMock: func() (remote.CloudConnectResp, error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
}
|
||||
remote.Cloud = mock
|
||||
|
||||
return mock
|
||||
@@ -117,7 +130,14 @@ type GatewayClientIns struct {
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
var testGatewayBase = &GatewayClientIns{}
|
||||
var testGatewayBase = &GatewayClientIns{
|
||||
PortConfigsMock: func(c *MockGatewayClient, params []remote.PortConfigsReq) error {
|
||||
panic("not implemented")
|
||||
},
|
||||
PortActiveMock: func(c *MockGatewayClient, param ...remote.PortActiveReq) (map[string]remote.PortData, error) {
|
||||
panic("not implemented")
|
||||
},
|
||||
}
|
||||
|
||||
// SetupGatewayClientMock 创建一个MockGatewayClient并提供替换函数
|
||||
func SetupGatewayClientMock(t *testing.T) *GatewayClientIns {
|
||||
|
||||
@@ -512,7 +512,7 @@ create table channel (
|
||||
auth_ip bool not null default false,
|
||||
user_host varchar(255),
|
||||
auth_pass bool not null default false,
|
||||
username varchar(255) unique,
|
||||
username varchar(255),
|
||||
password varchar(255),
|
||||
expiration timestamp not null,
|
||||
created_at timestamp default current_timestamp,
|
||||
|
||||
@@ -67,19 +67,43 @@ type PageResp struct {
|
||||
|
||||
type LocalDateTime time.Time
|
||||
|
||||
var formats = []string{
|
||||
"2006-01-02 15:04:05.999999999-07:00",
|
||||
"2006-01-02T15:04:05.999999999-07:00",
|
||||
"2006-01-02 15:04:05.999999999",
|
||||
"2006-01-02T15:04:05.999999999",
|
||||
"2006-01-02 15:04:05",
|
||||
"2006-01-02T15:04:05",
|
||||
"2006-01-02 15:04",
|
||||
"2006-01-02T15:04",
|
||||
"2006-01-02",
|
||||
}
|
||||
|
||||
//goland:noinspection GoMixedReceiverTypes
|
||||
func (ldt *LocalDateTime) Scan(value interface{}) (err error) {
|
||||
var t time.Time
|
||||
|
||||
nullTime := &sql.NullTime{}
|
||||
err = nullTime.Scan(value)
|
||||
if err != nil {
|
||||
return err
|
||||
if strValue, ok := value.(string); ok {
|
||||
var timeValue time.Time
|
||||
for _, format := range formats {
|
||||
timeValue, err = time.Parse(format, strValue)
|
||||
if err == nil {
|
||||
t = timeValue
|
||||
break
|
||||
}
|
||||
}
|
||||
t = timeValue
|
||||
} else {
|
||||
nullTime := &sql.NullTime{}
|
||||
err = nullTime.Scan(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if nullTime == nil {
|
||||
return nil
|
||||
}
|
||||
t = nullTime.Time
|
||||
}
|
||||
if nullTime == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
t := nullTime.Time
|
||||
*ldt = LocalDateTime(time.Date(
|
||||
t.Year(), t.Month(), t.Day(), t.Hour(), t.Minute(), t.Second(), t.Nanosecond(), time.Local,
|
||||
))
|
||||
|
||||
@@ -15,34 +15,22 @@ import (
|
||||
// region CreateChannel
|
||||
|
||||
type CreateChannelReq struct {
|
||||
ResourceId int32 `json:"resource_id" validate:"required"`
|
||||
Protocol services.ChannelProtocol `json:"protocol" validate:"required,oneof=socks5 http https"`
|
||||
AuthType services.ChannelAuthType `json:"auth_type" validate:"required,oneof=0 1"`
|
||||
Count int `json:"count" validate:"required"`
|
||||
Prov string `json:"prov" validate:"required"`
|
||||
City string `json:"city" validate:"required"`
|
||||
Isp string `json:"isp" validate:"required"`
|
||||
ResultType CreateChannelResultType `json:"result_type" validate:"required,oneof=json text"`
|
||||
ResultBreaker []rune `json:"result_breaker" validate:""`
|
||||
ResultSeparator []rune `json:"result_separator" validate:""`
|
||||
ResourceId int32 `json:"resource_id" validate:"required"`
|
||||
AuthType services.ChannelAuthType `json:"auth_type" validate:"required"`
|
||||
Protocol services.ChannelProtocol `json:"protocol" validate:"required"`
|
||||
Count int `json:"count" validate:"required"`
|
||||
Prov string `json:"prov"`
|
||||
City string `json:"city"`
|
||||
Isp string `json:"isp"`
|
||||
}
|
||||
|
||||
func CreateChannel(c *fiber.Ctx) error {
|
||||
|
||||
req := new(CreateChannelReq)
|
||||
if err := c.BodyParser(req); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if req.ResultType == "" {
|
||||
req.ResultType = CreateChannelResultTypeText
|
||||
}
|
||||
if req.ResultBreaker == nil {
|
||||
req.ResultBreaker = []rune("\r\n")
|
||||
}
|
||||
if req.ResultSeparator == nil {
|
||||
req.ResultSeparator = []rune("|")
|
||||
}
|
||||
|
||||
// 建立连接通道
|
||||
auth, ok := c.Locals("auth").(*services.AuthContext)
|
||||
if !ok {
|
||||
@@ -66,35 +54,7 @@ func CreateChannel(c *fiber.Ctx) error {
|
||||
return err
|
||||
}
|
||||
|
||||
var separator = string(req.ResultSeparator)
|
||||
switch req.ResultType {
|
||||
case CreateChannelResultTypeJson:
|
||||
return c.JSON(fiber.Map{
|
||||
"code": 1,
|
||||
"data": result,
|
||||
})
|
||||
default:
|
||||
var breaker = string(req.ResultBreaker)
|
||||
var str = strings.Builder{}
|
||||
for _, info := range result {
|
||||
|
||||
str.WriteString(info.Host)
|
||||
|
||||
str.WriteString(separator)
|
||||
str.WriteString(strconv.Itoa(info.Port))
|
||||
|
||||
if info.Username != nil {
|
||||
str.WriteString(separator)
|
||||
str.WriteString(*info.Username)
|
||||
}
|
||||
if info.Password != nil {
|
||||
str.WriteString(separator)
|
||||
str.WriteString(*info.Password)
|
||||
}
|
||||
str.WriteString(breaker)
|
||||
}
|
||||
return c.SendString(str.String())
|
||||
}
|
||||
return c.JSON(result)
|
||||
}
|
||||
|
||||
type CreateChannelResultType string
|
||||
|
||||
@@ -66,12 +66,11 @@ func ListResourcePss(c *fiber.Ctx) error {
|
||||
do = do.Where(q.ResourcePss.As(q.Resource.Pss.Name()).Expire.Lte(common.LocalDateTime(*req.ExpireBefore)))
|
||||
}
|
||||
|
||||
var resource []*m.Resource
|
||||
err = do.Debug().
|
||||
resource, err := do.Debug().
|
||||
Order(q.Resource.CreatedAt.Desc()).
|
||||
Offset(req.GetOffset()).
|
||||
Limit(req.GetLimit()).
|
||||
Scan(&resource)
|
||||
Find()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -33,16 +33,18 @@ type channelService struct {
|
||||
type ChannelAuthType int
|
||||
|
||||
const (
|
||||
ChannelAuthTypeIp = iota
|
||||
ChannelAuthTypeAll ChannelAuthType = iota
|
||||
ChannelAuthTypeIp
|
||||
ChannelAuthTypePass
|
||||
)
|
||||
|
||||
type ChannelProtocol int32
|
||||
|
||||
const (
|
||||
ProtocolHTTP = ChannelProtocol(1)
|
||||
ProtocolHttps = ChannelProtocol(2)
|
||||
ProtocolSocks5 = ChannelProtocol(3)
|
||||
ProtocolAll ChannelProtocol = iota
|
||||
ProtocolHTTP
|
||||
ProtocolHttps
|
||||
ProtocolSocks5
|
||||
)
|
||||
|
||||
type ResourceInfo struct {
|
||||
@@ -53,10 +55,10 @@ type ResourceInfo struct {
|
||||
Live int32
|
||||
DailyLimit int32
|
||||
DailyUsed int32
|
||||
DailyLast time.Time
|
||||
DailyLast common.LocalDateTime
|
||||
Quota int32
|
||||
Used int32
|
||||
Expire time.Time
|
||||
Expire common.LocalDateTime
|
||||
}
|
||||
|
||||
// region RemoveChannel
|
||||
@@ -313,7 +315,7 @@ func (s *channelService) CreateChannel(
|
||||
Used: resource.Used + int32(count),
|
||||
DailyLast: common.LocalDateTime(now),
|
||||
}
|
||||
last := resource.DailyLast
|
||||
last := time.Time(resource.DailyLast)
|
||||
if now.Year() != last.Year() || now.Month() != last.Month() || now.Day() != last.Day() {
|
||||
toUpdate.DailyUsed = int32(count)
|
||||
} else {
|
||||
@@ -365,7 +367,7 @@ func checkUser(auth *AuthContext, resource *ResourceInfo, count int) error {
|
||||
}
|
||||
|
||||
// 检查每日限额
|
||||
today := time.Now().Format("2006-01-02") == resource.DailyLast.Format("2006-01-02")
|
||||
today := time.Now().Format("2006-01-02") == time.Time(resource.DailyLast).Format("2006-01-02")
|
||||
dailyRemain := int(math.Max(float64(resource.DailyLimit-resource.DailyUsed), 0))
|
||||
if today && dailyRemain < count {
|
||||
return ChannelServiceErr("套餐每日配额不足")
|
||||
@@ -373,7 +375,7 @@ func checkUser(auth *AuthContext, resource *ResourceInfo, count int) error {
|
||||
|
||||
// 检查时间或配额
|
||||
if resource.Type == 1 { // 包时
|
||||
if resource.Expire.Before(time.Now()) {
|
||||
if time.Time(resource.Expire).Before(time.Now()) {
|
||||
return ChannelServiceErr("套餐已过期")
|
||||
}
|
||||
} else { // 包量
|
||||
@@ -559,6 +561,7 @@ func assignPort(
|
||||
key := uint64(channel.ProxyID)<<32 | uint64(channel.ProxyPort)
|
||||
portsMap[key] = struct{}{}
|
||||
}
|
||||
println(len(portsMap))
|
||||
|
||||
// 查找用户白名单
|
||||
var whitelist []string
|
||||
@@ -570,6 +573,9 @@ func assignPort(
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if len(whitelist) == 0 {
|
||||
return nil, nil, ChannelServiceErr("用户没有白名单")
|
||||
}
|
||||
}
|
||||
|
||||
// 配置启用代理
|
||||
|
||||
@@ -288,26 +288,6 @@ func Test_channelService_CreateChannel(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), requestid.ConfigDefault.ContextKey, "test-request-id")
|
||||
var adminAuth = &AuthContext{Payload: Payload{Id: 100, Type: PayloadAdmin}}
|
||||
var userAuth = &AuthContext{Payload: Payload{Id: 101, Type: PayloadUser}}
|
||||
var user = &models.User{
|
||||
ID: 101,
|
||||
Phone: "12312341234",
|
||||
}
|
||||
db.Create(user)
|
||||
var whitelists = []*models.Whitelist{
|
||||
{ID: 1, UserID: 101, Host: "123.123.123.123"},
|
||||
{ID: 2, UserID: 101, Host: "456.456.456.456"},
|
||||
{ID: 3, UserID: 101, Host: "789.789.789.789"},
|
||||
}
|
||||
db.Create(whitelists)
|
||||
var proxy = &models.Proxy{
|
||||
ID: 1,
|
||||
Version: 1,
|
||||
Name: "test-proxy",
|
||||
Host: "111.111.111.111",
|
||||
Type: 1,
|
||||
Secret: "test:secret",
|
||||
}
|
||||
db.Create(proxy)
|
||||
mc.AutoQueryMock = func() (remote.CloudConnectResp, error) {
|
||||
return remote.CloudConnectResp{
|
||||
"test-proxy": []remote.AutoConfig{
|
||||
@@ -315,18 +295,48 @@ func Test_channelService_CreateChannel(t *testing.T) {
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
var clearDb = func() {
|
||||
|
||||
db.Exec("delete from resource where true")
|
||||
var resource = &models.Resource{
|
||||
var user *models.User
|
||||
var whitelists []*models.Whitelist
|
||||
var proxy *models.Proxy
|
||||
var resource *models.Resource
|
||||
var resourcePss *models.ResourcePss
|
||||
var resetDb = func() {
|
||||
user = &models.User{
|
||||
ID: 101,
|
||||
Phone: "12312341234",
|
||||
}
|
||||
db.Exec("delete from user where true")
|
||||
db.Create(user)
|
||||
|
||||
whitelists = []*models.Whitelist{
|
||||
{ID: 1, UserID: 101, Host: "123.123.123.123"},
|
||||
{ID: 2, UserID: 101, Host: "456.456.456.456"},
|
||||
{ID: 3, UserID: 101, Host: "789.789.789.789"},
|
||||
}
|
||||
db.Exec("delete from whitelist where true")
|
||||
db.Create(whitelists)
|
||||
|
||||
proxy = &models.Proxy{
|
||||
ID: 1,
|
||||
Version: 1,
|
||||
Name: "test-proxy",
|
||||
Host: "111.111.111.111",
|
||||
Type: 1,
|
||||
Secret: "test:secret",
|
||||
}
|
||||
db.Exec("delete from proxy where true")
|
||||
db.Create(proxy)
|
||||
|
||||
resource = &models.Resource{
|
||||
ID: 1,
|
||||
UserID: 101,
|
||||
Active: true,
|
||||
}
|
||||
db.Exec("delete from resource where true")
|
||||
db.Create(resource)
|
||||
|
||||
db.Exec("delete from resource_pss where true")
|
||||
var resourcePss = &models.ResourcePss{
|
||||
resourcePss = &models.ResourcePss{
|
||||
ID: 1,
|
||||
ResourceID: 1,
|
||||
Type: 1,
|
||||
@@ -334,10 +344,12 @@ func Test_channelService_CreateChannel(t *testing.T) {
|
||||
Expire: common.LocalDateTime(time.Now().AddDate(1, 0, 0)),
|
||||
DailyLimit: 10000,
|
||||
}
|
||||
db.Exec("delete from resource_pss where true")
|
||||
db.Create(resourcePss)
|
||||
|
||||
db.Exec("delete from channel where true")
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
@@ -359,7 +371,7 @@ func Test_channelService_CreateChannel(t *testing.T) {
|
||||
},
|
||||
setup: func() {
|
||||
mr.FlushAll()
|
||||
clearDb()
|
||||
resetDb()
|
||||
|
||||
mc.ConnectMock = func(param remote.CloudConnectReq) error {
|
||||
if param.Uuid != proxy.Name {
|
||||
@@ -509,7 +521,7 @@ func Test_channelService_CreateChannel(t *testing.T) {
|
||||
},
|
||||
setup: func() {
|
||||
mr.FlushAll()
|
||||
clearDb()
|
||||
resetDb()
|
||||
|
||||
mc.ConnectMock = func(param remote.CloudConnectReq) error {
|
||||
if param.Uuid != proxy.Name {
|
||||
@@ -653,7 +665,7 @@ func Test_channelService_CreateChannel(t *testing.T) {
|
||||
},
|
||||
setup: func() {
|
||||
mr.FlushAll()
|
||||
clearDb()
|
||||
resetDb()
|
||||
|
||||
mc.ConnectMock = func(param remote.CloudConnectReq) error {
|
||||
if param.Uuid != proxy.Name {
|
||||
@@ -802,7 +814,7 @@ func Test_channelService_CreateChannel(t *testing.T) {
|
||||
},
|
||||
setup: func() {
|
||||
mr.FlushAll()
|
||||
clearDb()
|
||||
resetDb()
|
||||
},
|
||||
wantErr: true,
|
||||
wantErrContains: "无权限访问",
|
||||
@@ -819,7 +831,7 @@ func Test_channelService_CreateChannel(t *testing.T) {
|
||||
},
|
||||
setup: func() {
|
||||
mr.FlushAll()
|
||||
clearDb()
|
||||
resetDb()
|
||||
|
||||
resource2 := &models.Resource{
|
||||
ID: 2,
|
||||
@@ -852,7 +864,7 @@ func Test_channelService_CreateChannel(t *testing.T) {
|
||||
},
|
||||
setup: func() {
|
||||
mr.FlushAll()
|
||||
clearDb()
|
||||
resetDb()
|
||||
|
||||
// 创建一个配额几乎用完的资源包
|
||||
resource2 := models.Resource{
|
||||
@@ -886,7 +898,14 @@ func Test_channelService_CreateChannel(t *testing.T) {
|
||||
},
|
||||
setup: func() {
|
||||
mr.FlushAll()
|
||||
clearDb()
|
||||
resetDb()
|
||||
mc.AutoQueryMock = func() (remote.CloudConnectResp, error) {
|
||||
return remote.CloudConnectResp{
|
||||
"test-proxy": []remote.AutoConfig{
|
||||
{Count: 20000},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
// 创建大量占用端口的通道
|
||||
var channels = make([]models.Channel, 10000)
|
||||
var expr = time.Now().Add(time.Hour)
|
||||
@@ -908,8 +927,6 @@ func Test_channelService_CreateChannel(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mr.FlushAll()
|
||||
clearDb()
|
||||
if tt.setup != nil {
|
||||
tt.setup()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user