Files
platform/web/services/channel.go

251 lines
5.8 KiB
Go

package services
import (
"context"
"errors"
"fmt"
"math"
"platform/pkg/rds"
"platform/web/common"
"platform/web/models"
q "platform/web/queries"
"time"
"github.com/google/uuid"
"github.com/jxskiss/base62"
"github.com/redis/go-redis/v9"
"gorm.io/gorm"
)
// channelService groups the channel operations of this package. It carries no
// state of its own.
type channelService struct{}

// Channel is the package-level channelService instance.
var Channel = new(channelService)
// CreateChannel creates connection channels for the caller against the given
// resource (plan) and returns them.
//
// Inside one database transaction it:
//   - loads the resource and its pss record and checks that the caller may use
//     it, that it is active, and that the daily limit and the expiry (type 1)
//     or total quota (other types) allow count more channels;
//   - asks Node.Filter for nodes and loads the caller's whitelist rows;
//   - creates one channel per (node, whitelist row) pair;
//   - adds count to the resource's daily and total usage counters.
//
// After the transaction has committed, every channel is written to Redis with
// a TTL that runs until its expiration, and its ID is added to the
// "tasks:channel" sorted set scored by the expiration Unix time.
//
// Parameters:
//   - ctx: used for the Redis commands only.
//   - auth: the caller; a PayloadUser caller must own the resource.
//   - resourceId: ID of the resource to draw quota from.
//   - protocol, authType: stored on every created channel.
//   - count: number of channels requested; must be positive and fit in int32.
//   - nodeFilter: passed through to Node.Filter.
//
// It returns a ChannelServiceErr for invalid input or failed plan checks, the
// error built by common.AuthForbiddenErr when the caller does not own the
// resource, and otherwise any error from the database or Redis. On error the
// returned slice is nil.
//
// NOTE(review): usage is charged by count while len(nodes)*len(whitelist)
// channels are created; confirm that this is the intended accounting.
func (s *channelService) CreateChannel(
	ctx context.Context,
	auth *AuthContext,
	resourceId int32,
	protocol ChannelProtocol,
	authType ChannelAuthType,
	count int,
	nodeFilter ...NodeFilterConfig,
) ([]*models.Channel, error) {
	// A non-positive count would pass every "remaining < count" check below
	// (a negative one would even lower the stored usage), and a count beyond
	// int32 would overflow the int32 usage counters.
	if count <= 0 || count > math.MaxInt32 {
		return nil, ChannelServiceErr("通道数量无效")
	}

	// Sample the clock once so the day check, the usage record and every
	// channel expiration agree with each other.
	now := time.Now()

	var channels []*models.Channel
	err := q.Q.Transaction(func(tx *q.Query) error {
		// Look up the resource together with its pss record.
		// NOTE(review): this read, Node.Filter and the whitelist read go
		// through the global q handles rather than tx — confirm whether they
		// should run inside the transaction.
		// NOTE(review): data and pss are unexported fields and the join mixes
		// aliased and unaliased table references; verify that Scan really
		// populates this struct and that it can report ErrRecordNotFound.
		var resource = struct {
			data models.Resource
			pss  models.ResourcePss
		}{}
		err := q.Resource.As("data").
			LeftJoin(q.ResourcePss.As("pss"), q.ResourcePss.ResourceID.EqCol(q.Resource.ID)).
			Where(q.Resource.ID.Eq(resourceId)).
			Scan(&resource)
		if err != nil {
			if errors.Is(err, gorm.ErrRecordNotFound) {
				return ChannelServiceErr("套餐不存在")
			}
			return err
		}

		// A plain user may only use their own resource.
		if auth.Payload.Type == PayloadUser && auth.Payload.Id != resource.data.UserID {
			return common.AuthForbiddenErr("无权限访问")
		}

		// The resource must still be active.
		if !resource.data.Active {
			return ChannelServiceErr("套餐已失效")
		}

		// Daily limit. DailyUsed only counts when it was last written today;
		// on a new day nothing has been used yet, but the request must still
		// fit into DailyLimit.
		const dayLayout = "2006-01-02"
		today := now.Format(dayLayout) == resource.pss.DailyLast.Format(dayLayout)
		dailyUsed := resource.pss.DailyUsed
		if !today {
			dailyUsed = 0
		}
		dailyRemain := int(math.Max(float64(resource.pss.DailyLimit-dailyUsed), 0))
		if dailyRemain < count {
			return ChannelServiceErr("套餐每日配额不足")
		}

		// Type 1 is limited by expiry time, every other type by total quota.
		if resource.pss.Type == 1 {
			if resource.pss.Expire.Before(now) {
				return ChannelServiceErr("套餐已过期")
			}
		} else {
			remain := int(math.Max(float64(resource.pss.Quota-resource.pss.Used), 0))
			if remain < count {
				return ChannelServiceErr("套餐配额不足")
			}
		}

		// Pick the nodes to use.
		nodes, err := Node.Filter(auth.Payload.Id, protocol, count, nodeFilter...)
		if err != nil {
			return err
		}

		// Load the caller's whitelist rows.
		whitelist, err := q.Whitelist.Where(
			q.Whitelist.UserID.Eq(auth.Payload.Id),
		).Find()
		if err != nil {
			return err
		}

		// Build one channel per (node, whitelist row) pair, each with its own
		// generated username/password.
		expiration := now.Add(time.Duration(resource.pss.Live) * time.Second)
		channels = make([]*models.Channel, 0, len(nodes)*len(whitelist))
		for _, node := range nodes {
			for _, allowed := range whitelist {
				username, password := genPassPair()
				channels = append(channels, &models.Channel{
					UserID:     auth.Payload.Id,
					NodeID:     node.ID,
					NodePort:   node.FwdPort,
					Protocol:   string(protocol),
					AuthIP:     authType == ChannelAuthTypeIp,
					UserHost:   allowed.Host,
					AuthPass:   authType == ChannelAuthTypePass,
					Username:   username,
					Password:   password,
					Expiration: expiration,
				})
			}
		}

		// Without nodes or whitelist rows there is nothing to create. Fail
		// here so the transaction rolls back instead of charging quota for
		// zero channels and sending an empty ZADD afterwards.
		if len(channels) == 0 {
			return ChannelServiceErr("没有可用的节点或白名单")
		}

		// Persist the channels.
		err = tx.Channel.Create(channels...)
		if err != nil {
			return err
		}

		// Record the usage; the daily counter restarts on a new day.
		if !today {
			resource.pss.DailyLast = now
			resource.pss.DailyUsed = 0
		}
		resource.pss.DailyUsed += int32(count)
		resource.pss.Used += int32(count)

		err = tx.ResourcePss.
			Where(q.ResourcePss.ID.Eq(resource.pss.ID)).
			Omit(q.ResourcePss.ResourceID).
			Save(&resource.pss)
		if err != nil {
			return err
		}

		return nil
	})
	if err != nil {
		return nil, err
	}

	// Cache each channel and schedule its removal. This happens after the
	// database commit, so a Redis failure leaves the rows in place.
	// NOTE(review): the channel itself is passed as the Redis value — confirm
	// that models.Channel can be marshalled by the Redis client.
	pipe := rds.Client.TxPipeline()
	zList := make([]redis.Z, 0, len(channels))
	for _, channel := range channels {
		pipe.Set(ctx, chKey(channel), channel, time.Until(channel.Expiration))
		zList = append(zList, redis.Z{
			Score:  float64(channel.Expiration.Unix()),
			Member: channel.ID,
		})
	}
	pipe.ZAdd(ctx, "tasks:channel", zList...)
	_, err = pipe.Exec(ctx)
	if err != nil {
		return nil, err
	}

	// Everything succeeded: hand back the channels without an error.
	return channels, nil
}
// ChannelAuthType selects how a channel authenticates its clients.
type ChannelAuthType int

