完善通道删除与定时失效功能

This commit is contained in:
2025-03-31 09:09:05 +08:00
parent ec4f499edd
commit 47bb49ce70
18 changed files with 832 additions and 619 deletions

View File

@@ -19,14 +19,18 @@
- [ ] Limiter - [ ] Limiter
- [ ] Compress - [ ] Compress
channel 数据存入顺序,数据库 > 缓存 > 外部接口
remote 令牌问题
用对称加密处理密钥
现在的节点分配逻辑是,每个 user_host:node_port 组算一个分配数,考虑是否改成每个用户算一个分配数 现在的节点分配逻辑是,每个 user_host:node_port 组算一个分配数,考虑是否改成每个用户算一个分配数
考虑将鉴权逻辑放到 handler 里,统一动静态鉴权以及解耦服务层 考虑将鉴权逻辑放到 handler 里,统一动静态鉴权以及解耦服务层
有些地方在用手动事务,有时间改成自动事务 有些地方在用手动事务,有时间改成自动事务
remote 用环境变量保存账号密码!
重新手动实现 model 层 重新手动实现 model 层
环境变量配置默认会话配置 环境变量配置默认会话配置

View File

@@ -5,8 +5,9 @@ import (
"platform/pkg/env" "platform/pkg/env"
"platform/pkg/logs" "platform/pkg/logs"
"platform/pkg/orm" "platform/pkg/orm"
"platform/web/models" m "platform/web/models"
q "platform/web/queries" q "platform/web/queries"
"time"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
) )
@@ -16,23 +17,47 @@ func main() {
logs.Init() logs.Init()
orm.Init() orm.Init()
err := q.Q.Transaction(func(tx *q.Query) error {
q.User. q.User.
Select(q.User.Phone).
Save(&m.User{
Phone: "12312341234",
})
var user, _ = q.User.First()
q.Resource.
Select(q.Resource.UserID, q.Resource.Active).
Create(&m.Resource{
UserID: user.ID,
Active: true,
})
var resource, _ = q.Resource.First()
q.ResourcePss.
Select( Select(
q.User.Phone). q.ResourcePss.ResourceID,
Create(&models.User{ q.ResourcePss.Live,
Phone: "12312341234"}) q.ResourcePss.Type,
q.ResourcePss.Expire,
q.ResourcePss.DailyLimit,
).
Create(&m.ResourcePss{
ResourceID: resource.ID,
Live: 180,
Type: 1,
Expire: time.Now().Add(24 * time.Hour * 1000),
DailyLimit: 300000,
})
q.Proxy. q.Proxy.
Select( Select(q.Proxy.Version, q.Proxy.Name, q.Proxy.Host, q.Proxy.Type, q.Proxy.Secret).
q.Proxy.Version, Create(&m.Proxy{
q.Proxy.Name,
q.Proxy.Host,
q.Proxy.Type).
Create(&models.Proxy{
Version: 1, Version: 1,
Name: "7a17e8b4-cdc3-4500-bf16-4a665991a7f6", Name: "7a17e8b4-cdc3-4500-bf16-4a665991a7f6",
Host: "110.40.82.248", Host: "110.40.82.248",
Type: 1}) Type: 1,
Secret: "api:123456",
})
q.Node. q.Node.
Select( Select(
@@ -43,7 +68,7 @@ func main() {
q.Node.Prov, q.Node.Prov,
q.Node.City, q.Node.City,
q.Node.Status). q.Node.Status).
Create(&models.Node{ Create(&m.Node{
Version: 1, Version: 1,
Name: "test-node", Name: "test-node",
Host: "123", Host: "123",
@@ -52,7 +77,8 @@ func main() {
City: "test-city", City: "test-city",
Status: 1}) Status: 1})
var secret, _ = bcrypt.GenerateFromPassword([]byte("test"), bcrypt.DefaultCost) var testSecret, _ = bcrypt.GenerateFromPassword([]byte("test"), bcrypt.DefaultCost)
var tasksSecret, _ = bcrypt.GenerateFromPassword([]byte("tasks"), bcrypt.DefaultCost)
q.Client. q.Client.
Select( Select(
q.Client.ClientID, q.Client.ClientID,
@@ -61,14 +87,26 @@ func main() {
q.Client.GrantRefresh, q.Client.GrantRefresh,
q.Client.Spec, q.Client.Spec,
q.Client.Name). q.Client.Name).
Create(&models.Client{ Create(&m.Client{
ClientID: "test", ClientID: "test",
ClientSecret: string(secret), ClientSecret: string(testSecret),
GrantClient: true, GrantClient: true,
GrantRefresh: true, GrantRefresh: true,
Spec: 0, Spec: 0,
Name: "默认客户端", Name: "默认客户端",
}, &m.Client{
ClientID: "tasks",
ClientSecret: string(tasksSecret),
GrantClient: true,
GrantRefresh: true,
Spec: 0,
Name: "异步任务处理服务",
}) })
return nil
})
if err != nil {
panic(err)
}
slog.Info("✔ Data inserted successfully") slog.Info("✔ Data inserted successfully")
} }

View File

@@ -1,33 +1,11 @@
package main package main
import ( import "math"
"fmt"
"platform/pkg/env"
"platform/pkg/logs"
"platform/pkg/orm"
"platform/web/models"
q "platform/web/queries"
)
type ResourceInfo struct { var b62Set = make(map[string]struct{})
data models.Resource var b64Set = make(map[string]struct{})
pss models.ResourcePss
}
func main() { func main() {
println(int(math.Ceil(100 * 1.1)))
env.Init() println(int(math.Ceil(float64(100) * 1.1)))
logs.Init()
orm.Init()
var resource = new(ResourceInfo)
data := q.Resource.As("data")
pss := q.ResourcePss.As("pss")
_ = data.Debug().Scopes(orm.Alias(data)).
Select(data.ALL, pss.ALL).
LeftJoin(q.ResourcePss.As("pss"), pss.ResourceID.EqCol(data.ID)).
Where(data.ID.Eq(1)).
Scan(&resource)
fmt.Printf("%+v\n", resource)
} }

View File

@@ -2,12 +2,18 @@ package main
import ( import (
"context" "context"
"encoding/json"
"errors" "errors"
"io"
"log/slog" "log/slog"
"net/http"
"platform/pkg/env" "platform/pkg/env"
"platform/pkg/logs" "platform/pkg/logs"
"platform/pkg/orm"
"platform/pkg/rds" "platform/pkg/rds"
"reflect"
"strconv"
"strings"
"sync"
"time" "time"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
@@ -17,27 +23,51 @@ func main() {
Start() Start()
} }
var taskList = make(map[string]func(ctx context.Context, curr time.Time) error)
func Start() { func Start() {
ctx := context.Background() ctx := context.Background()
env.Init() env.Init()
logs.Init() logs.Init()
rds.Init() rds.Init()
orm.Init()
taskList["stopChannels"] = stopChannels
ticker := time.NewTicker(time.Second) ticker := time.NewTicker(time.Second)
defer ticker.Stop() defer ticker.Stop()
// 互斥锁确保同一时间只有一个协程运行
// 如果之前的 tick 操作未完成,则跳过当前 tick
var mutex = &sync.Mutex{}
for curr := range ticker.C { for curr := range ticker.C {
if mutex.TryLock() {
err := process(ctx, curr) err := process(ctx, curr)
if err != nil { if err != nil {
panic(err) panic(err)
} }
mutex.Unlock()
} else {
slog.Warn("skip tick", slog.String("tick", curr.Format("2006-01-02 15:04:05")))
}
} }
} }
func process(ctx context.Context, curr time.Time) error { func process(ctx context.Context, curr time.Time) error {
// todo 异步化
for name, task := range taskList {
err := task(ctx, curr)
if err != nil {
slog.Error("task failed", slog.String("task", name), slog.String("error", err.Error()))
}
}
return nil
}
func stopChannels(ctx context.Context, curr time.Time) error {
// 获取并删除 // 获取并删除
script := redis.NewScript(` script := redis.NewScript(`
local result = redis.call('ZRANGEBYSCORE', KEYS[1], 0, ARGV[1]) local result = redis.call('ZRANGEBYSCORE', KEYS[1], 0, ARGV[1])
@@ -47,21 +77,64 @@ func process(ctx context.Context, curr time.Time) error {
return result return result
`) `)
// 计算时间范围
// 执行脚本 // 执行脚本
result, err := script.Run(ctx, rds.Client, []string{"tasks:session"}, curr.Unix()).Result() result, err := script.Run(ctx, rds.Client, []string{"tasks:channel"}, curr.Unix()).Result()
if err != nil { if err != nil {
return err return err
} }
// 处理结果 // 处理结果
list, ok := result.([]string) list, ok := result.([]any)
if !ok { if !ok {
return errors.New("failed to convert result to []string") return errors.New("failed to convert result to []string")
} }
for _, item := range list { var ids = make([]int, len(list))
// 从数据库删除授权信息 for i, item := range list {
slog.Debug(item) idStr, ok := item.(string)
if !ok {
slog.Debug(reflect.TypeOf(item).String())
return errors.New("failed to convert item to string")
}
id, err := strconv.Atoi(idStr)
if err != nil {
return err
}
ids[i] = id
}
if len(ids) == 0 {
return nil
}
var body = map[string]any{
"by_ids": ids,
}
bodyStr, err := json.Marshal(body)
if err != nil {
return err
}
req, err := http.NewRequest(
http.MethodPost,
"http://localhost:8080/api/channel/remove", // todo 环境变量获取服务地址
strings.NewReader(string(bodyStr)),
)
if err != nil {
return err
}
req.Header.Set("Content-Type", "application/json")
req.SetBasicAuth("tasks", "tasks")
resp, err := http.DefaultClient.Do(req)
if err != nil {
return err
}
if resp.StatusCode != http.StatusOK {
body, err := io.ReadAll(resp.Body)
if err != nil {
return err
}
return errors.New("failed to stop channels: " + string(body))
} }
return nil return nil

62
pkg/env/env.go vendored
View File

@@ -55,21 +55,21 @@ func loadDb() {
if _DbName != "" { if _DbName != "" {
DbName = _DbName DbName = _DbName
} else { } else {
panic("环境变量 DB_NAME 的值为空") panic("环境变量 DB_NAME 的值不能为空")
} }
_DbUserName := os.Getenv("DB_USERNAME") _DbUserName := os.Getenv("DB_USERNAME")
if _DbUserName != "" { if _DbUserName != "" {
DbUserName = _DbUserName DbUserName = _DbUserName
} else { } else {
panic("环境变量 DB_USERNAME 的值为空") panic("环境变量 DB_USERNAME 的值不能为空")
} }
_DbPassword := os.Getenv("DB_PASSWORD") _DbPassword := os.Getenv("DB_PASSWORD")
if _DbPassword != "" { if _DbPassword != "" {
DbPassword = _DbPassword DbPassword = _DbPassword
} else { } else {
panic("环境变量 DB_PASSWORD 的值为空") panic("环境变量 DB_PASSWORD 的值不能为空")
} }
} }
@@ -134,6 +134,60 @@ func loadLog() {
// endregion // endregion
// region remote
var (
RemoteAddr = "http://103.139.212.110:9989"
RemoteToken string
)
func loadRemote() {
_RemoteAddr := os.Getenv("REMOTE_ADDR")
if _RemoteAddr != "" {
RemoteAddr = _RemoteAddr
}
_RemoteToken := os.Getenv("REMOTE_TOKEN")
if _RemoteToken == "" {
panic("环境变量 REMOTE_TOKEN 的值不能为空")
}
RemoteToken = _RemoteToken
}
// endregion
// region debug
var (
// DebugHttpDump 是否打印请求和响应的原始数据
DebugHttpDump = false
// DebugExternalChange 是否实际执行非幂等外部接口的调用。
// 例如外部数据修改接口,在内部接口调试时可以关闭,避免对外部数据产生影响
DebugExternalChange = true
)
func loadDebug() {
debugHttpDump := os.Getenv("DEBUG_HTTP_DUMP")
if debugHttpDump != "" {
value, err := strconv.ParseBool(debugHttpDump)
if err != nil {
panic("环境变量 DEBUG_HTTP_DUMP 的值不是布尔值")
}
DebugHttpDump = value
}
debugExternalChange := os.Getenv("DEBUG_EXTERNAL_CHANGE")
if debugExternalChange != "" {
value, err := strconv.ParseBool(debugExternalChange)
if err != nil {
panic("环境变量 DEBUG_EXTERNAL_CHANGE 的值不是布尔值")
}
DebugExternalChange = value
}
}
// endregion
func Init() { func Init() {
err := godotenv.Load() err := godotenv.Load()
if err != nil { if err != nil {
@@ -146,4 +200,6 @@ func Init() {
loadDb() loadDb()
loadRedis() loadRedis()
loadLog() loadLog()
loadDebug()
loadRemote()
} }

View File

@@ -6,6 +6,9 @@ import (
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"net/http/httputil"
"net/url"
"platform/pkg/env"
"strconv" "strconv"
"strings" "strings"
) )
@@ -18,10 +21,9 @@ type client struct {
var Client client var Client client
func Init() { func Init() {
// todo 从环境变量中获取参数
Client = client{ Client = client{
url: "http://103.139.212.110:9989", url: env.RemoteAddr,
token: "PhdnRF3z6VF2sPgygTSl1Xx6QJN0yFIK.anVpcA==.MTc0MzE2ODAwMQ==", token: env.RemoteToken,
} }
} }
@@ -151,8 +153,8 @@ func (c *client) CloudConnect(param CloudConnectReq) error {
type CloudDisconnectReq struct { type CloudDisconnectReq struct {
Uuid string `json:"uuid"` Uuid string `json:"uuid"`
Edge []string `json:"edge"` Edge []string `json:"edge,omitempty"`
Config []Config `json:"auto_config"` Config []Config `json:"auto_config,omitempty"`
} }
type Config struct { type Config struct {
@@ -246,11 +248,29 @@ func (c *client) requestCloud(method string, url string, data string) (*http.Res
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
req.Header.Set("token", c.token) req.Header.Set("token", c.token)
if env.DebugHttpDump {
str, err := httputil.DumpRequest(req, true)
if err != nil {
return nil, err
}
fmt.Println("==============================")
fmt.Println(string(str))
}
resp, err := http.DefaultClient.Do(req) resp, err := http.DefaultClient.Do(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if env.DebugHttpDump {
str, err := httputil.DumpResponse(resp, true)
if err != nil {
return nil, err
}
fmt.Println("==============================")
fmt.Println(string(str))
}
return resp, nil return resp, nil
} }
@@ -268,21 +288,21 @@ func InitGateway(url, username, password string) *Gateway {
type PortConfigsReq struct { type PortConfigsReq struct {
Port int `json:"port"` Port int `json:"port"`
Edge []string `json:"edge,omitempty"` Edge *[]string `json:"edge,omitempty"`
Type string `json:"type,omitempty"` Type string `json:"type,omitempty"`
Time int `json:"time,omitempty"` Time int `json:"time,omitempty"`
Status bool `json:"status,omitempty"` Status bool `json:"status"`
Rate int `json:"rate,omitempty"` Rate int `json:"rate,omitempty"`
Whitelist []string `json:"whitelist,omitempty"` Whitelist *[]string `json:"whitelist,omitempty"`
Userpass string `json:"userpass,omitempty"` Userpass *string `json:"userpass,omitempty"`
AutoEdgeConfig AutoEdgeConfig `json:"auto_edge_config,omitempty"` AutoEdgeConfig *AutoEdgeConfig `json:"auto_edge_config,omitempty"`
} }
type AutoEdgeConfig struct { type AutoEdgeConfig struct {
Province string `json:"province,omitempty"` Province string `json:"province,omitempty"`
City string `json:"city,omitempty"` City string `json:"city,omitempty"`
Isp string `json:"isp,omitempty"` Isp string `json:"isp,omitempty"`
Count int `json:"count,omitempty"` Count *int `json:"count,omitempty"`
PacketLoss int `json:"packet_loss,omitempty"` PacketLoss int `json:"packet_loss,omitempty"`
} }
@@ -332,8 +352,8 @@ func (c *Gateway) GatewayPortConfigs(params []PortConfigsReq) error {
type PortActiveReq struct { type PortActiveReq struct {
Port string `json:"port"` Port string `json:"port"`
Active bool `json:"active"` Active *bool `json:"active"`
Status bool `json:"status"` Status *bool `json:"status"`
} }
type PortActiveResp struct { type PortActiveResp struct {
@@ -352,22 +372,34 @@ type PortData struct {
Userpass string `json:"userpass"` Userpass string `json:"userpass"`
} }
func (c *Gateway) GatewayPortActive(param PortActiveReq) (*PortActiveResp, error) { func (c *Gateway) GatewayPortActive(param ...PortActiveReq) (map[string]PortData, error) {
_param := PortActiveReq{}
url := strings.Builder{} if len(param) != 0 {
url.WriteString("/port/active") _param = param[0]
if param.Port != "" {
url.WriteString("/")
url.WriteString(param.Port)
} }
url.WriteString("?active=") path := strings.Builder{}
url.WriteString(strconv.FormatBool(param.Active)) path.WriteString("/port/active")
url.WriteString("&status=")
url.WriteString(strconv.FormatBool(param.Status))
resp, err := c.requestGateway("POST", url.String(), "") if _param.Port != "" {
path.WriteString("/")
path.WriteString(_param.Port)
}
values := url.Values{}
if _param.Active != nil {
values.Set("active", strconv.FormatBool(*_param.Active))
}
if _param.Status != nil {
values.Set("status", strconv.FormatBool(*_param.Status))
}
if len(values) > 0 {
path.WriteString("?")
path.WriteString(values.Encode())
}
resp, err := c.requestGateway("GET", path.String(), "")
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -390,7 +422,11 @@ func (c *Gateway) GatewayPortActive(param PortActiveReq) (*PortActiveResp, error
return nil, err return nil, err
} }
return &result, nil if result.Code != 0 {
return nil, errors.New(result.Msg)
}
return result.Data, nil
} }
// endregion // endregion
@@ -404,10 +440,28 @@ func (c *Gateway) requestGateway(method string, url string, data string) (*http.
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
if env.DebugHttpDump {
str, err := httputil.DumpRequest(req, true)
if err != nil {
return nil, err
}
fmt.Println("==============================")
fmt.Println(string(str))
}
resp, err := http.DefaultClient.Do(req) resp, err := http.DefaultClient.Do(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if env.DebugHttpDump {
str, err := httputil.DumpResponse(resp, true)
if err != nil {
return nil, err
}
fmt.Println("==============================")
fmt.Println(string(str))
}
return resp, nil return resp, nil
} }

6
pkg/v/v.go Normal file
View File

@@ -0,0 +1,6 @@
package v
// P 是一个工具函数,用于在表达式内原地创建一个指针
func P[T any](v T) *T {
return &v
}

View File

@@ -388,6 +388,7 @@ create table proxy (
name varchar(255) not null unique, name varchar(255) not null unique,
host varchar(255) not null, host varchar(255) not null,
type int not null default 0, type int not null default 0,
secret varchar(255),
created_at timestamp default current_timestamp, created_at timestamp default current_timestamp,
updated_at timestamp default current_timestamp, updated_at timestamp default current_timestamp,
deleted_at timestamp deleted_at timestamp
@@ -402,6 +403,7 @@ comment on column proxy.version is '代理服务版本';
comment on column proxy.name is '代理服务名称'; comment on column proxy.name is '代理服务名称';
comment on column proxy.host is '代理服务地址'; comment on column proxy.host is '代理服务地址';
comment on column proxy.type is '代理服务类型0-自有1-三方'; comment on column proxy.type is '代理服务类型0-自有1-三方';
comment on column proxy.secret is '代理服务密钥';
comment on column proxy.created_at is '创建时间'; comment on column proxy.created_at is '创建时间';
comment on column proxy.updated_at is '更新时间'; comment on column proxy.updated_at is '更新时间';
comment on column proxy.deleted_at is '删除时间'; comment on column proxy.deleted_at is '删除时间';
@@ -507,6 +509,7 @@ create index channel_user_host_index on channel (user_host);
create index channel_proxy_port_index on channel (proxy_port); create index channel_proxy_port_index on channel (proxy_port);
create index channel_node_host_index on channel (node_host); create index channel_node_host_index on channel (node_host);
create index channel_expiration_index on channel (expiration); create index channel_expiration_index on channel (expiration);
create index channel_deleted_at_index on channel (deleted_at);
-- channel表字段注释 -- channel表字段注释
comment on table channel is '通道表'; comment on table channel is '通道表';
@@ -593,8 +596,8 @@ create table resource_pss (
type int, type int,
live int, live int,
quota int, quota int,
used int,
expire timestamp, expire timestamp,
used int not null default 0,
daily_limit int not null default 0, daily_limit int not null default 0,
daily_used int not null default 0, daily_used int not null default 0,
daily_last timestamp, daily_last timestamp,

View File

@@ -3,46 +3,37 @@ GET http://110.40.82.250:18702/server/index/getToken/key/juipbyjdapiverify
### remote 配置信息 ### remote 配置信息
GET http://103.139.212.110:9989/api/auto_query GET http://103.139.212.110:9989/api/auto_query
token: PhdnRF3z6VF2sPgygTSl1Xx6QJN0yFIK.anVpcA==.MTc0MzE2ODAwMQ== token: et1wWdrLLRsiQPCar8GunNFEZqcxATFa.anVpcA==.MTc0MzM0MjAwMQ==
### remote 配置重置 ### remote 配置连接
POST http://103.139.212.110:9989/api/connect POST http://103.139.212.110:9989/api/connect
token: PhdnRF3z6VF2sPgygTSl1Xx6QJN0yFIK.anVpcA==.MTc0MzE2ODAwMQ== token: et1wWdrLLRsiQPCar8GunNFEZqcxATFa.anVpcA==.MTc0MzM0MjAwMQ==
Content-Type: application/json Content-Type: application/json
{ {
"uuid": "7a17e8b4-cdc3-4500-bf16-4a665991a7f6", "uuid": "7a17e8b4-cdc3-4500-bf16-4a665991a7f6",
"auto_config": [ "auto_config": [
{ {
"count": 1 "count": 10
} }
] ]
} }
### gateway 端口信息 ### remote 下线全部
GET http://api:123456@110.40.82.248:9990/port/active/ POST http://103.139.212.110:9989/api/disconnect
token: et1wWdrLLRsiQPCar8GunNFEZqcxATFa.anVpcA==.MTc0MzM0MjAwMQ==
### gateway 配置端口代理
POST http://api:123456@110.40.82.248:9990/port/configs
Content-Type: application/json Content-Type: application/json
//[ {
// { "uuid": "7a17e8b4-cdc3-4500-bf16-4a665991a7f6",
// "port": 10000, "config": {}
// "status": true, }
// "userpass": "mIfSlXIBwVUrKqObNdTzvB:cyDaGtfeBlhRfojTYJP2tR",
// "auto_edge_config": { ### gateway 连接信息
// "count": 1 GET http://api:123456@110.40.82.248:9990/edge
// }
// } ### gateway 端口信息
//] GET http://api:123456@110.40.82.248:9990/port/active
[
{
"port": 10000,
"status": false,
"edge": []
}
]
### 设备令牌 ### 设备令牌
POST http://localhost:8080/api/auth/token POST http://localhost:8080/api/auth/token
@@ -59,23 +50,20 @@ Content-Type: application/json
client.global.set("refresh_token", response.body.refresh_token); client.global.set("refresh_token", response.body.refresh_token);
%} %}
### 密码模式代理 ### 提取代理
POST http://localhost:8080/api/channel/create POST http://localhost:8080/api/channel/create
Authorization: Bearer {{access_token}} Authorization: Basic test test
Content-Type: application/json Content-Type: application/json
Accept: application/json Accept: application/json
{ {
"resource_id": 1, "resource_id": 1,
"protocol": "http", "protocol": "http",
"auth_type": 1, "auth_type": 0,
"count": 1, "count": 200,
"prov": "", "prov": "",
"city": "", "city": "",
"isp": "", "isp": "",
"result_type": "text", "result_type": "text",
"result_separator": "both" "result_separator": "both"
} }
### 白名单模式代理

View File

@@ -1,7 +1,12 @@
package web package web
import ( import (
"context"
"encoding/base64"
"errors"
"log/slog"
"platform/web/common" "platform/web/common"
q "platform/web/queries"
"slices" "slices"
"strings" "strings"
@@ -14,16 +19,36 @@ func Permit(types []services.PayloadType, permissions ...string) fiber.Handler {
return func(c *fiber.Ctx) error { return func(c *fiber.Ctx) error {
// 获取令牌 // 获取令牌
var header = c.Get("Authorization") var header = c.Get("Authorization")
var token = strings.TrimPrefix(header, "Bearer ") var split = strings.Split(header, " ")
if len(split) != 2 {
return c.Status(fiber.StatusBadRequest).JSON(common.ErrResp{
Error: true,
Message: "无效的令牌",
})
}
var token = split[1]
if token == "" { if token == "" {
return c.Status(fiber.StatusBadRequest).JSON(common.ErrResp{
Error: true,
Message: "无效的令牌",
})
}
var auth *services.AuthContext
var err error
switch split[0] {
case "Bearer":
auth, err = authBearer(c.Context(), token)
case "Basic":
if !slices.Contains(types, services.PayloadClientConfidential) {
return c.Status(fiber.StatusUnauthorized).JSON(common.ErrResp{ return c.Status(fiber.StatusUnauthorized).JSON(common.ErrResp{
Error: true, Error: true,
Message: "没有权限", Message: "没有权限",
}) })
} }
auth, err = authBasic(c.Context(), token)
// 验证令牌 }
auth, err := services.Session.Find(c.Context(), token)
if err != nil { if err != nil {
return c.Status(fiber.StatusUnauthorized).JSON(common.ErrResp{ return c.Status(fiber.StatusUnauthorized).JSON(common.ErrResp{
Error: true, Error: true,
@@ -32,22 +57,6 @@ func Permit(types []services.PayloadType, permissions ...string) fiber.Handler {
} }
// 检查权限 // 检查权限
// switch auth.Payload.Type {
// case services.PayloadAdmin:
// // 管理员不需要权限检查
// case services.PayloadUser:
// if len(permissions) > 0 && !auth.AnyPermission(permissions...) {
// return c.Status(fiber.StatusForbidden).JSON(common.ErrResp{
// Error: true,
// Message: "拒绝访问",
// })
// }
// default:
// return c.Status(fiber.StatusForbidden).JSON(common.ErrResp{
// Error: true,
// Message: "拒绝访问",
// })
// }
if !slices.Contains(types, auth.Payload.Type) { if !slices.Contains(types, auth.Payload.Type) {
return c.Status(fiber.StatusForbidden).JSON(common.ErrResp{ return c.Status(fiber.StatusForbidden).JSON(common.ErrResp{
Error: true, Error: true,
@@ -70,97 +79,95 @@ func Permit(types []services.PayloadType, permissions ...string) fiber.Handler {
} }
func PermitAll(permissions ...string) fiber.Handler {
return Permit([]services.PayloadType{
services.PayloadClientPublic,
services.PayloadClientConfidential,
services.PayloadUser,
services.PayloadAdmin,
}, permissions...)
}
// PermitUser 创建针对单个路由的鉴权中间件 // PermitUser 创建针对单个路由的鉴权中间件
func PermitUser(permissions ...string) fiber.Handler { func PermitUser(permissions ...string) fiber.Handler {
return func(c *fiber.Ctx) error { return Permit([]services.PayloadType{
// 获取令牌 services.PayloadUser,
var header = c.Get("Authorization") services.PayloadAdmin,
var token = strings.TrimPrefix(header, "Bearer ") }, permissions...)
if token == "" {
return c.Status(fiber.StatusUnauthorized).JSON(common.ErrResp{
Error: true,
Message: "没有权限",
})
}
// 验证令牌
auth, err := services.Session.Find(c.Context(), token)
if err != nil {
return c.Status(fiber.StatusUnauthorized).JSON(common.ErrResp{
Error: true,
Message: "没有权限",
})
}
// 检查权限
switch auth.Payload.Type {
case services.PayloadAdmin:
// 管理员不需要权限检查
case services.PayloadUser:
if len(permissions) > 0 && !auth.AnyPermission(permissions...) {
return c.Status(fiber.StatusForbidden).JSON(common.ErrResp{
Error: true,
Message: "拒绝访问",
})
}
default:
return c.Status(fiber.StatusForbidden).JSON(common.ErrResp{
Error: true,
Message: "拒绝访问",
})
}
// 将认证信息存储在上下文中
c.Locals("auth", auth)
c.Locals("access_token", token) // 存储原始令牌,便于后续操作
return c.Next()
}
} }
func PermitDevice(permissions ...string) fiber.Handler { func PermitDevice(permissions ...string) fiber.Handler {
return func(c *fiber.Ctx) error { return Permit([]services.PayloadType{
// 获取令牌 services.PayloadClientPublic,
var header = c.Get("Authorization") services.PayloadClientConfidential,
var token = strings.TrimPrefix(header, "Bearer ") services.PayloadAdmin,
if token == "" { }, permissions...)
return c.Status(fiber.StatusUnauthorized).JSON(common.ErrResp{ }
Error: true,
Message: "没有权限", func PermitPublic(permissions ...string) fiber.Handler {
}) return Permit([]services.PayloadType{
} services.PayloadClientPublic,
services.PayloadAdmin,
// 验证令牌 }, permissions...)
auth, err := services.Session.Find(c.Context(), token) }
if err != nil {
return c.Status(fiber.StatusUnauthorized).JSON(common.ErrResp{ func PermitConfidential(permissions ...string) fiber.Handler {
Error: true, return Permit([]services.PayloadType{
Message: "没有权限", services.PayloadClientConfidential,
}) services.PayloadAdmin,
} }, permissions...)
}
// 检查权限
switch auth.Payload.Type { func authBearer(ctx context.Context, token string) (*services.AuthContext, error) {
case services.PayloadAdmin: auth, err := services.Session.Find(ctx, token)
// 管理员不需要权限检查 if err != nil {
case services.PayloadClientPublic, services.PayloadClientConfidential: slog.Debug(err.Error())
if len(permissions) > 0 && !auth.AnyPermission(permissions...) { return nil, err
return c.Status(fiber.StatusForbidden).JSON(common.ErrResp{ }
Error: true, return auth, nil
Message: "拒绝访问", }
})
} func authBasic(ctx context.Context, token string) (*services.AuthContext, error) {
default:
return c.Status(fiber.StatusForbidden).JSON(common.ErrResp{ // 解析 Basic 认证信息
Error: true, var base, err = base64.URLEncoding.DecodeString(token)
Message: "拒绝访问", if err != nil {
}) slog.Debug(err.Error())
} return nil, err
}
// 将认证信息存储在上下文中
c.Locals("auth", auth) var split = strings.Split(string(base), ":")
c.Locals("access_token", token) // 存储原始令牌,便于后续操作 if len(split) != 2 {
msg := "无法解析 Basic 认证信息"
return c.Next() slog.Debug(msg)
} return nil, errors.New(msg)
}
var clientID = split[0]
// 获取客户端信息
client, err := q.Client.
Where(
q.Client.ClientID.Eq(clientID),
q.Client.Spec.Eq(0),
q.Client.GrantClient.Is(true),
q.Client.Status.Eq(1)).
Take()
if err != nil {
return nil, err
}
// todo 查询客户端关联权限
// 组织授权信息(一次性请求)
return &services.AuthContext{
Payload: services.Payload{
Id: client.ID,
Type: services.PayloadClientConfidential,
Name: client.Name,
Avatar: client.Icon,
},
Permissions: nil,
Metadata: nil,
}, nil
} }

View File

@@ -2,7 +2,6 @@ package handlers
import ( import (
"errors" "errors"
"fmt"
"platform/web/services" "platform/web/services"
"strings" "strings"
@@ -35,7 +34,7 @@ func CreateChannel(c *fiber.Ctx) error {
return errors.New("user not found") return errors.New("user not found")
} }
assigns, err := services.Channel.RemoteCreateChannel( result, err := services.Channel.CreateChannel(
c.Context(), c.Context(),
auth, auth,
req.ResourceId, req.ResourceId,
@@ -52,17 +51,6 @@ func CreateChannel(c *fiber.Ctx) error {
return err return err
} }
// 返回连接通道列表
var result []string
for _, assign := range assigns {
var proxy = assign.Proxy
var channels = assign.Channels
for _, channel := range channels {
url := fmt.Sprintf("%s://%s:%d", channel.Protocol, proxy.Host, channel.ProxyPort)
result = append(result, url)
}
}
switch req.ResultType { switch req.ResultType {
case CreateChannelResultTypeJson: case CreateChannelResultTypeJson:
return c.JSON(fiber.Map{ return c.JSON(fiber.Map{
@@ -101,3 +89,32 @@ const (
) )
// endregion // endregion
// region RemoveChannels
type RemoveChannelsReq struct {
ByIds []int32 `json:"by_ids" validate:"required"`
}
func RemoveChannels(c *fiber.Ctx) error {
req := new(RemoveChannelsReq)
if err := c.BodyParser(req); err != nil {
return err
}
// 获取用户信息
auth, ok := c.Locals("auth").(*services.AuthContext)
if !ok {
return errors.New("user not found")
}
// 删除通道
err := services.Channel.RemoveChannels(c.Context(), auth, req.ByIds...)
if err != nil {
return err
}
return c.SendStatus(fiber.StatusOK)
}
// endregion

View File

@@ -19,6 +19,7 @@ type Proxy struct {
Name string `gorm:"column:name;not null;comment:代理服务名称" json:"name"` // 代理服务名称 Name string `gorm:"column:name;not null;comment:代理服务名称" json:"name"` // 代理服务名称
Host string `gorm:"column:host;not null;comment:代理服务地址" json:"host"` // 代理服务地址 Host string `gorm:"column:host;not null;comment:代理服务地址" json:"host"` // 代理服务地址
Type int32 `gorm:"column:type;not null;comment:代理服务类型0-自有1-三方" json:"type"` // 代理服务类型0-自有1-三方 Type int32 `gorm:"column:type;not null;comment:代理服务类型0-自有1-三方" json:"type"` // 代理服务类型0-自有1-三方
Secret string `gorm:"column:secret;comment:代理服务密钥" json:"secret"` // 代理服务密钥
CreatedAt time.Time `gorm:"column:created_at;default:CURRENT_TIMESTAMP;comment:创建时间" json:"created_at"` // 创建时间 CreatedAt time.Time `gorm:"column:created_at;default:CURRENT_TIMESTAMP;comment:创建时间" json:"created_at"` // 创建时间
UpdatedAt time.Time `gorm:"column:updated_at;default:CURRENT_TIMESTAMP;comment:更新时间" json:"updated_at"` // 更新时间 UpdatedAt time.Time `gorm:"column:updated_at;default:CURRENT_TIMESTAMP;comment:更新时间" json:"updated_at"` // 更新时间
DeletedAt gorm.DeletedAt `gorm:"column:deleted_at;comment:删除时间" json:"deleted_at"` // 删除时间 DeletedAt gorm.DeletedAt `gorm:"column:deleted_at;comment:删除时间" json:"deleted_at"` // 删除时间

View File

@@ -19,8 +19,8 @@ type ResourcePss struct {
Type int32 `gorm:"column:type;comment:套餐类型1-包时2-包量" json:"type"` // 套餐类型1-包时2-包量 Type int32 `gorm:"column:type;comment:套餐类型1-包时2-包量" json:"type"` // 套餐类型1-包时2-包量
Live int32 `gorm:"column:live;comment:可用时长(秒)" json:"live"` // 可用时长(秒) Live int32 `gorm:"column:live;comment:可用时长(秒)" json:"live"` // 可用时长(秒)
Quota int32 `gorm:"column:quota;comment:配额数量" json:"quota"` // 配额数量 Quota int32 `gorm:"column:quota;comment:配额数量" json:"quota"` // 配额数量
Used int32 `gorm:"column:used;comment:已用数量" json:"used"` // 已用数量
Expire time.Time `gorm:"column:expire;comment:过期时间" json:"expire"` // 过期时间 Expire time.Time `gorm:"column:expire;comment:过期时间" json:"expire"` // 过期时间
Used int32 `gorm:"column:used;not null;comment:已用数量" json:"used"` // 已用数量
DailyLimit int32 `gorm:"column:daily_limit;not null;comment:每日限制" json:"daily_limit"` // 每日限制 DailyLimit int32 `gorm:"column:daily_limit;not null;comment:每日限制" json:"daily_limit"` // 每日限制
DailyUsed int32 `gorm:"column:daily_used;not null;comment:今日已用数量" json:"daily_used"` // 今日已用数量 DailyUsed int32 `gorm:"column:daily_used;not null;comment:今日已用数量" json:"daily_used"` // 今日已用数量
DailyLast time.Time `gorm:"column:daily_last;comment:今日最后使用时间" json:"daily_last"` // 今日最后使用时间 DailyLast time.Time `gorm:"column:daily_last;comment:今日最后使用时间" json:"daily_last"` // 今日最后使用时间

View File

@@ -32,6 +32,7 @@ func newProxy(db *gorm.DB, opts ...gen.DOOption) proxy {
_proxy.Name = field.NewString(tableName, "name") _proxy.Name = field.NewString(tableName, "name")
_proxy.Host = field.NewString(tableName, "host") _proxy.Host = field.NewString(tableName, "host")
_proxy.Type = field.NewInt32(tableName, "type") _proxy.Type = field.NewInt32(tableName, "type")
_proxy.Secret = field.NewString(tableName, "secret")
_proxy.CreatedAt = field.NewTime(tableName, "created_at") _proxy.CreatedAt = field.NewTime(tableName, "created_at")
_proxy.UpdatedAt = field.NewTime(tableName, "updated_at") _proxy.UpdatedAt = field.NewTime(tableName, "updated_at")
_proxy.DeletedAt = field.NewField(tableName, "deleted_at") _proxy.DeletedAt = field.NewField(tableName, "deleted_at")
@@ -50,6 +51,7 @@ type proxy struct {
Name field.String // 代理服务名称 Name field.String // 代理服务名称
Host field.String // 代理服务地址 Host field.String // 代理服务地址
Type field.Int32 // 代理服务类型0-自有1-三方 Type field.Int32 // 代理服务类型0-自有1-三方
Secret field.String // 代理服务密钥
CreatedAt field.Time // 创建时间 CreatedAt field.Time // 创建时间
UpdatedAt field.Time // 更新时间 UpdatedAt field.Time // 更新时间
DeletedAt field.Field // 删除时间 DeletedAt field.Field // 删除时间
@@ -74,6 +76,7 @@ func (p *proxy) updateTableName(table string) *proxy {
p.Name = field.NewString(table, "name") p.Name = field.NewString(table, "name")
p.Host = field.NewString(table, "host") p.Host = field.NewString(table, "host")
p.Type = field.NewInt32(table, "type") p.Type = field.NewInt32(table, "type")
p.Secret = field.NewString(table, "secret")
p.CreatedAt = field.NewTime(table, "created_at") p.CreatedAt = field.NewTime(table, "created_at")
p.UpdatedAt = field.NewTime(table, "updated_at") p.UpdatedAt = field.NewTime(table, "updated_at")
p.DeletedAt = field.NewField(table, "deleted_at") p.DeletedAt = field.NewField(table, "deleted_at")
@@ -93,12 +96,13 @@ func (p *proxy) GetFieldByName(fieldName string) (field.OrderExpr, bool) {
} }
func (p *proxy) fillFieldMap() { func (p *proxy) fillFieldMap() {
p.fieldMap = make(map[string]field.Expr, 8) p.fieldMap = make(map[string]field.Expr, 9)
p.fieldMap["id"] = p.ID p.fieldMap["id"] = p.ID
p.fieldMap["version"] = p.Version p.fieldMap["version"] = p.Version
p.fieldMap["name"] = p.Name p.fieldMap["name"] = p.Name
p.fieldMap["host"] = p.Host p.fieldMap["host"] = p.Host
p.fieldMap["type"] = p.Type p.fieldMap["type"] = p.Type
p.fieldMap["secret"] = p.Secret
p.fieldMap["created_at"] = p.CreatedAt p.fieldMap["created_at"] = p.CreatedAt
p.fieldMap["updated_at"] = p.UpdatedAt p.fieldMap["updated_at"] = p.UpdatedAt
p.fieldMap["deleted_at"] = p.DeletedAt p.fieldMap["deleted_at"] = p.DeletedAt

View File

@@ -32,8 +32,8 @@ func newResourcePss(db *gorm.DB, opts ...gen.DOOption) resourcePss {
_resourcePss.Type = field.NewInt32(tableName, "type") _resourcePss.Type = field.NewInt32(tableName, "type")
_resourcePss.Live = field.NewInt32(tableName, "live") _resourcePss.Live = field.NewInt32(tableName, "live")
_resourcePss.Quota = field.NewInt32(tableName, "quota") _resourcePss.Quota = field.NewInt32(tableName, "quota")
_resourcePss.Used = field.NewInt32(tableName, "used")
_resourcePss.Expire = field.NewTime(tableName, "expire") _resourcePss.Expire = field.NewTime(tableName, "expire")
_resourcePss.Used = field.NewInt32(tableName, "used")
_resourcePss.DailyLimit = field.NewInt32(tableName, "daily_limit") _resourcePss.DailyLimit = field.NewInt32(tableName, "daily_limit")
_resourcePss.DailyUsed = field.NewInt32(tableName, "daily_used") _resourcePss.DailyUsed = field.NewInt32(tableName, "daily_used")
_resourcePss.DailyLast = field.NewTime(tableName, "daily_last") _resourcePss.DailyLast = field.NewTime(tableName, "daily_last")
@@ -55,8 +55,8 @@ type resourcePss struct {
Type field.Int32 // 套餐类型1-包时2-包量 Type field.Int32 // 套餐类型1-包时2-包量
Live field.Int32 // 可用时长(秒) Live field.Int32 // 可用时长(秒)
Quota field.Int32 // 配额数量 Quota field.Int32 // 配额数量
Used field.Int32 // 已用数量
Expire field.Time // 过期时间 Expire field.Time // 过期时间
Used field.Int32 // 已用数量
DailyLimit field.Int32 // 每日限制 DailyLimit field.Int32 // 每日限制
DailyUsed field.Int32 // 今日已用数量 DailyUsed field.Int32 // 今日已用数量
DailyLast field.Time // 今日最后使用时间 DailyLast field.Time // 今日最后使用时间
@@ -84,8 +84,8 @@ func (r *resourcePss) updateTableName(table string) *resourcePss {
r.Type = field.NewInt32(table, "type") r.Type = field.NewInt32(table, "type")
r.Live = field.NewInt32(table, "live") r.Live = field.NewInt32(table, "live")
r.Quota = field.NewInt32(table, "quota") r.Quota = field.NewInt32(table, "quota")
r.Used = field.NewInt32(table, "used")
r.Expire = field.NewTime(table, "expire") r.Expire = field.NewTime(table, "expire")
r.Used = field.NewInt32(table, "used")
r.DailyLimit = field.NewInt32(table, "daily_limit") r.DailyLimit = field.NewInt32(table, "daily_limit")
r.DailyUsed = field.NewInt32(table, "daily_used") r.DailyUsed = field.NewInt32(table, "daily_used")
r.DailyLast = field.NewTime(table, "daily_last") r.DailyLast = field.NewTime(table, "daily_last")
@@ -114,8 +114,8 @@ func (r *resourcePss) fillFieldMap() {
r.fieldMap["type"] = r.Type r.fieldMap["type"] = r.Type
r.fieldMap["live"] = r.Live r.fieldMap["live"] = r.Live
r.fieldMap["quota"] = r.Quota r.fieldMap["quota"] = r.Quota
r.fieldMap["used"] = r.Used
r.fieldMap["expire"] = r.Expire r.fieldMap["expire"] = r.Expire
r.fieldMap["used"] = r.Used
r.fieldMap["daily_limit"] = r.DailyLimit r.fieldMap["daily_limit"] = r.DailyLimit
r.fieldMap["daily_used"] = r.DailyUsed r.fieldMap["daily_used"] = r.DailyUsed
r.fieldMap["daily_last"] = r.DailyLast r.fieldMap["daily_last"] = r.DailyLast

View File

@@ -2,7 +2,6 @@ package web
import ( import (
"platform/web/handlers" "platform/web/handlers"
"platform/web/services"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
) )
@@ -18,10 +17,6 @@ func ApplyRouters(app *fiber.App) {
// 通道 // 通道
channel := api.Group("/channel") channel := api.Group("/channel")
channel.Post("/create", Permit([]services.PayloadType{ channel.Post("/create", PermitAll(), handlers.CreateChannel)
services.PayloadClientConfidential, channel.Post("/remove", PermitAll(), handlers.RemoveChannels)
services.PayloadClientPublic,
services.PayloadUser,
services.PayloadAdmin,
}), handlers.CreateChannel)
} }

View File

@@ -7,12 +7,16 @@ import (
"fmt" "fmt"
"log/slog" "log/slog"
"math" "math"
"platform/pkg/env"
"platform/pkg/orm" "platform/pkg/orm"
"platform/pkg/rds" "platform/pkg/rds"
"platform/pkg/remote" "platform/pkg/remote"
"platform/pkg/v"
"platform/web/common" "platform/web/common"
"platform/web/models" "platform/web/models"
q "platform/web/queries" q "platform/web/queries"
"strconv"
"strings"
"time" "time"
"github.com/google/uuid" "github.com/google/uuid"
@@ -26,143 +30,6 @@ var Channel = &channelService{}
type channelService struct { type channelService struct {
} }
// CreateChannel 创建连接通道,并返回连接信息,如果配额不足则返回错误
func (s *channelService) CreateChannel(
ctx context.Context,
auth *AuthContext,
resourceId int32,
protocol ChannelProtocol,
authType ChannelAuthType,
count int,
nodeFilter ...NodeFilterConfig,
) ([]*models.Channel, error) {
// 创建通道
var channels []*models.Channel
err := q.Q.Transaction(func(tx *q.Query) error {
// 查找套餐
var resource = ResourceInfo{}
err := q.Resource.As("data").
LeftJoin(q.ResourcePss.As("pss"), q.ResourcePss.ResourceID.EqCol(q.Resource.ID)).
Where(q.Resource.ID.Eq(resourceId)).
Scan(&resource)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return ChannelServiceErr("套餐不存在")
}
return err
}
// 检查使用人
if auth.Payload.Type == PayloadUser && auth.Payload.Id != resource.UserId {
return common.AuthForbiddenErr("无权限访问")
}
// 检查套餐状态
if !resource.Active {
return ChannelServiceErr("套餐已失效")
}
// 检查每日限额
today := time.Now().Format("2006-01-02") == resource.DailyLast.Format("2006-01-02")
dailyRemain := int(math.Max(float64(resource.DailyLimit-resource.DailyUsed), 0))
if today && dailyRemain < count {
return ChannelServiceErr("套餐每日配额不足")
}
// 检查时间或配额
if resource.Type == 1 { // 包时
if resource.Expire.Before(time.Now()) {
return ChannelServiceErr("套餐已过期")
}
} else { // 包量
remain := int(math.Max(float64(resource.Quota-resource.Used), 0))
if remain < count {
return ChannelServiceErr("套餐配额不足")
}
}
// 筛选可用节点
nodes, err := Node.Filter(ctx, auth.Payload.Id, count, nodeFilter...)
if err != nil {
return err
}
// 获取用户配置白名单
whitelist, err := q.Whitelist.Where(
q.Whitelist.UserID.Eq(auth.Payload.Id),
).Find()
if err != nil {
return err
}
// 创建连接通道
channels = make([]*models.Channel, 0, len(nodes)*len(whitelist))
for _, node := range nodes {
for _, allowed := range whitelist {
username, password := genPassPair()
channels = append(channels, &models.Channel{
UserID: auth.Payload.Id,
NodeID: node.ID,
UserHost: allowed.Host,
NodeHost: node.Host,
ProxyPort: node.ProxyPort,
Protocol: string(protocol),
AuthIP: authType == ChannelAuthTypeIp,
AuthPass: authType == ChannelAuthTypePass,
Username: username,
Password: password,
Expiration: time.Now().Add(time.Duration(resource.Live) * time.Second),
})
}
}
// 保存到数据库
err = tx.Channel.Create(channels...)
if err != nil {
return err
}
// 更新套餐使用记录
if today {
resource.DailyUsed += int32(count)
resource.Used += int32(count)
} else {
resource.DailyLast = time.Now()
resource.DailyUsed = int32(count)
resource.Used += int32(count)
}
err = tx.ResourcePss.
Where(q.ResourcePss.ID.Eq(resource.Id)).
Select(
q.ResourcePss.Used,
q.ResourcePss.DailyUsed,
q.ResourcePss.DailyLast).
Save(&models.ResourcePss{
Used: resource.Used,
DailyUsed: resource.DailyUsed,
DailyLast: resource.DailyLast})
if err != nil {
return err
}
return nil
})
if err != nil {
return nil, err
}
// 缓存通道信息与异步删除任务
// err = cache(ctx, channels)
// if err != nil {
// return nil, err
// }
// 返回连接通道列表
return channels, errors.New("not implemented")
}
type ChannelAuthType int type ChannelAuthType int
const ( const (
@@ -178,24 +45,23 @@ const (
ProtocolHttps = ChannelProtocol("https") ProtocolHttps = ChannelProtocol("https")
) )
func genPassPair() (string, string) { type ResourceInfo struct {
usernameBytes, err := uuid.New().MarshalBinary() Id int32
if err != nil { UserId int32
panic(err) Active bool
} Type int32
passwordBytes, err := uuid.New().MarshalBinary() Live int32
if err != nil { DailyLimit int32
panic(err) DailyUsed int32
} DailyLast time.Time
username := base62.EncodeToString(usernameBytes) Quota int32
password := base62.EncodeToString(passwordBytes) Used int32
return username, password Expire time.Time
} }
// region RemoveChannel
func (s *channelService) RemoveChannels(ctx context.Context, auth *AuthContext, id ...int32) error { func (s *channelService) RemoveChannels(ctx context.Context, auth *AuthContext, id ...int32) error {
var channels []*models.Channel
// 删除通道 // 删除通道
err := q.Q.Transaction(func(tx *q.Query) error { err := q.Q.Transaction(func(tx *q.Query) error {
// 查找通道 // 查找通道
@@ -206,15 +72,30 @@ func (s *channelService) RemoveChannels(ctx context.Context, auth *AuthContext,
return err return err
} }
// 检查权限,只有用户自己和管理员能删除 // 检查权限,如果为用户操作的话,则只能删除自己的通道
for _, channel := range channels { for _, channel := range channels {
if auth.Payload.Type == PayloadUser && auth.Payload.Id != channel.UserID { if auth.Payload.Type == PayloadUser && auth.Payload.Id != channel.UserID {
return common.AuthForbiddenErr("无权限访问") return common.AuthForbiddenErr("无权限访问")
} }
} }
// 查找代理
proxySet := make(map[int32]struct{})
proxyIds := make([]int32, 0)
for _, channel := range channels {
if _, ok := proxySet[channel.ProxyID]; !ok {
proxyIds = append(proxyIds, channel.ProxyID)
proxySet[channel.ProxyID] = struct{}{}
}
}
proxies, err := tx.Proxy.Where(
q.Proxy.ID.In(proxyIds...),
).Find()
// 删除指定的通道 // 删除指定的通道
result, err := tx.Channel.Delete(channels...) result, err := tx.Channel.
Where(q.Channel.ID.In(id...)).
Update(q.Channel.DeletedAt, time.Now())
if err != nil { if err != nil {
return err return err
} }
@@ -222,30 +103,103 @@ func (s *channelService) RemoveChannels(ctx context.Context, auth *AuthContext,
return ChannelServiceErr("删除通道失败") return ChannelServiceErr("删除通道失败")
} }
return nil
})
if err != nil {
return err
}
// 删除缓存,异步任务直接在消费端处理删除 // 删除缓存,异步任务直接在消费端处理删除
err = deleteCache(ctx, channels) err = deleteCache(ctx, channels)
if err != nil { if err != nil {
return err return err
} }
// 禁用代理端口并下线用过的节点
if env.DebugExternalChange {
var configMap = make(map[int32][]remote.PortConfigsReq, len(proxies))
var proxyMap = make(map[int32]*models.Proxy, len(proxies))
for _, proxy := range proxies {
configMap[proxy.ID] = make([]remote.PortConfigsReq, 0)
proxyMap[proxy.ID] = proxy
}
var portMap = make(map[uint64]struct{})
for _, channel := range channels {
var config = remote.PortConfigsReq{
Port: int(channel.ProxyPort),
Edge: &[]string{},
AutoEdgeConfig: &remote.AutoEdgeConfig{
Count: v.P(0),
},
Status: false,
}
configMap[channel.ProxyID] = append(configMap[channel.ProxyID], config)
key := uint64(channel.ProxyID)<<32 | uint64(channel.ProxyPort)
portMap[key] = struct{}{}
}
for proxyId, configs := range configMap {
if len(configs) == 0 {
continue
}
proxy, ok := proxyMap[proxyId]
if !ok {
return ChannelServiceErr("代理不存在")
}
var secret = strings.Split(proxy.Secret, ":")
gateway := remote.InitGateway(
proxy.Host,
secret[0],
secret[1],
)
// 查询配置的节点
actives, err := gateway.GatewayPortActive()
if err != nil {
return err
}
// 取消配置
err = gateway.GatewayPortConfigs(configs)
if err != nil {
return err
}
// 下线对应节点
var edges []string
for portStr, active := range actives {
port, err := strconv.Atoi(portStr)
if err != nil {
return err
}
key := uint64(proxyId)<<32 | uint64(port)
if _, ok := portMap[key]; ok {
edges = append(edges, active.Edge...)
}
}
if len(edges) > 0 {
_, err := remote.Client.CloudDisconnect(remote.CloudDisconnectReq{
Uuid: proxy.Name,
Edge: edges,
})
if err != nil {
return err
}
}
}
}
return nil
})
if err != nil {
return err
}
return nil return nil
} }
type ChannelServiceErr string // endregion
func (c ChannelServiceErr) Error() string { // region CreateChannel
return string(c)
}
// region channel by remote func (s *channelService) CreateChannel(
func (s *channelService) RemoteCreateChannel(
ctx context.Context, ctx context.Context,
auth *AuthContext, auth *AuthContext,
resourceId int32, resourceId int32,
@@ -253,18 +207,21 @@ func (s *channelService) RemoteCreateChannel(
authType ChannelAuthType, authType ChannelAuthType,
count int, count int,
nodeFilter ...NodeFilterConfig, nodeFilter ...NodeFilterConfig,
) ([]AssignPortResult, error) { ) ([]string, error) {
filter := NodeFilterConfig{} filter := NodeFilterConfig{}
if len(nodeFilter) > 0 { if len(nodeFilter) > 0 {
filter = nodeFilter[0] filter = nodeFilter[0]
} }
var addr []string
err := q.Q.Transaction(func(tx *q.Query) error {
// 查找套餐 // 查找套餐
var resource = new(ResourceInfo) var resource = new(ResourceInfo)
data := q.Resource.As("data") data := q.Resource.As("data")
pss := q.ResourcePss.As("pss") pss := q.ResourcePss.As("pss")
err := data.Debug().Scopes(orm.Alias(data)). err := data.Scopes(orm.Alias(data)).
Select( Select(
data.ID, data.UserID, data.Active, data.ID, data.UserID, data.Active,
pss.Type, pss.Live, pss.DailyUsed, pss.DailyLimit, pss.DailyLast, pss.Quota, pss.Used, pss.Expire, pss.Type, pss.Live, pss.DailyUsed, pss.DailyLimit, pss.DailyLast, pss.Quota, pss.Used, pss.Expire,
@@ -274,36 +231,54 @@ func (s *channelService) RemoteCreateChannel(
Scan(&resource) Scan(&resource)
if err != nil { if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ChannelServiceErr("套餐不存在") return ChannelServiceErr("套餐不存在")
} }
return nil, err return err
} }
// 检查用户权限 // 检查用户权限
err = checkUser(auth, resource, count) err = checkUser(auth, resource, count)
if err != nil { if err != nil {
return nil, err return err
} }
slog.Debug("检查用户权限完成")
var postAssigns []AssignPortResult
err = q.Q.Transaction(func(tx *q.Query) error {
// 申请节点 // 申请节点
edgeAssigns, err := assignEdge(count, filter) edgeAssigns, err := assignEdge(count, filter)
if err != nil { if err != nil {
return err return err
} }
debugAssigned := fmt.Sprintf("%+v", edgeAssigns)
slog.Debug("申请节点完成", "edgeAssigns", debugAssigned)
// 分配端口 // 分配端口
expiration := time.Now().Add(time.Duration(resource.Live) * time.Second) now := time.Now()
postAssigns, err = assignPort(edgeAssigns, auth.Payload.Id, protocol, authType, expiration, filter) expiration := now.Add(time.Duration(resource.Live) * time.Second)
_addr, channels, err := assignPort(edgeAssigns, auth.Payload.Id, protocol, authType, expiration, filter)
if err != nil {
return err
}
addr = _addr
// 更新套餐使用记录
_, err = q.ResourcePss.
Where(q.ResourcePss.ResourceID.Eq(resourceId)).
Select(
q.ResourcePss.Used,
q.ResourcePss.DailyUsed,
q.ResourcePss.DailyLast,
).
Updates(models.ResourcePss{
Used: resource.Used + int32(count),
DailyUsed: resource.DailyUsed + int32(count),
DailyLast: now,
})
if err != nil {
return err
}
// 缓存通道数据
err = cache(ctx, channels)
if err != nil { if err != nil {
return err return err
} }
debugChannels := fmt.Sprintf("%+v", postAssigns)
slog.Debug("分配端口完成", "portAssigns", debugChannels)
return nil return nil
}) })
@@ -311,17 +286,9 @@ func (s *channelService) RemoteCreateChannel(
return nil, err return nil, err
} }
// 缓存并关闭代理 return addr, nil
err = cache(ctx, postAssigns)
if err != nil {
return nil, err
}
return postAssigns, nil
} }
// endregion
func checkUser(auth *AuthContext, resource *ResourceInfo, count int) error { func checkUser(auth *AuthContext, resource *ResourceInfo, count int) error {
// 检查使用人 // 检查使用人
@@ -368,17 +335,17 @@ func assignEdge(count int, filter NodeFilterConfig) (*AssignEdgeResult, error) {
} }
// 查询已配置的节点 // 查询已配置的节点
allConfigs, err := remote.Client.CloudAutoQuery() rProxyConfigs, err := remote.Client.CloudAutoQuery()
if err != nil { if err != nil {
return nil, err return nil, err
} }
// 查询已分配的节点 // 查询已使用的节点
var proxyIds = make([]int32, len(proxies)) var proxyIds = make([]int32, len(proxies))
for i, proxy := range proxies { for i, proxy := range proxies {
proxyIds[i] = proxy.ID proxyIds[i] = proxy.ID
} }
assigns, err := q.Channel. channels, err := q.Channel.
Select( Select(
q.Channel.ProxyID, q.Channel.ProxyID,
q.Channel.ProxyPort). q.Channel.ProxyPort).
@@ -392,80 +359,86 @@ func assignEdge(count int, filter NodeFilterConfig) (*AssignEdgeResult, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
var proxyUses = make(map[int32]int, len(channels))
for _, channel := range channels {
proxyUses[channel.ProxyID]++
}
// 过滤需要变动的连接配置 // 组织数据
var current = 0 var infos = make([]*ProxyInfo, len(proxies))
var result = make([]ProxyConfig, len(proxies))
for i, proxy := range proxies { for i, proxy := range proxies {
remoteConfigs, ok := allConfigs[proxy.Name] infos[i] = &ProxyInfo{
proxy: proxy,
used: proxyUses[proxy.ID],
}
rConfigs, ok := rProxyConfigs[proxy.Name]
if !ok { if !ok {
result[i] = ProxyConfig{ infos[i].count = 0
proxy: proxy,
config: &remote.AutoConfig{
Province: filter.Prov,
City: filter.City,
Isp: filter.Isp,
Count: 0,
},
}
continue continue
} }
for _, config := range remoteConfigs {
if config.Isp == filter.Isp && config.City == filter.City && config.Province == filter.Prov { for _, rConfig := range rConfigs {
current += config.Count if rConfig.Isp == filter.Isp && rConfig.City == filter.City && rConfig.Province == filter.Prov {
result[i] = ProxyConfig{ infos[i].count = rConfig.Count
proxy: proxy,
config: &config,
}
} }
} }
} }
// 如果需要新增节点 // 分配新增节点
var needed = len(assigns) + count var configs = make([]*ProxyConfig, len(proxies))
if needed-current > 0 { var needed = len(channels) + count
slog.Debug("新增新节点", "needed", needed, "current", current)
avg := int(math.Ceil(float64(needed) / float64(len(proxies)))) avg := int(math.Ceil(float64(needed) / float64(len(proxies))))
for i, assign := range result { for i, info := range infos {
var prev = assign.config.Count var prev = info.used
var next = assign.config.Count var next = int(math.Min(float64(avg), float64(needed)))
if prev >= avg || prev >= needed {
continue
}
next = int(math.Min(float64(avg), float64(needed))) info.used = int(math.Max(float64(prev), float64(next)))
result[i].config.Count = next - prev needed -= info.used
needed -= next
if env.DebugExternalChange && info.used > info.count {
slog.Debug("新增新节点", "proxy", info.proxy.Name, "used", info.used, "count", info.count)
err := remote.Client.CloudConnect(remote.CloudConnectReq{ err := remote.Client.CloudConnect(remote.CloudConnectReq{
Uuid: assign.proxy.Name, Uuid: info.proxy.Name,
Edge: nil, Edge: nil,
AutoConfig: []remote.AutoConfig{{ AutoConfig: []remote.AutoConfig{{
Province: filter.Prov, Province: filter.Prov,
City: filter.City, City: filter.City,
Isp: filter.Isp, Isp: filter.Isp,
Count: next, Count: int(math.Ceil(float64(info.used) * 1.1)),
}}, }},
}) })
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
configs[i] = &ProxyConfig{
proxy: info.proxy,
count: int(math.Max(float64(next-prev), 0)),
}
} }
return &AssignEdgeResult{ return &AssignEdgeResult{
configs: result, configs: configs,
channels: assigns, channels: channels,
}, nil }, nil
} }
type ProxyInfo struct {
proxy *models.Proxy
used int
count int
}
type AssignEdgeResult struct { type AssignEdgeResult struct {
configs []ProxyConfig configs []*ProxyConfig
channels []*models.Channel channels []*models.Channel
} }
type ProxyConfig struct { type ProxyConfig struct {
proxy *models.Proxy proxy *models.Proxy
config *remote.AutoConfig count int
} }
// assignPort 分配指定数量的端口 // assignPort 分配指定数量的端口
@@ -476,9 +449,9 @@ func assignPort(
authType ChannelAuthType, authType ChannelAuthType,
expiration time.Time, expiration time.Time,
filter NodeFilterConfig, filter NodeFilterConfig,
) ([]AssignPortResult, error) { ) ([]string, []*models.Channel, error) {
var assigns = proxies.configs var assigns = proxies.configs
var channels = proxies.channels var exists = proxies.channels
// 查询代理已配置端口 // 查询代理已配置端口
var proxyIds = make([]int32, 0, len(assigns)) var proxyIds = make([]int32, 0, len(assigns))
@@ -488,24 +461,20 @@ func assignPort(
// 端口查找表 // 端口查找表
var proxyPorts = make(map[uint64]struct{}) var proxyPorts = make(map[uint64]struct{})
for _, channel := range channels { for _, channel := range exists {
key := uint64(channel.ProxyID)<<32 | uint64(channel.ProxyPort) key := uint64(channel.ProxyID)<<32 | uint64(channel.ProxyPort)
proxyPorts[key] = struct{}{} proxyPorts[key] = struct{}{}
} }
// 配置启用代理 // 配置启用代理
var result = make([]AssignPortResult, len(assigns)) var result []string
for i, assign := range assigns { var channels []*models.Channel
for _, assign := range assigns {
var err error var err error
var proxy = assign.proxy var proxy = assign.proxy
var count = assign.config.Count var count = assign.count
result[i] = AssignPortResult{
Proxy: proxy,
}
// 筛选可用端口 // 筛选可用端口
var channels = result[i].Channels
var configs = make([]remote.PortConfigsReq, 0, count) var configs = make([]remote.PortConfigsReq, 0, count)
for port := 10000; port < 20000 && len(configs) < count; port++ { for port := 10000; port < 20000 && len(configs) < count; port++ {
// 跳过存在的端口 // 跳过存在的端口
@@ -521,13 +490,14 @@ func assignPort(
Port: port, Port: port,
Edge: nil, Edge: nil,
Status: true, Status: true,
AutoEdgeConfig: remote.AutoEdgeConfig{ AutoEdgeConfig: &remote.AutoEdgeConfig{
Province: filter.Prov, Province: filter.Prov,
City: filter.City, City: filter.City,
Isp: filter.Isp, Isp: filter.Isp,
Count: 1, Count: v.P(1),
}, },
}) })
switch authType { switch authType {
case ChannelAuthTypeIp: case ChannelAuthTypeIp:
var whitelist []string var whitelist []string
@@ -536,9 +506,10 @@ func assignPort(
Select(q.Whitelist.Host). Select(q.Whitelist.Host).
Scan(&whitelist) Scan(&whitelist)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
configs[i].Whitelist = whitelist configs[i].Whitelist = &whitelist
configs[i].Userpass = v.P("")
for _, item := range whitelist { for _, item := range whitelist {
channels = append(channels, &models.Channel{ channels = append(channels, &models.Channel{
UserID: userId, UserID: userId,
@@ -553,7 +524,8 @@ func assignPort(
} }
case ChannelAuthTypePass: case ChannelAuthTypePass:
username, password := genPassPair() username, password := genPassPair()
configs[i].Userpass = fmt.Sprintf("%s:%s", username, password) configs[i].Whitelist = new([]string)
configs[i].Userpass = v.P(fmt.Sprintf("%s:%s", username, password))
channels = append(channels, &models.Channel{ channels = append(channels, &models.Channel{
UserID: userId, UserID: userId,
ProxyID: proxy.ID, ProxyID: proxy.ID,
@@ -566,66 +538,82 @@ func assignPort(
Expiration: expiration, Expiration: expiration,
}) })
} }
}
result[i].Channels = channels result = append(result, fmt.Sprintf("%s://%s:%d", protocol, proxy.Host, port))
}
if len(configs) < count { if len(configs) < count {
return nil, ChannelServiceErr("网关端口数量到达上限,无法分配") return nil, nil, ChannelServiceErr("网关端口数量到达上限,无法分配")
}
// 提交端口配置
gateway := remote.InitGateway(
proxy.Host,
"api",
"123456",
)
err = gateway.GatewayPortConfigs(configs)
if err != nil {
return nil, err
} }
// 保存到数据库 // 保存到数据库
err = q.Channel. err = q.Channel.
Omit(q.Channel.NodeID). Omit(
q.Channel.NodeID,
q.Channel.NodeHost,
q.Channel.Username,
q.Channel.Password,
q.Channel.DeletedAt,
).
Save(channels...) Save(channels...)
if err != nil { if err != nil {
return nil, err return nil, nil, err
}
// 提交端口配置并更新节点列表
if env.DebugExternalChange {
var secret = strings.Split(proxy.Secret, ":")
gateway := remote.InitGateway(
proxy.Host,
secret[0],
secret[1],
)
err = gateway.GatewayPortConfigs(configs)
if err != nil {
return nil, nil, err
}
} }
} }
return result, nil return result, channels, nil
} }
type AssignPortResult struct { // endregion
Proxy *models.Proxy
Channels []*models.Channel func genPassPair() (string, string) {
usernameBytes, err := uuid.New().MarshalBinary()
if err != nil {
panic(err)
}
passwordBytes, err := uuid.New().MarshalBinary()
if err != nil {
panic(err)
}
username := base62.EncodeToString(usernameBytes)
password := base62.EncodeToString(passwordBytes)
return username, password
} }
func chKey(channel *models.Channel) string { func chKey(channel *models.Channel) string {
return fmt.Sprintf("channel:%s:%s", channel.UserHost, channel.NodeHost) return fmt.Sprintf("channel:%d", channel.ID)
} }
func cache(ctx context.Context, assigns []AssignPortResult) error { func cache(ctx context.Context, channels []*models.Channel) error {
pipe := rds.Client.TxPipeline() pipe := rds.Client.TxPipeline()
zList := make([]redis.Z, 0, len(assigns)) zList := make([]redis.Z, 0, len(channels))
for _, assign := range assigns {
var channels = assign.Channels
for _, channel := range channels { for _, channel := range channels {
marshal, err := json.Marshal(assign) marshal, err := json.Marshal(channel)
if err != nil { if err != nil {
return err return err
} }
pipe.Set(ctx, chKey(channel), string(marshal), channel.Expiration.Sub(time.Now())) pipe.Set(ctx, chKey(channel), string(marshal), channel.Expiration.Sub(time.Now()))
zList = append(zList, redis.Z{ zList = append(zList, redis.Z{
Score: float64(channel.Expiration.Unix()), Score: float64(channel.Expiration.Unix()),
Member: channel.ID, Member: channel.ID,
}) })
} }
} pipe.ZAdd(ctx, "tasks:channel", zList...)
pipe.ZAdd(ctx, "tasks:assign", zList...)
_, err := pipe.Exec(ctx) _, err := pipe.Exec(ctx)
if err != nil { if err != nil {
@@ -636,14 +624,11 @@ func cache(ctx context.Context, assigns []AssignPortResult) error {
} }
func deleteCache(ctx context.Context, channels []*models.Channel) error { func deleteCache(ctx context.Context, channels []*models.Channel) error {
pipe := rds.Client.TxPipeline() keys := make([]string, len(channels))
keys := make([]string, 0, len(channels)) for i := range channels {
for i := range keys {
keys[i] = chKey(channels[i]) keys[i] = chKey(channels[i])
} }
pipe.Del(ctx, keys...) _, err := rds.Client.Del(ctx, keys...).Result()
// 忽略异步任务zrem 效率较低,在使用时再删除
_, err := pipe.Exec(ctx)
if err != nil { if err != nil {
return err return err
} }
@@ -651,16 +636,8 @@ func deleteCache(ctx context.Context, channels []*models.Channel) error {
return nil return nil
} }
type ResourceInfo struct { type ChannelServiceErr string
Id int32
UserId int32 func (c ChannelServiceErr) Error() string {
Active bool return string(c)
Type int32
Live int32
DailyLimit int32
DailyUsed int32
DailyLast time.Time
Quota int32
Used int32
Expire time.Time
} }

View File

@@ -1,13 +1,17 @@
package web package web
import ( import (
"net/http"
"platform/pkg/env" "platform/pkg/env"
"log/slog"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/logger" "github.com/gofiber/fiber/v2/middleware/logger"
"github.com/gofiber/fiber/v2/middleware/requestid" "github.com/gofiber/fiber/v2/middleware/requestid"
) )
import "log/slog"
import _ "net/http/pprof"
type Config struct { type Config struct {
Listen string Listen string
@@ -30,6 +34,7 @@ func New(config *Config) (*Server, error) {
} }
func (s *Server) Run() error { func (s *Server) Run() error {
s.fiber = fiber.New(fiber.Config{ s.fiber = fiber.New(fiber.Config{
ErrorHandler: ErrorHandler, ErrorHandler: ErrorHandler,
}) })
@@ -39,6 +44,13 @@ func (s *Server) Run() error {
ApplyRouters(s.fiber) ApplyRouters(s.fiber)
go func() {
err := http.ListenAndServe(":6060", nil)
if err != nil {
slog.Error("pprof 服务错误", slog.Any("err", err))
}
}()
port := env.AppPort port := env.AppPort
slog.Info("Server started on :" + port) slog.Info("Server started on :" + port)
err := s.fiber.Listen(":" + port) err := s.fiber.Listen(":" + port)