重构通道管理逻辑,支持通过任务删除不同类型通道;引入 Asynq 处理异步任务;更新数据库结构以支持通道类型区分
This commit is contained in:
@@ -4,13 +4,13 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"github.com/hibiken/asynq"
|
||||
"gorm.io/gen/field"
|
||||
"log/slog"
|
||||
"math"
|
||||
"math/rand/v2"
|
||||
"platform/pkg/env"
|
||||
"platform/pkg/u"
|
||||
"platform/web/auth"
|
||||
channel2 "platform/web/domains/channel"
|
||||
edge2 "platform/web/domains/edge"
|
||||
proxy2 "platform/web/domains/proxy"
|
||||
@@ -19,11 +19,11 @@ import (
|
||||
"platform/web/globals/orm"
|
||||
m "platform/web/models"
|
||||
q "platform/web/queries"
|
||||
"platform/web/tasks"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gofiber/fiber/v2/middleware/requestid"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
@@ -32,47 +32,71 @@ var Channel = &channelService{}
|
||||
type channelService struct {
|
||||
}
|
||||
|
||||
// region RemoveChannel
|
||||
// region 删除通道
|
||||
|
||||
func (s *channelService) RemoveChannels(ctx context.Context, authCtx *auth.Context, id ...int32) error {
|
||||
func (s *channelService) RemoveChannels(id []int32, userId ...int32) error {
|
||||
var now = time.Now()
|
||||
var rid = ctx.Value(requestid.ConfigDefault.ContextKey).(string)
|
||||
|
||||
err := q.Q.Transaction(func(tx *q.Query) error {
|
||||
|
||||
// 查找通道
|
||||
channels, err := tx.Channel.Where(
|
||||
q.Channel.ID.In(id...),
|
||||
).Find()
|
||||
var do = tx.Channel.Where(q.Channel.ID.In(id...))
|
||||
if len(userId) > 0 {
|
||||
do.Where(q.Channel.UserID.Eq(userId[0]))
|
||||
}
|
||||
channels, err := tx.Channel.Where(do).Find()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 检查权限,如果为用户操作的话,则只能删除自己的通道
|
||||
proxyMap := make(map[int32]*m.Proxy)
|
||||
proxyIds := make([]int32, 0)
|
||||
resourceMap := make(map[int32]*m.Resource)
|
||||
resourceIds := make([]int32, 0)
|
||||
for _, channel := range channels {
|
||||
if authCtx.Payload.Type == auth.PayloadUser && authCtx.Payload.Id != channel.UserID {
|
||||
return ErrRemoveForbidden
|
||||
if _, ok := proxyMap[channel.ProxyID]; !ok {
|
||||
proxyIds = append(proxyIds, channel.ProxyID)
|
||||
proxyMap[channel.ProxyID] = &m.Proxy{}
|
||||
}
|
||||
if _, ok := resourceMap[channel.ResourceID]; !ok {
|
||||
resourceIds = append(resourceIds, channel.ResourceID)
|
||||
resourceMap[channel.ResourceID] = &m.Resource{}
|
||||
}
|
||||
}
|
||||
|
||||
// 查找资源
|
||||
resources, err := tx.Resource.Where(tx.Resource.ID.In(resourceIds...)).Find()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, res := range resources {
|
||||
resourceMap[res.ID] = res
|
||||
}
|
||||
|
||||
// 查找代理
|
||||
proxySet := make(map[int32]struct{})
|
||||
proxyIds := make([]int32, 0)
|
||||
for _, channel := range channels {
|
||||
if _, ok := proxySet[channel.ProxyID]; !ok {
|
||||
proxyIds = append(proxyIds, channel.ProxyID)
|
||||
proxySet[channel.ProxyID] = struct{}{}
|
||||
}
|
||||
}
|
||||
proxies, err := tx.Proxy.Where(
|
||||
q.Proxy.ID.In(proxyIds...),
|
||||
).Find()
|
||||
proxies, err := tx.Proxy.Where(q.Proxy.ID.In(proxyIds...)).Find()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, proxy := range proxies {
|
||||
proxyMap[proxy.ID] = proxy
|
||||
}
|
||||
|
||||
// 区分通道类型
|
||||
shortToRemove := make([]*m.Channel, 0)
|
||||
longToRemove := make([]*m.Channel, 0)
|
||||
for _, channel := range channels {
|
||||
resource := resourceMap[channel.ResourceID]
|
||||
switch resource2.Type(resource.Type) {
|
||||
case resource2.TypeShort:
|
||||
shortToRemove = append(shortToRemove, channel)
|
||||
case resource2.TypeLong:
|
||||
longToRemove = append(longToRemove, channel)
|
||||
}
|
||||
}
|
||||
|
||||
// 删除指定的通道
|
||||
result, err := tx.Channel.Debug().
|
||||
result, err := tx.Channel.
|
||||
Where(q.Channel.ID.In(id...)).
|
||||
Update(q.Channel.DeletedAt, now)
|
||||
if err != nil {
|
||||
@@ -86,95 +110,14 @@ func (s *channelService) RemoveChannels(ctx context.Context, authCtx *auth.Conte
|
||||
if env.DebugExternalChange {
|
||||
var step = time.Now()
|
||||
|
||||
// 组织数据
|
||||
var configMap = make(map[int32][]g.PortConfigsReq, len(proxies))
|
||||
var proxyMap = make(map[int32]*m.Proxy, len(proxies))
|
||||
for _, proxy := range proxies {
|
||||
configMap[proxy.ID] = make([]g.PortConfigsReq, 0)
|
||||
proxyMap[proxy.ID] = proxy
|
||||
}
|
||||
var portMap = make(map[uint64]struct{})
|
||||
for _, channel := range channels {
|
||||
var config = g.PortConfigsReq{
|
||||
Port: int(channel.ProxyPort),
|
||||
Edge: &[]string{},
|
||||
AutoEdgeConfig: &g.AutoEdgeConfig{
|
||||
Count: u.P(0),
|
||||
},
|
||||
Status: false,
|
||||
}
|
||||
configMap[channel.ProxyID] = append(configMap[channel.ProxyID], config)
|
||||
|
||||
key := uint64(channel.ProxyID)<<32 | uint64(channel.ProxyPort)
|
||||
portMap[key] = struct{}{}
|
||||
}
|
||||
|
||||
slog.Debug("组织数据", "rid", rid, "step", time.Since(step))
|
||||
|
||||
// 更新配置
|
||||
for proxyId, configs := range configMap {
|
||||
if len(configs) == 0 {
|
||||
continue
|
||||
}
|
||||
proxy, ok := proxyMap[proxyId]
|
||||
if !ok {
|
||||
return ChannelServiceErr("代理不存在")
|
||||
}
|
||||
|
||||
var secret = strings.Split(proxy.Secret, ":")
|
||||
gateway := g.NewGateway(
|
||||
proxy.Host,
|
||||
secret[0],
|
||||
secret[1],
|
||||
)
|
||||
|
||||
// 查询节点配置
|
||||
step = time.Now()
|
||||
|
||||
actives, err := gateway.GatewayPortActive()
|
||||
if len(shortToRemove) > 0 {
|
||||
err := removeShortChannelExternal(proxies, shortToRemove)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
slog.Debug("查询节点配置", "rid", rid, "step", time.Since(step))
|
||||
|
||||
// 更新节点配置
|
||||
step = time.Now()
|
||||
|
||||
err = gateway.GatewayPortConfigs(configs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
slog.Debug("更新节点配置", "rid", rid, "step", time.Since(step))
|
||||
|
||||
// 下线对应节点
|
||||
step = time.Now()
|
||||
|
||||
var edges []string
|
||||
for portStr, active := range actives {
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
key := uint64(proxyId)<<32 | uint64(port)
|
||||
if _, ok := portMap[key]; ok {
|
||||
edges = append(edges, active.Edge...)
|
||||
}
|
||||
}
|
||||
if len(edges) > 0 {
|
||||
_, err := g.Cloud.CloudDisconnect(g.CloudDisconnectReq{
|
||||
Uuid: proxy.Name,
|
||||
Edge: edges,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
slog.Debug("下线对应节点", "rid", rid, "step", time.Since(step))
|
||||
}
|
||||
|
||||
slog.Debug("提交删除通道配置", "step", time.Since(step))
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -186,12 +129,91 @@ func (s *channelService) RemoveChannels(ctx context.Context, authCtx *auth.Conte
|
||||
return nil
|
||||
}
|
||||
|
||||
func removeShortChannelExternal(proxies []*m.Proxy, channels []*m.Channel) error {
|
||||
// 组织数据
|
||||
var configMap = make(map[int32][]g.PortConfigsReq, len(proxies))
|
||||
var proxyMap = make(map[int32]*m.Proxy, len(proxies))
|
||||
for _, proxy := range proxies {
|
||||
configMap[proxy.ID] = make([]g.PortConfigsReq, 0)
|
||||
proxyMap[proxy.ID] = proxy
|
||||
}
|
||||
var portMap = make(map[uint64]struct{})
|
||||
for _, channel := range channels {
|
||||
var config = g.PortConfigsReq{
|
||||
Port: int(channel.ProxyPort),
|
||||
Edge: &[]string{},
|
||||
AutoEdgeConfig: &g.AutoEdgeConfig{
|
||||
Count: u.P(0),
|
||||
},
|
||||
Status: false,
|
||||
}
|
||||
configMap[channel.ProxyID] = append(configMap[channel.ProxyID], config)
|
||||
|
||||
key := uint64(channel.ProxyID)<<32 | uint64(channel.ProxyPort)
|
||||
portMap[key] = struct{}{}
|
||||
}
|
||||
|
||||
// 更新配置
|
||||
for proxyId, configs := range configMap {
|
||||
if len(configs) == 0 {
|
||||
continue
|
||||
}
|
||||
proxy, ok := proxyMap[proxyId]
|
||||
if !ok {
|
||||
return ChannelServiceErr("代理不存在")
|
||||
}
|
||||
|
||||
var secret = strings.Split(proxy.Secret, ":")
|
||||
gateway := g.NewGateway(
|
||||
proxy.Host,
|
||||
secret[0],
|
||||
secret[1],
|
||||
)
|
||||
|
||||
// 查询节点配置
|
||||
actives, err := gateway.GatewayPortActive()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 更新节点配置
|
||||
err = gateway.GatewayPortConfigs(configs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 下线对应节点
|
||||
var edges []string
|
||||
for portStr, active := range actives {
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
key := uint64(proxyId)<<32 | uint64(port)
|
||||
if _, ok := portMap[key]; ok {
|
||||
edges = append(edges, active.Edge...)
|
||||
}
|
||||
}
|
||||
if len(edges) > 0 {
|
||||
_, err := g.Cloud.CloudDisconnect(g.CloudDisconnectReq{
|
||||
Uuid: proxy.Name,
|
||||
Edge: edges,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// endregion
|
||||
|
||||
// region CreateChannel
|
||||
// region 创建通道
|
||||
|
||||
func (s *channelService) CreateChannel(
|
||||
authCtx *auth.Context,
|
||||
userId int32,
|
||||
resourceId int32,
|
||||
protocol channel2.Protocol,
|
||||
authType ChannelAuthType,
|
||||
@@ -204,10 +226,11 @@ func (s *channelService) CreateChannel(
|
||||
filter = edgeFilter[0]
|
||||
}
|
||||
|
||||
var resource *ResourceInfo
|
||||
err = q.Q.Transaction(func(q *q.Query) (err error) {
|
||||
|
||||
// 查找套餐
|
||||
resource, err := findResource(q, resourceId, authCtx.Payload.Id, count, now)
|
||||
resource, err = findResource(q, resourceId, userId, count, now)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -215,7 +238,7 @@ func (s *channelService) CreateChannel(
|
||||
// 查找白名单
|
||||
var whitelist []string
|
||||
if authType == ChannelAuthTypeIp {
|
||||
whitelist, err = findWhitelist(q, authCtx.Payload.Id)
|
||||
whitelist, err = findWhitelist(q, userId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -232,10 +255,10 @@ func (s *channelService) CreateChannel(
|
||||
switch resource.Type {
|
||||
case resource2.TypeShort:
|
||||
config.Expiration = now.Add(time.Duration(resource.Live) * time.Second)
|
||||
channels, err = assignShortChannels(q, authCtx.Payload.Id, count, config, filter, now)
|
||||
channels, err = assignShortChannels(q, userId, resourceId, count, config, filter, now)
|
||||
case resource2.TypeLong:
|
||||
config.Expiration = now.Add(time.Duration(resource.Live) * time.Hour)
|
||||
channels, err = assignLongChannels(q, authCtx.Payload.Id, count, config, filter)
|
||||
channels, err = assignLongChannels(q, userId, resourceId, count, config, filter)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -253,6 +276,27 @@ func (s *channelService) CreateChannel(
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 定时异步删除过期通道
|
||||
var duration time.Duration
|
||||
switch resource.Type {
|
||||
case resource2.TypeShort:
|
||||
duration = time.Duration(resource.Live) * time.Second
|
||||
case resource2.TypeLong:
|
||||
duration = time.Duration(resource.Live) * time.Minute
|
||||
}
|
||||
|
||||
var ids = make([]int32, len(channels))
|
||||
for i := range channels {
|
||||
ids[i] = channels[i].ID
|
||||
}
|
||||
_, err = g.Asynq.Enqueue(
|
||||
tasks.NewRemoveChannel(ids),
|
||||
asynq.ProcessIn(duration),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return channels, nil
|
||||
}
|
||||
|
||||
@@ -287,7 +331,7 @@ func findResource(q *q.Query, resourceId int32, userId int32, count int, now tim
|
||||
info.DailyLast = time.Time(sub.DailyLast)
|
||||
info.Quota = sub.Quota
|
||||
info.Used = sub.Used
|
||||
info.Expire = time.Time(sub.DailyLast)
|
||||
info.Expire = time.Time(sub.Expire)
|
||||
case resource2.TypeLong:
|
||||
var sub = resource.Long
|
||||
info.Mode = resource2.Mode(sub.Type)
|
||||
@@ -297,7 +341,7 @@ func findResource(q *q.Query, resourceId int32, userId int32, count int, now tim
|
||||
info.DailyLast = time.Time(sub.DailyLast)
|
||||
info.Quota = sub.Quota
|
||||
info.Used = sub.Used
|
||||
info.Expire = time.Time(sub.DailyLast)
|
||||
info.Expire = time.Time(sub.Expire)
|
||||
}
|
||||
|
||||
// 检查套餐状态
|
||||
@@ -353,14 +397,7 @@ func findWhitelist(q *q.Query, userId int32) ([]string, error) {
|
||||
return whitelist, nil
|
||||
}
|
||||
|
||||
func assignShortChannels(
|
||||
q *q.Query,
|
||||
userId int32,
|
||||
count int,
|
||||
config ChannelCreateConfig,
|
||||
filter EdgeFilterConfig,
|
||||
now time.Time,
|
||||
) ([]*m.Channel, error) {
|
||||
func assignShortChannels(q *q.Query, userId int32, resourceId int32, count int, config ChannelCreateConfig, filter EdgeFilterConfig, now time.Time) ([]*m.Channel, error) {
|
||||
|
||||
// 查找网关
|
||||
proxies, err := q.Proxy.
|
||||
@@ -503,6 +540,7 @@ func assignShortChannels(
|
||||
var newChannel = &m.Channel{
|
||||
UserID: userId,
|
||||
ProxyID: proxy.ID,
|
||||
ResourceID: resourceId,
|
||||
ProxyHost: proxy.Host,
|
||||
ProxyPort: int32(port),
|
||||
Protocol: int32(config.Protocol),
|
||||
@@ -556,7 +594,7 @@ func assignShortChannels(
|
||||
return newChannels, nil
|
||||
}
|
||||
|
||||
func assignLongChannels(q *q.Query, userId int32, count int, config ChannelCreateConfig, filter EdgeFilterConfig) ([]*m.Channel, error) {
|
||||
func assignLongChannels(q *q.Query, userId int32, resourceId int32, count int, config ChannelCreateConfig, filter EdgeFilterConfig) ([]*m.Channel, error) {
|
||||
|
||||
// 查询符合条件的节点,根据 channel 统计使用次数
|
||||
var edges = make([]struct {
|
||||
@@ -631,6 +669,7 @@ func assignLongChannels(q *q.Query, userId int32, count int, config ChannelCreat
|
||||
UserID: userId,
|
||||
ProxyID: edge.ProxyID,
|
||||
EdgeID: edge.ID,
|
||||
ResourceID: resourceId,
|
||||
Protocol: int32(config.Protocol),
|
||||
AuthIP: config.AuthIp,
|
||||
AuthPass: config.AuthPass,
|
||||
@@ -752,8 +791,6 @@ func saveAssigns(q *q.Query, resource *ResourceInfo, channels []*m.Channel, now
|
||||
return nil
|
||||
}
|
||||
|
||||
// endregion
|
||||
|
||||
func genPassPair() (string, string) {
|
||||
//goland:noinspection SpellCheckingInspection
|
||||
var alphabet = []rune("abcdefghjkmnpqrstuvwxyz")
|
||||
@@ -773,6 +810,8 @@ func genPassPair() (string, string) {
|
||||
return string(username), string(password)
|
||||
}
|
||||
|
||||
// endregion
|
||||
|
||||
type ChannelAuthType int
|
||||
|
||||
const (
|
||||
@@ -820,6 +859,5 @@ const (
|
||||
ErrResourceExhausted = ChannelServiceErr("套餐已用完")
|
||||
ErrResourceExpired = ChannelServiceErr("套餐已过期")
|
||||
ErrResourceDailyLimit = ChannelServiceErr("套餐每日配额已用完")
|
||||
ErrRemoveForbidden = ChannelServiceErr("删除通道失败,当前用户没有权限")
|
||||
ErrEdgesNoAvailable = ChannelServiceErr("没有可用的节点")
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user