// The constants are declared with the ChannelAuthType type so they cannot be
// mixed up with unrelated integers; their values remain 0 and 1.
const (
	// ChannelAuthTypeIp makes CreateChannel set the channel's AuthIP flag.
	ChannelAuthTypeIp ChannelAuthType = iota
	// ChannelAuthTypePass makes CreateChannel set the channel's AuthPass flag.
	ChannelAuthTypePass
)
// ChannelProtocol is the protocol of a channel; CreateChannel stores its
// string value on every channel it creates.
type ChannelProtocol string

// Protocol values defined by this package.
const (
	ProtocolSocks5 ChannelProtocol = "socks5"
	ProtocolHTTP   ChannelProtocol = "http"
	ProtocolHttps  ChannelProtocol = "https"
)
// genPassPair returns a username and a password, in that order. Each one is
// the base62 encoding of the binary form of a newly generated UUID. It panics
// if a UUID cannot be marshalled.
func genPassPair() (string, string) {
	// newToken generates one UUID and encodes its binary form as base62.
	newToken := func() string {
		raw, err := uuid.New().MarshalBinary()
		if err != nil {
			panic(err)
		}
		return base62.EncodeToString(raw)
	}

	username := newToken()
	password := newToken()
	return username, password
}
// RemoveChannels deletes the channels with the given IDs from the database and
// then removes their cache entries from Redis.
//
// The lookup, permission check and delete run in one transaction. A
// PayloadUser caller may only delete channels whose UserID matches their own;
// other payload types are not restricted. If the number of deleted rows
// differs from the number of channels found, the transaction fails with a
// ChannelServiceErr. With no IDs the call is a no-op.
//
// It returns the error built by common.AuthForbiddenErr on a permission
// failure, and otherwise any database or Redis error. The Redis step runs
// after the database commit, so a Redis error does not undo the deletion.
func (s *channelService) RemoveChannels(auth *AuthContext, id ...int32) error {
	// Nothing requested, nothing to delete.
	if len(id) == 0 {
		return nil
	}

	var channels []*models.Channel

	err := q.Q.Transaction(func(tx *q.Query) error {
		// Look up the requested channels. The result gets its own name so
		// that it does not shadow the outer channels variable, which the
		// cache cleanup below depends on.
		found, err := tx.Channel.Where(
			q.Channel.ID.In(id...),
		).Find()
		if err != nil {
			return err
		}

		// Only the owning user or a non-user caller may delete a channel.
		for _, channel := range found {
			if auth.Payload.Type == PayloadUser && auth.Payload.Id != channel.UserID {
				return common.AuthForbiddenErr("无权限访问")
			}
		}

		// Delete the channels and make sure every one of them was removed.
		result, err := tx.Channel.Delete(found...)
		if err != nil {
			return err
		}
		if result.RowsAffected != int64(len(found)) {
			return ChannelServiceErr("删除通道失败")
		}

		channels = found
		return nil
	})
	if err != nil {
		return err
	}

	if len(channels) == 0 {
		return nil
	}

	// Drop the cache entries. The entries in the task set are left to the
	// consumer side. The method receives no context, so a background context
	// is used.
	ctx := context.Background()
	pipe := rds.Client.TxPipeline()
	for _, channel := range channels {
		pipe.Del(ctx, chKey(channel))
	}
	// Queued commands are only sent to Redis when the pipeline is executed.
	if _, err := pipe.Exec(ctx); err != nil {
		return err
	}

	return nil
}
// chKey builds the Redis key under which a channel is cached, in the form
// "channel:<UserHost>:<NodePort>".
func chKey(channel *models.Channel) string {
	const layout = "channel:%s:%d"
	host, port := channel.UserHost, channel.NodePort
	return fmt.Sprintf(layout, host, port)
}
// ChannelServiceErr is an error raised by the channel service; the string
// value is the error message.
type ChannelServiceErr string

// Error implements the error interface and returns the message unchanged.
func (e ChannelServiceErr) Error() string {
	msg := string(e)
	return msg
}