Files
platform/web/services/channel.go

570 lines
13 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package services
import (
"context"
"errors"
"fmt"
"log/slog"
"math"
"platform/pkg/orm"
"platform/pkg/rds"
"platform/pkg/remote"
"platform/web/common"
"platform/web/models"
q "platform/web/queries"
"strconv"
"time"
"github.com/google/uuid"
"github.com/jxskiss/base62"
"github.com/redis/go-redis/v9"
"gorm.io/gorm"
)
// channelService groups the channel operations of this package; it carries no state.
type channelService struct{}

// Channel is the shared channelService instance.
var Channel = &channelService{}
// CreateChannel creates connection channels for the resource package
// resourceId and returns them. It returns an error when the caller may not use
// the package or the package's quota is insufficient.
//
// Nodes are chosen by Node.Filter (count nodes, narrowed by nodeFilter). One
// channel is created per selected node and per whitelist entry of the caller.
// The channels are inserted, and the package usage counters are advanced by
// count, in a single transaction. After the commit, the channels are cached in
// Redis together with their expiry task.
//
// NOTE(review): the package row is read inside the transaction but without a
// row lock, so concurrent requests can still both pass the quota check.
// Consider SELECT ... FOR UPDATE.
func (s *channelService) CreateChannel(
	ctx context.Context,
	auth *AuthContext,
	resourceId int32,
	protocol ChannelProtocol,
	authType ChannelAuthType,
	count int,
	nodeFilter ...NodeFilterConfig,
) ([]*models.Channel, error) {
	var channels []*models.Channel
	err := q.Q.Transaction(func(tx *q.Query) error {
		// Load the package and its pss row through tx, so the checks and the
		// usage update below work on the same transaction. Every column
		// reference goes through the "data"/"pss" aliases used in FROM/JOIN.
		resource := new(ResourceInfo)
		data := tx.Resource.As("data")
		pss := tx.ResourcePss.As("pss")
		err := data.Scopes(orm.Alias(data)).
			Select(data.ALL, pss.ALL).
			LeftJoin(pss, pss.ResourceID.EqCol(data.ID)).
			Where(data.ID.Eq(resourceId)).
			Scan(resource)
		if err != nil {
			if errors.Is(err, gorm.ErrRecordNotFound) {
				return ChannelServiceErr("套餐不存在")
			}
			return err
		}
		// Scan does not report ErrRecordNotFound for an empty result, so a
		// zero ID means no package matched.
		if resource.data.ID == 0 {
			return ChannelServiceErr("套餐不存在")
		}

		// Ownership, active flag, daily limit, and expiry or remaining quota.
		if err := checkUser(auth, resource, count); err != nil {
			return err
		}

		// Pick the nodes that will carry the channels.
		nodes, err := Node.Filter(ctx, auth.Payload.Id, count, nodeFilter...)
		if err != nil {
			return err
		}

		// The caller's whitelist: one channel per allowed host and node.
		whitelist, err := tx.Whitelist.
			Where(q.Whitelist.UserID.Eq(auth.Payload.Id)).
			Find()
		if err != nil {
			return err
		}

		expiration := time.Now().Add(time.Duration(resource.pss.Live) * time.Second)
		channels = make([]*models.Channel, 0, len(nodes)*len(whitelist))
		for _, node := range nodes {
			for _, allowed := range whitelist {
				username, password := genPassPair()
				channels = append(channels, &models.Channel{
					UserID:     auth.Payload.Id,
					NodeID:     node.ID,
					UserHost:   allowed.Host,
					NodeHost:   node.Host,
					ProxyPort:  node.ProxyPort,
					Protocol:   string(protocol),
					AuthIP:     authType == ChannelAuthTypeIp,
					AuthPass:   authType == ChannelAuthTypePass,
					Username:   username,
					Password:   password,
					Expiration: expiration,
				})
			}
		}
		// Without nodes or whitelist entries nothing would be created, but the
		// quota would still be charged. Reject before anything is written.
		if len(channels) == 0 {
			return ChannelServiceErr("无可用节点或未配置白名单")
		}

		if err := tx.Channel.Create(channels...); err != nil {
			return err
		}

		// Advance usage. The daily counter restarts on the first use of a new day.
		now := time.Now()
		if now.Format("2006-01-02") == resource.pss.DailyLast.Format("2006-01-02") {
			resource.pss.DailyUsed += int32(count)
		} else {
			resource.pss.DailyLast = now
			resource.pss.DailyUsed = int32(count)
		}
		resource.pss.Used += int32(count)
		return tx.ResourcePss.
			Where(q.ResourcePss.ID.Eq(resource.pss.ID)).
			Omit(q.ResourcePss.ResourceID).
			Save(&resource.pss)
	})
	if err != nil {
		return nil, err
	}

	// Cache the channels and queue their expiry task.
	err = cache(ctx, channels)
	if err != nil {
		return nil, err
	}
	return channels, nil
}
// ChannelAuthType selects how a channel authenticates its clients.
type ChannelAuthType int

