重构通道创建逻辑

This commit is contained in:
2025-05-08 19:02:07 +08:00
parent e2cc318560
commit 623e9652d9
3 changed files with 317 additions and 328 deletions

View File

@@ -16,7 +16,7 @@ import (
"platform/web/auth"
"platform/web/core"
g "platform/web/globals"
"platform/web/models"
m "platform/web/models"
q "platform/web/queries"
"strconv"
"strings"
@@ -32,23 +32,6 @@ var Channel = &channelService{}
type channelService struct {
}
type ChannelAuthType int
const (
ChannelAuthTypeAll ChannelAuthType = iota
ChannelAuthTypeIp
ChannelAuthTypePass
)
type ChannelProtocol int32
const (
ProtocolAll ChannelProtocol = iota
ProtocolHTTP
ProtocolHttps
ProtocolSocks5
)
type ResourceInfo struct {
Id int32
UserId int32
@@ -135,7 +118,7 @@ func (s *channelService) RemoveChannels(ctx context.Context, authCtx *auth.Conte
// 组织数据
var configMap = make(map[int32][]g.PortConfigsReq, len(proxies))
var proxyMap = make(map[int32]*models.Proxy, len(proxies))
var proxyMap = make(map[int32]*m.Proxy, len(proxies))
for _, proxy := range proxies {
configMap[proxy.ID] = make([]g.PortConfigsReq, 0)
proxyMap[proxy.ID] = proxy
@@ -245,185 +228,152 @@ func (s *channelService) CreateChannel(
authType ChannelAuthType,
count int,
nodeFilter ...NodeFilterConfig,
) ([]*PortInfo, error) {
) (newChannels []*m.Channel, err error) {
var now = time.Now()
var step = time.Now()
var rid = ctx.Value(requestid.ConfigDefault.ContextKey).(string)
filter := NodeFilterConfig{}
var filter = NodeFilterConfig{}
if len(nodeFilter) > 0 {
filter = nodeFilter[0]
}
var addr []*PortInfo
err := q.Q.Transaction(func(q *q.Query) error {
err = q.Q.Transaction(func(q *q.Query) (err error) {
// 查找套餐
step = time.Now()
var resource = new(ResourceInfo)
data := q.Resource.As("data")
pss := q.ResourcePss.As("pss")
err := data.Scopes(orm.Alias(data)).
Select(
data.ID, data.UserID, data.Active,
pss.Type, pss.Live, pss.DailyUsed, pss.DailyLimit, pss.DailyLast, pss.Quota, pss.Used, pss.Expire,
).
LeftJoin(q.ResourcePss.As("pss"), pss.ResourceID.EqCol(data.ID)).
Where(data.ID.Eq(resourceId)).
Scan(&resource)
resource, err := findResource(q, rid, resourceId, authCtx, count)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
// 禁止 id 猜测
return ChannelServiceErr("无权限访问")
return err
}
// 查找网关
proxies, err := findProxies(q, rid)
if err != nil {
return err
}
// 查找已使用的节点
channels, err := findChannels(q, rid, proxies)
if err != nil {
return err
}
// 查找白名单
var whitelist *[]string
if authType == ChannelAuthTypeIp {
whitelist, err = findWhitelist(q, rid, authCtx.Payload.Id)
if err != nil {
return err
}
return err
}
slog.Debug("查找套餐", "rid", rid, "step", time.Since(step))
// 检查用户权限
err = checkUser(authCtx, resource, count)
// 分配节点
var expire = now.Add(time.Duration(resource.Live) * time.Second)
newChannels, err = calcChannels(proxies, channels, whitelist, count, authCtx.Payload.Id, protocol, authType, expire, filter)
if err != nil {
return err
}
// 申请节点
step = time.Now()
edgeAssigns, err := assignEdge(q, count, filter)
if err != nil {
return err
}
slog.Debug("申请节点", "rid", rid, "total", time.Since(step))
// 分配端口
step = time.Now()
expiration := core.LocalDateTime(now.Add(time.Duration(resource.Live) * time.Second))
_addr, channels, err := assignPort(q, edgeAssigns, authCtx.Payload.Id, protocol, authType, expiration, filter)
if err != nil {
return err
}
addr = _addr
slog.Debug("分配端口", "rid", rid, "total", time.Since(step))
// 更新套餐使用记录
step = time.Now()
toUpdate := models.ResourcePss{
Used: resource.Used + int32(count),
DailyLast: core.LocalDateTime(now),
}
var last = time.Time(resource.DailyLast)
if now.Year() != last.Year() || now.Month() != last.Month() || now.Day() != last.Day() {
toUpdate.DailyUsed = int32(count)
} else {
toUpdate.DailyUsed = resource.DailyUsed + int32(count)
}
_, err = q.ResourcePss.
Where(q.ResourcePss.ResourceID.Eq(resourceId)).
Select(
q.ResourcePss.Used,
q.ResourcePss.DailyUsed,
q.ResourcePss.DailyLast,
).
Updates(toUpdate)
err = updateResource(rid, resource, count, now)
if err != nil {
return err
}
slog.Debug("更新套餐使用记录", "rid", rid, "step", time.Since(step))
// 保存通道
err = saveChannels(newChannels)
if err != nil {
return err
}
// 缓存通道数据
step = time.Now()
err = cache(ctx, channels)
err = cacheChannels(ctx, rid, newChannels)
if err != nil {
return err
}
slog.Debug("缓存通道数据", "rid", rid, "step", time.Since(step))
return nil
}, &sql.TxOptions{Isolation: sql.LevelRepeatableRead})
if err != nil {
return nil, err
}
return addr, nil
return newChannels, nil
}
func checkUser(authCtx *auth.Context, resource *ResourceInfo, count int) error {
func findResource(q *q.Query, rid string, resourceId int32, authCtx *auth.Context, count int) (*ResourceInfo, error) {
var step = time.Now()
// 检查使用人
if authCtx.Payload.Type == auth.PayloadUser && authCtx.Payload.Id != resource.UserId {
return core.ForbiddenErr("无权限访问")
// 查找套餐
var resource = new(ResourceInfo)
data := q.Resource.As("data")
pss := q.ResourcePss.As("pss")
err := data.Scopes(orm.Alias(data)).
Select(
data.ID, data.UserID, data.Active,
pss.Type, pss.Live, pss.DailyUsed, pss.DailyLimit, pss.DailyLast, pss.Quota, pss.Used, pss.Expire,
).
LeftJoin(q.ResourcePss.As("pss"), pss.ResourceID.EqCol(data.ID)).
Where(
data.ID.Eq(resourceId),
data.UserID.Eq(authCtx.Payload.Id),
).
Scan(&resource)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrResourceNotExist // 防止 id 猜测
}
return nil, err
}
// 检查套餐状态
if !resource.Active {
return ChannelServiceErr("套餐已失效")
return nil, ErrResourceInvalid
}
// 检查每日限额
today := time.Now().Format("2006-01-02") == time.Time(resource.DailyLast).Format("2006-01-02")
dailyRemain := int(math.Max(float64(resource.DailyLimit-resource.DailyUsed), 0))
if today && dailyRemain < count {
return ChannelServiceErr("套餐每日配额不足")
return nil, ErrResourceDailyLimit
}
// 检查时间或配额
if resource.Type == 1 { // 包时
if time.Time(resource.Expire).Before(time.Now()) {
return ChannelServiceErr("套餐已过期")
return nil, ErrResourceExpired
}
} else { // 包量
remain := int(math.Max(float64(resource.Quota-resource.Used), 0))
if remain < count {
return ChannelServiceErr("套餐配额不足")
return nil, ErrResourceExhausted
}
}
return nil
slog.Debug("查找套餐", "rid", rid, "step", time.Since(step))
return resource, nil
}
// assignEdge 分配边缘节点数量
func assignEdge(q *q.Query, count int, filter NodeFilterConfig) (*AssignEdgeResult, error) {
// 查询可以使用的网关
func findProxies(q *q.Query, rid string) (proxies []*m.Proxy, err error) {
var step = time.Now()
proxies, err := q.Proxy.
proxies, err = q.Proxy.
Where(q.Proxy.Type.Eq(1)).
Find()
if err != nil {
return nil, err
}
slog.Debug("查找网关", "step", time.Since(step))
slog.Debug("查找网关", "rid", rid, "step", time.Since(step))
return proxies, nil
}
// 查询已配置的节点
step = time.Now()
rProxyConfigs, err := g.Cloud.CloudAutoQuery()
if err != nil {
return nil, err
}
slog.Debug("查询已配置节点 (remote)", "step", time.Since(step))
// 查询已使用的节点
step = time.Now()
func findChannels(q *q.Query, rid string, proxies []*m.Proxy) (channels []*m.Channel, err error) {
var step = time.Now()
var proxyIds = make([]int32, len(proxies))
for i, proxy := range proxies {
proxyIds[i] = proxy.ID
}
channels, err := q.Channel.Debug().
channels, err = q.Channel.Debug().
Select(
q.Channel.ProxyID,
q.Channel.ProxyPort).
@@ -437,74 +387,111 @@ func assignEdge(q *q.Query, count int, filter NodeFilterConfig) (*AssignEdgeResu
if err != nil {
return nil, err
}
var proxyUses = make(map[int32]int, len(channels))
for _, channel := range channels {
proxyUses[channel.ProxyID]++
slog.Debug("查找已使用节点", "rid", rid, "step", time.Since(step))
return channels, nil
}
func findWhitelist(q *q.Query, rid string, userId int32) (*[]string, error) {
var step = time.Now()
// 按需查找用户白名单
var whitelist []string
err := q.Whitelist.
Where(q.Whitelist.UserID.Eq(userId)).
Select(q.Whitelist.Host).
Scan(&whitelist)
if err != nil {
return nil, err
}
if len(whitelist) == 0 {
return nil, ChannelServiceErr("用户没有白名单")
}
slog.Debug("查找已使用节点", "step", time.Since(step))
slog.Debug("查找用户白名单", "rid", rid, "step", time.Since(step))
return &whitelist, nil
}
// 组织数据
var infos = make([]*ProxyInfo, len(proxies))
for i, proxy := range proxies {
infos[i] = &ProxyInfo{
proxy: proxy,
used: proxyUses[proxy.ID],
}
func calcChannels(
proxies []*m.Proxy,
allChannels []*m.Channel,
whitelist *[]string,
count int,
userId int32,
protocol ChannelProtocol,
authType ChannelAuthType,
expiration time.Time,
filter NodeFilterConfig,
) ([]*m.Channel, error) {
var step = time.Now()
rConfigs, ok := rProxyConfigs[proxy.Name]
if !ok {
infos[i].count = 0
continue
}
// 查询已配置的节点
remoteConfigs, err := g.Cloud.CloudAutoQuery()
if err != nil {
return nil, err
}
for _, rConfig := range rConfigs {
if rConfig.Isp == filter.Isp && rConfig.City == filter.City && rConfig.Province == filter.Prov {
infos[i].count = rConfig.Count
// 统计已用节点量与端口查找表
var proxyUses = make(map[int32]int, len(allChannels))
var portsMap = make(map[uint64]struct{})
for _, channel := range allChannels {
proxyUses[channel.ProxyID]++
key := uint64(channel.ProxyID)<<32 | uint64(channel.ProxyPort)
portsMap[key] = struct{}{}
}
// 计算分配额度
var total = len(allChannels) + count
var avg = int(math.Ceil(float64(total) / float64(len(proxies))))
// 分配节点
var newChannels []*m.Channel
for _, proxy := range proxies {
// 分配前后的节点量
var prev = proxyUses[proxy.ID]
var next = int(math.Max(float64(prev), float64(int(math.Min(float64(avg), float64(total))))))
total -= next
// 网关配置的节点量
var count = 0
remoteConfig, ok := remoteConfigs[proxy.Name]
if ok {
for _, config := range remoteConfig {
if config.Isp == filter.Isp && config.City == filter.City && config.Province == filter.Prov {
count = config.Count
break
}
}
}
}
// 分配新增的节点
var configs = make([]*ProxyConfig, len(proxies))
var needed = len(channels) + count
avg := int(math.Ceil(float64(needed) / float64(len(proxies))))
for i, info := range infos {
var prev = info.used
var next = int(math.Min(float64(avg), float64(needed)))
info.used = int(math.Max(float64(prev), float64(next)))
needed -= info.used
if env.DebugExternalChange && info.used > info.count {
if env.DebugExternalChange && next > count {
step = time.Now()
slog.Debug("新增新节点", "proxy", info.proxy.Name, "used", info.used, "count", info.count)
rConfigs := rProxyConfigs[info.proxy.Name]
var multiple float64 = 2 // 扩张倍数
var newConfig = g.AutoConfig{
Province: filter.Prov,
City: filter.City,
Isp: filter.Isp,
Count: int(math.Ceil(float64(info.used) * 2)),
Count: int(math.Ceil(float64(next) * multiple)),
}
var newConfigs []g.AutoConfig
var update = false
for _, rConfig := range rConfigs {
if rConfig.Isp == filter.Isp && rConfig.City == filter.City && rConfig.Province == filter.Prov {
newConfigs = append(newConfigs, newConfig)
update = true
} else {
newConfigs = append(newConfigs, rConfig)
}
}
if !update {
if count == 0 {
newConfigs = append(newConfigs, newConfig)
} else {
newConfigs = make([]g.AutoConfig, len(remoteConfig))
for i, config := range remoteConfig {
if config.Isp == filter.Isp && config.City == filter.City && config.Province == filter.Prov {
count = config.Count
break
}
newConfigs[i] = config
}
}
err := g.Cloud.CloudConnect(g.CloudConnectReq{
Uuid: info.proxy.Name,
Uuid: proxy.Name,
Edge: nil,
AutoConfig: newConfigs,
})
@@ -512,86 +499,23 @@ func assignEdge(q *q.Query, count int, filter NodeFilterConfig) (*AssignEdgeResu
return nil, err
}
slog.Debug("分配新增的节点", "step", time.Since(step))
slog.Debug("提交节点配置",
slog.Duration("step", time.Since(step)),
slog.String("proxy", proxy.Name),
slog.Int("used", prev),
slog.Int("count", next),
)
}
configs[i] = &ProxyConfig{
proxy: info.proxy,
count: int(math.Max(float64(next-prev), 0)),
// 节点增量
var acc = next - prev
if acc <= 0 {
continue
}
}
return &AssignEdgeResult{
configs: configs,
channels: channels,
}, nil
}
type ProxyInfo struct {
proxy *models.Proxy
used int
count int
}
type AssignEdgeResult struct {
configs []*ProxyConfig
channels []*models.Channel
}
type ProxyConfig struct {
proxy *models.Proxy
count int
}
// assignPort 分配指定数量的端口
func assignPort(
q *q.Query,
proxies *AssignEdgeResult,
userId int32,
protocol ChannelProtocol,
authType ChannelAuthType,
expiration core.LocalDateTime,
filter NodeFilterConfig,
) ([]*PortInfo, []*models.Channel, error) {
var step time.Time
var configs = proxies.configs
var exists = proxies.channels
// 端口查找表
var portsMap = make(map[uint64]struct{})
for _, channel := range exists {
key := uint64(channel.ProxyID)<<32 | uint64(channel.ProxyPort)
portsMap[key] = struct{}{}
}
println(len(portsMap))
// 查找用户白名单
var whitelist []string
if authType == ChannelAuthTypeIp {
err := q.Whitelist.
Where(q.Whitelist.UserID.Eq(userId)).
Select(q.Whitelist.Host).
Scan(&whitelist)
if err != nil {
return nil, nil, err
}
if len(whitelist) == 0 {
return nil, nil, ChannelServiceErr("用户没有白名单")
}
}
// 配置启用代理
var result []*PortInfo
var channels []*models.Channel
for _, config := range configs {
var err error
var proxy = config.proxy
var count = config.count
// 筛选可用端口
var configs = make([]g.PortConfigsReq, 0, count)
for port := 10000; port < 20000 && len(configs) < count; port++ {
var portConfigs = make([]g.PortConfigsReq, 0, acc)
for port := 10000; port < 20000 && len(portConfigs) < acc; port++ {
// 跳过存在的端口
key := uint64(proxy.ID)<<32 | uint64(port)
_, ok := portsMap[key]
@@ -600,8 +524,7 @@ func assignPort(
}
// 配置新端口
var i = len(configs)
configs = append(configs, g.PortConfigsReq{
var portConf = g.PortConfigsReq{
Port: port,
Edge: nil,
Status: true,
@@ -612,78 +535,42 @@ func assignPort(
Count: u.P(1),
PacketLoss: 30,
},
})
}
var newChannel = &m.Channel{
UserID: userId,
ProxyID: proxy.ID,
ProxyHost: proxy.Host,
ProxyPort: int32(port),
Protocol: int32(protocol),
Expiration: core.LocalDateTime(expiration),
}
switch authType {
case ChannelAuthTypeIp:
configs[i].Whitelist = &whitelist
configs[i].Userpass = u.P("")
for _, item := range whitelist {
channels = append(channels, &models.Channel{
UserID: userId,
ProxyID: proxy.ID,
UserHost: item,
ProxyHost: proxy.Host,
ProxyPort: int32(port),
AuthIP: true,
AuthPass: false,
Protocol: int32(protocol),
Expiration: expiration,
})
}
result = append(result, &PortInfo{
Proto: protocol,
Host: proxy.Host,
Port: port,
})
portConf.Whitelist = whitelist
portConf.Userpass = u.P("")
newChannel.AuthIP = true
case ChannelAuthTypePass:
username, password := genPassPair()
configs[i].Whitelist = &[]string{}
configs[i].Userpass = u.P(fmt.Sprintf("%s:%s", username, password))
channels = append(channels, &models.Channel{
UserID: userId,
ProxyID: proxy.ID,
ProxyHost: proxy.Host,
ProxyPort: int32(port),
AuthIP: false,
AuthPass: true,
Username: username,
Password: password,
Protocol: int32(protocol),
Expiration: expiration,
})
result = append(result, &PortInfo{
Proto: protocol,
Host: proxy.Host,
Port: port,
Username: &username,
Password: &password,
})
portConf.Whitelist = &[]string{}
portConf.Userpass = u.P(fmt.Sprintf("%s:%s", username, password))
newChannel.AuthPass = true
newChannel.Username = username
newChannel.Password = password
default:
return nil, nil, ChannelServiceErr("不支持的通道认证方式")
return nil, ChannelServiceErr("不支持的通道认证方式")
}
portConfigs = append(portConfigs, portConf)
newChannels = append(newChannels, newChannel)
}
if len(configs) < count {
return nil, nil, ChannelServiceErr("网关端口数量到达上限,无法分配")
if len(portConfigs) < acc {
return nil, ChannelServiceErr("网关端口数量到达上限,无法分配")
}
// 保存到数据库
step = time.Now()
err = q.Channel.
Omit(
q.Channel.NodeID,
q.Channel.NodeHost,
q.Channel.DeletedAt,
).
Create(channels...)
if err != nil {
return nil, nil, err
}
slog.Debug("保存到数据库", "step", time.Since(step))
// 提交端口配置并更新节点列表
if env.DebugExternalChange {
step = time.Now()
@@ -694,24 +581,77 @@ func assignPort(
secret[0],
secret[1],
)
err = gateway.GatewayPortConfigs(configs)
err = gateway.GatewayPortConfigs(portConfigs)
if err != nil {
return nil, nil, err
return nil, err
}
slog.Debug("提交端口配置", "step", time.Since(step))
}
}
return result, channels, nil
slog.Debug("申请节点", "rid", step, "total", time.Since(step))
return newChannels, nil
}
type PortInfo struct {
Proto ChannelProtocol `json:"-"`
Host string `json:"host"`
Port int `json:"port"`
Username *string `json:"username,omitempty"`
Password *string `json:"password,omitempty"`
func updateResource(rid string, resource *ResourceInfo, count int, now time.Time) (err error) {
var step = time.Now()
toUpdate := m.ResourcePss{
Used: resource.Used + int32(count),
DailyLast: core.LocalDateTime(now),
}
var last = time.Time(resource.DailyLast)
if now.Year() != last.Year() || now.Month() != last.Month() || now.Day() != last.Day() {
toUpdate.DailyUsed = int32(count)
} else {
toUpdate.DailyUsed = resource.DailyUsed + int32(count)
}
_, err = q.ResourcePss.
Where(q.ResourcePss.ResourceID.Eq(resource.Id)).
Select(
q.ResourcePss.Used,
q.ResourcePss.DailyUsed,
q.ResourcePss.DailyLast,
).
Updates(toUpdate)
if err != nil {
return err
}
slog.Debug("更新套餐使用记录", "rid", rid, "step", time.Since(step))
return nil
}
func saveChannels(channels []*m.Channel) (err error) {
// 保存到数据库
var step = time.Now()
err = q.Channel.
Omit(
q.Channel.NodeID,
q.Channel.NodeHost,
q.Channel.DeletedAt,
).
Create(channels...)
if err != nil {
return err
}
slog.Debug("保存到数据库", "step", time.Since(step))
return nil
}
func cacheChannels(ctx context.Context, rid string, channels []*m.Channel) (err error) {
var step = time.Now()
err = cache(ctx, channels)
if err != nil {
return err
}
slog.Debug("缓存通道数据", "rid", rid, "step", time.Since(step))
return nil
}
// endregion
@@ -735,11 +675,11 @@ func genPassPair() (string, string) {
return string(username), string(password)
}
func chKey(channel *models.Channel) string {
func chKey(channel *m.Channel) string {
return fmt.Sprintf("channel:%d", channel.ID)
}
func cache(ctx context.Context, channels []*models.Channel) error {
func cache(ctx context.Context, channels []*m.Channel) error {
if len(channels) == 0 {
return nil
}
@@ -769,7 +709,7 @@ func cache(ctx context.Context, channels []*models.Channel) error {
return nil
}
func deleteCache(ctx context.Context, channels []*models.Channel) error {
func deleteCache(ctx context.Context, channels []*m.Channel) error {
if len(channels) == 0 {
return nil
}
@@ -786,8 +726,33 @@ func deleteCache(ctx context.Context, channels []*models.Channel) error {
return nil
}
type ChannelAuthType int
const (
ChannelAuthTypeAll ChannelAuthType = iota
ChannelAuthTypeIp
ChannelAuthTypePass
)
type ChannelProtocol int32
const (
ProtocolAll ChannelProtocol = iota
ProtocolHTTP
ProtocolHttps
ProtocolSocks5
)
type ChannelServiceErr string
func (c ChannelServiceErr) Error() string {
return string(c)
}
const (
ErrResourceNotExist = ChannelServiceErr("套餐不存在")
ErrResourceInvalid = ChannelServiceErr("套餐不可用")
ErrResourceExhausted = ChannelServiceErr("套餐已用完")
ErrResourceExpired = ChannelServiceErr("套餐已过期")
ErrResourceDailyLimit = ChannelServiceErr("套餐每日配额已用完")
)