diff --git a/README.md b/README.md index b05dce1..608f38b 100644 --- a/README.md +++ b/README.md @@ -62,7 +62,7 @@ oauth token 验证授权范围 开发环境数据库迁移: ```powershell -pg-schema-diff apply --schema-dir .\scripts\sql --dsn "host=localhost user=test password=test dbname=app port=5432 sslmode=disable TimeZone=Asia/Shanghai" --allow-hazards INDEX_BUILD,INDEX_DROPPE +pg-schema-diff apply --schema-dir .\scripts\sql --dsn "host=localhost user=test password=test dbname=app port=5432 sslmode=disable TimeZone=Asia/Shanghai" ``` ## 枚举字典 diff --git a/cmd/playground/main.go b/cmd/playground/main.go index cdad5be..95d7622 100644 --- a/cmd/playground/main.go +++ b/cmd/playground/main.go @@ -1,10 +1,9 @@ package main -import ( - "fmt" - "time" -) - func main() { - fmt.Printf("%v\n", time.Now()) + println('|') + println(':') + println('\t') + println('\r') + println('\n') } diff --git a/go.mod b/go.mod index 7071590..1202951 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/joho/godotenv v1.5.1 github.com/jxskiss/base62 v1.1.0 github.com/lmittmann/tint v1.0.7 + github.com/mattn/go-sqlite3 v1.14.24 github.com/redis/go-redis/v9 v9.7.3 golang.org/x/crypto v0.36.0 gorm.io/driver/postgres v1.5.11 @@ -37,7 +38,6 @@ require ( github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-runewidth v0.0.16 // indirect - github.com/mattn/go-sqlite3 v1.14.24 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/rivo/uniseg v0.4.7 // indirect github.com/stretchr/testify v1.8.2 // indirect diff --git a/pkg/testutil/remote.go b/pkg/testutil/remote.go index e5f5983..71ef483 100644 --- a/pkg/testutil/remote.go +++ b/pkg/testutil/remote.go @@ -69,7 +69,20 @@ func (m *MockCloudClient) CloudAutoQuery() (remote.CloudConnectResp, error) { // SetupCloudClientMock 替换全局CloudClient为测试实现并在测试完成后恢复 func SetupCloudClientMock(t *testing.T) *MockCloudClient { - mock := &MockCloudClient{} + mock := &MockCloudClient{ + EdgesMock: func(param remote.CloudEdgesReq) (*remote.CloudEdgesResp, error) { + panic("not implemented") + }, + ConnectMock: func(param remote.CloudConnectReq) error { + panic("not implemented") + }, + DisconnectMock: func(param remote.CloudDisconnectReq) (int, error) { + panic("not implemented") + }, + AutoQueryMock: func() (remote.CloudConnectResp, error) { + panic("not implemented") + }, + } remote.Cloud = mock return mock @@ -117,7 +130,14 @@ type GatewayClientIns struct { mu sync.Mutex } -var testGatewayBase = &GatewayClientIns{} +var testGatewayBase = &GatewayClientIns{ + PortConfigsMock: func(c *MockGatewayClient, params []remote.PortConfigsReq) error { + panic("not implemented") + }, + PortActiveMock: func(c *MockGatewayClient, param ...remote.PortActiveReq) (map[string]remote.PortData, error) { + panic("not implemented") + }, +} // SetupGatewayClientMock 创建一个MockGatewayClient并提供替换函数 func SetupGatewayClientMock(t *testing.T) *GatewayClientIns { diff --git a/scripts/sql/init.sql b/scripts/sql/init.sql index af797c3..8e38840 100644 --- a/scripts/sql/init.sql +++ b/scripts/sql/init.sql @@ -512,7 +512,7 @@ create table channel ( auth_ip bool not null default false, user_host varchar(255), auth_pass bool not null default false, - username varchar(255) unique, + username varchar(255), password varchar(255), expiration timestamp not null, created_at timestamp default current_timestamp, diff --git a/web/common/types.go b/web/common/types.go index 4198440..c616f51 100644 --- a/web/common/types.go +++ b/web/common/types.go @@ -67,19 +67,43 @@ type PageResp struct { type LocalDateTime time.Time +var formats = []string{ + "2006-01-02 15:04:05.999999999-07:00", + "2006-01-02T15:04:05.999999999-07:00", + "2006-01-02 15:04:05.999999999", + "2006-01-02T15:04:05.999999999", + "2006-01-02 15:04:05", + "2006-01-02T15:04:05", + "2006-01-02 15:04", + "2006-01-02T15:04", + "2006-01-02", +} + //goland:noinspection GoMixedReceiverTypes func (ldt *LocalDateTime) Scan(value interface{}) (err error) { + var t time.Time - nullTime := &sql.NullTime{} - err = nullTime.Scan(value) - if err != nil { - return err + if strValue, ok := value.(string); ok { + var timeValue time.Time + for _, format := range formats { + timeValue, err = time.Parse(format, strValue) + if err == nil { + t = timeValue + break + } + } + t = timeValue + } else { + nullTime := &sql.NullTime{} + err = nullTime.Scan(value) + if err != nil { + return err + } + if nullTime == nil { + return nil + } + t = nullTime.Time } - if nullTime == nil { - return nil - } - - t := nullTime.Time *ldt = LocalDateTime(time.Date( t.Year(), t.Month(), t.Day(), t.Hour(), t.Minute(), t.Second(), t.Nanosecond(), time.Local, )) diff --git a/web/handlers/channel.go b/web/handlers/channel.go index dd2a791..cfd1364 100644 --- a/web/handlers/channel.go +++ b/web/handlers/channel.go @@ -15,34 +15,22 @@ import ( // region CreateChannel type CreateChannelReq struct { - ResourceId int32 `json:"resource_id" validate:"required"` - Protocol services.ChannelProtocol `json:"protocol" validate:"required,oneof=socks5 http https"` - AuthType services.ChannelAuthType `json:"auth_type" validate:"required,oneof=0 1"` - Count int `json:"count" validate:"required"` - Prov string `json:"prov" validate:"required"` - City string `json:"city" validate:"required"` - Isp string `json:"isp" validate:"required"` - ResultType CreateChannelResultType `json:"result_type" validate:"required,oneof=json text"` - ResultBreaker []rune `json:"result_breaker" validate:""` - ResultSeparator []rune `json:"result_separator" validate:""` + ResourceId int32 `json:"resource_id" validate:"required"` + AuthType services.ChannelAuthType `json:"auth_type" validate:"required"` + Protocol services.ChannelProtocol `json:"protocol" validate:"required"` + Count int `json:"count" validate:"required"` + Prov string `json:"prov"` + City string `json:"city"` + Isp string `json:"isp"` } func CreateChannel(c *fiber.Ctx) error { + req := new(CreateChannelReq) if err := c.BodyParser(req); err != nil { return err } - if req.ResultType == "" { - req.ResultType = CreateChannelResultTypeText - } - if req.ResultBreaker == nil { - req.ResultBreaker = []rune("\r\n") - } - if req.ResultSeparator == nil { - req.ResultSeparator = []rune("|") - } - // 建立连接通道 auth, ok := c.Locals("auth").(*services.AuthContext) if !ok { @@ -66,35 +54,7 @@ func CreateChannel(c *fiber.Ctx) error { return err } - var separator = string(req.ResultSeparator) - switch req.ResultType { - case CreateChannelResultTypeJson: - return c.JSON(fiber.Map{ - "code": 1, - "data": result, - }) - default: - var breaker = string(req.ResultBreaker) - var str = strings.Builder{} - for _, info := range result { - - str.WriteString(info.Host) - - str.WriteString(separator) - str.WriteString(strconv.Itoa(info.Port)) - - if info.Username != nil { - str.WriteString(separator) - str.WriteString(*info.Username) - } - if info.Password != nil { - str.WriteString(separator) - str.WriteString(*info.Password) - } - str.WriteString(breaker) - } - return c.SendString(str.String()) - } + return c.JSON(result) } type CreateChannelResultType string diff --git a/web/handlers/resource.go b/web/handlers/resource.go index 71e597b..b14e5e5 100644 --- a/web/handlers/resource.go +++ b/web/handlers/resource.go @@ -66,12 +66,11 @@ func ListResourcePss(c *fiber.Ctx) error { do = do.Where(q.ResourcePss.As(q.Resource.Pss.Name()).Expire.Lte(common.LocalDateTime(*req.ExpireBefore))) } - var resource []*m.Resource - err = do.Debug(). + resource, err := do.Debug(). Order(q.Resource.CreatedAt.Desc()). Offset(req.GetOffset()). Limit(req.GetLimit()). - Scan(&resource) + Find() if err != nil { return err } diff --git a/web/services/channel.go b/web/services/channel.go index f652092..2dc240b 100644 --- a/web/services/channel.go +++ b/web/services/channel.go @@ -33,16 +33,18 @@ type channelService struct { type ChannelAuthType int const ( - ChannelAuthTypeIp = iota + ChannelAuthTypeAll ChannelAuthType = iota + ChannelAuthTypeIp ChannelAuthTypePass ) type ChannelProtocol int32 const ( - ProtocolHTTP = ChannelProtocol(1) - ProtocolHttps = ChannelProtocol(2) - ProtocolSocks5 = ChannelProtocol(3) + ProtocolAll ChannelProtocol = iota + ProtocolHTTP + ProtocolHttps + ProtocolSocks5 ) type ResourceInfo struct { @@ -53,10 +55,10 @@ type ResourceInfo struct { Live int32 DailyLimit int32 DailyUsed int32 - DailyLast time.Time + DailyLast common.LocalDateTime Quota int32 Used int32 - Expire time.Time + Expire common.LocalDateTime } // region RemoveChannel @@ -313,7 +315,7 @@ func (s *channelService) CreateChannel( Used: resource.Used + int32(count), DailyLast: common.LocalDateTime(now), } - last := resource.DailyLast + last := time.Time(resource.DailyLast) if now.Year() != last.Year() || now.Month() != last.Month() || now.Day() != last.Day() { toUpdate.DailyUsed = int32(count) } else { @@ -365,7 +367,7 @@ func checkUser(auth *AuthContext, resource *ResourceInfo, count int) error { } // 检查每日限额 - today := time.Now().Format("2006-01-02") == resource.DailyLast.Format("2006-01-02") + today := time.Now().Format("2006-01-02") == time.Time(resource.DailyLast).Format("2006-01-02") dailyRemain := int(math.Max(float64(resource.DailyLimit-resource.DailyUsed), 0)) if today && dailyRemain < count { return ChannelServiceErr("套餐每日配额不足") @@ -373,7 +375,7 @@ func checkUser(auth *AuthContext, resource *ResourceInfo, count int) error { // 检查时间或配额 if resource.Type == 1 { // 包时 - if resource.Expire.Before(time.Now()) { + if time.Time(resource.Expire).Before(time.Now()) { return ChannelServiceErr("套餐已过期") } } else { // 包量 @@ -559,6 +561,7 @@ func assignPort( key := uint64(channel.ProxyID)<<32 | uint64(channel.ProxyPort) portsMap[key] = struct{}{} } + println(len(portsMap)) // 查找用户白名单 var whitelist []string @@ -570,6 +573,9 @@ func assignPort( if err != nil { return nil, nil, err } + if len(whitelist) == 0 { + return nil, nil, ChannelServiceErr("用户没有白名单") + } } // 配置启用代理 diff --git a/web/services/channel_test.go b/web/services/channel_test.go index cd3cbb4..9087fea 100644 --- a/web/services/channel_test.go +++ b/web/services/channel_test.go @@ -288,26 +288,6 @@ func Test_channelService_CreateChannel(t *testing.T) { ctx := context.WithValue(context.Background(), requestid.ConfigDefault.ContextKey, "test-request-id") var adminAuth = &AuthContext{Payload: Payload{Id: 100, Type: PayloadAdmin}} var userAuth = &AuthContext{Payload: Payload{Id: 101, Type: PayloadUser}} - var user = &models.User{ - ID: 101, - Phone: "12312341234", - } - db.Create(user) - var whitelists = []*models.Whitelist{ - {ID: 1, UserID: 101, Host: "123.123.123.123"}, - {ID: 2, UserID: 101, Host: "456.456.456.456"}, - {ID: 3, UserID: 101, Host: "789.789.789.789"}, - } - db.Create(whitelists) - var proxy = &models.Proxy{ - ID: 1, - Version: 1, - Name: "test-proxy", - Host: "111.111.111.111", - Type: 1, - Secret: "test:secret", - } - db.Create(proxy) mc.AutoQueryMock = func() (remote.CloudConnectResp, error) { return remote.CloudConnectResp{ "test-proxy": []remote.AutoConfig{ @@ -315,18 +295,48 @@ func Test_channelService_CreateChannel(t *testing.T) { }, }, nil } - var clearDb = func() { - db.Exec("delete from resource where true") - var resource = &models.Resource{ + var user *models.User + var whitelists []*models.Whitelist + var proxy *models.Proxy + var resource *models.Resource + var resourcePss *models.ResourcePss + var resetDb = func() { + user = &models.User{ + ID: 101, + Phone: "12312341234", + } + db.Exec("delete from user where true") + db.Create(user) + + whitelists = []*models.Whitelist{ + {ID: 1, UserID: 101, Host: "123.123.123.123"}, + {ID: 2, UserID: 101, Host: "456.456.456.456"}, + {ID: 3, UserID: 101, Host: "789.789.789.789"}, + } + db.Exec("delete from whitelist where true") + db.Create(whitelists) + + proxy = &models.Proxy{ + ID: 1, + Version: 1, + Name: "test-proxy", + Host: "111.111.111.111", + Type: 1, + Secret: "test:secret", + } + db.Exec("delete from proxy where true") + db.Create(proxy) + + resource = &models.Resource{ ID: 1, UserID: 101, Active: true, } + db.Exec("delete from resource where true") db.Create(resource) - db.Exec("delete from resource_pss where true") - var resourcePss = &models.ResourcePss{ + resourcePss = &models.ResourcePss{ ID: 1, ResourceID: 1, Type: 1, @@ -334,10 +344,12 @@ func Test_channelService_CreateChannel(t *testing.T) { Expire: common.LocalDateTime(time.Now().AddDate(1, 0, 0)), DailyLimit: 10000, } + db.Exec("delete from resource_pss where true") db.Create(resourcePss) db.Exec("delete from channel where true") } + tests := []struct { name string args args @@ -359,7 +371,7 @@ func Test_channelService_CreateChannel(t *testing.T) { }, setup: func() { mr.FlushAll() - clearDb() + resetDb() mc.ConnectMock = func(param remote.CloudConnectReq) error { if param.Uuid != proxy.Name { @@ -509,7 +521,7 @@ func Test_channelService_CreateChannel(t *testing.T) { }, setup: func() { mr.FlushAll() - clearDb() + resetDb() mc.ConnectMock = func(param remote.CloudConnectReq) error { if param.Uuid != proxy.Name { @@ -653,7 +665,7 @@ func Test_channelService_CreateChannel(t *testing.T) { }, setup: func() { mr.FlushAll() - clearDb() + resetDb() mc.ConnectMock = func(param remote.CloudConnectReq) error { if param.Uuid != proxy.Name { @@ -802,7 +814,7 @@ func Test_channelService_CreateChannel(t *testing.T) { }, setup: func() { mr.FlushAll() - clearDb() + resetDb() }, wantErr: true, wantErrContains: "无权限访问", @@ -819,7 +831,7 @@ func Test_channelService_CreateChannel(t *testing.T) { }, setup: func() { mr.FlushAll() - clearDb() + resetDb() resource2 := &models.Resource{ ID: 2, @@ -852,7 +864,7 @@ func Test_channelService_CreateChannel(t *testing.T) { }, setup: func() { mr.FlushAll() - clearDb() + resetDb() // 创建一个配额几乎用完的资源包 resource2 := models.Resource{ @@ -886,7 +898,14 @@ func Test_channelService_CreateChannel(t *testing.T) { }, setup: func() { mr.FlushAll() - clearDb() + resetDb() + mc.AutoQueryMock = func() (remote.CloudConnectResp, error) { + return remote.CloudConnectResp{ + "test-proxy": []remote.AutoConfig{ + {Count: 20000}, + }, + }, nil + } // 创建大量占用端口的通道 var channels = make([]models.Channel, 10000) var expr = time.Now().Add(time.Hour) @@ -908,8 +927,6 @@ func Test_channelService_CreateChannel(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - mr.FlushAll() - clearDb() if tt.setup != nil { tt.setup() }