// Authentication modes. ChannelAuthTypeIp restricts a channel to whitelisted
// client hosts; ChannelAuthTypePass issues a username/password pair.
//
// NOTE(review): these are untyped int constants, not ChannelAuthType values.
// Consider typing them once callers that treat them as plain ints are checked.
const (
	ChannelAuthTypeIp   = 0
	ChannelAuthTypePass = 1
)

// ChannelProtocol names the proxy protocol a channel speaks.
type ChannelProtocol string

// Supported channel protocols.
const (
	ProtocolSocks5 ChannelProtocol = "socks5"
	ProtocolHTTP   ChannelProtocol = "http"
	ProtocolHttps  ChannelProtocol = "https"
)
// genPassPair returns a new (username, password) pair. Each value is the
// base62 encoding of the 16 raw bytes of a freshly generated uuid.New().
func genPassPair() (string, string) {
	userID, passID := uuid.New(), uuid.New()
	return base62.EncodeToString(userID[:]), base62.EncodeToString(passID[:])
}
// RemoveChannels deletes the channels with the given ids and drops their Redis
// cache entries. A PayloadUser may only delete its own channels; other payload
// types (such as administrators) may delete any channel. Queued expiry tasks
// are not removed here; the task consumer handles deleted channels.
//
// Ids that do not exist are ignored; an empty id list is a no-op.
func (s *channelService) RemoveChannels(ctx context.Context, auth *AuthContext, id ...int32) error {
	if len(id) == 0 {
		return nil
	}

	var channels []*models.Channel
	err := q.Q.Transaction(func(tx *q.Query) error {
		// Assign to a new local and copy to the outer variable at the end.
		// Using := on `channels` here would shadow it and leave the outer
		// slice nil for deleteCache.
		found, err := tx.Channel.Where(
			q.Channel.ID.In(id...),
		).Find()
		if err != nil {
			return err
		}

		// Only the owning user, or a non-user payload, may delete a channel.
		for _, channel := range found {
			if auth.Payload.Type == PayloadUser && auth.Payload.Id != channel.UserID {
				return common.AuthForbiddenErr("无权限访问")
			}
		}

		// Delete with no models would fall back to an unconditioned delete.
		if len(found) == 0 {
			return nil
		}
		result, err := tx.Channel.Delete(found...)
		if err != nil {
			return err
		}
		if result.RowsAffected != int64(len(found)) {
			return ChannelServiceErr("删除通道失败")
		}
		channels = found
		return nil
	})
	if err != nil {
		return err
	}

	if len(channels) == 0 {
		return nil
	}
	// Drop the cache entries.
	return deleteCache(ctx, channels)
}
// chKey returns the Redis key under which a channel is cached:
// "channel:<UserHost>:<NodeHost>".
//
// NOTE(review): channels built by assignPort for password auth leave both
// UserHost and NodeHost empty, so they would all share the key "channel::".
// Confirm whether the key should include something unique such as the ID.
func chKey(channel *models.Channel) string {
	const keyFormat = "channel:%s:%s"
	userHost, nodeHost := channel.UserHost, channel.NodeHost
	return fmt.Sprintf(keyFormat, userHost, nodeHost)
}
// ChannelServiceErr is a business-rule failure reported by the channel
// service. The string value is the message returned to callers.
type ChannelServiceErr string

// Error implements the error interface by returning the message unchanged.
func (e ChannelServiceErr) Error() string {
	msg := string(e)
	return msg
}
// region channel by remote

