Files
platform/web/services/channel.go

644 lines
14 KiB
Go

package services
import (
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
"math"
"platform/pkg/env"
"platform/pkg/orm"
"platform/pkg/rds"
"platform/pkg/remote"
"platform/pkg/v"
"platform/web/common"
"platform/web/models"
q "platform/web/queries"
"strconv"
"strings"
"time"
"github.com/google/uuid"
"github.com/jxskiss/base62"
"github.com/redis/go-redis/v9"
"gorm.io/gorm"
)
// channelService groups the channel operations (create/remove). It holds no
// state; all dependencies are package-level clients.
type channelService struct{}

// Channel is the shared channelService instance used by callers.
var Channel = &channelService{}
// ChannelAuthType selects how clients authenticate against an allocated port.
type ChannelAuthType int

// Supported authentication modes.
//
// NOTE(review): these constants are untyped ints rather than ChannelAuthType
// values; they still convert implicitly wherever a ChannelAuthType is expected.
const (
	ChannelAuthTypeIp   = 0 // client hosts taken from the user's whitelist
	ChannelAuthTypePass = 1 // generated username/password pair
)
// ChannelProtocol is the proxy protocol of a channel; it is also used as the
// scheme of the addresses returned by CreateChannel.
type ChannelProtocol string

// Supported channel protocols.
const (
	ProtocolSocks5 ChannelProtocol = "socks5"
	ProtocolHTTP   ChannelProtocol = "http"
	ProtocolHttps  ChannelProtocol = "https"
)
// ResourceInfo is the joined view of a resource and its pss settings, as
// scanned by CreateChannel and validated by checkUser.
type ResourceInfo struct {
	// Owner and status of the resource.
	Id, UserId int32
	Active     bool
	// Type 1 is time-based (checked against Expire); other values are
	// quota-based (checked against Quota/Used). Live is the channel lifetime
	// in seconds.
	Type, Live int32
	// Per-day allowance and the usage recorded on the day of DailyLast.
	DailyLimit, DailyUsed int32
	DailyLast             time.Time
	// Total allowance for quota-based resources.
	Quota, Used int32
	// Expiry instant for time-based resources.
	Expire time.Time
}
// region RemoveChannel
// RemoveChannels soft-deletes the channels identified by id inside a single
// database transaction and drops their Redis cache entries. When
// env.DebugExternalChange is enabled it also disables the matching gateway
// ports and asks the cloud service to disconnect the edge nodes that were
// active on those ports.
//
// A caller whose payload type is PayloadUser may only remove channels whose
// UserID equals its own id; otherwise common.AuthForbiddenErr is returned and
// the transaction is rolled back. Ids that match no channel are ignored.
func (s *channelService) RemoveChannels(ctx context.Context, auth *AuthContext, id ...int32) error {
	return q.Q.Transaction(func(tx *q.Query) error {
		// Load the channels to delete.
		channels, err := tx.Channel.Where(q.Channel.ID.In(id...)).Find()
		if err != nil {
			return err
		}
		// Nothing matched: nothing to delete, uncache or disable. (Continuing
		// would also send a Redis DEL with no keys, which the server rejects.)
		if len(channels) == 0 {
			return nil
		}

		// Users may only delete their own channels.
		if auth.Payload.Type == PayloadUser {
			for _, channel := range channels {
				if channel.UserID != auth.Payload.Id {
					return common.AuthForbiddenErr("无权限访问")
				}
			}
		}

		// Load the distinct proxies the channels are bound to.
		proxySet := make(map[int32]struct{}, len(channels))
		proxyIds := make([]int32, 0, len(channels))
		for _, channel := range channels {
			if _, ok := proxySet[channel.ProxyID]; !ok {
				proxySet[channel.ProxyID] = struct{}{}
				proxyIds = append(proxyIds, channel.ProxyID)
			}
		}
		proxies, err := tx.Proxy.Where(q.Proxy.ID.In(proxyIds...)).Find()
		if err != nil {
			return err
		}

		// Soft-delete the channels by stamping DeletedAt.
		result, err := tx.Channel.
			Where(q.Channel.ID.In(id...)).
			Update(q.Channel.DeletedAt, time.Now())
		if err != nil {
			return err
		}
		if result.RowsAffected != int64(len(channels)) {
			return ChannelServiceErr("删除通道失败")
		}

		// Drop the cached channel data; queued async tasks are cleaned up on
		// the consumer side.
		if err := deleteCache(ctx, channels); err != nil {
			return err
		}

		// Disable the proxy ports and take the edge nodes they used offline.
		if !env.DebugExternalChange {
			return nil
		}
		configMap := make(map[int32][]remote.PortConfigsReq, len(proxies))
		proxyMap := make(map[int32]*models.Proxy, len(proxies))
		for _, proxy := range proxies {
			configMap[proxy.ID] = make([]remote.PortConfigsReq, 0)
			proxyMap[proxy.ID] = proxy
		}
		// Set of removed ports, keyed by proxyID<<32 | port.
		portMap := make(map[uint64]struct{}, len(channels))
		for _, channel := range channels {
			configMap[channel.ProxyID] = append(configMap[channel.ProxyID], remote.PortConfigsReq{
				Port:           int(channel.ProxyPort),
				Edge:           &[]string{},
				AutoEdgeConfig: &remote.AutoEdgeConfig{Count: v.P(0)},
				Status:         false,
			})
			portMap[uint64(channel.ProxyID)<<32|uint64(channel.ProxyPort)] = struct{}{}
		}
		for proxyId, configs := range configMap {
			if len(configs) == 0 {
				continue
			}
			proxy, ok := proxyMap[proxyId]
			if !ok {
				return ChannelServiceErr("代理不存在")
			}
			// The secret holds two ':'-separated credentials; reject a
			// malformed value instead of panicking on secret[1].
			secret := strings.Split(proxy.Secret, ":")
			if len(secret) < 2 {
				return ChannelServiceErr("代理密钥格式错误")
			}
			gateway := remote.InitGateway(proxy.Host, secret[0], secret[1])

			// Snapshot the active port -> edge bindings before cancelling.
			actives, err := gateway.GatewayPortActive()
			if err != nil {
				return err
			}

			// Cancel the port configurations.
			if err := gateway.GatewayPortConfigs(configs); err != nil {
				return err
			}

			// Disconnect the edges that were serving the removed ports.
			var edges []string
			for portStr, active := range actives {
				port, err := strconv.Atoi(portStr)
				if err != nil {
					return err
				}
				if _, ok := portMap[uint64(proxyId)<<32|uint64(port)]; ok {
					edges = append(edges, active.Edge...)
				}
			}
			if len(edges) > 0 {
				_, err := remote.Client.CloudDisconnect(remote.CloudDisconnectReq{
					Uuid: proxy.Name,
					Edge: edges,
				})
				if err != nil {
					return err
				}
			}
		}
		return nil
	})
}
// endregion
// region CreateChannel
// CreateChannel allocates count proxy ports for the given resource and returns
// their addresses formatted as "<protocol>://<host>:<port>".
//
// Inside one database transaction it:
//   - loads the resource and its pss settings;
//   - validates ownership, status and quotas via checkUser;
//   - picks gateways and edge nodes (assignEdge);
//   - allocates and persists the ports (assignPort);
//   - records the consumption on the resource;
//   - caches the new channels in Redis.
//
// The first nodeFilter, if given, restricts the edge nodes by province, city
// and ISP. Channels are owned by the resource owner (resource.UserId).
//
// NOTE(review): assignEdge/assignPort use the global q.* handles, so their
// reads and the channel rows they save are not covered by this transaction.
func (s *channelService) CreateChannel(
	ctx context.Context,
	auth *AuthContext,
	resourceId int32,
	protocol ChannelProtocol,
	authType ChannelAuthType,
	count int,
	nodeFilter ...NodeFilterConfig,
) ([]string, error) {
	// A non-positive count would allocate nothing and could lower the usage
	// counters.
	if count <= 0 {
		return nil, ChannelServiceErr("通道数量无效")
	}
	filter := NodeFilterConfig{}
	if len(nodeFilter) > 0 {
		filter = nodeFilter[0]
	}

	var addr []string
	err := q.Q.Transaction(func(tx *q.Query) error {
		// Load the resource together with its pss settings.
		resource := new(ResourceInfo)
		data := tx.Resource.As("data")
		pss := q.ResourcePss.As("pss")
		err := data.Scopes(orm.Alias(data)).
			Select(
				data.ID, data.UserID, data.Active,
				pss.Type, pss.Live, pss.DailyUsed, pss.DailyLimit, pss.DailyLast, pss.Quota, pss.Used, pss.Expire,
			).
			LeftJoin(q.ResourcePss.As("pss"), pss.ResourceID.EqCol(data.ID)).
			Where(data.ID.Eq(resourceId)).
			Scan(resource)
		if err != nil {
			if errors.Is(err, gorm.ErrRecordNotFound) {
				return ChannelServiceErr("套餐不存在")
			}
			return err
		}
		// Scan does not report ErrRecordNotFound; an unmatched row leaves
		// the struct zeroed.
		if resource.Id == 0 {
			return ChannelServiceErr("套餐不存在")
		}

		// Check ownership, status and remaining quota.
		if err := checkUser(auth, resource, count); err != nil {
			return err
		}

		// Pick gateways and edge nodes.
		edgeAssigns, err := assignEdge(count, filter)
		if err != nil {
			return err
		}

		// Allocate ports; the channels belong to the resource owner, which for
		// PayloadUser callers is the caller itself (enforced by checkUser).
		now := time.Now()
		expiration := now.Add(time.Duration(resource.Live) * time.Second)
		assigned, channels, err := assignPort(edgeAssigns, resource.UserId, protocol, authType, expiration, filter)
		if err != nil {
			return err
		}
		addr = assigned

		// The daily counter only refers to the day of DailyLast; on a new day
		// it starts again from zero.
		dailyUsed := resource.DailyUsed
		last := resource.DailyLast.In(now.Location())
		if last.Year() != now.Year() || last.YearDay() != now.YearDay() {
			dailyUsed = 0
		}

		// Record the consumption inside the transaction.
		_, err = tx.ResourcePss.
			Where(q.ResourcePss.ResourceID.Eq(resourceId)).
			Select(
				q.ResourcePss.Used,
				q.ResourcePss.DailyUsed,
				q.ResourcePss.DailyLast,
			).
			Updates(models.ResourcePss{
				Used:      resource.Used + int32(count),
				DailyUsed: dailyUsed + int32(count),
				DailyLast: now,
			})
		if err != nil {
			return err
		}

		// Cache the channel data.
		return cache(ctx, channels)
	})
	if err != nil {
		return nil, err
	}
	return addr, nil
}
// checkUser verifies that the caller may allocate count channels on resource.
//
// It returns:
//   - common.AuthForbiddenErr when a PayloadUser caller does not own the resource;
//   - ChannelServiceErr when the resource is inactive, or the daily allowance
//     is insufficient;
//   - ChannelServiceErr when, for Type 1 (time-based), the resource has
//     expired, or, for other types (quota-based), the remaining quota is
//     insufficient.
//
// DailyUsed only counts usage on the calendar day of DailyLast (in the local
// time zone); on any other day the full DailyLimit is available.
func checkUser(auth *AuthContext, resource *ResourceInfo, count int) error {
	// Check ownership.
	if auth.Payload.Type == PayloadUser && auth.Payload.Id != resource.UserId {
		return common.AuthForbiddenErr("无权限访问")
	}

	// Check the resource status.
	if !resource.Active {
		return ChannelServiceErr("套餐已失效")
	}

	now := time.Now()

	// Check the daily allowance. Compare calendar days in one location so a
	// DailyLast stored in another zone (e.g. UTC) is not misjudged.
	dailyUsed := int(resource.DailyUsed)
	last := resource.DailyLast.In(now.Location())
	if last.Year() != now.Year() || last.YearDay() != now.YearDay() {
		dailyUsed = 0
	}
	// int arithmetic avoids int32 overflow on extreme values.
	dailyRemain := int(resource.DailyLimit) - dailyUsed
	if dailyRemain < 0 {
		dailyRemain = 0
	}
	if dailyRemain < count {
		return ChannelServiceErr("套餐每日配额不足")
	}

	// Check expiry (time-based) or total quota (quota-based).
	if resource.Type == 1 {
		if resource.Expire.Before(now) {
			return ChannelServiceErr("套餐已过期")
		}
		return nil
	}
	remain := int(resource.Quota) - int(resource.Used)
	if remain < 0 {
		remain = 0
	}
	if remain < count {
		return ChannelServiceErr("套餐配额不足")
	}
	return nil
}
// assignEdge 分配边缘节点数量
func assignEdge(count int, filter NodeFilterConfig) (*AssignEdgeResult, error) {
// 查询可以使用的网关
proxies, err := q.Proxy.
Where(q.Proxy.Type.Eq(1)).
Find()
if err != nil {
return nil, err
}
// 查询已配置的节点
rProxyConfigs, err := remote.Client.CloudAutoQuery()
if err != nil {
return nil, err
}
// 查询已使用的节点
var proxyIds = make([]int32, len(proxies))
for i, proxy := range proxies {
proxyIds[i] = proxy.ID
}
channels, err := q.Channel.
Select(
q.Channel.ProxyID,
q.Channel.ProxyPort).
Where(
q.Channel.ProxyID.In(proxyIds...),
q.Channel.Expiration.Gt(time.Now())).
Group(
q.Channel.ProxyPort,
q.Channel.ProxyID).
Find()
if err != nil {
return nil, err
}
var proxyUses = make(map[int32]int, len(channels))
for _, channel := range channels {
proxyUses[channel.ProxyID]++
}
// 组织数据
var infos = make([]*ProxyInfo, len(proxies))
for i, proxy := range proxies {
infos[i] = &ProxyInfo{
proxy: proxy,
used: proxyUses[proxy.ID],
}
rConfigs, ok := rProxyConfigs[proxy.Name]
if !ok {
infos[i].count = 0
continue
}
for _, rConfig := range rConfigs {
if rConfig.Isp == filter.Isp && rConfig.City == filter.City && rConfig.Province == filter.Prov {
infos[i].count = rConfig.Count
}
}
}
// 分配新增的节点
var configs = make([]*ProxyConfig, len(proxies))
var needed = len(channels) + count
avg := int(math.Ceil(float64(needed) / float64(len(proxies))))
for i, info := range infos {
var prev = info.used
var next = int(math.Min(float64(avg), float64(needed)))
info.used = int(math.Max(float64(prev), float64(next)))
needed -= info.used
if env.DebugExternalChange && info.used > info.count {
slog.Debug("新增新节点", "proxy", info.proxy.Name, "used", info.used, "count", info.count)
err := remote.Client.CloudConnect(remote.CloudConnectReq{
Uuid: info.proxy.Name,
Edge: nil,
AutoConfig: []remote.AutoConfig{{
Province: filter.Prov,
City: filter.City,
Isp: filter.Isp,
Count: int(math.Ceil(float64(info.used) * 1.1)),
}},
})
if err != nil {
return nil, err
}
}
configs[i] = &ProxyConfig{
proxy: info.proxy,
count: int(math.Max(float64(next-prev), 0)),
}
}
return &AssignEdgeResult{
configs: configs,
channels: channels,
}, nil
}
type (
	// ProxyInfo is assignEdge's per-gateway working state.
	ProxyInfo struct {
		proxy *models.Proxy
		used  int // ports in use on this gateway (updated while assigning)
		count int // edge count configured for the requested filter
	}

	// AssignEdgeResult is what assignEdge hands to assignPort.
	AssignEdgeResult struct {
		configs  []*ProxyConfig    // new ports to open per gateway
		channels []*models.Channel // (ProxyID, ProxyPort) pairs already taken
	}

	// ProxyConfig is the number of new ports to open on one gateway.
	ProxyConfig struct {
		proxy *models.Proxy
		count int
	}
)
// assignPort 分配指定数量的端口
// assignPort opens the ports planned by assignEdge.
//
// For each gateway it picks the first free ports in [10000, 20000), creates
// one channel row per port (ChannelAuthTypePass) or one per whitelisted host
// (ChannelAuthTypeIp), and saves those rows. When env.DebugExternalChange is
// set, it then pushes the port configuration to the gateway.
//
// It returns the addresses "<protocol>://<host>:<port>" and every created
// channel.
//
// It fails if:
//   - a gateway runs out of ports;
//   - the user has no whitelist entries for IP authentication (the port
//     would otherwise be configured with no credentials at all);
//   - authType is unknown.
func assignPort(
	proxies *AssignEdgeResult,
	userId int32,
	protocol ChannelProtocol,
	authType ChannelAuthType,
	expiration time.Time,
	filter NodeFilterConfig,
) ([]string, []*models.Channel, error) {
	// Ports already taken, keyed by proxyID<<32 | port.
	taken := make(map[uint64]struct{}, len(proxies.channels))
	for _, channel := range proxies.channels {
		taken[uint64(channel.ProxyID)<<32|uint64(channel.ProxyPort)] = struct{}{}
	}

	// Resolve the credentials source once instead of once per port.
	var whitelist []string
	switch authType {
	case ChannelAuthTypeIp:
		err := q.Whitelist.
			Where(q.Whitelist.UserID.Eq(userId)).
			Select(q.Whitelist.Host).
			Scan(&whitelist)
		if err != nil {
			return nil, nil, err
		}
		if len(whitelist) == 0 {
			return nil, nil, ChannelServiceErr("用户未配置白名单")
		}
	case ChannelAuthTypePass:
		// Credentials are generated per port below.
	default:
		return nil, nil, ChannelServiceErr("不支持的认证方式")
	}

	var result []string
	var channels []*models.Channel
	for _, assign := range proxies.configs {
		proxy, count := assign.proxy, assign.count
		if count <= 0 {
			continue
		}

		// Pick free ports and build their configs and channel rows.
		configs := make([]remote.PortConfigsReq, 0, count)
		var created []*models.Channel
		for port := 10000; port < 20000 && len(configs) < count; port++ {
			if _, ok := taken[uint64(proxy.ID)<<32|uint64(port)]; ok {
				continue
			}
			config := remote.PortConfigsReq{
				Port:   port,
				Edge:   nil,
				Status: true,
				AutoEdgeConfig: &remote.AutoEdgeConfig{
					Province: filter.Prov,
					City:     filter.City,
					Isp:      filter.Isp,
					Count:    v.P(1),
				},
			}
			switch authType {
			case ChannelAuthTypeIp:
				hosts := append([]string(nil), whitelist...)
				config.Whitelist = &hosts
				config.Userpass = v.P("")
				for _, host := range whitelist {
					created = append(created, &models.Channel{
						UserID:     userId,
						ProxyID:    proxy.ID,
						UserHost:   host,
						ProxyPort:  int32(port),
						AuthIP:     true,
						AuthPass:   false,
						Protocol:   string(protocol),
						Expiration: expiration,
					})
				}
			case ChannelAuthTypePass:
				username, password := genPassPair()
				config.Whitelist = new([]string)
				config.Userpass = v.P(fmt.Sprintf("%s:%s", username, password))
				created = append(created, &models.Channel{
					UserID:     userId,
					ProxyID:    proxy.ID,
					ProxyPort:  int32(port),
					AuthIP:     false,
					AuthPass:   true,
					Username:   username,
					Password:   password,
					Protocol:   string(protocol),
					Expiration: expiration,
				})
			}
			configs = append(configs, config)
			result = append(result, fmt.Sprintf("%s://%s:%d", protocol, proxy.Host, port))
		}
		if len(configs) < count {
			return nil, nil, ChannelServiceErr("网关端口数量到达上限,无法分配")
		}

		// Persist only this gateway's channels (the accumulated slice holds
		// rows already saved for earlier gateways).
		// NOTE(review): Username/Password are omitted, so password channels
		// are saved without their credentials — confirm this is intended.
		err := q.Channel.
			Omit(
				q.Channel.NodeID,
				q.Channel.NodeHost,
				q.Channel.Username,
				q.Channel.Password,
				q.Channel.DeletedAt,
			).
			Save(created...)
		if err != nil {
			return nil, nil, err
		}
		channels = append(channels, created...)

		// Push the port configuration to the gateway.
		if env.DebugExternalChange {
			secret := strings.Split(proxy.Secret, ":")
			if len(secret) < 2 {
				return nil, nil, ChannelServiceErr("代理密钥格式错误")
			}
			gateway := remote.InitGateway(proxy.Host, secret[0], secret[1])
			if err := gateway.GatewayPortConfigs(configs); err != nil {
				return nil, nil, err
			}
		}
	}
	return result, channels, nil
}
// endregion
// genPassPair returns a fresh (username, password) pair. Each value is the
// base62 encoding of the 16 bytes of a new random UUID.
func genPassPair() (string, string) {
	newToken := func() string {
		raw, err := uuid.New().MarshalBinary()
		if err != nil {
			// MarshalBinary on a uuid.UUID is not expected to fail.
			panic(err)
		}
		return base62.EncodeToString(raw)
	}
	user := newToken()
	pass := newToken()
	return user, pass
}
func chKey(channel *models.Channel) string {
return fmt.Sprintf("channel:%d", channel.ID)
}
// cache stores each channel's JSON snapshot under chKey with a TTL lasting
// until the channel expires. It also adds the channel ID to the
// "tasks:channel" sorted set, scored by its expiration Unix time. All commands
// run in one MULTI/EXEC pipeline.
//
// An empty input is a no-op: ZADD with no members is rejected by Redis.
// Channels that have already expired are queued in the sorted set but not
// snapshotted: go-redis treats a non-positive expiration as "no TTL", so they
// would otherwise be cached forever.
func cache(ctx context.Context, channels []*models.Channel) error {
	if len(channels) == 0 {
		return nil
	}
	pipe := rds.Client.TxPipeline()
	zList := make([]redis.Z, 0, len(channels))
	for _, channel := range channels {
		marshal, err := json.Marshal(channel)
		if err != nil {
			return err
		}
		if ttl := time.Until(channel.Expiration); ttl > 0 {
			pipe.Set(ctx, chKey(channel), string(marshal), ttl)
		}
		zList = append(zList, redis.Z{
			Score:  float64(channel.Expiration.Unix()),
			Member: channel.ID,
		})
	}
	pipe.ZAdd(ctx, "tasks:channel", zList...)
	_, err := pipe.Exec(ctx)
	return err
}
// deleteCache removes the cached snapshots (chKey) of the given channels.
// An empty input is a no-op, because Redis rejects DEL without keys.
func deleteCache(ctx context.Context, channels []*models.Channel) error {
	if len(channels) == 0 {
		return nil
	}
	keys := make([]string, len(channels))
	for i, channel := range channels {
		keys[i] = chKey(channel)
	}
	return rds.Client.Del(ctx, keys...).Err()
}
// ChannelServiceErr is a business error raised by the channel service; the
// string value is the message shown to the caller.
type ChannelServiceErr string

// Error implements the error interface by returning the message unchanged.
func (e ChannelServiceErr) Error() string {
	msg := string(e)
	return msg
}