diff --git a/README.md b/README.md index 859c888..66dc8ec 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,8 @@ trade/create 性能问题,缩短事务时间,考虑其他方式实现可靠 网关缩扩容太慢 +redis channel lease 加一个 zset,定时处理没有成功释放的端口 + ### 长期 分离 task 的客户端,支持多进程(prefork 必要!) diff --git a/scripts/sql/init.sql b/scripts/sql/init.sql index 1ec850c..bf49354 100644 --- a/scripts/sql/init.sql +++ b/scripts/sql/init.sql @@ -604,7 +604,7 @@ create table channel ( filter_prov text, filter_city text, ip inet, - whitelists text[], + whitelists text, username text, password text, expired_at timestamptz not null, diff --git a/web/globals/orm/localdatetime.go b/web/globals/orm/localdatetime.go deleted file mode 100644 index bec4cad..0000000 --- a/web/globals/orm/localdatetime.go +++ /dev/null @@ -1,74 +0,0 @@ -package orm - -import ( - "database/sql" - "database/sql/driver" - "time" -) - -type LocalDateTime time.Time - -var formats = []string{ - "2006-01-02 15:04:05.999999999-07:00", - "2006-01-02T15:04:05.999999999-07:00", - "2006-01-02 15:04:05.999999999", - "2006-01-02T15:04:05.999999999", - "2006-01-02 15:04:05", - "2006-01-02T15:04:05", - "2006-01-02 15:04", - "2006-01-02T15:04", - "2006-01-02", -} - -func (ldt *LocalDateTime) Scan(value any) (err error) { - var t time.Time - if strValue, ok := value.(string); ok { - var timeValue time.Time - for _, format := range formats { - timeValue, err = time.Parse(format, strValue) - if err == nil { - t = timeValue - break - } - } - t = timeValue - } else { - nullTime := &sql.NullTime{} - err = nullTime.Scan(value) - if err != nil { - return err - } - if nullTime == nil { - return nil - } - t = nullTime.Time - } - *ldt = LocalDateTime(time.Date( - t.Year(), t.Month(), t.Day(), t.Hour(), t.Minute(), t.Second(), t.Nanosecond(), time.Local, - )) - return -} - -func (ldt LocalDateTime) Value() (driver.Value, error) { - return time.Time(ldt).Local(), nil -} - -func (ldt LocalDateTime) GormDataType() string { - return "ldt" -} - -func (ldt LocalDateTime) GobEncode() ([]byte, error) { - return time.Time(ldt).GobEncode() -} - -func (ldt *LocalDateTime) GobDecode(b []byte) error { - return (*time.Time)(ldt).GobDecode(b) -} - -func (ldt LocalDateTime) MarshalJSON() ([]byte, error) { - return time.Time(ldt).MarshalJSON() -} - -func (ldt *LocalDateTime) UnmarshalJSON(b []byte) error { - return (*time.Time)(ldt).UnmarshalJSON(b) -} diff --git a/web/globals/orm/slice.go b/web/globals/orm/slice.go deleted file mode 100644 index 5f97ae1..0000000 --- a/web/globals/orm/slice.go +++ /dev/null @@ -1,24 +0,0 @@ -package orm - -import ( - "database/sql/driver" - "encoding/json" - "fmt" -) - -type Slice[T any] struct { - Arr []T -} - -func (s Slice[T]) Value() (driver.Value, error) { - return json.Marshal(s) -} - -func (s *Slice[T]) Scan(value any) error { - switch value := value.(type) { - case []byte: - return json.Unmarshal(value, s) - default: - return fmt.Errorf("不支持的类型: %T", value) - } -} diff --git a/web/handlers/channel.go b/web/handlers/channel.go index 011f561..0db7cff 100644 --- a/web/handlers/channel.go +++ b/web/handlers/channel.go @@ -56,7 +56,8 @@ func ListChannels(c *fiber.Ctx) error { } // 查询数据 - channels, err := q.Channel.Debug(). + channels, err := q.Channel. + Preload(q.Channel.Proxy). Where(cond). Order(q.Channel.CreatedAt.Desc()). Offset(req.GetOffset()). @@ -110,17 +111,6 @@ type CreateChannelRespItem struct { func CreateChannel(c *fiber.Ctx) error { - // 检查权限 - authCtx, err := auth.GetAuthCtx(c).PermitUser() - if err != nil { - return err - } - - user := authCtx.User - if user.IDToken == nil || *user.IDToken == "" { - return fiber.NewError(fiber.StatusForbidden, "账号未实名") - } - // 解析参数 req := new(CreateChannelReq) if err := g.Validator.Validate(c, req); err != nil { @@ -135,7 +125,6 @@ func CreateChannel(c *fiber.Ctx) error { // 创建通道 result, err := s.Channel.CreateChannels( ip, - user.ID, req.ResourceId, req.AuthType == s.ChannelAuthTypeIp, req.AuthType == s.ChannelAuthTypePass, diff --git a/web/models/channel.go b/web/models/channel.go index 6225df2..460ef8b 100644 --- a/web/models/channel.go +++ b/web/models/channel.go @@ -9,20 +9,20 @@ import ( // Channel 通道表 type Channel struct { core.Model - UserID int32 `json:"user_id" gorm:"column:user_id"` // 用户ID - ResourceID int32 `json:"resource_id" gorm:"column:resource_id"` // 套餐ID - ProxyID int32 `json:"proxy_id" gorm:"column:proxy_id"` // 代理ID - BatchNo string `json:"batch_no" gorm:"column:batch_no"` // 批次编号 - Port uint16 `json:"port" gorm:"column:port"` // 代理端口 - EdgeID *int32 `json:"edge_id" gorm:"column:edge_id"` // 节点ID(手动配置) - FilterISP *EdgeISP `json:"filter_isp" gorm:"column:filter_isp"` // 运营商过滤(自动配置):参考 edge.isp - FilterProv *string `json:"filter_prov" gorm:"column:filter_prov"` // 省份过滤(自动配置) - FilterCity *string `json:"filter_city" gorm:"column:filter_city"` // 城市过滤(自动配置) - IP *orm.Inet `json:"ip" gorm:"column:ip"` // 节点地址 - Whitelists *orm.Slice[string] `json:"whitelists" gorm:"column:whitelists"` // IP白名单,逗号分隔 - Username *string `json:"username" gorm:"column:username"` // 用户名 - Password *string `json:"password" gorm:"column:password"` // 密码 - ExpiredAt time.Time `json:"expired_at" gorm:"column:expired_at"` // 过期时间 + UserID int32 `json:"user_id" gorm:"column:user_id"` // 用户ID + ResourceID int32 `json:"resource_id" gorm:"column:resource_id"` // 套餐ID + ProxyID int32 `json:"proxy_id" gorm:"column:proxy_id"` // 代理ID + BatchNo string `json:"batch_no" gorm:"column:batch_no"` // 批次编号 + Port uint16 `json:"port" gorm:"column:port"` // 代理端口 + EdgeID *int32 `json:"edge_id" gorm:"column:edge_id"` // 节点ID(手动配置) + FilterISP *EdgeISP `json:"filter_isp" gorm:"column:filter_isp"` // 运营商过滤(自动配置):参考 edge.isp + FilterProv *string `json:"filter_prov" gorm:"column:filter_prov"` // 省份过滤(自动配置) + FilterCity *string `json:"filter_city" gorm:"column:filter_city"` // 城市过滤(自动配置) + IP *orm.Inet `json:"ip" gorm:"column:ip"` // 节点地址 + Whitelists *string `json:"whitelists" gorm:"column:whitelists"` // IP白名单,逗号分隔 + Username *string `json:"username" gorm:"column:username"` // 用户名 + Password *string `json:"password" gorm:"column:password"` // 密码 + ExpiredAt time.Time `json:"expired_at" gorm:"column:expired_at"` // 过期时间 User User `json:"user" gorm:"foreignKey:UserID"` Resource Resource `json:"resource" gorm:"foreignKey:ResourceID"` diff --git a/web/queries/channel.gen.go b/web/queries/channel.gen.go index c4f2a2a..9f807be 100644 --- a/web/queries/channel.gen.go +++ b/web/queries/channel.gen.go @@ -41,7 +41,7 @@ func newChannel(db *gorm.DB, opts ...gen.DOOption) channel { _channel.FilterProv = field.NewString(tableName, "filter_prov") _channel.FilterCity = field.NewString(tableName, "filter_city") _channel.IP = field.NewField(tableName, "ip") - _channel.Whitelists = field.NewField(tableName, "whitelists") + _channel.Whitelists = field.NewString(tableName, "whitelists") _channel.Username = field.NewString(tableName, "username") _channel.Password = field.NewString(tableName, "password") _channel.ExpiredAt = field.NewTime(tableName, "expired_at") @@ -149,7 +149,7 @@ type channel struct { FilterProv field.String FilterCity field.String IP field.Field - Whitelists field.Field + Whitelists field.String Username field.String Password field.String ExpiredAt field.Time @@ -190,7 +190,7 @@ func (c *channel) updateTableName(table string) *channel { c.FilterProv = field.NewString(table, "filter_prov") c.FilterCity = field.NewString(table, "filter_city") c.IP = field.NewField(table, "ip") - c.Whitelists = field.NewField(table, "whitelists") + c.Whitelists = field.NewString(table, "whitelists") c.Username = field.NewString(table, "username") c.Password = field.NewString(table, "password") c.ExpiredAt = field.NewTime(table, "expired_at") diff --git a/web/queries/proxy.gen.go b/web/queries/proxy.gen.go index 04ba4cc..f4f5fd9 100644 --- a/web/queries/proxy.gen.go +++ b/web/queries/proxy.gen.go @@ -37,6 +37,7 @@ func newProxy(db *gorm.DB, opts ...gen.DOOption) proxy { _proxy.Secret = field.NewString(tableName, "secret") _proxy.Type = field.NewInt(tableName, "type") _proxy.Status = field.NewInt(tableName, "status") + _proxy.Meta = field.NewField(tableName, "meta") _proxy.Channels = proxyHasManyChannels{ db: db.Session(&gorm.Session{}), @@ -122,6 +123,7 @@ type proxy struct { Secret field.String Type field.Int Status field.Int + Meta field.Field Channels proxyHasManyChannels fieldMap map[string]field.Expr @@ -149,6 +151,7 @@ func (p *proxy) updateTableName(table string) *proxy { p.Secret = field.NewString(table, "secret") p.Type = field.NewInt(table, "type") p.Status = field.NewInt(table, "status") + p.Meta = field.NewField(table, "meta") p.fillFieldMap() @@ -165,7 +168,7 @@ func (p *proxy) GetFieldByName(fieldName string) (field.OrderExpr, bool) { } func (p *proxy) fillFieldMap() { - p.fieldMap = make(map[string]field.Expr, 11) + p.fieldMap = make(map[string]field.Expr, 12) p.fieldMap["id"] = p.ID p.fieldMap["created_at"] = p.CreatedAt p.fieldMap["updated_at"] = p.UpdatedAt @@ -176,6 +179,7 @@ func (p *proxy) fillFieldMap() { p.fieldMap["secret"] = p.Secret p.fieldMap["type"] = p.Type p.fieldMap["status"] = p.Status + p.fieldMap["meta"] = p.Meta } diff --git a/web/services/channel.go b/web/services/channel.go index 83af030..99f921e 100644 --- a/web/services/channel.go +++ b/web/services/channel.go @@ -17,7 +17,7 @@ var Channel ChannelService = &channelBaiyinService{} // 通道服务 type ChannelService interface { - CreateChannels(source netip.Addr, userId int32, resourceId int32, authWhitelist bool, authPassword bool, count int, edgeFilter ...EdgeFilter) ([]*m.Channel, error) + CreateChannels(source netip.Addr, resourceId int32, authWhitelist bool, authPassword bool, count int, edgeFilter ...EdgeFilter) ([]*m.Channel, error) RemoveChannels(batch string, ids []int32) error } @@ -47,12 +47,11 @@ func genPassPair() (string, string) { return string(username), string(password) } -func findResource(q *q.Query, resourceId int32, userId int32, count int, now time.Time) (*ResourceView, error) { +func findResource(q *q.Query, resourceId int32, count int, now time.Time) (*ResourceView, error) { resource, err := q.Resource. Preload(field.Associations). Where( q.Resource.ID.Eq(resourceId), - q.Resource.UserID.Eq(userId), q.Resource.Active.Is(true), ). Take() @@ -64,6 +63,7 @@ func findResource(q *q.Query, resourceId int32, userId int32, count int, now tim Id: resource.ID, Active: resource.Active, Type: resource.Type, + User: resource.User, } switch resource.Type { @@ -114,35 +114,6 @@ func findResource(q *q.Query, resourceId int32, userId int32, count int, now tim info.Used = sub.Used } - // 检查套餐使用情况 - switch info.Mode { - default: - return nil, core.NewBizErr("不支持的套餐模式") - - // 包时 - case m.ResourceModeTime: - // 检查过期时间 - if info.Expire.Before(now) { - return nil, ErrResourceExpired - } - // 检查每日限额 - used := 0 - if now.Format("2006-01-02") == info.DailyLast.Format("2006-01-02") { - used = int(info.DailyUsed) - } - excess := used+count > int(info.DailyLimit) - if excess { - return nil, ErrResourceDailyLimit - } - - // 包量 - case m.ResourceModeQuota: - // 检查可用配额 - if int(info.Quota)-int(info.Used) < count { - return nil, ErrResourceExhausted - } - } - return info, nil } @@ -159,14 +130,17 @@ type ResourceView struct { Quota int32 Used int32 Expire time.Time + User m.User } func lockChans(batch string, count int, expire time.Time) ([]netip.AddrPort, error) { chans, err := g.Redis.Eval( context.Background(), RedisScriptLockChans, - []string{"channel"}, - batch, + []string{ + "channel:chans", + "channel:lease:" + batch, + }, count, expire.Unix(), ).StringSlice() @@ -187,13 +161,10 @@ func lockChans(batch string, count int, expire time.Time) ([]netip.AddrPort, err } var RedisScriptLockChans = ` -local key = KEYS[1] -local batch = ARGV[1] -local count = tonumber(ARGV[2]) -local expire = tonumber(ARGV[3]) - -local chans_key = key .. ":chans" -local lease_key = key .. ":lease:" .. batch +local chans_key = KEYS[1] +local lease_key = KEYS[2] +local count = tonumber(ARGV[1]) +local expire = tonumber(ARGV[2]) if redis.call("SCARD", chans_key) < count then return nil @@ -210,12 +181,19 @@ return ports ` func freeChans(batch string, chans []string) error { + values := make([]any, len(chans)) + for i, ch := range chans { + values[i] = ch + } + err := g.Redis.Eval( context.Background(), RedisScriptFreeChans, - []string{"channel"}, - batch, - chans, + []string{ + "channel:chans", + "channel:lease:" + batch, + }, + values..., ).Err() if err != nil { return core.NewBizErr("释放通道失败", err) @@ -225,15 +203,11 @@ func freeChans(batch string, chans []string) error { } var RedisScriptFreeChans = ` -local key = KEYS[1] -local batch = ARGV[1] -local chans = ARGV[2] - -local chans_key = key .. ":chans" -local lease_key = key .. ":lease:" .. batch +local chans_key = KEYS[1] +local lease_key = KEYS[2] +local chans = ARGV redis.call("SADD", chans_key, unpack(chans)) - redis.call("DEL", lease_key) return chans diff --git a/web/services/channel_baiyin.go b/web/services/channel_baiyin.go index 1583a61..b66c56b 100644 --- a/web/services/channel_baiyin.go +++ b/web/services/channel_baiyin.go @@ -18,11 +18,12 @@ import ( "time" "github.com/hibiken/asynq" + "gorm.io/gen/field" ) type channelBaiyinService struct{} -func (s *channelBaiyinService) CreateChannels(source netip.Addr, userId int32, resourceId int32, authWhitelist bool, authPassword bool, count int, edgeFilter ...EdgeFilter) ([]*m.Channel, error) { +func (s *channelBaiyinService) CreateChannels(source netip.Addr, resourceId int32, authWhitelist bool, authPassword bool, count int, edgeFilter ...EdgeFilter) ([]*m.Channel, error) { if count > 400 { return nil, core.NewBizErr("单次最多提取 400 个") } @@ -35,9 +36,21 @@ func (s *channelBaiyinService) CreateChannels(source netip.Addr, userId int32, r now := time.Now() batch := ID.GenReadable("bat") + // 获取用户套餐 + resource, err := findResource(q.Q, resourceId, count, now) + if err != nil { + return nil, err + } + + // 检查用户 + user := resource.User + if user.IDToken == nil || *user.IDToken == "" { + return nil, core.NewBizErr("账号未实名") + } + // 获取用户白名单并检查用户 ip 地址 whitelists, err := q.Whitelist.Where( - q.Whitelist.UserID.Eq(userId), + q.Whitelist.UserID.Eq(user.ID), ).Find() if err != nil { return nil, err @@ -55,11 +68,35 @@ func (s *channelBaiyinService) CreateChannels(source netip.Addr, userId int32, r return nil, core.NewBizErr(fmt.Sprintf("IP 地址 %s 不在白名单内", source.String())) } - // 获取用户套餐并检查 - resource, err := findResource(q.Q, resourceId, userId, count, now) - if err != nil { - return nil, err + // 检查套餐使用情况 + switch resource.Mode { + default: + return nil, core.NewBizErr("不支持的套餐模式") + + // 包时 + case m.ResourceModeTime: + // 检查过期时间 + if resource.Expire.Before(now) { + return nil, ErrResourceExpired + } + // 检查每日限额 + used := 0 + if now.Format("2006-01-02") == resource.DailyLast.Format("2006-01-02") { + used = int(resource.DailyUsed) + } + excess := used+count > int(resource.DailyLimit) + if excess { + return nil, ErrResourceDailyLimit + } + + // 包量 + case m.ResourceModeQuota: + // 检查可用配额 + if int(resource.Quota)-int(resource.Used) < count { + return nil, ErrResourceExhausted + } } + expire := now.Add(resource.Live) // 获取可用通道 @@ -104,18 +141,19 @@ func (s *channelBaiyinService) CreateChannels(source netip.Addr, userId int32, r // 使用记录 actions[i] = &m.LogsUserUsage{ - UserID: userId, + UserID: user.ID, ResourceID: resourceId, Count: int32(count), ISP: u.P(filter.Isp.String()), Prov: filter.Prov, City: filter.City, + IP: orm.Inet{Addr: source}, } // 通道数据 inet := orm.Inet{Addr: ch.Addr()} channels[i] = &m.Channel{ - UserID: userId, + UserID: user.ID, ResourceID: resourceId, BatchNo: batch, ProxyID: findProxy[inet].ID, @@ -124,9 +162,10 @@ func (s *channelBaiyinService) CreateChannels(source netip.Addr, userId int32, r FilterProv: filter.Prov, FilterCity: filter.City, ExpiredAt: expire, + Proxy: *findProxy[inet], } if authWhitelist { - channels[i].Whitelists = &orm.Slice[string]{Arr: whitelistIPs} + channels[i].Whitelists = u.P(strings.Join(whitelistIPs, ",")) } if authPassword { username, password := genPassPair() @@ -181,7 +220,9 @@ func (s *channelBaiyinService) CreateChannels(source netip.Addr, userId int32, r } // 保存通道和分配记录 - err = q.Channel.Create(channels...) + err = q.Channel. + Omit(field.AssociationFields). + Create(channels...) if err != nil { return core.NewServErr("保存通道失败", err) } @@ -226,7 +267,7 @@ func (s *channelBaiyinService) CreateChannels(source netip.Addr, userId int32, r }, } if authWhitelist { - configs[i].Whitelist = &channel.Whitelists.Arr + configs[i].Whitelist = &whitelistIPs } if authPassword { configs[i].Userpass = u.P(fmt.Sprintf("%s:%s", *channel.Username, *channel.Password)) @@ -248,9 +289,10 @@ func (s *channelBaiyinService) CreateChannels(source netip.Addr, userId int32, r } func (s *channelBaiyinService) RemoveChannels(batch string, ids []int32) error { + start := time.Now() // 获取连接数据 - channels, err := q.Channel.Debug(). + channels, err := q.Channel. Preload(q.Channel.Proxy). Where(q.Channel.ID.In(ids...)). Find() @@ -279,7 +321,7 @@ func (s *channelBaiyinService) RemoveChannels(batch string, ids []int32) error { // 释放端口 err = freeChans(batch, chans) if err != nil { - return core.NewServErr("释放端口失败", err) + return err } // 清空配置 @@ -304,9 +346,10 @@ func (s *channelBaiyinService) RemoveChannels(batch string, ids []int32) error { } } else { bytes, _ := json.Marshal(configs) - slog.Debug("清除代理端口配置", "config", bytes) + slog.Debug("清除代理端口配置", "config", string(bytes)) } } + slog.Debug("清除代理端口配置", "time", time.Since(start).String()) return nil } diff --git a/web/tasks/task.go b/web/tasks/task.go index 2a55090..e7587fe 100644 --- a/web/tasks/task.go +++ b/web/tasks/task.go @@ -11,11 +11,13 @@ import ( e "platform/web/events" g "platform/web/globals" m "platform/web/models" + q "platform/web/queries" s "platform/web/services" "strings" "time" "github.com/hibiken/asynq" + "gorm.io/datatypes" ) func HandleCompleteTrade(_ context.Context, task *asynq.Task) (err error) { @@ -58,7 +60,7 @@ func HandleRemoveChannel(_ context.Context, task *asynq.Task) (err error) { } func HandleFlushGateway(_ context.Context, task *asynq.Task) (err error) { - now := time.Now() + start := time.Now() // 获取所有网关:配置组 proxies, err := s.Proxy.AllProxies(m.ProxyTypeBaiYin, true) @@ -120,16 +122,26 @@ func HandleFlushGateway(_ context.Context, task *asynq.Task) (err error) { } if env.DebugExternalChange { - g.Cloud.CloudConnect(g.CloudConnectReq{ + err := g.Cloud.CloudConnect(g.CloudConnectReq{ Uuid: proxy.Mac, AutoConfig: configs, }) + if err != nil { + slog.Error("提交代理后备配置失败", "error", err) + } } else { bytes, _ := json.Marshal(configs) slog.Debug("更新代理后备配置", "config", string(bytes)) } + + _, err := q.Proxy. + Where(q.Proxy.ID.Eq(proxy.ID)). + UpdateSimple(q.Proxy.Meta.Value(datatypes.NewJSONType(configs))) + if err != nil { + slog.Error("更新代理后备配置失败", "error", err) + } } - slog.Debug("更新代理后备配置", "time", time.Since(now).String()) + slog.Debug("更新代理后备配置", "time", time.Since(start).String()) return nil }