完善通道删除与定时失效功能

This commit is contained in:
2025-03-31 09:09:05 +08:00
parent ec4f499edd
commit 47bb49ce70
18 changed files with 832 additions and 619 deletions

View File

@@ -1,7 +1,12 @@
package web
import (
"context"
"encoding/base64"
"errors"
"log/slog"
"platform/web/common"
q "platform/web/queries"
"slices"
"strings"
@@ -14,16 +19,36 @@ func Permit(types []services.PayloadType, permissions ...string) fiber.Handler {
return func(c *fiber.Ctx) error {
// 获取令牌
var header = c.Get("Authorization")
var token = strings.TrimPrefix(header, "Bearer ")
if token == "" {
return c.Status(fiber.StatusUnauthorized).JSON(common.ErrResp{
var split = strings.Split(header, " ")
if len(split) != 2 {
return c.Status(fiber.StatusBadRequest).JSON(common.ErrResp{
Error: true,
Message: "没有权限",
Message: "无效的令牌",
})
}
// 验证令牌
auth, err := services.Session.Find(c.Context(), token)
var token = split[1]
if token == "" {
return c.Status(fiber.StatusBadRequest).JSON(common.ErrResp{
Error: true,
Message: "无效的令牌",
})
}
var auth *services.AuthContext
var err error
switch split[0] {
case "Bearer":
auth, err = authBearer(c.Context(), token)
case "Basic":
if !slices.Contains(types, services.PayloadClientConfidential) {
return c.Status(fiber.StatusUnauthorized).JSON(common.ErrResp{
Error: true,
Message: "没有权限",
})
}
auth, err = authBasic(c.Context(), token)
}
if err != nil {
return c.Status(fiber.StatusUnauthorized).JSON(common.ErrResp{
Error: true,
@@ -32,22 +57,6 @@ func Permit(types []services.PayloadType, permissions ...string) fiber.Handler {
}
// 检查权限
// switch auth.Payload.Type {
// case services.PayloadAdmin:
// // 管理员不需要权限检查
// case services.PayloadUser:
// if len(permissions) > 0 && !auth.AnyPermission(permissions...) {
// return c.Status(fiber.StatusForbidden).JSON(common.ErrResp{
// Error: true,
// Message: "拒绝访问",
// })
// }
// default:
// return c.Status(fiber.StatusForbidden).JSON(common.ErrResp{
// Error: true,
// Message: "拒绝访问",
// })
// }
if !slices.Contains(types, auth.Payload.Type) {
return c.Status(fiber.StatusForbidden).JSON(common.ErrResp{
Error: true,
@@ -70,97 +79,95 @@ func Permit(types []services.PayloadType, permissions ...string) fiber.Handler {
}
func PermitAll(permissions ...string) fiber.Handler {
return Permit([]services.PayloadType{
services.PayloadClientPublic,
services.PayloadClientConfidential,
services.PayloadUser,
services.PayloadAdmin,
}, permissions...)
}
// PermitUser 创建针对单个路由的鉴权中间件
func PermitUser(permissions ...string) fiber.Handler {
return func(c *fiber.Ctx) error {
// 获取令牌
var header = c.Get("Authorization")
var token = strings.TrimPrefix(header, "Bearer ")
if token == "" {
return c.Status(fiber.StatusUnauthorized).JSON(common.ErrResp{
Error: true,
Message: "没有权限",
})
}
// 验证令牌
auth, err := services.Session.Find(c.Context(), token)
if err != nil {
return c.Status(fiber.StatusUnauthorized).JSON(common.ErrResp{
Error: true,
Message: "没有权限",
})
}
// 检查权限
switch auth.Payload.Type {
case services.PayloadAdmin:
// 管理员不需要权限检查
case services.PayloadUser:
if len(permissions) > 0 && !auth.AnyPermission(permissions...) {
return c.Status(fiber.StatusForbidden).JSON(common.ErrResp{
Error: true,
Message: "拒绝访问",
})
}
default:
return c.Status(fiber.StatusForbidden).JSON(common.ErrResp{
Error: true,
Message: "拒绝访问",
})
}
// 将认证信息存储在上下文中
c.Locals("auth", auth)
c.Locals("access_token", token) // 存储原始令牌,便于后续操作
return c.Next()
}
return Permit([]services.PayloadType{
services.PayloadUser,
services.PayloadAdmin,
}, permissions...)
}
func PermitDevice(permissions ...string) fiber.Handler {
return func(c *fiber.Ctx) error {
// 获取令牌
var header = c.Get("Authorization")
var token = strings.TrimPrefix(header, "Bearer ")
if token == "" {
return c.Status(fiber.StatusUnauthorized).JSON(common.ErrResp{
Error: true,
Message: "没有权限",
})
}
// 验证令牌
auth, err := services.Session.Find(c.Context(), token)
if err != nil {
return c.Status(fiber.StatusUnauthorized).JSON(common.ErrResp{
Error: true,
Message: "没有权限",
})
}
// 检查权限
switch auth.Payload.Type {
case services.PayloadAdmin:
// 管理员不需要权限检查
case services.PayloadClientPublic, services.PayloadClientConfidential:
if len(permissions) > 0 && !auth.AnyPermission(permissions...) {
return c.Status(fiber.StatusForbidden).JSON(common.ErrResp{
Error: true,
Message: "拒绝访问",
})
}
default:
return c.Status(fiber.StatusForbidden).JSON(common.ErrResp{
Error: true,
Message: "拒绝访问",
})
}
// 将认证信息存储在上下文中
c.Locals("auth", auth)
c.Locals("access_token", token) // 存储原始令牌,便于后续操作
return c.Next()
}
return Permit([]services.PayloadType{
services.PayloadClientPublic,
services.PayloadClientConfidential,
services.PayloadAdmin,
}, permissions...)
}
func PermitPublic(permissions ...string) fiber.Handler {
return Permit([]services.PayloadType{
services.PayloadClientPublic,
services.PayloadAdmin,
}, permissions...)
}
func PermitConfidential(permissions ...string) fiber.Handler {
return Permit([]services.PayloadType{
services.PayloadClientConfidential,
services.PayloadAdmin,
}, permissions...)
}
func authBearer(ctx context.Context, token string) (*services.AuthContext, error) {
auth, err := services.Session.Find(ctx, token)
if err != nil {
slog.Debug(err.Error())
return nil, err
}
return auth, nil
}
func authBasic(ctx context.Context, token string) (*services.AuthContext, error) {
// 解析 Basic 认证信息
var base, err = base64.URLEncoding.DecodeString(token)
if err != nil {
slog.Debug(err.Error())
return nil, err
}
var split = strings.Split(string(base), ":")
if len(split) != 2 {
msg := "无法解析 Basic 认证信息"
slog.Debug(msg)
return nil, errors.New(msg)
}
var clientID = split[0]
// 获取客户端信息
client, err := q.Client.
Where(
q.Client.ClientID.Eq(clientID),
q.Client.Spec.Eq(0),
q.Client.GrantClient.Is(true),
q.Client.Status.Eq(1)).
Take()
if err != nil {
return nil, err
}
// todo 查询客户端关联权限
// 组织授权信息(一次性请求)
return &services.AuthContext{
Payload: services.Payload{
Id: client.ID,
Type: services.PayloadClientConfidential,
Name: client.Name,
Avatar: client.Icon,
},
Permissions: nil,
Metadata: nil,
}, nil
}

View File

@@ -2,7 +2,6 @@ package handlers
import (
"errors"
"fmt"
"platform/web/services"
"strings"
@@ -35,7 +34,7 @@ func CreateChannel(c *fiber.Ctx) error {
return errors.New("user not found")
}
assigns, err := services.Channel.RemoteCreateChannel(
result, err := services.Channel.CreateChannel(
c.Context(),
auth,
req.ResourceId,
@@ -52,17 +51,6 @@ func CreateChannel(c *fiber.Ctx) error {
return err
}
// 返回连接通道列表
var result []string
for _, assign := range assigns {
var proxy = assign.Proxy
var channels = assign.Channels
for _, channel := range channels {
url := fmt.Sprintf("%s://%s:%d", channel.Protocol, proxy.Host, channel.ProxyPort)
result = append(result, url)
}
}
switch req.ResultType {
case CreateChannelResultTypeJson:
return c.JSON(fiber.Map{
@@ -101,3 +89,32 @@ const (
)
// endregion
// region RemoveChannels
type RemoveChannelsReq struct {
ByIds []int32 `json:"by_ids" validate:"required"`
}
func RemoveChannels(c *fiber.Ctx) error {
req := new(RemoveChannelsReq)
if err := c.BodyParser(req); err != nil {
return err
}
// 获取用户信息
auth, ok := c.Locals("auth").(*services.AuthContext)
if !ok {
return errors.New("user not found")
}
// 删除通道
err := services.Channel.RemoveChannels(c.Context(), auth, req.ByIds...)
if err != nil {
return err
}
return c.SendStatus(fiber.StatusOK)
}
// endregion

View File

@@ -19,6 +19,7 @@ type Proxy struct {
Name string `gorm:"column:name;not null;comment:代理服务名称" json:"name"` // 代理服务名称
Host string `gorm:"column:host;not null;comment:代理服务地址" json:"host"` // 代理服务地址
Type int32 `gorm:"column:type;not null;comment:代理服务类型0-自有1-三方" json:"type"` // 代理服务类型0-自有1-三方
Secret string `gorm:"column:secret;comment:代理服务密钥" json:"secret"` // 代理服务密钥
CreatedAt time.Time `gorm:"column:created_at;default:CURRENT_TIMESTAMP;comment:创建时间" json:"created_at"` // 创建时间
UpdatedAt time.Time `gorm:"column:updated_at;default:CURRENT_TIMESTAMP;comment:更新时间" json:"updated_at"` // 更新时间
DeletedAt gorm.DeletedAt `gorm:"column:deleted_at;comment:删除时间" json:"deleted_at"` // 删除时间

View File

@@ -19,8 +19,8 @@ type ResourcePss struct {
Type int32 `gorm:"column:type;comment:套餐类型1-包时2-包量" json:"type"` // 套餐类型1-包时2-包量
Live int32 `gorm:"column:live;comment:可用时长(秒)" json:"live"` // 可用时长(秒)
Quota int32 `gorm:"column:quota;comment:配额数量" json:"quota"` // 配额数量
Used int32 `gorm:"column:used;comment:已用数量" json:"used"` // 已用数量
Expire time.Time `gorm:"column:expire;comment:过期时间" json:"expire"` // 过期时间
Used int32 `gorm:"column:used;not null;comment:已用数量" json:"used"` // 已用数量
DailyLimit int32 `gorm:"column:daily_limit;not null;comment:每日限制" json:"daily_limit"` // 每日限制
DailyUsed int32 `gorm:"column:daily_used;not null;comment:今日已用数量" json:"daily_used"` // 今日已用数量
DailyLast time.Time `gorm:"column:daily_last;comment:今日最后使用时间" json:"daily_last"` // 今日最后使用时间

View File

@@ -32,6 +32,7 @@ func newProxy(db *gorm.DB, opts ...gen.DOOption) proxy {
_proxy.Name = field.NewString(tableName, "name")
_proxy.Host = field.NewString(tableName, "host")
_proxy.Type = field.NewInt32(tableName, "type")
_proxy.Secret = field.NewString(tableName, "secret")
_proxy.CreatedAt = field.NewTime(tableName, "created_at")
_proxy.UpdatedAt = field.NewTime(tableName, "updated_at")
_proxy.DeletedAt = field.NewField(tableName, "deleted_at")
@@ -50,6 +51,7 @@ type proxy struct {
Name field.String // 代理服务名称
Host field.String // 代理服务地址
Type field.Int32 // 代理服务类型0-自有1-三方
Secret field.String // 代理服务密钥
CreatedAt field.Time // 创建时间
UpdatedAt field.Time // 更新时间
DeletedAt field.Field // 删除时间
@@ -74,6 +76,7 @@ func (p *proxy) updateTableName(table string) *proxy {
p.Name = field.NewString(table, "name")
p.Host = field.NewString(table, "host")
p.Type = field.NewInt32(table, "type")
p.Secret = field.NewString(table, "secret")
p.CreatedAt = field.NewTime(table, "created_at")
p.UpdatedAt = field.NewTime(table, "updated_at")
p.DeletedAt = field.NewField(table, "deleted_at")
@@ -93,12 +96,13 @@ func (p *proxy) GetFieldByName(fieldName string) (field.OrderExpr, bool) {
}
func (p *proxy) fillFieldMap() {
p.fieldMap = make(map[string]field.Expr, 8)
p.fieldMap = make(map[string]field.Expr, 9)
p.fieldMap["id"] = p.ID
p.fieldMap["version"] = p.Version
p.fieldMap["name"] = p.Name
p.fieldMap["host"] = p.Host
p.fieldMap["type"] = p.Type
p.fieldMap["secret"] = p.Secret
p.fieldMap["created_at"] = p.CreatedAt
p.fieldMap["updated_at"] = p.UpdatedAt
p.fieldMap["deleted_at"] = p.DeletedAt

View File

@@ -32,8 +32,8 @@ func newResourcePss(db *gorm.DB, opts ...gen.DOOption) resourcePss {
_resourcePss.Type = field.NewInt32(tableName, "type")
_resourcePss.Live = field.NewInt32(tableName, "live")
_resourcePss.Quota = field.NewInt32(tableName, "quota")
_resourcePss.Used = field.NewInt32(tableName, "used")
_resourcePss.Expire = field.NewTime(tableName, "expire")
_resourcePss.Used = field.NewInt32(tableName, "used")
_resourcePss.DailyLimit = field.NewInt32(tableName, "daily_limit")
_resourcePss.DailyUsed = field.NewInt32(tableName, "daily_used")
_resourcePss.DailyLast = field.NewTime(tableName, "daily_last")
@@ -55,8 +55,8 @@ type resourcePss struct {
Type field.Int32 // 套餐类型1-包时2-包量
Live field.Int32 // 可用时长(秒)
Quota field.Int32 // 配额数量
Used field.Int32 // 已用数量
Expire field.Time // 过期时间
Used field.Int32 // 已用数量
DailyLimit field.Int32 // 每日限制
DailyUsed field.Int32 // 今日已用数量
DailyLast field.Time // 今日最后使用时间
@@ -84,8 +84,8 @@ func (r *resourcePss) updateTableName(table string) *resourcePss {
r.Type = field.NewInt32(table, "type")
r.Live = field.NewInt32(table, "live")
r.Quota = field.NewInt32(table, "quota")
r.Used = field.NewInt32(table, "used")
r.Expire = field.NewTime(table, "expire")
r.Used = field.NewInt32(table, "used")
r.DailyLimit = field.NewInt32(table, "daily_limit")
r.DailyUsed = field.NewInt32(table, "daily_used")
r.DailyLast = field.NewTime(table, "daily_last")
@@ -114,8 +114,8 @@ func (r *resourcePss) fillFieldMap() {
r.fieldMap["type"] = r.Type
r.fieldMap["live"] = r.Live
r.fieldMap["quota"] = r.Quota
r.fieldMap["used"] = r.Used
r.fieldMap["expire"] = r.Expire
r.fieldMap["used"] = r.Used
r.fieldMap["daily_limit"] = r.DailyLimit
r.fieldMap["daily_used"] = r.DailyUsed
r.fieldMap["daily_last"] = r.DailyLast

View File

@@ -2,7 +2,6 @@ package web
import (
"platform/web/handlers"
"platform/web/services"
"github.com/gofiber/fiber/v2"
)
@@ -18,10 +17,6 @@ func ApplyRouters(app *fiber.App) {
// 通道
channel := api.Group("/channel")
channel.Post("/create", Permit([]services.PayloadType{
services.PayloadClientConfidential,
services.PayloadClientPublic,
services.PayloadUser,
services.PayloadAdmin,
}), handlers.CreateChannel)
channel.Post("/create", PermitAll(), handlers.CreateChannel)
channel.Post("/remove", PermitAll(), handlers.RemoveChannels)
}

View File

@@ -7,12 +7,16 @@ import (
"fmt"
"log/slog"
"math"
"platform/pkg/env"
"platform/pkg/orm"
"platform/pkg/rds"
"platform/pkg/remote"
"platform/pkg/v"
"platform/web/common"
"platform/web/models"
q "platform/web/queries"
"strconv"
"strings"
"time"
"github.com/google/uuid"
@@ -26,143 +30,6 @@ var Channel = &channelService{}
type channelService struct {
}
// CreateChannel 创建连接通道,并返回连接信息,如果配额不足则返回错误
func (s *channelService) CreateChannel(
ctx context.Context,
auth *AuthContext,
resourceId int32,
protocol ChannelProtocol,
authType ChannelAuthType,
count int,
nodeFilter ...NodeFilterConfig,
) ([]*models.Channel, error) {
// 创建通道
var channels []*models.Channel
err := q.Q.Transaction(func(tx *q.Query) error {
// 查找套餐
var resource = ResourceInfo{}
err := q.Resource.As("data").
LeftJoin(q.ResourcePss.As("pss"), q.ResourcePss.ResourceID.EqCol(q.Resource.ID)).
Where(q.Resource.ID.Eq(resourceId)).
Scan(&resource)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ChannelServiceErr("套餐不存在")
}
return err
}
// 检查使用人
if auth.Payload.Type == PayloadUser && auth.Payload.Id != resource.UserId {
return common.AuthForbiddenErr("无权限访问")
}
// 检查套餐状态
if !resource.Active {
return ChannelServiceErr("套餐已失效")
}
// 检查每日限额
today := time.Now().Format("2006-01-02") == resource.DailyLast.Format("2006-01-02")
dailyRemain := int(math.Max(float64(resource.DailyLimit-resource.DailyUsed), 0))
if today && dailyRemain < count {
return ChannelServiceErr("套餐每日配额不足")
}
// 检查时间或配额
if resource.Type == 1 { // 包时
if resource.Expire.Before(time.Now()) {
return ChannelServiceErr("套餐已过期")
}
} else { // 包量
remain := int(math.Max(float64(resource.Quota-resource.Used), 0))
if remain < count {
return ChannelServiceErr("套餐配额不足")
}
}
// 筛选可用节点
nodes, err := Node.Filter(ctx, auth.Payload.Id, count, nodeFilter...)
if err != nil {
return err
}
// 获取用户配置白名单
whitelist, err := q.Whitelist.Where(
q.Whitelist.UserID.Eq(auth.Payload.Id),
).Find()
if err != nil {
return err
}
// 创建连接通道
channels = make([]*models.Channel, 0, len(nodes)*len(whitelist))
for _, node := range nodes {
for _, allowed := range whitelist {
username, password := genPassPair()
channels = append(channels, &models.Channel{
UserID: auth.Payload.Id,
NodeID: node.ID,
UserHost: allowed.Host,
NodeHost: node.Host,
ProxyPort: node.ProxyPort,
Protocol: string(protocol),
AuthIP: authType == ChannelAuthTypeIp,
AuthPass: authType == ChannelAuthTypePass,
Username: username,
Password: password,
Expiration: time.Now().Add(time.Duration(resource.Live) * time.Second),
})
}
}
// 保存到数据库
err = tx.Channel.Create(channels...)
if err != nil {
return err
}
// 更新套餐使用记录
if today {
resource.DailyUsed += int32(count)
resource.Used += int32(count)
} else {
resource.DailyLast = time.Now()
resource.DailyUsed = int32(count)
resource.Used += int32(count)
}
err = tx.ResourcePss.
Where(q.ResourcePss.ID.Eq(resource.Id)).
Select(
q.ResourcePss.Used,
q.ResourcePss.DailyUsed,
q.ResourcePss.DailyLast).
Save(&models.ResourcePss{
Used: resource.Used,
DailyUsed: resource.DailyUsed,
DailyLast: resource.DailyLast})
if err != nil {
return err
}
return nil
})
if err != nil {
return nil, err
}
// 缓存通道信息与异步删除任务
// err = cache(ctx, channels)
// if err != nil {
// return nil, err
// }
// 返回连接通道列表
return channels, errors.New("not implemented")
}
type ChannelAuthType int
const (
@@ -178,24 +45,23 @@ const (
ProtocolHttps = ChannelProtocol("https")
)
func genPassPair() (string, string) {
usernameBytes, err := uuid.New().MarshalBinary()
if err != nil {
panic(err)
}
passwordBytes, err := uuid.New().MarshalBinary()
if err != nil {
panic(err)
}
username := base62.EncodeToString(usernameBytes)
password := base62.EncodeToString(passwordBytes)
return username, password
type ResourceInfo struct {
Id int32
UserId int32
Active bool
Type int32
Live int32
DailyLimit int32
DailyUsed int32
DailyLast time.Time
Quota int32
Used int32
Expire time.Time
}
// region RemoveChannel
func (s *channelService) RemoveChannels(ctx context.Context, auth *AuthContext, id ...int32) error {
var channels []*models.Channel
// 删除通道
err := q.Q.Transaction(func(tx *q.Query) error {
// 查找通道
@@ -206,15 +72,30 @@ func (s *channelService) RemoveChannels(ctx context.Context, auth *AuthContext,
return err
}
// 检查权限,只有用户自己和管理员能删除
// 检查权限,如果为用户操作的话,则只能删除自己的通道
for _, channel := range channels {
if auth.Payload.Type == PayloadUser && auth.Payload.Id != channel.UserID {
return common.AuthForbiddenErr("无权限访问")
}
}
// 查找代理
proxySet := make(map[int32]struct{})
proxyIds := make([]int32, 0)
for _, channel := range channels {
if _, ok := proxySet[channel.ProxyID]; !ok {
proxyIds = append(proxyIds, channel.ProxyID)
proxySet[channel.ProxyID] = struct{}{}
}
}
proxies, err := tx.Proxy.Where(
q.Proxy.ID.In(proxyIds...),
).Find()
// 删除指定的通道
result, err := tx.Channel.Delete(channels...)
result, err := tx.Channel.
Where(q.Channel.ID.In(id...)).
Update(q.Channel.DeletedAt, time.Now())
if err != nil {
return err
}
@@ -222,30 +103,103 @@ func (s *channelService) RemoveChannels(ctx context.Context, auth *AuthContext,
return ChannelServiceErr("删除通道失败")
}
// 删除缓存,异步任务直接在消费端处理删除
err = deleteCache(ctx, channels)
if err != nil {
return err
}
// 禁用代理端口并下线用过的节点
if env.DebugExternalChange {
var configMap = make(map[int32][]remote.PortConfigsReq, len(proxies))
var proxyMap = make(map[int32]*models.Proxy, len(proxies))
for _, proxy := range proxies {
configMap[proxy.ID] = make([]remote.PortConfigsReq, 0)
proxyMap[proxy.ID] = proxy
}
var portMap = make(map[uint64]struct{})
for _, channel := range channels {
var config = remote.PortConfigsReq{
Port: int(channel.ProxyPort),
Edge: &[]string{},
AutoEdgeConfig: &remote.AutoEdgeConfig{
Count: v.P(0),
},
Status: false,
}
configMap[channel.ProxyID] = append(configMap[channel.ProxyID], config)
key := uint64(channel.ProxyID)<<32 | uint64(channel.ProxyPort)
portMap[key] = struct{}{}
}
for proxyId, configs := range configMap {
if len(configs) == 0 {
continue
}
proxy, ok := proxyMap[proxyId]
if !ok {
return ChannelServiceErr("代理不存在")
}
var secret = strings.Split(proxy.Secret, ":")
gateway := remote.InitGateway(
proxy.Host,
secret[0],
secret[1],
)
// 查询配置的节点
actives, err := gateway.GatewayPortActive()
if err != nil {
return err
}
// 取消配置
err = gateway.GatewayPortConfigs(configs)
if err != nil {
return err
}
// 下线对应节点
var edges []string
for portStr, active := range actives {
port, err := strconv.Atoi(portStr)
if err != nil {
return err
}
key := uint64(proxyId)<<32 | uint64(port)
if _, ok := portMap[key]; ok {
edges = append(edges, active.Edge...)
}
}
if len(edges) > 0 {
_, err := remote.Client.CloudDisconnect(remote.CloudDisconnectReq{
Uuid: proxy.Name,
Edge: edges,
})
if err != nil {
return err
}
}
}
}
return nil
})
if err != nil {
return err
}
// 删除缓存,异步任务直接在消费端处理删除
err = deleteCache(ctx, channels)
if err != nil {
return err
}
return nil
}
type ChannelServiceErr string
// endregion
func (c ChannelServiceErr) Error() string {
return string(c)
}
// region CreateChannel
// region channel by remote
func (s *channelService) RemoteCreateChannel(
func (s *channelService) CreateChannel(
ctx context.Context,
auth *AuthContext,
resourceId int32,
@@ -253,57 +207,78 @@ func (s *channelService) RemoteCreateChannel(
authType ChannelAuthType,
count int,
nodeFilter ...NodeFilterConfig,
) ([]AssignPortResult, error) {
) ([]string, error) {
filter := NodeFilterConfig{}
if len(nodeFilter) > 0 {
filter = nodeFilter[0]
}
// 查找套餐
var resource = new(ResourceInfo)
data := q.Resource.As("data")
pss := q.ResourcePss.As("pss")
err := data.Debug().Scopes(orm.Alias(data)).
Select(
data.ID, data.UserID, data.Active,
pss.Type, pss.Live, pss.DailyUsed, pss.DailyLimit, pss.DailyLast, pss.Quota, pss.Used, pss.Expire,
).
LeftJoin(q.ResourcePss.As("pss"), pss.ResourceID.EqCol(data.ID)).
Where(data.ID.Eq(resourceId)).
Scan(&resource)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ChannelServiceErr("套餐不存在")
var addr []string
err := q.Q.Transaction(func(tx *q.Query) error {
// 查找套餐
var resource = new(ResourceInfo)
data := q.Resource.As("data")
pss := q.ResourcePss.As("pss")
err := data.Scopes(orm.Alias(data)).
Select(
data.ID, data.UserID, data.Active,
pss.Type, pss.Live, pss.DailyUsed, pss.DailyLimit, pss.DailyLast, pss.Quota, pss.Used, pss.Expire,
).
LeftJoin(q.ResourcePss.As("pss"), pss.ResourceID.EqCol(data.ID)).
Where(data.ID.Eq(resourceId)).
Scan(&resource)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ChannelServiceErr("套餐不存在")
}
return err
}
return nil, err
}
// 检查用户权限
err = checkUser(auth, resource, count)
if err != nil {
return nil, err
}
slog.Debug("检查用户权限完成")
// 检查用户权限
err = checkUser(auth, resource, count)
if err != nil {
return err
}
var postAssigns []AssignPortResult
err = q.Q.Transaction(func(tx *q.Query) error {
// 申请节点
edgeAssigns, err := assignEdge(count, filter)
if err != nil {
return err
}
debugAssigned := fmt.Sprintf("%+v", edgeAssigns)
slog.Debug("申请节点完成", "edgeAssigns", debugAssigned)
// 分配端口
expiration := time.Now().Add(time.Duration(resource.Live) * time.Second)
postAssigns, err = assignPort(edgeAssigns, auth.Payload.Id, protocol, authType, expiration, filter)
now := time.Now()
expiration := now.Add(time.Duration(resource.Live) * time.Second)
_addr, channels, err := assignPort(edgeAssigns, auth.Payload.Id, protocol, authType, expiration, filter)
if err != nil {
return err
}
addr = _addr
// 更新套餐使用记录
_, err = q.ResourcePss.
Where(q.ResourcePss.ResourceID.Eq(resourceId)).
Select(
q.ResourcePss.Used,
q.ResourcePss.DailyUsed,
q.ResourcePss.DailyLast,
).
Updates(models.ResourcePss{
Used: resource.Used + int32(count),
DailyUsed: resource.DailyUsed + int32(count),
DailyLast: now,
})
if err != nil {
return err
}
// 缓存通道数据
err = cache(ctx, channels)
if err != nil {
return err
}
debugChannels := fmt.Sprintf("%+v", postAssigns)
slog.Debug("分配端口完成", "portAssigns", debugChannels)
return nil
})
@@ -311,17 +286,9 @@ func (s *channelService) RemoteCreateChannel(
return nil, err
}
// 缓存并关闭代理
err = cache(ctx, postAssigns)
if err != nil {
return nil, err
}
return postAssigns, nil
return addr, nil
}
// endregion
func checkUser(auth *AuthContext, resource *ResourceInfo, count int) error {
// 检查使用人
@@ -368,17 +335,17 @@ func assignEdge(count int, filter NodeFilterConfig) (*AssignEdgeResult, error) {
}
// 查询已配置的节点
allConfigs, err := remote.Client.CloudAutoQuery()
rProxyConfigs, err := remote.Client.CloudAutoQuery()
if err != nil {
return nil, err
}
// 查询已分配的节点
// 查询已使用的节点
var proxyIds = make([]int32, len(proxies))
for i, proxy := range proxies {
proxyIds[i] = proxy.ID
}
assigns, err := q.Channel.
channels, err := q.Channel.
Select(
q.Channel.ProxyID,
q.Channel.ProxyPort).
@@ -392,80 +359,86 @@ func assignEdge(count int, filter NodeFilterConfig) (*AssignEdgeResult, error) {
if err != nil {
return nil, err
}
var proxyUses = make(map[int32]int, len(channels))
for _, channel := range channels {
proxyUses[channel.ProxyID]++
}
// 过滤需要变动的连接配置
var current = 0
var result = make([]ProxyConfig, len(proxies))
// 组织数据
var infos = make([]*ProxyInfo, len(proxies))
for i, proxy := range proxies {
remoteConfigs, ok := allConfigs[proxy.Name]
infos[i] = &ProxyInfo{
proxy: proxy,
used: proxyUses[proxy.ID],
}
rConfigs, ok := rProxyConfigs[proxy.Name]
if !ok {
result[i] = ProxyConfig{
proxy: proxy,
config: &remote.AutoConfig{
Province: filter.Prov,
City: filter.City,
Isp: filter.Isp,
Count: 0,
},
}
infos[i].count = 0
continue
}
for _, config := range remoteConfigs {
if config.Isp == filter.Isp && config.City == filter.City && config.Province == filter.Prov {
current += config.Count
result[i] = ProxyConfig{
proxy: proxy,
config: &config,
}
for _, rConfig := range rConfigs {
if rConfig.Isp == filter.Isp && rConfig.City == filter.City && rConfig.Province == filter.Prov {
infos[i].count = rConfig.Count
}
}
}
// 如果需要新增节点
var needed = len(assigns) + count
if needed-current > 0 {
slog.Debug("新增新节点", "needed", needed, "current", current)
avg := int(math.Ceil(float64(needed) / float64(len(proxies))))
for i, assign := range result {
var prev = assign.config.Count
var next = assign.config.Count
if prev >= avg || prev >= needed {
continue
}
// 分配新增节点
var configs = make([]*ProxyConfig, len(proxies))
var needed = len(channels) + count
avg := int(math.Ceil(float64(needed) / float64(len(proxies))))
for i, info := range infos {
var prev = info.used
var next = int(math.Min(float64(avg), float64(needed)))
next = int(math.Min(float64(avg), float64(needed)))
result[i].config.Count = next - prev
needed -= next
info.used = int(math.Max(float64(prev), float64(next)))
needed -= info.used
if env.DebugExternalChange && info.used > info.count {
slog.Debug("新增新节点", "proxy", info.proxy.Name, "used", info.used, "count", info.count)
err := remote.Client.CloudConnect(remote.CloudConnectReq{
Uuid: assign.proxy.Name,
Uuid: info.proxy.Name,
Edge: nil,
AutoConfig: []remote.AutoConfig{{
Province: filter.Prov,
City: filter.City,
Isp: filter.Isp,
Count: next,
Count: int(math.Ceil(float64(info.used) * 1.1)),
}},
})
if err != nil {
return nil, err
}
}
configs[i] = &ProxyConfig{
proxy: info.proxy,
count: int(math.Max(float64(next-prev), 0)),
}
}
return &AssignEdgeResult{
configs: result,
channels: assigns,
configs: configs,
channels: channels,
}, nil
}
type ProxyInfo struct {
proxy *models.Proxy
used int
count int
}
type AssignEdgeResult struct {
configs []ProxyConfig
configs []*ProxyConfig
channels []*models.Channel
}
type ProxyConfig struct {
proxy *models.Proxy
config *remote.AutoConfig
proxy *models.Proxy
count int
}
// assignPort 分配指定数量的端口
@@ -476,9 +449,9 @@ func assignPort(
authType ChannelAuthType,
expiration time.Time,
filter NodeFilterConfig,
) ([]AssignPortResult, error) {
) ([]string, []*models.Channel, error) {
var assigns = proxies.configs
var channels = proxies.channels
var exists = proxies.channels
// 查询代理已配置端口
var proxyIds = make([]int32, 0, len(assigns))
@@ -488,24 +461,20 @@ func assignPort(
// 端口查找表
var proxyPorts = make(map[uint64]struct{})
for _, channel := range channels {
for _, channel := range exists {
key := uint64(channel.ProxyID)<<32 | uint64(channel.ProxyPort)
proxyPorts[key] = struct{}{}
}
// 配置启用代理
var result = make([]AssignPortResult, len(assigns))
for i, assign := range assigns {
var result []string
var channels []*models.Channel
for _, assign := range assigns {
var err error
var proxy = assign.proxy
var count = assign.config.Count
result[i] = AssignPortResult{
Proxy: proxy,
}
var count = assign.count
// 筛选可用端口
var channels = result[i].Channels
var configs = make([]remote.PortConfigsReq, 0, count)
for port := 10000; port < 20000 && len(configs) < count; port++ {
// 跳过存在的端口
@@ -521,13 +490,14 @@ func assignPort(
Port: port,
Edge: nil,
Status: true,
AutoEdgeConfig: remote.AutoEdgeConfig{
AutoEdgeConfig: &remote.AutoEdgeConfig{
Province: filter.Prov,
City: filter.City,
Isp: filter.Isp,
Count: 1,
Count: v.P(1),
},
})
switch authType {
case ChannelAuthTypeIp:
var whitelist []string
@@ -536,9 +506,10 @@ func assignPort(
Select(q.Whitelist.Host).
Scan(&whitelist)
if err != nil {
return nil, err
return nil, nil, err
}
configs[i].Whitelist = whitelist
configs[i].Whitelist = &whitelist
configs[i].Userpass = v.P("")
for _, item := range whitelist {
channels = append(channels, &models.Channel{
UserID: userId,
@@ -553,7 +524,8 @@ func assignPort(
}
case ChannelAuthTypePass:
username, password := genPassPair()
configs[i].Userpass = fmt.Sprintf("%s:%s", username, password)
configs[i].Whitelist = new([]string)
configs[i].Userpass = v.P(fmt.Sprintf("%s:%s", username, password))
channels = append(channels, &models.Channel{
UserID: userId,
ProxyID: proxy.ID,
@@ -566,66 +538,82 @@ func assignPort(
Expiration: expiration,
})
}
}
result[i].Channels = channels
result = append(result, fmt.Sprintf("%s://%s:%d", protocol, proxy.Host, port))
}
if len(configs) < count {
return nil, ChannelServiceErr("网关端口数量到达上限,无法分配")
}
// 提交端口配置
gateway := remote.InitGateway(
proxy.Host,
"api",
"123456",
)
err = gateway.GatewayPortConfigs(configs)
if err != nil {
return nil, err
return nil, nil, ChannelServiceErr("网关端口数量到达上限,无法分配")
}
// 保存到数据库
err = q.Channel.
Omit(q.Channel.NodeID).
Omit(
q.Channel.NodeID,
q.Channel.NodeHost,
q.Channel.Username,
q.Channel.Password,
q.Channel.DeletedAt,
).
Save(channels...)
if err != nil {
return nil, err
return nil, nil, err
}
// 提交端口配置并更新节点列表
if env.DebugExternalChange {
var secret = strings.Split(proxy.Secret, ":")
gateway := remote.InitGateway(
proxy.Host,
secret[0],
secret[1],
)
err = gateway.GatewayPortConfigs(configs)
if err != nil {
return nil, nil, err
}
}
}
return result, nil
return result, channels, nil
}
type AssignPortResult struct {
Proxy *models.Proxy
Channels []*models.Channel
// endregion
func genPassPair() (string, string) {
usernameBytes, err := uuid.New().MarshalBinary()
if err != nil {
panic(err)
}
passwordBytes, err := uuid.New().MarshalBinary()
if err != nil {
panic(err)
}
username := base62.EncodeToString(usernameBytes)
password := base62.EncodeToString(passwordBytes)
return username, password
}
func chKey(channel *models.Channel) string {
return fmt.Sprintf("channel:%s:%s", channel.UserHost, channel.NodeHost)
return fmt.Sprintf("channel:%d", channel.ID)
}
func cache(ctx context.Context, assigns []AssignPortResult) error {
func cache(ctx context.Context, channels []*models.Channel) error {
pipe := rds.Client.TxPipeline()
zList := make([]redis.Z, 0, len(assigns))
for _, assign := range assigns {
var channels = assign.Channels
for _, channel := range channels {
marshal, err := json.Marshal(assign)
if err != nil {
return err
}
pipe.Set(ctx, chKey(channel), string(marshal), channel.Expiration.Sub(time.Now()))
zList = append(zList, redis.Z{
Score: float64(channel.Expiration.Unix()),
Member: channel.ID,
})
zList := make([]redis.Z, 0, len(channels))
for _, channel := range channels {
marshal, err := json.Marshal(channel)
if err != nil {
return err
}
pipe.Set(ctx, chKey(channel), string(marshal), channel.Expiration.Sub(time.Now()))
zList = append(zList, redis.Z{
Score: float64(channel.Expiration.Unix()),
Member: channel.ID,
})
}
pipe.ZAdd(ctx, "tasks:assign", zList...)
pipe.ZAdd(ctx, "tasks:channel", zList...)
_, err := pipe.Exec(ctx)
if err != nil {
@@ -636,14 +624,11 @@ func cache(ctx context.Context, assigns []AssignPortResult) error {
}
func deleteCache(ctx context.Context, channels []*models.Channel) error {
pipe := rds.Client.TxPipeline()
keys := make([]string, 0, len(channels))
for i := range keys {
keys := make([]string, len(channels))
for i := range channels {
keys[i] = chKey(channels[i])
}
pipe.Del(ctx, keys...)
// 忽略异步任务zrem 效率较低,在使用时再删除
_, err := pipe.Exec(ctx)
_, err := rds.Client.Del(ctx, keys...).Result()
if err != nil {
return err
}
@@ -651,16 +636,8 @@ func deleteCache(ctx context.Context, channels []*models.Channel) error {
return nil
}
type ResourceInfo struct {
Id int32
UserId int32
Active bool
Type int32
Live int32
DailyLimit int32
DailyUsed int32
DailyLast time.Time
Quota int32
Used int32
Expire time.Time
type ChannelServiceErr string
func (c ChannelServiceErr) Error() string {
return string(c)
}

View File

@@ -1,13 +1,17 @@
package web
import (
"net/http"
"platform/pkg/env"
"log/slog"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/logger"
"github.com/gofiber/fiber/v2/middleware/requestid"
)
import "log/slog"
import _ "net/http/pprof"
type Config struct {
Listen string
@@ -30,6 +34,7 @@ func New(config *Config) (*Server, error) {
}
func (s *Server) Run() error {
s.fiber = fiber.New(fiber.Config{
ErrorHandler: ErrorHandler,
})
@@ -39,6 +44,13 @@ func (s *Server) Run() error {
ApplyRouters(s.fiber)
go func() {
err := http.ListenAndServe(":6060", nil)
if err != nil {
slog.Error("pprof 服务错误", slog.Any("err", err))
}
}()
port := env.AppPort
slog.Info("Server started on :" + port)
err := s.fiber.Listen(":" + port)