实现云端控制的动态节点分配逻辑

This commit is contained in:
2025-03-28 10:03:29 +08:00
parent e337a9c08e
commit e61f0bef32
16 changed files with 1313 additions and 138 deletions

View File

@@ -4,11 +4,15 @@ import (
"context"
"errors"
"fmt"
"log/slog"
"math"
"platform/pkg/orm"
"platform/pkg/rds"
"platform/pkg/remote"
"platform/web/common"
"platform/web/models"
q "platform/web/queries"
"strconv"
"time"
"github.com/google/uuid"
@@ -82,7 +86,7 @@ func (s *channelService) CreateChannel(
}
// 筛选可用节点
nodes, err := Node.Filter(auth.Payload.Id, protocol, count, nodeFilter...)
nodes, err := Node.Filter(ctx, auth.Payload.Id, count, nodeFilter...)
if err != nil {
return err
}
@@ -103,10 +107,11 @@ func (s *channelService) CreateChannel(
channels = append(channels, &models.Channel{
UserID: auth.Payload.Id,
NodeID: node.ID,
NodePort: node.FwdPort,
UserHost: allowed.Host,
NodeHost: node.Host,
ProxyPort: node.ProxyPort,
Protocol: string(protocol),
AuthIP: authType == ChannelAuthTypeIp,
UserHost: allowed.Host,
AuthPass: authType == ChannelAuthTypePass,
Username: username,
Password: password,
@@ -146,17 +151,7 @@ func (s *channelService) CreateChannel(
}
// 缓存通道信息与异步删除任务
pipe := rds.Client.TxPipeline()
zList := make([]redis.Z, 0, len(channels))
for _, channel := range channels {
pipe.Set(ctx, chKey(channel), channel, channel.Expiration.Sub(time.Now()))
zList = append(zList, redis.Z{
Score: float64(channel.Expiration.Unix()),
Member: channel.ID,
})
}
pipe.ZAdd(ctx, "tasks:channel", zList...)
_, err = pipe.Exec(ctx)
err = cache(ctx, channels)
if err != nil {
return nil, err
}
@@ -194,7 +189,7 @@ func genPassPair() (string, string) {
return username, password
}
func (s *channelService) RemoveChannels(auth *AuthContext, id ...int32) error {
func (s *channelService) RemoveChannels(ctx context.Context, auth *AuthContext, id ...int32) error {
var channels []*models.Channel
@@ -231,16 +226,16 @@ func (s *channelService) RemoveChannels(auth *AuthContext, id ...int32) error {
}
// 删除缓存,异步任务直接在消费端处理删除
pipe := rds.Client.TxPipeline()
for _, channel := range channels {
pipe.Del(context.Background(), chKey(channel))
err = deleteCache(ctx, channels)
if err != nil {
return err
}
return nil
}
func chKey(channel *models.Channel) string {
return fmt.Sprintf("channel:%s:%d", channel.UserHost, channel.NodePort)
return fmt.Sprintf("channel:%s:%s", channel.UserHost, channel.NodeHost)
}
type ChannelServiceErr string
@@ -248,3 +243,327 @@ type ChannelServiceErr string
func (c ChannelServiceErr) Error() string {
return string(c)
}
// region channel by remote
func (s *channelService) RemoteCreateChannel(
ctx context.Context,
auth *AuthContext,
resourceId int32,
protocol ChannelProtocol,
authType ChannelAuthType,
count int,
nodeFilter ...NodeFilterConfig,
) ([]*models.Channel, error) {
filter := NodeFilterConfig{}
if len(nodeFilter) > 0 {
filter = nodeFilter[0]
}
// 查找套餐
var resource = new(ResourceInfo)
data := q.Resource.As("data")
pss := q.ResourcePss.As("pss")
err := data.Scopes(orm.Alias(data)).
Select(data.ALL, pss.ALL).
LeftJoin(q.ResourcePss.As("pss"), pss.ResourceID.EqCol(data.ID)).
Where(data.ID.Eq(resourceId)).
Scan(&resource)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ChannelServiceErr("套餐不存在")
}
return nil, err
}
// 检查用户权限
err = checkUser(auth, resource, count)
if err != nil {
return nil, err
}
// 申请节点
assigned, err := assignEdge(count, filter)
if err != nil {
return nil, err
}
// 分配端口
expiration := time.Now().Add(time.Duration(resource.pss.Live) * time.Second)
channels, err := assignPort(assigned, auth.Payload.Id, protocol, authType, expiration, filter)
if err != nil {
return nil, err
}
// 缓存并关闭代理
err = cache(ctx, channels)
if err != nil {
return nil, err
}
return channels, nil
}
// endregion
func checkUser(auth *AuthContext, resource *ResourceInfo, count int) error {
// 检查使用人
if auth.Payload.Type == PayloadUser && auth.Payload.Id != resource.data.UserID {
return common.AuthForbiddenErr("无权限访问")
}
// 检查套餐状态
if !resource.data.Active {
return ChannelServiceErr("套餐已失效")
}
// 检查每日限额
today := time.Now().Format("2006-01-02") == resource.pss.DailyLast.Format("2006-01-02")
dailyRemain := int(math.Max(float64(resource.pss.DailyLimit-resource.pss.DailyUsed), 0))
if today && dailyRemain < count {
return ChannelServiceErr("套餐每日配额不足")
}
// 检查时间或配额
if resource.pss.Type == 1 { // 包时
if resource.pss.Expire.Before(time.Now()) {
return ChannelServiceErr("套餐已过期")
}
} else { // 包量
remain := int(math.Max(float64(resource.pss.Quota-resource.pss.Used), 0))
if remain < count {
return ChannelServiceErr("套餐配额不足")
}
}
return nil
}
// assignEdge 分配边缘节点数量
func assignEdge(count int, filter NodeFilterConfig) ([]AssignEdgeResult, error) {
// 查询现有节点连接情况
edgeConfigs, err := remote.Client.CloudAutoQuery()
if err != nil {
return nil, err
}
proxies, err := q.Proxy.
Where(q.Proxy.Type.Eq(1)).
Find()
if err != nil {
return nil, err
}
// 尽量平均分配节点用量
var total = count
for _, v := range edgeConfigs {
total += v.Count
}
avg := int(math.Ceil(float64(total) / float64(len(edgeConfigs))))
var result []AssignEdgeResult
var rCount = 0
for _, proxy := range proxies {
prev, ok := edgeConfigs[proxy.Name]
var nextCount = 0
if !ok || (prev.Count < avg && prev.Count < total) {
nextCount = int(math.Min(float64(avg), float64(total)))
result = append(result, AssignEdgeResult{
proxy: proxy,
count: nextCount - prev.Count,
})
total -= nextCount
} else {
continue
}
_rCount, err := remote.Client.CloudConnect(remote.CloudConnectReq{
Uuid: proxy.Name,
Edge: nil,
AutoConfig: []remote.AutoConfig{{
Province: filter.Prov,
City: filter.City,
Isp: filter.Isp,
Count: nextCount,
}},
})
if err != nil {
return nil, err
}
rCount += _rCount
}
slog.Debug("cloud connect", "count", rCount)
return result, nil
}
type AssignEdgeResult struct {
proxy *models.Proxy
count int
}
// assignPort 分配指定数量的端口
func assignPort(
assigns []AssignEdgeResult,
userId int32,
protocol ChannelProtocol,
authType ChannelAuthType,
expiration time.Time,
filter NodeFilterConfig,
) ([]*models.Channel, error) {
// 查询代理已配置端口
var proxyIds = make([]int32, 0, len(assigns))
for _, assigned := range assigns {
proxyIds = append(proxyIds, assigned.proxy.ID)
}
channels, err := q.Channel.
Select(
q.Channel.ProxyID,
q.Channel.ProxyPort).
Where(
q.Channel.ProxyID.In(proxyIds...),
q.Channel.Expiration.Gt(time.Now())).
Group(
q.Channel.ProxyPort).
Find()
if err != nil {
return nil, err
}
// 端口查找表
var proxyPorts = make(map[uint64]struct{})
for _, channel := range channels {
key := uint64(channel.ProxyID)<<32 | uint64(channel.ProxyPort)
proxyPorts[key] = struct{}{}
}
// 配置启用代理
var result []*models.Channel
for i := 0; i < len(assigns); i++ {
proxy := assigns[i].proxy
count := assigns[i].count
// 筛选可用端口
var portConfigs = make([]remote.PortConfigsReq, count)
for port := 10000; port < 20000 || len(portConfigs) < count; port++ {
// 跳过存在的端口
key := uint64(proxy.ID)<<32 | uint64(port)
_, ok := proxyPorts[key]
if ok {
continue
}
// 配置新端口
portConfigs[port] = remote.PortConfigsReq{
Port: strconv.Itoa(port),
Edge: nil,
Status: true,
AutoEdgeConfig: remote.AutoEdgeConfig{
Province: filter.Prov,
City: filter.City,
Isp: filter.Isp,
Count: 1,
},
}
switch authType {
case ChannelAuthTypeIp:
var whitelist []string
err := q.Whitelist.
Where(q.Whitelist.UserID.Eq(userId)).
Select(q.Whitelist.Host).
Scan(&whitelist)
if err != nil {
return nil, err
}
portConfigs[port].Whitelist = whitelist
for _, item := range whitelist {
result = append(result, &models.Channel{
UserID: userId,
ProxyID: proxy.ID,
UserHost: item,
ProxyPort: int32(port),
AuthIP: true,
AuthPass: false,
Protocol: string(protocol),
Expiration: expiration,
})
}
case ChannelAuthTypePass:
username, password := genPassPair()
portConfigs[port].Userpass = fmt.Sprintf("%s:%s", username, password)
result = append(result, &models.Channel{
UserID: userId,
ProxyID: proxy.ID,
ProxyPort: int32(port),
AuthIP: false,
AuthPass: true,
Username: username,
Password: password,
Protocol: string(protocol),
Expiration: expiration,
})
}
}
// 提交端口配置
gateway := remote.InitGateway(
proxy.Host,
"api",
"123456",
)
err = gateway.GatewayPortConfigs(portConfigs)
if err != nil {
return nil, err
}
}
err = q.Channel.Save(result...)
if err != nil {
return nil, err
}
return result, nil
}
func cache(ctx context.Context, channels []*models.Channel) error {
pipe := rds.Client.TxPipeline()
zList := make([]redis.Z, 0, len(channels))
for _, channel := range channels {
pipe.Set(ctx, chKey(channel), channel, channel.Expiration.Sub(time.Now()))
zList = append(zList, redis.Z{
Score: float64(channel.Expiration.Unix()),
Member: channel.ID,
})
}
pipe.ZAdd(ctx, "tasks:channel", zList...)
_, err := pipe.Exec(ctx)
if err != nil {
return err
}
return nil
}
func deleteCache(ctx context.Context, channels []*models.Channel) error {
pipe := rds.Client.TxPipeline()
keys := make([]string, len(channels))
for i := range keys {
keys[i] = chKey(channels[i])
}
pipe.Del(ctx, keys...)
// 忽略异步任务zrem 效率较低,在使用时再删除
_, err := pipe.Exec(ctx)
if err != nil {
return err
}
return nil
}
type ResourceInfo struct {
data models.Resource
pss models.ResourcePss
}

View File

@@ -1,16 +1,22 @@
package services
import (
"encoding/json"
"fmt"
"context"
"platform/pkg/orm"
"platform/web/models"
)
type NodeServiceErr string
func (e NodeServiceErr) Error() string {
return string(e)
}
var Node = &nodeService{}
type nodeService struct{}
func (s *nodeService) Filter(userId int32, count int, config ...NodeFilterConfig) ([]*FilteredNode, error) {
func (s *nodeService) Filter(ctx context.Context, userId int32, count int, config ...NodeFilterConfig) ([]*models.Node, error) {
_config := NodeFilterConfig{}
if len(config) > 0 {
_config = config[0]
@@ -24,24 +30,35 @@ func (s *nodeService) Filter(userId int32, count int, config ...NodeFilterConfig
Limit(count).
Find(&nodes)
rs, _ := json.Marshal(nodes)
fmt.Printf(string(rs))
// todo 异步任务关闭代理
// 返回节点列表
return nodes, nil
// todo 异步任务缩容
return nil, nil
}
type NodeFilterConfig struct {
Isp string
Prov string
City string
Isp string `json:"isp"`
Prov string `json:"prov"`
City string `json:"city"`
}
type NodeFilterAsyncTask struct {
Config NodeFilterConfig `json:"config"`
Count int `json:"count"`
}
// 筛选节点的SQL语句暂时用不到
// 筛选已连接的符合条件且未分配给用户过的节点
//
// 静态条件:省,市,运营商
// 排序方式1.分配给该用户的次数 2.分配给全部用户的次数
const filterSqlRaw = `
select
n.id as id,
n.name as name,
n.fwd_port as port,
n.host as host,
n.fwd_port as fwd_port,
count(c.*) as total,
count(c.*) filter ( where c.user_id = ? ) as assigned
from
@@ -61,7 +78,8 @@ order by
type FilteredNode struct {
Id int32 `json:"id"`
Name string `json:"name"`
Port int32 `json:"port"`
Host string `json:"host"`
FwdPort int32 `json:"fwd_port"`
Total int32 `json:"total"`
Assigned int32 `json:"assigned"`
}