package services import ( "context" "errors" "fmt" "math" "platform/pkg/rds" "platform/web/common" "platform/web/models" q "platform/web/queries" "time" "github.com/google/uuid" "github.com/jxskiss/base62" "github.com/redis/go-redis/v9" "gorm.io/gorm" ) var Channel = &channelService{} type channelService struct { } // CreateChannel 创建连接通道,并返回连接信息,如果配额不足则返回错误 func (s *channelService) CreateChannel( ctx context.Context, auth *AuthContext, resourceId int32, protocol ChannelProtocol, authType ChannelAuthType, count int, nodeFilter ...NodeFilterConfig, ) ([]*models.Channel, error) { // 创建通道 var channels []*models.Channel err := q.Q.Transaction(func(tx *q.Query) error { // 查找套餐 var resource = struct { data models.Resource pss models.ResourcePss }{} err := q.Resource.As("data"). LeftJoin(q.ResourcePss.As("pss"), q.ResourcePss.ResourceID.EqCol(q.Resource.ID)). Where(q.Resource.ID.Eq(resourceId)). Scan(&resource) if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return ChannelServiceErr("套餐不存在") } return err } // 检查使用人 if auth.Payload.Type == PayloadUser && auth.Payload.Id != resource.data.UserID { return common.AuthForbiddenErr("无权限访问") } // 检查套餐状态 if !resource.data.Active { return ChannelServiceErr("套餐已失效") } // 检查每日限额 today := time.Now().Format("2006-01-02") == resource.pss.DailyLast.Format("2006-01-02") dailyRemain := int(math.Max(float64(resource.pss.DailyLimit-resource.pss.DailyUsed), 0)) if today && dailyRemain < count { return ChannelServiceErr("套餐每日配额不足") } // 检查时间或配额 if resource.pss.Type == 1 { // 包时 if resource.pss.Expire.Before(time.Now()) { return ChannelServiceErr("套餐已过期") } } else { // 包量 remain := int(math.Max(float64(resource.pss.Quota-resource.pss.Used), 0)) if remain < count { return ChannelServiceErr("套餐配额不足") } } // 筛选可用节点 nodes, err := Node.Filter(auth.Payload.Id, protocol, count, nodeFilter...) if err != nil { return err } // 获取用户配置白名单 whitelist, err := q.Whitelist.Where( q.Whitelist.UserID.Eq(auth.Payload.Id), ).Find() if err != nil { return err } // 创建连接通道 channels = make([]*models.Channel, 0, len(nodes)*len(whitelist)) for _, node := range nodes { for _, allowed := range whitelist { username, password := genPassPair() channels = append(channels, &models.Channel{ UserID: auth.Payload.Id, NodeID: node.ID, NodePort: node.FwdPort, Protocol: string(protocol), AuthIP: authType == ChannelAuthTypeIp, UserHost: allowed.Host, AuthPass: authType == ChannelAuthTypePass, Username: username, Password: password, Expiration: time.Now().Add(time.Duration(resource.pss.Live) * time.Second), }) } } // 保存到数据库 err = tx.Channel.Create(channels...) if err != nil { return err } // 更新套餐使用记录 if today { resource.pss.DailyUsed += int32(count) resource.pss.Used += int32(count) } else { resource.pss.DailyLast = time.Now() resource.pss.DailyUsed = int32(count) resource.pss.Used += int32(count) } err = tx.ResourcePss. Where(q.ResourcePss.ID.Eq(resource.pss.ID)). Omit(q.ResourcePss.ResourceID). Save(&resource.pss) if err != nil { return err } return nil }) if err != nil { return nil, err } // 缓存通道信息与异步删除任务 pipe := rds.Client.TxPipeline() zList := make([]redis.Z, 0, len(channels)) for _, channel := range channels { pipe.Set(ctx, chKey(channel), channel, channel.Expiration.Sub(time.Now())) zList = append(zList, redis.Z{ Score: float64(channel.Expiration.Unix()), Member: channel.ID, }) } pipe.ZAdd(ctx, "tasks:channel", zList...) _, err = pipe.Exec(ctx) if err != nil { return nil, err } // 返回连接通道列表 return channels, errors.New("not implemented") } type ChannelAuthType int const ( ChannelAuthTypeIp = iota ChannelAuthTypePass ) type ChannelProtocol string const ( ProtocolSocks5 = ChannelProtocol("socks5") ProtocolHTTP = ChannelProtocol("http") ProtocolHttps = ChannelProtocol("https") ) func genPassPair() (string, string) { usernameBytes, err := uuid.New().MarshalBinary() if err != nil { panic(err) } passwordBytes, err := uuid.New().MarshalBinary() if err != nil { panic(err) } username := base62.EncodeToString(usernameBytes) password := base62.EncodeToString(passwordBytes) return username, password } func (s *channelService) RemoveChannels(auth *AuthContext, id ...int32) error { var channels []*models.Channel // 删除通道 err := q.Q.Transaction(func(tx *q.Query) error { // 查找通道 channels, err := tx.Channel.Where( q.Channel.ID.In(id...), ).Find() if err != nil { return err } // 检查权限,只有用户自己和管理员能删除 for _, channel := range channels { if auth.Payload.Type == PayloadUser && auth.Payload.Id != channel.UserID { return common.AuthForbiddenErr("无权限访问") } } // 删除指定的通道 result, err := tx.Channel.Delete(channels...) if err != nil { return err } if result.RowsAffected != int64(len(channels)) { return ChannelServiceErr("删除通道失败") } return nil }) if err != nil { return err } // 删除缓存,异步任务直接在消费端处理删除 pipe := rds.Client.TxPipeline() for _, channel := range channels { pipe.Del(context.Background(), chKey(channel)) } return nil } func chKey(channel *models.Channel) string { return fmt.Sprintf("channel:%s:%d", channel.UserHost, channel.NodePort) } type ChannelServiceErr string func (c ChannelServiceErr) Error() string { return string(c) }