// RemoteCreateChannel creates count channels for the resource package
// resourceId through the remote cloud and gateway APIs. It:
//   - checks the caller's access and the package quota (checkUser),
//   - spreads the edge connections over the proxies (assignEdge),
//   - configures gateway ports and stores the channels (assignPort),
//   - caches the channels in Redis.
//
// Only the first element of nodeFilter is used.
//
// NOTE(review): unlike CreateChannel, this path does not advance the package's
// usage counters (Used / DailyUsed / DailyLast). Confirm whether that is intended.
func (s *channelService) RemoteCreateChannel(
	ctx context.Context,
	auth *AuthContext,
	resourceId int32,
	protocol ChannelProtocol,
	authType ChannelAuthType,
	count int,
	nodeFilter ...NodeFilterConfig,
) ([]*models.Channel, error) {
	filter := NodeFilterConfig{}
	if len(nodeFilter) > 0 {
		filter = nodeFilter[0]
	}

	// Load the package together with its pss row.
	resource := new(ResourceInfo)
	data := q.Resource.As("data")
	pss := q.ResourcePss.As("pss")
	err := data.Scopes(orm.Alias(data)).
		Select(data.ALL, pss.ALL).
		LeftJoin(pss, pss.ResourceID.EqCol(data.ID)).
		Where(data.ID.Eq(resourceId)).
		Scan(resource)
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, ChannelServiceErr("套餐不存在")
		}
		return nil, err
	}
	// Scan leaves dest untouched instead of returning ErrRecordNotFound when
	// no row matches, so detect a missing package by its zero ID.
	if resource.data.ID == 0 {
		return nil, ChannelServiceErr("套餐不存在")
	}

	// Ownership, active flag, daily limit, and expiry or remaining quota.
	err = checkUser(auth, resource, count)
	if err != nil {
		return nil, err
	}

	// Request edge connections on the proxies.
	assigned, err := assignEdge(count, filter)
	if err != nil {
		return nil, err
	}

	// Configure ports on the gateways and persist the channels.
	expiration := time.Now().Add(time.Duration(resource.pss.Live) * time.Second)
	channels, err := assignPort(assigned, auth.Payload.Id, protocol, authType, expiration, filter)
	if err != nil {
		return nil, err
	}

	// Cache the channels and schedule their expiry task.
	err = cache(ctx, channels)
	if err != nil {
		return nil, err
	}
	return channels, nil
}

