实现长效套餐创建逻辑,并整合不同套餐类型的创建流程

This commit is contained in:
2025-05-17 18:59:43 +08:00
parent d9613a19fb
commit 3f8e48ec68
9 changed files with 391 additions and 297 deletions

View File

@@ -3,9 +3,9 @@ package services
import (
"context"
"database/sql"
"errors"
"fmt"
"github.com/gofiber/fiber/v2"
"gorm.io/gen/field"
"log/slog"
"math"
"math/rand/v2"
@@ -13,7 +13,9 @@ import (
"platform/pkg/u"
"platform/web/auth"
channel2 "platform/web/domains/channel"
edge2 "platform/web/domains/edge"
proxy2 "platform/web/domains/proxy"
resource2 "platform/web/domains/resource"
g "platform/web/globals"
"platform/web/globals/orm"
m "platform/web/models"
@@ -24,7 +26,6 @@ import (
"github.com/gofiber/fiber/v2/middleware/requestid"
"github.com/redis/go-redis/v9"
"gorm.io/gorm"
)
var Channel = &channelService{}
@@ -32,20 +33,6 @@ var Channel = &channelService{}
type channelService struct {
}
type ResourceInfo struct {
Id int32
UserId int32
Active bool
Type int32
Live int32
DailyLimit int32
DailyUsed int32
DailyLast orm.LocalDateTime
Quota int32
Used int32
Expire orm.LocalDateTime
}
// region RemoveChannel
func (s *channelService) RemoveChannels(ctx context.Context, authCtx *auth.Context, id ...int32) error {
@@ -205,16 +192,14 @@ func (s *channelService) RemoveChannels(ctx context.Context, authCtx *auth.Conte
// region CreateChannel
func (s *channelService) CreateChannel(
ctx context.Context,
authCtx *auth.Context,
resourceId int32,
protocol channel2.Protocol,
authType ChannelAuthType,
count int,
edgeFilter ...EdgeFilterConfig,
) (newChannels []*m.Channel, err error) {
) (channels []*m.Channel, err error) {
var now = time.Now()
var rid = ctx.Value(requestid.ConfigDefault.ContextKey).(string)
var filter = EdgeFilterConfig{}
if len(edgeFilter) > 0 {
filter = edgeFilter[0]
@@ -223,25 +208,13 @@ func (s *channelService) CreateChannel(
err = q.Q.Transaction(func(q *q.Query) (err error) {
// 查找套餐
resource, err := findResource(q, resourceId, authCtx, count)
if err != nil {
return err
}
// 查找网关
proxies, err := findProxies(q)
if err != nil {
return err
}
// 查找已使用的节点
channels, err := findChannels(q, proxies)
resource, err := findResource(q, resourceId, authCtx.Payload.Id, count, now)
if err != nil {
return err
}
// 查找白名单
var whitelist *[]string
var whitelist []string
if authType == ChannelAuthTypeIp {
whitelist, err = findWhitelist(q, authCtx.Payload.Id)
if err != nil {
@@ -250,26 +223,26 @@ func (s *channelService) CreateChannel(
}
// 分配节点
var expire = now.Add(time.Duration(resource.Live) * time.Second)
newChannels, err = calcChannels(proxies, channels, whitelist, count, authCtx.Payload.Id, protocol, authType, expire, filter)
var config = ChannelCreateConfig{
Protocol: protocol,
AuthIp: authType == ChannelAuthTypeIp,
Whitelists: whitelist,
AuthPass: authType == ChannelAuthTypePass,
Expiration: now.Add(time.Duration(resource.Live) * time.Second),
}
switch resource2.Type(resource.Type) {
case resource2.TypeShort:
channels, err = assignShortChannels(q, authCtx.Payload.Id, count, config, filter, now)
case resource2.TypeLong:
channels, err = assignLongChannels(q, authCtx.Payload.Id, count, config, filter)
}
if err != nil {
return err
}
// 更新套餐使用记录
err = updateResource(rid, resource, count, now)
if err != nil {
return err
}
// 保存通道
err = saveChannels(newChannels)
if err != nil {
return err
}
// 缓存通道数据
err = cacheChannels(ctx, rid, newChannels, whitelist)
// 保存通道开通结果
err = saveAssigns(q, resource, channels, now)
if err != nil {
return err
}
@@ -280,93 +253,86 @@ func (s *channelService) CreateChannel(
return nil, err
}
return newChannels, nil
return channels, nil
}
func findResource(q *q.Query, resourceId int32, authCtx *auth.Context, count int) (*ResourceInfo, error) {
var resource = new(ResourceInfo)
data := q.Resource.As("data")
short := q.ResourceShort.As("short")
err := data.Scopes(orm.Alias(data)).
Select(
data.ID, data.UserID, data.Active,
short.Type, short.Live, short.DailyUsed, short.DailyLimit, short.DailyLast, short.Quota, short.Used, short.Expire,
func findResource(q *q.Query, resourceId int32, userId int32, count int, now time.Time) (*ResourceInfo, error) {
resource, err := q.Resource.
Preload(
q.Resource.Short.On(q.Resource.Type.Eq(int32(resource2.TypeShort))),
q.Resource.Long.On(q.Resource.Type.Eq(int32(resource2.TypeLong))),
).
LeftJoin(q.ResourceShort.As("short"), short.ResourceID.EqCol(data.ID)).
Where(
data.ID.Eq(resourceId),
data.UserID.Eq(authCtx.Payload.Id),
q.Resource.ID.Eq(resourceId),
q.Resource.UserID.Eq(userId),
).
Scan(&resource)
Take()
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrResourceNotExist // 防止 id 猜测
}
return nil, err
return nil, ErrResourceNotExist
}
var info = &ResourceInfo{
Id: resource.ID,
Active: resource.Active,
Type: resource2.Type(resource.Type),
}
switch resource2.Type(resource.Type) {
case resource2.TypeShort:
var sub = resource.Short
info.Mode = resource2.Mode(sub.Type)
info.Live = sub.Live
info.DailyLimit = sub.DailyLimit
info.DailyUsed = sub.DailyUsed
info.DailyLast = time.Time(sub.DailyLast)
info.Quota = sub.Quota
info.Used = sub.Used
info.Expire = time.Time(sub.DailyLast)
case resource2.TypeLong:
var sub = resource.Long
info.Mode = resource2.Mode(sub.Type)
info.Live = sub.Live
info.DailyLimit = sub.DailyLimit
info.DailyUsed = sub.DailyUsed
info.DailyLast = time.Time(sub.DailyLast)
info.Quota = sub.Quota
info.Used = sub.Used
info.Expire = time.Time(sub.DailyLast)
}
// 检查套餐状态
if !resource.Active {
if !info.Active {
return nil, ErrResourceInvalid
}
// 检查每日限额
today := time.Now().Format("2006-01-02") == time.Time(resource.DailyLast).Format("2006-01-02")
dailyRemain := int(math.Max(float64(resource.DailyLimit-resource.DailyUsed), 0))
if today && dailyRemain < count {
used := 0
if now.Format("2006-01-02") == info.DailyLast.Format("2006-01-02") {
used = int(info.DailyUsed)
}
excess := used+count > int(info.DailyLimit)
if excess {
return nil, ErrResourceDailyLimit
}
// 检查时间或配额
if resource.Type == 1 { // 包时
if time.Time(resource.Expire).Before(time.Now()) {
switch info.Mode {
case resource2.ModeTime:
if info.Expire.Before(now) {
return nil, ErrResourceExpired
}
} else { // 包量
remain := int(math.Max(float64(resource.Quota-resource.Used), 0))
if remain < count {
case resource2.ModeCount:
if int(info.Quota)-int(info.Used) < count {
return nil, ErrResourceExhausted
}
default:
return nil, ChannelServiceErr("不支持的套餐模式")
}
return resource, nil
return info, nil
}
func findProxies(q *q.Query) (proxies []*m.Proxy, err error) {
proxies, err = q.Proxy.
Where(q.Proxy.Type.Eq(int32(proxy2.TypeThirdParty))).
Find()
if err != nil {
return nil, err
}
return proxies, nil
}
func findChannels(q *q.Query, proxies []*m.Proxy) (channels []*m.Channel, err error) {
var proxyIds = make([]int32, len(proxies))
for i, proxy := range proxies {
proxyIds[i] = proxy.ID
}
channels, err = q.Channel.
Select(
q.Channel.ProxyID,
q.Channel.ProxyPort).
Where(
q.Channel.ProxyID.In(proxyIds...),
q.Channel.Expiration.Gt(orm.LocalDateTime(time.Now()))).
Group(
q.Channel.ProxyPort,
q.Channel.ProxyID).
Find()
if err != nil {
return nil, err
}
return channels, nil
}
func findWhitelist(q *q.Query, userId int32) (*[]string, error) {
func findWhitelist(q *q.Query, userId int32) ([]string, error) {
var whitelist []string
err := q.Whitelist.
Where(q.Whitelist.UserID.Eq(userId)).
@@ -379,21 +345,45 @@ func findWhitelist(q *q.Query, userId int32) (*[]string, error) {
return nil, ChannelServiceErr("用户没有白名单")
}
return &whitelist, nil
return whitelist, nil
}
func calcChannels(
proxies []*m.Proxy,
allChannels []*m.Channel,
whitelist *[]string,
count int,
func assignShortChannels(
q *q.Query,
userId int32,
protocol channel2.Protocol,
authType ChannelAuthType,
expiration time.Time,
count int,
config ChannelCreateConfig,
filter EdgeFilterConfig,
now time.Time,
) ([]*m.Channel, error) {
var step = time.Now()
// 查找网关
proxies, err := q.Proxy.
Where(q.Proxy.Type.Eq(int32(proxy2.TypeThirdParty))).
Find()
if err != nil {
return nil, err
}
// 查找已使用的节点
var proxyIds = make([]int32, len(proxies))
for i, proxy := range proxies {
proxyIds[i] = proxy.ID
}
allChannels, err := q.Channel.
Select(
q.Channel.ProxyID,
q.Channel.ProxyPort).
Where(
q.Channel.ProxyID.In(proxyIds...),
q.Channel.Expiration.Gt(orm.LocalDateTime(now))).
Group(
q.Channel.ProxyPort,
q.Channel.ProxyID).
Find()
if err != nil {
return nil, err
}
// 查询已配置的节点
remoteConfigs, err := g.Cloud.CloudAutoQuery()
@@ -415,15 +405,19 @@ func calcChannels(
var avg = int(math.Ceil(float64(total) / float64(len(proxies))))
// 分配节点
var newChannels []*m.Channel
var newChannels = make([]*m.Channel, 0, count)
for _, proxy := range proxies {
// 分配前后的节点量
var prev = proxyUses[proxy.ID]
var next = int(math.Max(float64(prev), float64(int(math.Min(float64(avg), float64(total))))))
total -= next
// 网关配置的节点量
var acc = next - prev
if acc <= 0 {
continue
}
// 获取远端配置量
var count = 0
remoteConfig, ok := remoteConfigs[proxy.Name]
if ok {
@@ -435,6 +429,7 @@ func calcChannels(
}
}
// 提交节点配置
if env.DebugExternalChange && next > count {
var step = time.Now()
@@ -477,13 +472,7 @@ func calcChannels(
)
}
// 节点增量
var acc = next - prev
if acc <= 0 {
continue
}
// 筛选可用端口
// 筛选可用端口 todo auth all
var portConfigs = make([]g.PortConfigsReq, 0, acc)
for port := 10000; port < 20000 && len(portConfigs) < acc; port++ {
// 跳过存在的端口
@@ -511,18 +500,18 @@ func calcChannels(
ProxyID: proxy.ID,
ProxyHost: proxy.Host,
ProxyPort: int32(port),
Protocol: int32(protocol),
Expiration: orm.LocalDateTime(expiration),
Protocol: int32(config.Protocol),
Expiration: orm.LocalDateTime(config.Expiration),
}
switch authType {
switch {
case ChannelAuthTypeIp:
portConf.Whitelist = whitelist
case config.AuthIp:
portConf.Whitelist = &config.Whitelists
portConf.Userpass = u.P("")
newChannel.AuthIP = true
case ChannelAuthTypePass:
case config.AuthPass:
username, password := genPassPair()
portConf.Whitelist = &[]string{}
portConf.Userpass = u.P(fmt.Sprintf("%s:%s", username, password))
@@ -541,7 +530,7 @@ func calcChannels(
return nil, ChannelServiceErr("网关端口数量到达上限,无法分配")
}
// 提交端口配置并更新节点列表
// 提交端口配置
if env.DebugExternalChange {
var step = time.Now()
@@ -560,37 +549,99 @@ func calcChannels(
}
}
slog.Debug("申请节点", "total", time.Since(step))
if len(newChannels) != count {
return nil, ChannelServiceErr("分配节点失败")
}
return newChannels, nil
}
func updateResource(rid string, resource *ResourceInfo, count int, now time.Time) (err error) {
toUpdate := m.ResourceShort{
Used: resource.Used + int32(count),
DailyLast: orm.LocalDateTime(now),
}
var last = time.Time(resource.DailyLast)
if now.Year() != last.Year() || now.Month() != last.Month() || now.Day() != last.Day() {
toUpdate.DailyUsed = int32(count)
} else {
toUpdate.DailyUsed = resource.DailyUsed + int32(count)
}
_, err = q.ResourceShort.
Where(q.ResourceShort.ResourceID.Eq(resource.Id)).
Select(
q.ResourceShort.Used,
q.ResourceShort.DailyUsed,
q.ResourceShort.DailyLast,
func assignLongChannels(q *q.Query, userId int32, count int, config ChannelCreateConfig, filter EdgeFilterConfig) ([]*m.Channel, error) {
// 查询符合条件的节点,根据 channel 统计使用次数
var edges = make([]struct {
m.Edge
Count int
}, 0)
err := q.Edge.
LeftJoin(q.Channel, q.Channel.EdgeID.EqCol(q.Edge.ID)).
Select(q.Edge.ALL, q.Channel.ALL.Count().As("count")).
Group(q.Edge.ID).
Where(
q.Edge.Prov.Eq(filter.Prov),
q.Edge.City.Eq(filter.City),
q.Edge.Isp.Eq(int32(edge2.ISPFromStr(filter.Isp))),
q.Edge.Status.Eq(1),
).
Updates(toUpdate)
Order(field.NewField("", "count").Asc()).
Scan(edges)
if err != nil {
return nil, err
}
fmt.Printf("edges: %v\n", edges)
// 计算分配负载(考虑去重,维护一个节点使用记录表,优先分配未使用节点,达到算法额定负载后再选择负载最少的节点)
var total = count
for _, edge := range edges {
total += edge.Count
}
var avg = int(math.Ceil(float64(total) / float64(len(edges))))
var channels = make([]*m.Channel, 0, count)
for _, edge := range edges {
prev := edge.Count
next := int(math.Max(float64(prev), float64(int(math.Min(float64(avg), float64(total))))))
total -= next
acc := next - prev
if acc <= 0 {
continue
}
for range acc {
var channel = &m.Channel{
UserID: userId,
ProxyID: edge.ProxyID,
EdgeID: edge.ID,
Protocol: int32(config.Protocol),
AuthIP: config.AuthIp,
AuthPass: config.AuthPass,
Expiration: orm.LocalDateTime(config.Expiration),
}
if config.AuthPass {
username, password := genPassPair()
channel.Username = username
channel.Password = password
}
channels = append(channels, channel)
}
}
// todo 发送配置到网关
return channels, nil
}
func saveAssigns(q *q.Query, resource *ResourceInfo, channels []*m.Channel, now time.Time) (err error) {
// 缓存通道数据
pipe := g.Redis.TxPipeline()
zList := make([]redis.Z, 0, len(channels))
for _, channel := range channels {
expiration := time.Time(channel.Expiration)
zList = append(zList, redis.Z{
Score: float64(expiration.Unix()),
Member: channel.ID,
})
}
pipe.ZAdd(context.Background(), "tasks:channel", zList...)
_, err = pipe.Exec(context.Background())
if err != nil {
return err
}
return nil
}
func saveChannels(channels []*m.Channel) (err error) {
// 保存通道
err = q.Channel.
Omit(
q.Channel.EdgeID,
@@ -602,31 +653,34 @@ func saveChannels(channels []*m.Channel) (err error) {
return err
}
return nil
}
func cacheChannels(ctx context.Context, rid string, channels []*m.Channel, whitelists *[]string) (err error) {
if len(channels) == 0 {
return nil
// 更新套餐使用记录
var count = len(channels)
var last = time.Time(resource.DailyLast)
var dailyUsed int32
if now.Year() != last.Year() || now.Month() != last.Month() || now.Day() != last.Day() {
dailyUsed = int32(count)
} else {
dailyUsed = resource.DailyUsed + int32(count)
}
pipe := g.Redis.TxPipeline()
zList := make([]redis.Z, 0, len(channels))
for _, channel := range channels {
expiration := time.Time(channel.Expiration)
keys := chAuthItems(channel, whitelists)
for _, key := range keys {
pipe.Set(ctx, key, true, time.Since(expiration))
}
zList = append(zList, redis.Z{
Score: float64(expiration.Unix()),
Member: channel.ID,
})
switch resource2.Type(resource.Type) {
case resource2.TypeShort:
_, err = q.ResourceShort.
Where(q.ResourceShort.ResourceID.Eq(resource.Id)).
UpdateSimple(
q.ResourceShort.Used.Add(int32(count)),
q.ResourceShort.DailyUsed.Value(dailyUsed),
q.ResourceShort.DailyLast.Value(orm.LocalDateTime(now)),
)
case resource2.TypeLong:
_, err = q.ResourceLong.
Where(q.ResourceLong.ResourceID.Eq(resource.Id)).
UpdateSimple(
q.ResourceLong.Used.Add(int32(count)),
q.ResourceLong.DailyUsed.Value(dailyUsed),
q.ResourceLong.DailyLast.Value(orm.LocalDateTime(now)),
)
}
pipe.ZAdd(ctx, "tasks:channel", zList...)
_, err = pipe.Exec(ctx)
if err != nil {
return err
}
@@ -655,31 +709,6 @@ func genPassPair() (string, string) {
return string(username), string(password)
}
func chAuthItems(channel *m.Channel, whitelists *[]string) []string {
var count = 1
var ips = make([]string, 0)
if channel.AuthIP && whitelists != nil {
count = len(*whitelists)
ips = *whitelists
}
var proxy = channel.ProxyHost + ":" + strconv.Itoa(int(channel.ProxyPort))
var sb = strings.Builder{}
var items = make([]string, count)
for i := range items {
// 权限 key 格式:<proxy_host>:<proxy_port>:<user_ip>?:<username>?:<password>?
sb.WriteString(proxy)
if channel.AuthIP {
sb.WriteString(":" + ips[i])
}
if channel.AuthPass {
sb.WriteString(":" + channel.Username + ":" + channel.Password)
}
}
return items
}
type ChannelAuthType int
const (
@@ -687,6 +716,34 @@ const (
ChannelAuthTypePass
)
type ChannelCreateConfig struct {
Protocol channel2.Protocol
AuthIp bool
Whitelists []string
AuthPass bool
Expiration time.Time
}
type EdgeFilterConfig struct {
Isp string `json:"isp"`
Prov string `json:"prov"`
City string `json:"city"`
}
type ResourceInfo struct {
Id int32
Active bool
Type resource2.Type
Mode resource2.Mode
Live int32
DailyLimit int32
DailyUsed int32
DailyLast time.Time
Quota int32
Used int32
Expire time.Time
}
type ChannelServiceErr string
func (c ChannelServiceErr) Error() string {