718 lines
16 KiB
Go
718 lines
16 KiB
Go
package services
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"log/slog"
|
|
"math"
|
|
"math/rand/v2"
|
|
"platform/pkg/env"
|
|
"platform/pkg/orm"
|
|
"platform/pkg/rds"
|
|
"platform/pkg/u"
|
|
"platform/web/auth"
|
|
"platform/web/core"
|
|
channel2 "platform/web/domains/channel"
|
|
proxy2 "platform/web/domains/proxy"
|
|
g "platform/web/globals"
|
|
m "platform/web/models"
|
|
q "platform/web/queries"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/gofiber/fiber/v2/middleware/requestid"
|
|
"github.com/redis/go-redis/v9"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
// Channel is the package-level channelService instance through which callers
// create and remove channels.
var Channel = new(channelService)

// channelService groups the channel operations of this package. It holds no
// state of its own; every method works on its arguments and package globals.
type channelService struct{}
|
|
|
|
type ResourceInfo struct {
|
|
Id int32
|
|
UserId int32
|
|
Active bool
|
|
Type int32
|
|
Live int32
|
|
DailyLimit int32
|
|
DailyUsed int32
|
|
DailyLast core.LocalDateTime
|
|
Quota int32
|
|
Used int32
|
|
Expire core.LocalDateTime
|
|
}
|
|
|
|
// region RemoveChannels
|
|
|
|
func (s *channelService) RemoveChannels(ctx context.Context, authCtx *auth.Context, id ...int32) error {
|
|
var now = time.Now()
|
|
var rid = ctx.Value(requestid.ConfigDefault.ContextKey).(string)
|
|
|
|
err := q.Q.Transaction(func(tx *q.Query) error {
|
|
|
|
// 查找通道
|
|
channels, err := tx.Channel.Where(
|
|
q.Channel.ID.In(id...),
|
|
).Find()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// 检查权限,如果为用户操作的话,则只能删除自己的通道
|
|
for _, channel := range channels {
|
|
if authCtx.Payload.Type == auth.PayloadUser && authCtx.Payload.Id != channel.UserID {
|
|
return core.ForbiddenErr("无权限访问")
|
|
}
|
|
}
|
|
|
|
// 查找代理
|
|
proxySet := make(map[int32]struct{})
|
|
proxyIds := make([]int32, 0)
|
|
for _, channel := range channels {
|
|
if _, ok := proxySet[channel.ProxyID]; !ok {
|
|
proxyIds = append(proxyIds, channel.ProxyID)
|
|
proxySet[channel.ProxyID] = struct{}{}
|
|
}
|
|
}
|
|
proxies, err := tx.Proxy.Where(
|
|
q.Proxy.ID.In(proxyIds...),
|
|
).Find()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// 删除指定的通道
|
|
result, err := tx.Channel.Debug().
|
|
Where(q.Channel.ID.In(id...)).
|
|
Update(q.Channel.DeletedAt, now)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if result.RowsAffected != int64(len(channels)) {
|
|
return ChannelServiceErr("删除通道失败")
|
|
}
|
|
|
|
// 删除缓存
|
|
err = deleteCache(ctx, channels)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// 禁用代理端口并下线用过的节点
|
|
if env.DebugExternalChange {
|
|
var step = time.Now()
|
|
|
|
// 组织数据
|
|
var configMap = make(map[int32][]g.PortConfigsReq, len(proxies))
|
|
var proxyMap = make(map[int32]*m.Proxy, len(proxies))
|
|
for _, proxy := range proxies {
|
|
configMap[proxy.ID] = make([]g.PortConfigsReq, 0)
|
|
proxyMap[proxy.ID] = proxy
|
|
}
|
|
var portMap = make(map[uint64]struct{})
|
|
for _, channel := range channels {
|
|
var config = g.PortConfigsReq{
|
|
Port: int(channel.ProxyPort),
|
|
Edge: &[]string{},
|
|
AutoEdgeConfig: &g.AutoEdgeConfig{
|
|
Count: u.P(0),
|
|
},
|
|
Status: false,
|
|
}
|
|
configMap[channel.ProxyID] = append(configMap[channel.ProxyID], config)
|
|
|
|
key := uint64(channel.ProxyID)<<32 | uint64(channel.ProxyPort)
|
|
portMap[key] = struct{}{}
|
|
}
|
|
|
|
slog.Debug("组织数据", "rid", rid, "step", time.Since(step))
|
|
|
|
// 更新配置
|
|
for proxyId, configs := range configMap {
|
|
if len(configs) == 0 {
|
|
continue
|
|
}
|
|
proxy, ok := proxyMap[proxyId]
|
|
if !ok {
|
|
return ChannelServiceErr("代理不存在")
|
|
}
|
|
|
|
var secret = strings.Split(proxy.Secret, ":")
|
|
gateway := g.NewGateway(
|
|
proxy.Host,
|
|
secret[0],
|
|
secret[1],
|
|
)
|
|
|
|
// 查询节点配置
|
|
step = time.Now()
|
|
|
|
actives, err := gateway.GatewayPortActive()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
slog.Debug("查询节点配置", "rid", rid, "step", time.Since(step))
|
|
|
|
// 更新节点配置
|
|
step = time.Now()
|
|
|
|
err = gateway.GatewayPortConfigs(configs)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
slog.Debug("更新节点配置", "rid", rid, "step", time.Since(step))
|
|
|
|
// 下线对应节点
|
|
step = time.Now()
|
|
|
|
var edges []string
|
|
for portStr, active := range actives {
|
|
port, err := strconv.Atoi(portStr)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
key := uint64(proxyId)<<32 | uint64(port)
|
|
if _, ok := portMap[key]; ok {
|
|
edges = append(edges, active.Edge...)
|
|
}
|
|
}
|
|
if len(edges) > 0 {
|
|
_, err := g.Cloud.CloudDisconnect(g.CloudDisconnectReq{
|
|
Uuid: proxy.Name,
|
|
Edge: edges,
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
slog.Debug("下线对应节点", "rid", rid, "step", time.Since(step))
|
|
}
|
|
|
|
}
|
|
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// endregion
|
|
|
|
// region CreateChannel
|
|
|
|
func (s *channelService) CreateChannel(
|
|
ctx context.Context,
|
|
authCtx *auth.Context,
|
|
resourceId int32,
|
|
protocol channel2.Protocol,
|
|
authType ChannelAuthType,
|
|
count int,
|
|
nodeFilter ...NodeFilterConfig,
|
|
) (newChannels []*m.Channel, err error) {
|
|
var now = time.Now()
|
|
var rid = ctx.Value(requestid.ConfigDefault.ContextKey).(string)
|
|
var filter = NodeFilterConfig{}
|
|
if len(nodeFilter) > 0 {
|
|
filter = nodeFilter[0]
|
|
}
|
|
|
|
err = q.Q.Transaction(func(q *q.Query) (err error) {
|
|
|
|
// 查找套餐
|
|
resource, err := findResource(q, resourceId, authCtx, count)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// 查找网关
|
|
proxies, err := findProxies(q)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// 查找已使用的节点
|
|
channels, err := findChannels(q, proxies)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// 查找白名单
|
|
var whitelist *[]string
|
|
if authType == ChannelAuthTypeIp {
|
|
whitelist, err = findWhitelist(q, authCtx.Payload.Id)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
// 分配节点
|
|
var expire = now.Add(time.Duration(resource.Live) * time.Second)
|
|
newChannels, err = calcChannels(proxies, channels, whitelist, count, authCtx.Payload.Id, protocol, authType, expire, filter)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// 更新套餐使用记录
|
|
err = updateResource(rid, resource, count, now)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// 保存通道
|
|
err = saveChannels(newChannels)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// 缓存通道数据
|
|
err = cacheChannels(ctx, rid, newChannels)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}, &sql.TxOptions{Isolation: sql.LevelRepeatableRead})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return newChannels, nil
|
|
}
|
|
|
|
func findResource(q *q.Query, resourceId int32, authCtx *auth.Context, count int) (*ResourceInfo, error) {
|
|
var resource = new(ResourceInfo)
|
|
data := q.Resource.As("data")
|
|
pss := q.ResourcePss.As("pss")
|
|
err := data.Scopes(orm.Alias(data)).
|
|
Select(
|
|
data.ID, data.UserID, data.Active,
|
|
pss.Type, pss.Live, pss.DailyUsed, pss.DailyLimit, pss.DailyLast, pss.Quota, pss.Used, pss.Expire,
|
|
).
|
|
LeftJoin(q.ResourcePss.As("pss"), pss.ResourceID.EqCol(data.ID)).
|
|
Where(
|
|
data.ID.Eq(resourceId),
|
|
data.UserID.Eq(authCtx.Payload.Id),
|
|
).
|
|
Scan(&resource)
|
|
if err != nil {
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return nil, ErrResourceNotExist // 防止 id 猜测
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
// 检查套餐状态
|
|
if !resource.Active {
|
|
return nil, ErrResourceInvalid
|
|
}
|
|
|
|
// 检查每日限额
|
|
today := time.Now().Format("2006-01-02") == time.Time(resource.DailyLast).Format("2006-01-02")
|
|
dailyRemain := int(math.Max(float64(resource.DailyLimit-resource.DailyUsed), 0))
|
|
if today && dailyRemain < count {
|
|
return nil, ErrResourceDailyLimit
|
|
}
|
|
|
|
// 检查时间或配额
|
|
if resource.Type == 1 { // 包时
|
|
if time.Time(resource.Expire).Before(time.Now()) {
|
|
return nil, ErrResourceExpired
|
|
}
|
|
} else { // 包量
|
|
remain := int(math.Max(float64(resource.Quota-resource.Used), 0))
|
|
if remain < count {
|
|
return nil, ErrResourceExhausted
|
|
}
|
|
}
|
|
|
|
return resource, nil
|
|
}
|
|
|
|
func findProxies(q *q.Query) (proxies []*m.Proxy, err error) {
|
|
proxies, err = q.Proxy.
|
|
Where(q.Proxy.Type.Eq(int32(proxy2.TypeThirdParty))).
|
|
Find()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return proxies, nil
|
|
}
|
|
|
|
func findChannels(q *q.Query, proxies []*m.Proxy) (channels []*m.Channel, err error) {
|
|
var proxyIds = make([]int32, len(proxies))
|
|
for i, proxy := range proxies {
|
|
proxyIds[i] = proxy.ID
|
|
}
|
|
channels, err = q.Channel.
|
|
Select(
|
|
q.Channel.ProxyID,
|
|
q.Channel.ProxyPort).
|
|
Where(
|
|
q.Channel.ProxyID.In(proxyIds...),
|
|
q.Channel.Expiration.Gt(core.LocalDateTime(time.Now()))).
|
|
Group(
|
|
q.Channel.ProxyPort,
|
|
q.Channel.ProxyID).
|
|
Find()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return channels, nil
|
|
}
|
|
|
|
func findWhitelist(q *q.Query, userId int32) (*[]string, error) {
|
|
var whitelist []string
|
|
err := q.Whitelist.
|
|
Where(q.Whitelist.UserID.Eq(userId)).
|
|
Select(q.Whitelist.Host).
|
|
Scan(&whitelist)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if len(whitelist) == 0 {
|
|
return nil, ChannelServiceErr("用户没有白名单")
|
|
}
|
|
|
|
return &whitelist, nil
|
|
}
|
|
|
|
func calcChannels(
|
|
proxies []*m.Proxy,
|
|
allChannels []*m.Channel,
|
|
whitelist *[]string,
|
|
count int,
|
|
userId int32,
|
|
protocol channel2.Protocol,
|
|
authType ChannelAuthType,
|
|
expiration time.Time,
|
|
filter NodeFilterConfig,
|
|
) ([]*m.Channel, error) {
|
|
var step = time.Now()
|
|
|
|
// 查询已配置的节点
|
|
remoteConfigs, err := g.Cloud.CloudAutoQuery()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// 统计已用节点量与端口查找表
|
|
var proxyUses = make(map[int32]int, len(allChannels))
|
|
var portsMap = make(map[uint64]struct{})
|
|
for _, channel := range allChannels {
|
|
proxyUses[channel.ProxyID]++
|
|
key := uint64(channel.ProxyID)<<32 | uint64(channel.ProxyPort)
|
|
portsMap[key] = struct{}{}
|
|
}
|
|
|
|
// 计算分配额度
|
|
var total = len(allChannels) + count
|
|
var avg = int(math.Ceil(float64(total) / float64(len(proxies))))
|
|
|
|
// 分配节点
|
|
var newChannels []*m.Channel
|
|
for _, proxy := range proxies {
|
|
|
|
// 分配前后的节点量
|
|
var prev = proxyUses[proxy.ID]
|
|
var next = int(math.Max(float64(prev), float64(int(math.Min(float64(avg), float64(total))))))
|
|
total -= next
|
|
|
|
// 网关配置的节点量
|
|
var count = 0
|
|
remoteConfig, ok := remoteConfigs[proxy.Name]
|
|
if ok {
|
|
for _, config := range remoteConfig {
|
|
if config.Isp == filter.Isp && config.City == filter.City && config.Province == filter.Prov {
|
|
count = config.Count
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
if env.DebugExternalChange && next > count {
|
|
step = time.Now()
|
|
|
|
var multiple float64 = 2 // 扩张倍数
|
|
var newConfig = g.AutoConfig{
|
|
Province: filter.Prov,
|
|
City: filter.City,
|
|
Isp: filter.Isp,
|
|
Count: int(math.Ceil(float64(next) * multiple)),
|
|
}
|
|
|
|
var newConfigs []g.AutoConfig
|
|
if count == 0 {
|
|
newConfigs = append(newConfigs, newConfig)
|
|
} else {
|
|
newConfigs = make([]g.AutoConfig, len(remoteConfig))
|
|
for i, config := range remoteConfig {
|
|
if config.Isp == filter.Isp && config.City == filter.City && config.Province == filter.Prov {
|
|
count = config.Count
|
|
break
|
|
}
|
|
newConfigs[i] = config
|
|
}
|
|
}
|
|
|
|
err := g.Cloud.CloudConnect(g.CloudConnectReq{
|
|
Uuid: proxy.Name,
|
|
Edge: nil,
|
|
AutoConfig: newConfigs,
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
slog.Debug("提交节点配置",
|
|
slog.Duration("step", time.Since(step)),
|
|
slog.String("proxy", proxy.Name),
|
|
slog.Int("used", prev),
|
|
slog.Int("count", next),
|
|
)
|
|
}
|
|
|
|
// 节点增量
|
|
var acc = next - prev
|
|
if acc <= 0 {
|
|
continue
|
|
}
|
|
|
|
// 筛选可用端口
|
|
var portConfigs = make([]g.PortConfigsReq, 0, acc)
|
|
for port := 10000; port < 20000 && len(portConfigs) < acc; port++ {
|
|
// 跳过存在的端口
|
|
key := uint64(proxy.ID)<<32 | uint64(port)
|
|
_, ok := portsMap[key]
|
|
if ok {
|
|
continue
|
|
}
|
|
|
|
// 配置新端口
|
|
var portConf = g.PortConfigsReq{
|
|
Port: port,
|
|
Edge: nil,
|
|
Status: true,
|
|
AutoEdgeConfig: &g.AutoEdgeConfig{
|
|
Province: filter.Prov,
|
|
City: filter.City,
|
|
Isp: filter.Isp,
|
|
Count: u.P(1),
|
|
PacketLoss: 30,
|
|
},
|
|
}
|
|
var newChannel = &m.Channel{
|
|
UserID: userId,
|
|
ProxyID: proxy.ID,
|
|
ProxyHost: proxy.Host,
|
|
ProxyPort: int32(port),
|
|
Protocol: int32(protocol),
|
|
Expiration: core.LocalDateTime(expiration),
|
|
}
|
|
|
|
switch authType {
|
|
|
|
case ChannelAuthTypeIp:
|
|
portConf.Whitelist = whitelist
|
|
portConf.Userpass = u.P("")
|
|
newChannel.AuthIP = true
|
|
|
|
case ChannelAuthTypePass:
|
|
username, password := genPassPair()
|
|
portConf.Whitelist = &[]string{}
|
|
portConf.Userpass = u.P(fmt.Sprintf("%s:%s", username, password))
|
|
newChannel.AuthPass = true
|
|
newChannel.Username = username
|
|
newChannel.Password = password
|
|
|
|
default:
|
|
return nil, ChannelServiceErr("不支持的通道认证方式")
|
|
}
|
|
|
|
portConfigs = append(portConfigs, portConf)
|
|
newChannels = append(newChannels, newChannel)
|
|
}
|
|
if len(portConfigs) < acc {
|
|
return nil, ChannelServiceErr("网关端口数量到达上限,无法分配")
|
|
}
|
|
|
|
// 提交端口配置并更新节点列表
|
|
if env.DebugExternalChange {
|
|
step = time.Now()
|
|
|
|
var secret = strings.Split(proxy.Secret, ":")
|
|
gateway := g.NewGateway(
|
|
proxy.Host,
|
|
secret[0],
|
|
secret[1],
|
|
)
|
|
err = gateway.GatewayPortConfigs(portConfigs)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
slog.Debug("提交端口配置", "step", time.Since(step))
|
|
}
|
|
}
|
|
|
|
slog.Debug("申请节点", "total", time.Since(step))
|
|
return newChannels, nil
|
|
}
|
|
|
|
func updateResource(rid string, resource *ResourceInfo, count int, now time.Time) (err error) {
|
|
toUpdate := m.ResourcePss{
|
|
Used: resource.Used + int32(count),
|
|
DailyLast: core.LocalDateTime(now),
|
|
}
|
|
var last = time.Time(resource.DailyLast)
|
|
if now.Year() != last.Year() || now.Month() != last.Month() || now.Day() != last.Day() {
|
|
toUpdate.DailyUsed = int32(count)
|
|
} else {
|
|
toUpdate.DailyUsed = resource.DailyUsed + int32(count)
|
|
}
|
|
_, err = q.ResourcePss.
|
|
Where(q.ResourcePss.ResourceID.Eq(resource.Id)).
|
|
Select(
|
|
q.ResourcePss.Used,
|
|
q.ResourcePss.DailyUsed,
|
|
q.ResourcePss.DailyLast,
|
|
).
|
|
Updates(toUpdate)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func saveChannels(channels []*m.Channel) (err error) {
|
|
err = q.Channel.
|
|
Omit(
|
|
q.Channel.NodeID,
|
|
q.Channel.NodeHost,
|
|
q.Channel.DeletedAt,
|
|
).
|
|
Create(channels...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func cacheChannels(ctx context.Context, rid string, channels []*m.Channel) (err error) {
|
|
err = cache(ctx, channels)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// endregion
|
|
|
|
func genPassPair() (string, string) {
|
|
//goland:noinspection SpellCheckingInspection
|
|
var alphabet = []rune("abcdefghjkmnpqrstuvwxyz")
|
|
var numbers = []rune("23456789")
|
|
|
|
var username = make([]rune, 6)
|
|
var password = make([]rune, 6)
|
|
for i := range 6 {
|
|
if i < 2 {
|
|
username[i] = alphabet[rand.N(len(alphabet))]
|
|
} else {
|
|
username[i] = numbers[rand.N(len(numbers))]
|
|
}
|
|
password[i] = numbers[rand.N(len(numbers))]
|
|
}
|
|
|
|
return string(username), string(password)
|
|
}
|
|
|
|
func chKey(channel *m.Channel) string {
|
|
return fmt.Sprintf("channel:%d", channel.ID)
|
|
}
|
|
|
|
func cache(ctx context.Context, channels []*m.Channel) error {
|
|
if len(channels) == 0 {
|
|
return nil
|
|
}
|
|
|
|
pipe := rds.Client.TxPipeline()
|
|
|
|
zList := make([]redis.Z, 0, len(channels))
|
|
for _, channel := range channels {
|
|
marshal, err := json.Marshal(channel)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
expiration := time.Time(channel.Expiration)
|
|
pipe.Set(ctx, chKey(channel), string(marshal), time.Until(expiration))
|
|
zList = append(zList, redis.Z{
|
|
Score: float64(expiration.Unix()),
|
|
Member: channel.ID,
|
|
})
|
|
}
|
|
pipe.ZAdd(ctx, "tasks:channel", zList...)
|
|
|
|
_, err := pipe.Exec(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func deleteCache(ctx context.Context, channels []*m.Channel) error {
|
|
if len(channels) == 0 {
|
|
return nil
|
|
}
|
|
|
|
keys := make([]string, len(channels))
|
|
for i := range channels {
|
|
keys[i] = chKey(channels[i])
|
|
}
|
|
_, err := rds.Client.Del(ctx, keys...).Result()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// ChannelAuthType selects how clients authenticate against a new channel.
type ChannelAuthType int

const (
	// ChannelAuthTypeAll is not accepted by CreateChannel in this file
	// (calcChannels rejects it); presumably used for filtering elsewhere.
	ChannelAuthTypeAll ChannelAuthType = 0
	// ChannelAuthTypeIp authenticates by the user's IP whitelist.
	ChannelAuthTypeIp ChannelAuthType = 1
	// ChannelAuthTypePass authenticates by a generated username/password pair.
	ChannelAuthTypePass ChannelAuthType = 2
)
|
|
|
|
// ChannelServiceErr is the error type the channel service returns for
// business-rule failures; the string value is the error message.
type ChannelServiceErr string

// Error implements the error interface by returning the message unchanged.
func (e ChannelServiceErr) Error() string { return string(e) }

// Errors reported while validating a resource package.
const (
	ErrResourceNotExist   ChannelServiceErr = "套餐不存在"       // package not found for this user
	ErrResourceInvalid    ChannelServiceErr = "套餐不可用"       // package is not active
	ErrResourceExhausted  ChannelServiceErr = "套餐已用完"       // quota-based package has too little quota left
	ErrResourceExpired    ChannelServiceErr = "套餐已过期"       // time-based package has expired
	ErrResourceDailyLimit ChannelServiceErr = "套餐每日配额已用完" // daily allocation limit reached
)
|