// endregion
// checkUser verifies that auth may draw count units from the resource package.
// It requires that:
//   - a PayloadUser owns the package,
//   - the package is active,
//   - count fits within the package's daily limit,
//   - the package is either unexpired (time-based, pss.Type == 1) or has at
//     least count units of quota left (otherwise).
func checkUser(auth *AuthContext, resource *ResourceInfo, count int) error {
	// Only the owning user (or a non-user payload) may use the package.
	if auth.Payload.Type == PayloadUser && auth.Payload.Id != resource.data.UserID {
		return common.AuthForbiddenErr("无权限访问")
	}

	if !resource.data.Active {
		return ChannelServiceErr("套餐已失效")
	}

	// Daily limit. Usage recorded on an earlier day does not count against
	// today, but the request itself must still fit in the daily limit.
	dailyUsed := resource.pss.DailyUsed
	if time.Now().Format("2006-01-02") != resource.pss.DailyLast.Format("2006-01-02") {
		dailyUsed = 0
	}
	if int(max(resource.pss.DailyLimit-dailyUsed, 0)) < count {
		return ChannelServiceErr("套餐每日配额不足")
	}

	// Time-based package (包时): only the expiry matters.
	if resource.pss.Type == 1 {
		if resource.pss.Expire.Before(time.Now()) {
			return ChannelServiceErr("套餐已过期")
		}
		return nil
	}

	// Volume-based package (包量): enough quota must remain.
	if int(max(resource.pss.Quota-resource.pss.Used, 0)) < count {
		return ChannelServiceErr("套餐配额不足")
	}
	return nil
}
// assignEdge spreads count new edge connections over the proxies (q.Proxy rows
// with Type == 1). It tries to keep the per-proxy totals reported by
// remote.Client.CloudAutoQuery as even as possible, and requests the
// connections through remote.Client.CloudConnect. Each returned
// AssignEdgeResult holds a proxy and the number of connections newly added to it.
//
// NOTE(review): CloudConnect receives the proxy's new target total (nextCount),
// while the result records only the increase. Confirm the API expects an
// absolute count. Also, raising a below-average proxy to the average can still
// add more than count connections when another proxy is above average.
func assignEdge(count int, filter NodeFilterConfig) ([]AssignEdgeResult, error) {
	// Current connection counts, keyed by proxy name.
	edgeConfigs, err := remote.Client.CloudAutoQuery()
	if err != nil {
		return nil, err
	}
	proxies, err := q.Proxy.
		Where(q.Proxy.Type.Eq(1)).
		Find()
	if err != nil {
		return nil, err
	}
	if len(proxies) == 0 {
		return nil, ChannelServiceErr("无可用代理")
	}

	// total is the number of connections that the proxies not yet visited
	// still have to hold: the new request plus every existing connection.
	total := count
	for _, v := range edgeConfigs {
		total += v.Count
	}
	// Per-proxy target. len(edgeConfigs) is zero when nothing is connected
	// yet. Dividing by it yields +Inf, whose int conversion is
	// implementation-defined, so fall back to the number of proxies.
	buckets := len(edgeConfigs)
	if buckets == 0 {
		buckets = len(proxies)
	}
	avg := int(math.Ceil(float64(total) / float64(buckets)))

	var result []AssignEdgeResult
	connected := 0
	for _, proxy := range proxies {
		if total <= 0 {
			break // every connection has been placed
		}
		prev, ok := edgeConfigs[proxy.Name]
		if ok && (prev.Count >= avg || prev.Count >= total) {
			// This proxy is already at or above its share. Its connections stay,
			// but they still count toward total, so later proxies are not
			// over-assigned.
			total -= prev.Count
			continue
		}

		nextCount := min(avg, total)
		result = append(result, AssignEdgeResult{
			proxy: proxy,
			count: nextCount - prev.Count,
		})
		total -= nextCount

		n, err := remote.Client.CloudConnect(remote.CloudConnectReq{
			Uuid: proxy.Name,
			Edge: nil,
			AutoConfig: []remote.AutoConfig{{
				Province: filter.Prov,
				City:     filter.City,
				Isp:      filter.Isp,
				Count:    nextCount,
			}},
		})
		if err != nil {
			return nil, err
		}
		connected += n
	}
	slog.Debug("cloud connect", "count", connected)
	return result, nil
}
// AssignEdgeResult records how many new edge connections assignEdge gave one
// proxy. assignPort then configures that many ports on the proxy.
type AssignEdgeResult struct {
	proxy *models.Proxy // proxy receiving the connections
	count int           // number of newly assigned connections
}
// assignPort configures assign.count new gateway ports on every proxy in
// assigns and saves the resulting channels. Ports are taken from the range
// [10000, 20000) and exclude any port held by an unexpired channel.
//
// Auth handling:
//   - ChannelAuthTypeIp: each port is restricted to the user's whitelist, and
//     one channel is created per whitelisted host.
//   - ChannelAuthTypePass: each port gets a freshly generated username and
//     password, and one channel is created.
//
// All channels expire at expiration. filter is forwarded as the port's
// automatic edge selection (province/city/ISP). It fails when a proxy does not
// have enough free ports.
//
// NOTE(review): gateways are configured before the channels are saved, so a
// failing Save leaves configured ports without database records.
func assignPort(
	assigns []AssignEdgeResult,
	userId int32,
	protocol ChannelProtocol,
	authType ChannelAuthType,
	expiration time.Time,
	filter NodeFilterConfig,
) ([]*models.Channel, error) {
	if len(assigns) == 0 {
		return nil, nil
	}

	// Ports still in use on the target proxies. Group by both columns: the same
	// port number can be in use on several proxies, and selecting proxy_id
	// without grouping on it may be rejected by strict SQL modes.
	proxyIds := make([]int32, 0, len(assigns))
	for _, assigned := range assigns {
		proxyIds = append(proxyIds, assigned.proxy.ID)
	}
	inUse, err := q.Channel.
		Select(
			q.Channel.ProxyID,
			q.Channel.ProxyPort).
		Where(
			q.Channel.ProxyID.In(proxyIds...),
			q.Channel.Expiration.Gt(time.Now())).
		Group(
			q.Channel.ProxyID,
			q.Channel.ProxyPort).
		Find()
	if err != nil {
		return nil, err
	}

	// Lookup set keyed by (proxy ID << 32 | port).
	proxyPorts := make(map[uint64]struct{}, len(inUse))
	for _, channel := range inUse {
		proxyPorts[uint64(channel.ProxyID)<<32|uint64(channel.ProxyPort)] = struct{}{}
	}

	// The whitelist does not depend on the port, so load it once.
	var whitelist []string
	if authType == ChannelAuthTypeIp {
		err = q.Whitelist.
			Where(q.Whitelist.UserID.Eq(userId)).
			Select(q.Whitelist.Host).
			Scan(&whitelist)
		if err != nil {
			return nil, err
		}
	}

	var result []*models.Channel
	for _, assign := range assigns {
		proxy := assign.proxy
		if assign.count <= 0 {
			continue
		}

		// Collect up to assign.count free ports for this proxy.
		portConfigs := make([]remote.PortConfigsReq, 0, assign.count)
		for port := 10000; port < 20000 && len(portConfigs) < assign.count; port++ {
			key := uint64(proxy.ID)<<32 | uint64(port)
			if _, taken := proxyPorts[key]; taken {
				continue
			}
			proxyPorts[key] = struct{}{} // reserve for the rest of this call

			config := remote.PortConfigsReq{
				Port:   strconv.Itoa(port),
				Edge:   nil,
				Status: true,
				AutoEdgeConfig: remote.AutoEdgeConfig{
					Province: filter.Prov,
					City:     filter.City,
					Isp:      filter.Isp,
					Count:    1,
				},
			}
			switch authType {
			case ChannelAuthTypeIp:
				config.Whitelist = whitelist
				for _, host := range whitelist {
					result = append(result, &models.Channel{
						UserID:     userId,
						ProxyID:    proxy.ID,
						UserHost:   host,
						ProxyPort:  int32(port),
						AuthIP:     true,
						AuthPass:   false,
						Protocol:   string(protocol),
						Expiration: expiration,
					})
				}
			case ChannelAuthTypePass:
				username, password := genPassPair()
				config.Userpass = fmt.Sprintf("%s:%s", username, password)
				result = append(result, &models.Channel{
					UserID:     userId,
					ProxyID:    proxy.ID,
					ProxyPort:  int32(port),
					AuthIP:     false,
					AuthPass:   true,
					Username:   username,
					Password:   password,
					Protocol:   string(protocol),
					Expiration: expiration,
				})
			}
			portConfigs = append(portConfigs, config)
		}
		if len(portConfigs) < assign.count {
			return nil, ChannelServiceErr("代理端口不足")
		}

		// Push the port configuration to the proxy's gateway.
		// NOTE(review): hard-coded gateway credentials; load these from configuration.
		gateway := remote.InitGateway(
			proxy.Host,
			"api",
			"123456",
		)
		if err := gateway.GatewayPortConfigs(portConfigs); err != nil {
			return nil, err
		}
	}

	if err := q.Channel.Save(result...); err != nil {
		return nil, err
	}
	return result, nil
}
// cache stores each channel in Redis under chKey(channel), expiring at the
// channel's Expiration. It also adds the channel ID to the "tasks:channel"
// sorted set, scored by the expiry Unix timestamp, so the task consumer can
// process it. All commands are sent in one MULTI/EXEC transaction pipeline.
//
// NOTE(review): go-redis can only store a struct pointer as a value if it
// implements encoding.BinaryMarshaler. Confirm that *models.Channel does.
func cache(ctx context.Context, channels []*models.Channel) error {
	if len(channels) == 0 {
		// ZADD without members is rejected by Redis, and there is nothing to cache.
		return nil
	}

	pipe := rds.Client.TxPipeline()
	tasks := make([]redis.Z, 0, len(channels))
	for _, channel := range channels {
		// go-redis only sets an expiry for positive durations. A zero or
		// negative TTL would store the key permanently, so already-expired
		// channels are left to the expiry task instead of being cached.
		if ttl := time.Until(channel.Expiration); ttl > 0 {
			pipe.Set(ctx, chKey(channel), channel, ttl)
		}
		tasks = append(tasks, redis.Z{
			Score:  float64(channel.Expiration.Unix()),
			Member: channel.ID,
		})
	}
	pipe.ZAdd(ctx, "tasks:channel", tasks...)

	_, err := pipe.Exec(ctx)
	return err
}
// deleteCache removes the Redis cache entries (chKey) of the given channels.
// Their members in the "tasks:channel" sorted set are deliberately kept,
// because ZREM here is comparatively costly. Stale tasks are dropped when
// they are consumed.
func deleteCache(ctx context.Context, channels []*models.Channel) error {
	if len(channels) == 0 {
		// DEL with no keys is rejected by Redis, and there is nothing to remove.
		return nil
	}

	keys := make([]string, 0, len(channels))
	for _, channel := range channels {
		keys = append(keys, chKey(channel))
	}

	pipe := rds.Client.TxPipeline()
	pipe.Del(ctx, keys...)
	_, err := pipe.Exec(ctx)
	return err
}
// ResourceInfo is the scan target for a resource package joined with its pss
// row. The queries in this file alias these tables "data" and "pss".
//
// NOTE(review): both fields are unexported. GORM's schema parser only
// considers exported fields, so Scan may be unable to populate them. Verify
// this, and consider exported fields with `gorm:"embedded"` tags.
type ResourceInfo struct {
	data models.Resource    // the resource package row
	pss  models.ResourcePss // the package's pss row (limits, usage, expiry)
}