diff --git a/README.md b/README.md index 412689d..5a97007 100644 --- a/README.md +++ b/README.md @@ -19,14 +19,18 @@ - [ ] Limiter - [ ] Compress +channel 数据存入顺序,数据库 > 缓存 > 外部接口 + +remote 令牌问题 + +用对称加密处理密钥 + 现在的节点分配逻辑是,每个 user_host:node_port 组算一个分配数,考虑是否改成每个用户算一个分配数 考虑将鉴权逻辑放到 handler 里,统一动静态鉴权以及解耦服务层 有些地方在用手动事务,有时间改成自动事务 -remote 用环境变量保存账号密码! - 重新手动实现 model 层 环境变量配置默认会话配置 diff --git a/cmd/fill/main.go b/cmd/fill/main.go index 17cdf41..b4d31e1 100644 --- a/cmd/fill/main.go +++ b/cmd/fill/main.go @@ -5,8 +5,9 @@ import ( "platform/pkg/env" "platform/pkg/logs" "platform/pkg/orm" - "platform/web/models" + m "platform/web/models" q "platform/web/queries" + "time" "golang.org/x/crypto/bcrypt" ) @@ -16,59 +17,96 @@ func main() { logs.Init() orm.Init() - q.User. - Select( - q.User.Phone). - Create(&models.User{ - Phone: "12312341234"}) + err := q.Q.Transaction(func(tx *q.Query) error { + q.User. + Select(q.User.Phone). + Save(&m.User{ + Phone: "12312341234", + }) + var user, _ = q.User.First() - q.Proxy. - Select( - q.Proxy.Version, - q.Proxy.Name, - q.Proxy.Host, - q.Proxy.Type). - Create(&models.Proxy{ - Version: 1, - Name: "7a17e8b4-cdc3-4500-bf16-4a665991a7f6", - Host: "110.40.82.248", - Type: 1}) + q.Resource. + Select(q.Resource.UserID, q.Resource.Active). + Create(&m.Resource{ + UserID: user.ID, + Active: true, + }) + var resource, _ = q.Resource.First() - q.Node. - Select( - q.Node.Version, - q.Node.Name, - q.Node.Host, - q.Node.Isp, - q.Node.Prov, - q.Node.City, - q.Node.Status). - Create(&models.Node{ - Version: 1, - Name: "test-node", - Host: "123", - Isp: "test-isp", - Prov: "test-prov", - City: "test-city", - Status: 1}) + q.ResourcePss. + Select( + q.ResourcePss.ResourceID, + q.ResourcePss.Live, + q.ResourcePss.Type, + q.ResourcePss.Expire, + q.ResourcePss.DailyLimit, + ). + Create(&m.ResourcePss{ + ResourceID: resource.ID, + Live: 180, + Type: 1, + Expire: time.Now().Add(24 * time.Hour * 1000), + DailyLimit: 300000, + }) - var secret, _ = bcrypt.GenerateFromPassword([]byte("test"), bcrypt.DefaultCost) - q.Client. - Select( - q.Client.ClientID, - q.Client.ClientSecret, - q.Client.GrantClient, - q.Client.GrantRefresh, - q.Client.Spec, - q.Client.Name). - Create(&models.Client{ - ClientID: "test", - ClientSecret: string(secret), - GrantClient: true, - GrantRefresh: true, - Spec: 0, - Name: "默认客户端", - }) + q.Proxy. + Select(q.Proxy.Version, q.Proxy.Name, q.Proxy.Host, q.Proxy.Type, q.Proxy.Secret). + Create(&m.Proxy{ + Version: 1, + Name: "7a17e8b4-cdc3-4500-bf16-4a665991a7f6", + Host: "110.40.82.248", + Type: 1, + Secret: "api:123456", + }) + + q.Node. + Select( + q.Node.Version, + q.Node.Name, + q.Node.Host, + q.Node.Isp, + q.Node.Prov, + q.Node.City, + q.Node.Status). + Create(&m.Node{ + Version: 1, + Name: "test-node", + Host: "123", + Isp: "test-isp", + Prov: "test-prov", + City: "test-city", + Status: 1}) + + var testSecret, _ = bcrypt.GenerateFromPassword([]byte("test"), bcrypt.DefaultCost) + var tasksSecret, _ = bcrypt.GenerateFromPassword([]byte("tasks"), bcrypt.DefaultCost) + q.Client. + Select( + q.Client.ClientID, + q.Client.ClientSecret, + q.Client.GrantClient, + q.Client.GrantRefresh, + q.Client.Spec, + q.Client.Name). + Create(&m.Client{ + ClientID: "test", + ClientSecret: string(testSecret), + GrantClient: true, + GrantRefresh: true, + Spec: 0, + Name: "默认客户端", + }, &m.Client{ + ClientID: "tasks", + ClientSecret: string(tasksSecret), + GrantClient: true, + GrantRefresh: true, + Spec: 0, + Name: "异步任务处理服务", + }) + return nil + }) + if err != nil { + panic(err) + } slog.Info("✔ Data inserted successfully") } diff --git a/cmd/playground/main.go b/cmd/playground/main.go index 311277e..7e6261c 100644 --- a/cmd/playground/main.go +++ b/cmd/playground/main.go @@ -1,33 +1,11 @@ package main -import ( - "fmt" - "platform/pkg/env" - "platform/pkg/logs" - "platform/pkg/orm" - "platform/web/models" - q "platform/web/queries" -) +import "math" -type ResourceInfo struct { - data models.Resource - pss models.ResourcePss -} +var b62Set = make(map[string]struct{}) +var b64Set = make(map[string]struct{}) func main() { - - env.Init() - logs.Init() - orm.Init() - - var resource = new(ResourceInfo) - data := q.Resource.As("data") - pss := q.ResourcePss.As("pss") - _ = data.Debug().Scopes(orm.Alias(data)). - Select(data.ALL, pss.ALL). - LeftJoin(q.ResourcePss.As("pss"), pss.ResourceID.EqCol(data.ID)). - Where(data.ID.Eq(1)). - Scan(&resource) - - fmt.Printf("%+v\n", resource) + println(int(math.Ceil(100 * 1.1))) + println(int(math.Ceil(float64(100) * 1.1))) } diff --git a/cmd/tasks/main.go b/cmd/tasks/main.go index beba0e2..a701d7d 100644 --- a/cmd/tasks/main.go +++ b/cmd/tasks/main.go @@ -2,12 +2,18 @@ package main import ( "context" + "encoding/json" "errors" + "io" "log/slog" + "net/http" "platform/pkg/env" "platform/pkg/logs" - "platform/pkg/orm" "platform/pkg/rds" + "reflect" + "strconv" + "strings" + "sync" "time" "github.com/redis/go-redis/v9" @@ -17,27 +23,51 @@ func main() { Start() } +var taskList = make(map[string]func(ctx context.Context, curr time.Time) error) + func Start() { ctx := context.Background() env.Init() logs.Init() rds.Init() - orm.Init() + + taskList["stopChannels"] = stopChannels ticker := time.NewTicker(time.Second) defer ticker.Stop() + // 互斥锁确保同一时间只有一个协程运行 + // 如果之前的 tick 操作未完成,则跳过当前 tick + var mutex = &sync.Mutex{} for curr := range ticker.C { - err := process(ctx, curr) - if err != nil { - panic(err) + if mutex.TryLock() { + err := process(ctx, curr) + if err != nil { + panic(err) + } + mutex.Unlock() + } else { + slog.Warn("skip tick", slog.String("tick", curr.Format("2006-01-02 15:04:05"))) } } } func process(ctx context.Context, curr time.Time) error { + // todo 异步化 + for name, task := range taskList { + err := task(ctx, curr) + if err != nil { + slog.Error("task failed", slog.String("task", name), slog.String("error", err.Error())) + } + } + + return nil +} + +func stopChannels(ctx context.Context, curr time.Time) error { + // 获取并删除 script := redis.NewScript(` local result = redis.call('ZRANGEBYSCORE', KEYS[1], 0, ARGV[1]) @@ -47,21 +77,64 @@ func process(ctx context.Context, curr time.Time) error { return result `) - // 计算时间范围 // 执行脚本 - result, err := script.Run(ctx, rds.Client, []string{"tasks:session"}, curr.Unix()).Result() + result, err := script.Run(ctx, rds.Client, []string{"tasks:channel"}, curr.Unix()).Result() if err != nil { return err } // 处理结果 - list, ok := result.([]string) + list, ok := result.([]any) if !ok { return errors.New("failed to convert result to []string") } - for _, item := range list { - // 从数据库删除授权信息 - slog.Debug(item) + var ids = make([]int, len(list)) + for i, item := range list { + idStr, ok := item.(string) + if !ok { + slog.Debug(reflect.TypeOf(item).String()) + return errors.New("failed to convert item to string") + } + id, err := strconv.Atoi(idStr) + if err != nil { + return err + } + ids[i] = id + } + if len(ids) == 0 { + return nil + } + + var body = map[string]any{ + "by_ids": ids, + } + bodyStr, err := json.Marshal(body) + if err != nil { + return err + } + + req, err := http.NewRequest( + http.MethodPost, + "http://localhost:8080/api/channel/remove", // todo 环境变量获取服务地址 + strings.NewReader(string(bodyStr)), + ) + if err != nil { + return err + } + req.Header.Set("Content-Type", "application/json") + req.SetBasicAuth("tasks", "tasks") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return err + } + + if resp.StatusCode != http.StatusOK { + body, err := io.ReadAll(resp.Body) + if err != nil { + return err + } + return errors.New("failed to stop channels: " + string(body)) } return nil diff --git a/pkg/env/env.go b/pkg/env/env.go index 05cbf80..2f90951 100644 --- a/pkg/env/env.go +++ b/pkg/env/env.go @@ -55,21 +55,21 @@ func loadDb() { if _DbName != "" { DbName = _DbName } else { - panic("环境变量 DB_NAME 的值为空") + panic("环境变量 DB_NAME 的值不能为空") } _DbUserName := os.Getenv("DB_USERNAME") if _DbUserName != "" { DbUserName = _DbUserName } else { - panic("环境变量 DB_USERNAME 的值为空") + panic("环境变量 DB_USERNAME 的值不能为空") } _DbPassword := os.Getenv("DB_PASSWORD") if _DbPassword != "" { DbPassword = _DbPassword } else { - panic("环境变量 DB_PASSWORD 的值为空") + panic("环境变量 DB_PASSWORD 的值不能为空") } } @@ -134,6 +134,60 @@ func loadLog() { // endregion +// region remote + +var ( + RemoteAddr = "http://103.139.212.110:9989" + RemoteToken string +) + +func loadRemote() { + _RemoteAddr := os.Getenv("REMOTE_ADDR") + if _RemoteAddr != "" { + RemoteAddr = _RemoteAddr + } + + _RemoteToken := os.Getenv("REMOTE_TOKEN") + if _RemoteToken == "" { + panic("环境变量 REMOTE_TOKEN 的值不能为空") + } + RemoteToken = _RemoteToken +} + +// endregion + +// region debug + +var ( + // DebugHttpDump 是否打印请求和响应的原始数据 + DebugHttpDump = false + // DebugExternalChange 是否实际执行非幂等外部接口的调用。 + // 例如外部数据修改接口,在内部接口调试时可以关闭,避免对外部数据产生影响 + DebugExternalChange = true +) + +func loadDebug() { + debugHttpDump := os.Getenv("DEBUG_HTTP_DUMP") + if debugHttpDump != "" { + value, err := strconv.ParseBool(debugHttpDump) + if err != nil { + panic("环境变量 DEBUG_HTTP_DUMP 的值不是布尔值") + } + DebugHttpDump = value + } + + debugExternalChange := os.Getenv("DEBUG_EXTERNAL_CHANGE") + if debugExternalChange != "" { + value, err := strconv.ParseBool(debugExternalChange) + if err != nil { + panic("环境变量 DEBUG_EXTERNAL_CHANGE 的值不是布尔值") + } + DebugExternalChange = value + } +} + +// endregion + func Init() { err := godotenv.Load() if err != nil { @@ -146,4 +200,6 @@ func Init() { loadDb() loadRedis() loadLog() + loadDebug() + loadRemote() } diff --git a/pkg/remote/remote.go b/pkg/remote/remote.go index 4b70d18..719ee42 100644 --- a/pkg/remote/remote.go +++ b/pkg/remote/remote.go @@ -6,6 +6,9 @@ import ( "fmt" "io" "net/http" + "net/http/httputil" + "net/url" + "platform/pkg/env" "strconv" "strings" ) @@ -18,10 +21,9 @@ type client struct { var Client client func Init() { - // todo 从环境变量中获取参数 Client = client{ - url: "http://103.139.212.110:9989", - token: "PhdnRF3z6VF2sPgygTSl1Xx6QJN0yFIK.anVpcA==.MTc0MzE2ODAwMQ==", + url: env.RemoteAddr, + token: env.RemoteToken, } } @@ -151,8 +153,8 @@ func (c *client) CloudConnect(param CloudConnectReq) error { type CloudDisconnectReq struct { Uuid string `json:"uuid"` - Edge []string `json:"edge"` - Config []Config `json:"auto_config"` + Edge []string `json:"edge,omitempty"` + Config []Config `json:"auto_config,omitempty"` } type Config struct { @@ -246,11 +248,29 @@ func (c *client) requestCloud(method string, url string, data string) (*http.Res req.Header.Set("Content-Type", "application/json") req.Header.Set("token", c.token) + if env.DebugHttpDump { + str, err := httputil.DumpRequest(req, true) + if err != nil { + return nil, err + } + fmt.Println("==============================") + fmt.Println(string(str)) + } + resp, err := http.DefaultClient.Do(req) if err != nil { return nil, err } + if env.DebugHttpDump { + str, err := httputil.DumpResponse(resp, true) + if err != nil { + return nil, err + } + fmt.Println("==============================") + fmt.Println(string(str)) + } + return resp, nil } @@ -267,22 +287,22 @@ func InitGateway(url, username, password string) *Gateway { // region gateway:/port/configs type PortConfigsReq struct { - Port int `json:"port"` - Edge []string `json:"edge,omitempty"` - Type string `json:"type,omitempty"` - Time int `json:"time,omitempty"` - Status bool `json:"status,omitempty"` - Rate int `json:"rate,omitempty"` - Whitelist []string `json:"whitelist,omitempty"` - Userpass string `json:"userpass,omitempty"` - AutoEdgeConfig AutoEdgeConfig `json:"auto_edge_config,omitempty"` + Port int `json:"port"` + Edge *[]string `json:"edge,omitempty"` + Type string `json:"type,omitempty"` + Time int `json:"time,omitempty"` + Status bool `json:"status"` + Rate int `json:"rate,omitempty"` + Whitelist *[]string `json:"whitelist,omitempty"` + Userpass *string `json:"userpass,omitempty"` + AutoEdgeConfig *AutoEdgeConfig `json:"auto_edge_config,omitempty"` } type AutoEdgeConfig struct { Province string `json:"province,omitempty"` City string `json:"city,omitempty"` Isp string `json:"isp,omitempty"` - Count int `json:"count,omitempty"` + Count *int `json:"count,omitempty"` PacketLoss int `json:"packet_loss,omitempty"` } @@ -332,8 +352,8 @@ func (c *Gateway) GatewayPortConfigs(params []PortConfigsReq) error { type PortActiveReq struct { Port string `json:"port"` - Active bool `json:"active"` - Status bool `json:"status"` + Active *bool `json:"active"` + Status *bool `json:"status"` } type PortActiveResp struct { @@ -352,22 +372,34 @@ type PortData struct { Userpass string `json:"userpass"` } -func (c *Gateway) GatewayPortActive(param PortActiveReq) (*PortActiveResp, error) { - - url := strings.Builder{} - url.WriteString("/port/active") - - if param.Port != "" { - url.WriteString("/") - url.WriteString(param.Port) +func (c *Gateway) GatewayPortActive(param ...PortActiveReq) (map[string]PortData, error) { + _param := PortActiveReq{} + if len(param) != 0 { + _param = param[0] } - url.WriteString("?active=") - url.WriteString(strconv.FormatBool(param.Active)) - url.WriteString("&status=") - url.WriteString(strconv.FormatBool(param.Status)) + path := strings.Builder{} + path.WriteString("/port/active") - resp, err := c.requestGateway("POST", url.String(), "") + if _param.Port != "" { + path.WriteString("/") + path.WriteString(_param.Port) + } + + values := url.Values{} + if _param.Active != nil { + values.Set("active", strconv.FormatBool(*_param.Active)) + } + if _param.Status != nil { + values.Set("status", strconv.FormatBool(*_param.Status)) + } + + if len(values) > 0 { + path.WriteString("?") + path.WriteString(values.Encode()) + } + + resp, err := c.requestGateway("GET", path.String(), "") if err != nil { return nil, err } @@ -390,7 +422,11 @@ func (c *Gateway) GatewayPortActive(param PortActiveReq) (*PortActiveResp, error return nil, err } - return &result, nil + if result.Code != 0 { + return nil, errors.New(result.Msg) + } + + return result.Data, nil } // endregion @@ -404,10 +440,28 @@ func (c *Gateway) requestGateway(method string, url string, data string) (*http. req.Header.Set("Content-Type", "application/json") + if env.DebugHttpDump { + str, err := httputil.DumpRequest(req, true) + if err != nil { + return nil, err + } + fmt.Println("==============================") + fmt.Println(string(str)) + } + resp, err := http.DefaultClient.Do(req) if err != nil { return nil, err } + if env.DebugHttpDump { + str, err := httputil.DumpResponse(resp, true) + if err != nil { + return nil, err + } + fmt.Println("==============================") + fmt.Println(string(str)) + } + return resp, nil } diff --git a/pkg/v/v.go b/pkg/v/v.go new file mode 100644 index 0000000..68e278b --- /dev/null +++ b/pkg/v/v.go @@ -0,0 +1,6 @@ +package v + +// P 是一个工具函数,用于在表达式内原地创建一个指针 +func P[T any](v T) *T { + return &v +} diff --git a/scripts/sql/init.sql b/scripts/sql/init.sql index a0ed496..e8f3d99 100644 --- a/scripts/sql/init.sql +++ b/scripts/sql/init.sql @@ -388,6 +388,7 @@ create table proxy ( name varchar(255) not null unique, host varchar(255) not null, type int not null default 0, + secret varchar(255), created_at timestamp default current_timestamp, updated_at timestamp default current_timestamp, deleted_at timestamp @@ -402,6 +403,7 @@ comment on column proxy.version is '代理服务版本'; comment on column proxy.name is '代理服务名称'; comment on column proxy.host is '代理服务地址'; comment on column proxy.type is '代理服务类型:0-自有,1-三方'; +comment on column proxy.secret is '代理服务密钥'; comment on column proxy.created_at is '创建时间'; comment on column proxy.updated_at is '更新时间'; comment on column proxy.deleted_at is '删除时间'; @@ -507,6 +509,7 @@ create index channel_user_host_index on channel (user_host); create index channel_proxy_port_index on channel (proxy_port); create index channel_node_host_index on channel (node_host); create index channel_expiration_index on channel (expiration); +create index channel_deleted_at_index on channel (deleted_at); -- channel表字段注释 comment on table channel is '通道表'; @@ -593,8 +596,8 @@ create table resource_pss ( type int, live int, quota int, - used int, expire timestamp, + used int not null default 0, daily_limit int not null default 0, daily_used int not null default 0, daily_last timestamp, diff --git a/test/test-api.http b/test/test-api.http index a1acdef..0b3ba84 100644 --- a/test/test-api.http +++ b/test/test-api.http @@ -3,46 +3,37 @@ GET http://110.40.82.250:18702/server/index/getToken/key/juipbyjdapiverify ### remote 配置信息 GET http://103.139.212.110:9989/api/auto_query -token: PhdnRF3z6VF2sPgygTSl1Xx6QJN0yFIK.anVpcA==.MTc0MzE2ODAwMQ== +token: et1wWdrLLRsiQPCar8GunNFEZqcxATFa.anVpcA==.MTc0MzM0MjAwMQ== -### remote 配置重置 +### remote 配置连接 POST http://103.139.212.110:9989/api/connect -token: PhdnRF3z6VF2sPgygTSl1Xx6QJN0yFIK.anVpcA==.MTc0MzE2ODAwMQ== +token: et1wWdrLLRsiQPCar8GunNFEZqcxATFa.anVpcA==.MTc0MzM0MjAwMQ== Content-Type: application/json { "uuid": "7a17e8b4-cdc3-4500-bf16-4a665991a7f6", "auto_config": [ { - "count": 1 + "count": 10 } ] } -### gateway 端口信息 -GET http://api:123456@110.40.82.248:9990/port/active/ - -### gateway 配置端口代理 -POST http://api:123456@110.40.82.248:9990/port/configs +### remote 下线全部 +POST http://103.139.212.110:9989/api/disconnect +token: et1wWdrLLRsiQPCar8GunNFEZqcxATFa.anVpcA==.MTc0MzM0MjAwMQ== Content-Type: application/json -//[ -// { -// "port": 10000, -// "status": true, -// "userpass": "mIfSlXIBwVUrKqObNdTzvB:cyDaGtfeBlhRfojTYJP2tR", -// "auto_edge_config": { -// "count": 1 -// } -// } -//] -[ - { - "port": 10000, - "status": false, - "edge": [] - } -] +{ + "uuid": "7a17e8b4-cdc3-4500-bf16-4a665991a7f6", + "config": {} +} + +### gateway 连接信息 +GET http://api:123456@110.40.82.248:9990/edge + +### gateway 端口信息 +GET http://api:123456@110.40.82.248:9990/port/active ### 设备令牌 POST http://localhost:8080/api/auth/token @@ -59,23 +50,20 @@ Content-Type: application/json client.global.set("refresh_token", response.body.refresh_token); %} -### 密码模式代理 +### 提取代理 POST http://localhost:8080/api/channel/create -Authorization: Bearer {{access_token}} +Authorization: Basic test test Content-Type: application/json Accept: application/json { "resource_id": 1, "protocol": "http", - "auth_type": 1, - "count": 1, + "auth_type": 0, + "count": 200, "prov": "", "city": "", "isp": "", "result_type": "text", "result_separator": "both" } - -### 白名单模式代理 - diff --git a/web/auth.go b/web/auth.go index e16941b..89d0115 100644 --- a/web/auth.go +++ b/web/auth.go @@ -1,7 +1,12 @@ package web import ( + "context" + "encoding/base64" + "errors" + "log/slog" "platform/web/common" + q "platform/web/queries" "slices" "strings" @@ -14,16 +19,36 @@ func Permit(types []services.PayloadType, permissions ...string) fiber.Handler { return func(c *fiber.Ctx) error { // 获取令牌 var header = c.Get("Authorization") - var token = strings.TrimPrefix(header, "Bearer ") - if token == "" { - return c.Status(fiber.StatusUnauthorized).JSON(common.ErrResp{ + var split = strings.Split(header, " ") + if len(split) != 2 { + return c.Status(fiber.StatusBadRequest).JSON(common.ErrResp{ Error: true, - Message: "没有权限", + Message: "无效的令牌", }) } - // 验证令牌 - auth, err := services.Session.Find(c.Context(), token) + var token = split[1] + if token == "" { + return c.Status(fiber.StatusBadRequest).JSON(common.ErrResp{ + Error: true, + Message: "无效的令牌", + }) + } + + var auth *services.AuthContext + var err error + switch split[0] { + case "Bearer": + auth, err = authBearer(c.Context(), token) + case "Basic": + if !slices.Contains(types, services.PayloadClientConfidential) { + return c.Status(fiber.StatusUnauthorized).JSON(common.ErrResp{ + Error: true, + Message: "没有权限", + }) + } + auth, err = authBasic(c.Context(), token) + } if err != nil { return c.Status(fiber.StatusUnauthorized).JSON(common.ErrResp{ Error: true, @@ -32,22 +57,6 @@ func Permit(types []services.PayloadType, permissions ...string) fiber.Handler { } // 检查权限 - // switch auth.Payload.Type { - // case services.PayloadAdmin: - // // 管理员不需要权限检查 - // case services.PayloadUser: - // if len(permissions) > 0 && !auth.AnyPermission(permissions...) { - // return c.Status(fiber.StatusForbidden).JSON(common.ErrResp{ - // Error: true, - // Message: "拒绝访问", - // }) - // } - // default: - // return c.Status(fiber.StatusForbidden).JSON(common.ErrResp{ - // Error: true, - // Message: "拒绝访问", - // }) - // } if !slices.Contains(types, auth.Payload.Type) { return c.Status(fiber.StatusForbidden).JSON(common.ErrResp{ Error: true, @@ -70,97 +79,95 @@ func Permit(types []services.PayloadType, permissions ...string) fiber.Handler { } +func PermitAll(permissions ...string) fiber.Handler { + return Permit([]services.PayloadType{ + services.PayloadClientPublic, + services.PayloadClientConfidential, + services.PayloadUser, + services.PayloadAdmin, + }, permissions...) +} + // PermitUser 创建针对单个路由的鉴权中间件 func PermitUser(permissions ...string) fiber.Handler { - return func(c *fiber.Ctx) error { - // 获取令牌 - var header = c.Get("Authorization") - var token = strings.TrimPrefix(header, "Bearer ") - if token == "" { - return c.Status(fiber.StatusUnauthorized).JSON(common.ErrResp{ - Error: true, - Message: "没有权限", - }) - } - - // 验证令牌 - auth, err := services.Session.Find(c.Context(), token) - if err != nil { - return c.Status(fiber.StatusUnauthorized).JSON(common.ErrResp{ - Error: true, - Message: "没有权限", - }) - } - - // 检查权限 - switch auth.Payload.Type { - case services.PayloadAdmin: - // 管理员不需要权限检查 - case services.PayloadUser: - if len(permissions) > 0 && !auth.AnyPermission(permissions...) { - return c.Status(fiber.StatusForbidden).JSON(common.ErrResp{ - Error: true, - Message: "拒绝访问", - }) - } - default: - return c.Status(fiber.StatusForbidden).JSON(common.ErrResp{ - Error: true, - Message: "拒绝访问", - }) - } - - // 将认证信息存储在上下文中 - c.Locals("auth", auth) - c.Locals("access_token", token) // 存储原始令牌,便于后续操作 - - return c.Next() - } + return Permit([]services.PayloadType{ + services.PayloadUser, + services.PayloadAdmin, + }, permissions...) } func PermitDevice(permissions ...string) fiber.Handler { - return func(c *fiber.Ctx) error { - // 获取令牌 - var header = c.Get("Authorization") - var token = strings.TrimPrefix(header, "Bearer ") - if token == "" { - return c.Status(fiber.StatusUnauthorized).JSON(common.ErrResp{ - Error: true, - Message: "没有权限", - }) - } - - // 验证令牌 - auth, err := services.Session.Find(c.Context(), token) - if err != nil { - return c.Status(fiber.StatusUnauthorized).JSON(common.ErrResp{ - Error: true, - Message: "没有权限", - }) - } - - // 检查权限 - switch auth.Payload.Type { - case services.PayloadAdmin: - // 管理员不需要权限检查 - case services.PayloadClientPublic, services.PayloadClientConfidential: - if len(permissions) > 0 && !auth.AnyPermission(permissions...) { - return c.Status(fiber.StatusForbidden).JSON(common.ErrResp{ - Error: true, - Message: "拒绝访问", - }) - } - default: - return c.Status(fiber.StatusForbidden).JSON(common.ErrResp{ - Error: true, - Message: "拒绝访问", - }) - } - - // 将认证信息存储在上下文中 - c.Locals("auth", auth) - c.Locals("access_token", token) // 存储原始令牌,便于后续操作 - - return c.Next() - } + return Permit([]services.PayloadType{ + services.PayloadClientPublic, + services.PayloadClientConfidential, + services.PayloadAdmin, + }, permissions...) +} + +func PermitPublic(permissions ...string) fiber.Handler { + return Permit([]services.PayloadType{ + services.PayloadClientPublic, + services.PayloadAdmin, + }, permissions...) +} + +func PermitConfidential(permissions ...string) fiber.Handler { + return Permit([]services.PayloadType{ + services.PayloadClientConfidential, + services.PayloadAdmin, + }, permissions...) +} + +func authBearer(ctx context.Context, token string) (*services.AuthContext, error) { + auth, err := services.Session.Find(ctx, token) + if err != nil { + slog.Debug(err.Error()) + return nil, err + } + return auth, nil +} + +func authBasic(ctx context.Context, token string) (*services.AuthContext, error) { + + // 解析 Basic 认证信息 + var base, err = base64.URLEncoding.DecodeString(token) + if err != nil { + slog.Debug(err.Error()) + return nil, err + } + + var split = strings.Split(string(base), ":") + if len(split) != 2 { + msg := "无法解析 Basic 认证信息" + slog.Debug(msg) + return nil, errors.New(msg) + } + + var clientID = split[0] + + // 获取客户端信息 + client, err := q.Client. + Where( + q.Client.ClientID.Eq(clientID), + q.Client.Spec.Eq(0), + q.Client.GrantClient.Is(true), + q.Client.Status.Eq(1)). + Take() + if err != nil { + return nil, err + } + + // todo 查询客户端关联权限 + + // 组织授权信息(一次性请求) + return &services.AuthContext{ + Payload: services.Payload{ + Id: client.ID, + Type: services.PayloadClientConfidential, + Name: client.Name, + Avatar: client.Icon, + }, + Permissions: nil, + Metadata: nil, + }, nil } diff --git a/web/handlers/channel.go b/web/handlers/channel.go index 433adc3..d9e98bb 100644 --- a/web/handlers/channel.go +++ b/web/handlers/channel.go @@ -2,7 +2,6 @@ package handlers import ( "errors" - "fmt" "platform/web/services" "strings" @@ -35,7 +34,7 @@ func CreateChannel(c *fiber.Ctx) error { return errors.New("user not found") } - assigns, err := services.Channel.RemoteCreateChannel( + result, err := services.Channel.CreateChannel( c.Context(), auth, req.ResourceId, @@ -52,17 +51,6 @@ func CreateChannel(c *fiber.Ctx) error { return err } - // 返回连接通道列表 - var result []string - for _, assign := range assigns { - var proxy = assign.Proxy - var channels = assign.Channels - for _, channel := range channels { - url := fmt.Sprintf("%s://%s:%d", channel.Protocol, proxy.Host, channel.ProxyPort) - result = append(result, url) - } - } - switch req.ResultType { case CreateChannelResultTypeJson: return c.JSON(fiber.Map{ @@ -101,3 +89,32 @@ const ( ) // endregion + +// region RemoveChannels + +type RemoveChannelsReq struct { + ByIds []int32 `json:"by_ids" validate:"required"` +} + +func RemoveChannels(c *fiber.Ctx) error { + req := new(RemoveChannelsReq) + if err := c.BodyParser(req); err != nil { + return err + } + + // 获取用户信息 + auth, ok := c.Locals("auth").(*services.AuthContext) + if !ok { + return errors.New("user not found") + } + + // 删除通道 + err := services.Channel.RemoveChannels(c.Context(), auth, req.ByIds...) + if err != nil { + return err + } + + return c.SendStatus(fiber.StatusOK) +} + +// endregion diff --git a/web/models/proxy.gen.go b/web/models/proxy.gen.go index be0e452..7447378 100644 --- a/web/models/proxy.gen.go +++ b/web/models/proxy.gen.go @@ -19,6 +19,7 @@ type Proxy struct { Name string `gorm:"column:name;not null;comment:代理服务名称" json:"name"` // 代理服务名称 Host string `gorm:"column:host;not null;comment:代理服务地址" json:"host"` // 代理服务地址 Type int32 `gorm:"column:type;not null;comment:代理服务类型:0-自有,1-三方" json:"type"` // 代理服务类型:0-自有,1-三方 + Secret string `gorm:"column:secret;comment:代理服务密钥" json:"secret"` // 代理服务密钥 CreatedAt time.Time `gorm:"column:created_at;default:CURRENT_TIMESTAMP;comment:创建时间" json:"created_at"` // 创建时间 UpdatedAt time.Time `gorm:"column:updated_at;default:CURRENT_TIMESTAMP;comment:更新时间" json:"updated_at"` // 更新时间 DeletedAt gorm.DeletedAt `gorm:"column:deleted_at;comment:删除时间" json:"deleted_at"` // 删除时间 diff --git a/web/models/resource_pss.gen.go b/web/models/resource_pss.gen.go index f3f813e..d3c3fa2 100644 --- a/web/models/resource_pss.gen.go +++ b/web/models/resource_pss.gen.go @@ -19,8 +19,8 @@ type ResourcePss struct { Type int32 `gorm:"column:type;comment:套餐类型:1-包时,2-包量" json:"type"` // 套餐类型:1-包时,2-包量 Live int32 `gorm:"column:live;comment:可用时长(秒)" json:"live"` // 可用时长(秒) Quota int32 `gorm:"column:quota;comment:配额数量" json:"quota"` // 配额数量 - Used int32 `gorm:"column:used;comment:已用数量" json:"used"` // 已用数量 Expire time.Time `gorm:"column:expire;comment:过期时间" json:"expire"` // 过期时间 + Used int32 `gorm:"column:used;not null;comment:已用数量" json:"used"` // 已用数量 DailyLimit int32 `gorm:"column:daily_limit;not null;comment:每日限制" json:"daily_limit"` // 每日限制 DailyUsed int32 `gorm:"column:daily_used;not null;comment:今日已用数量" json:"daily_used"` // 今日已用数量 DailyLast time.Time `gorm:"column:daily_last;comment:今日最后使用时间" json:"daily_last"` // 今日最后使用时间 diff --git a/web/queries/proxy.gen.go b/web/queries/proxy.gen.go index 866ea9c..9eeab84 100644 --- a/web/queries/proxy.gen.go +++ b/web/queries/proxy.gen.go @@ -32,6 +32,7 @@ func newProxy(db *gorm.DB, opts ...gen.DOOption) proxy { _proxy.Name = field.NewString(tableName, "name") _proxy.Host = field.NewString(tableName, "host") _proxy.Type = field.NewInt32(tableName, "type") + _proxy.Secret = field.NewString(tableName, "secret") _proxy.CreatedAt = field.NewTime(tableName, "created_at") _proxy.UpdatedAt = field.NewTime(tableName, "updated_at") _proxy.DeletedAt = field.NewField(tableName, "deleted_at") @@ -50,6 +51,7 @@ type proxy struct { Name field.String // 代理服务名称 Host field.String // 代理服务地址 Type field.Int32 // 代理服务类型:0-自有,1-三方 + Secret field.String // 代理服务密钥 CreatedAt field.Time // 创建时间 UpdatedAt field.Time // 更新时间 DeletedAt field.Field // 删除时间 @@ -74,6 +76,7 @@ func (p *proxy) updateTableName(table string) *proxy { p.Name = field.NewString(table, "name") p.Host = field.NewString(table, "host") p.Type = field.NewInt32(table, "type") + p.Secret = field.NewString(table, "secret") p.CreatedAt = field.NewTime(table, "created_at") p.UpdatedAt = field.NewTime(table, "updated_at") p.DeletedAt = field.NewField(table, "deleted_at") @@ -93,12 +96,13 @@ func (p *proxy) GetFieldByName(fieldName string) (field.OrderExpr, bool) { } func (p *proxy) fillFieldMap() { - p.fieldMap = make(map[string]field.Expr, 8) + p.fieldMap = make(map[string]field.Expr, 9) p.fieldMap["id"] = p.ID p.fieldMap["version"] = p.Version p.fieldMap["name"] = p.Name p.fieldMap["host"] = p.Host p.fieldMap["type"] = p.Type + p.fieldMap["secret"] = p.Secret p.fieldMap["created_at"] = p.CreatedAt p.fieldMap["updated_at"] = p.UpdatedAt p.fieldMap["deleted_at"] = p.DeletedAt diff --git a/web/queries/resource_pss.gen.go b/web/queries/resource_pss.gen.go index 2583bb2..b61a4a0 100644 --- a/web/queries/resource_pss.gen.go +++ b/web/queries/resource_pss.gen.go @@ -32,8 +32,8 @@ func newResourcePss(db *gorm.DB, opts ...gen.DOOption) resourcePss { _resourcePss.Type = field.NewInt32(tableName, "type") _resourcePss.Live = field.NewInt32(tableName, "live") _resourcePss.Quota = field.NewInt32(tableName, "quota") - _resourcePss.Used = field.NewInt32(tableName, "used") _resourcePss.Expire = field.NewTime(tableName, "expire") + _resourcePss.Used = field.NewInt32(tableName, "used") _resourcePss.DailyLimit = field.NewInt32(tableName, "daily_limit") _resourcePss.DailyUsed = field.NewInt32(tableName, "daily_used") _resourcePss.DailyLast = field.NewTime(tableName, "daily_last") @@ -55,8 +55,8 @@ type resourcePss struct { Type field.Int32 // 套餐类型:1-包时,2-包量 Live field.Int32 // 可用时长(秒) Quota field.Int32 // 配额数量 - Used field.Int32 // 已用数量 Expire field.Time // 过期时间 + Used field.Int32 // 已用数量 DailyLimit field.Int32 // 每日限制 DailyUsed field.Int32 // 今日已用数量 DailyLast field.Time // 今日最后使用时间 @@ -84,8 +84,8 @@ func (r *resourcePss) updateTableName(table string) *resourcePss { r.Type = field.NewInt32(table, "type") r.Live = field.NewInt32(table, "live") r.Quota = field.NewInt32(table, "quota") - r.Used = field.NewInt32(table, "used") r.Expire = field.NewTime(table, "expire") + r.Used = field.NewInt32(table, "used") r.DailyLimit = field.NewInt32(table, "daily_limit") r.DailyUsed = field.NewInt32(table, "daily_used") r.DailyLast = field.NewTime(table, "daily_last") @@ -114,8 +114,8 @@ func (r *resourcePss) fillFieldMap() { r.fieldMap["type"] = r.Type r.fieldMap["live"] = r.Live r.fieldMap["quota"] = r.Quota - r.fieldMap["used"] = r.Used r.fieldMap["expire"] = r.Expire + r.fieldMap["used"] = r.Used r.fieldMap["daily_limit"] = r.DailyLimit r.fieldMap["daily_used"] = r.DailyUsed r.fieldMap["daily_last"] = r.DailyLast diff --git a/web/router.go b/web/router.go index df6c655..3d9d7c9 100644 --- a/web/router.go +++ b/web/router.go @@ -2,7 +2,6 @@ package web import ( "platform/web/handlers" - "platform/web/services" "github.com/gofiber/fiber/v2" ) @@ -18,10 +17,6 @@ func ApplyRouters(app *fiber.App) { // 通道 channel := api.Group("/channel") - channel.Post("/create", Permit([]services.PayloadType{ - services.PayloadClientConfidential, - services.PayloadClientPublic, - services.PayloadUser, - services.PayloadAdmin, - }), handlers.CreateChannel) + channel.Post("/create", PermitAll(), handlers.CreateChannel) + channel.Post("/remove", PermitAll(), handlers.RemoveChannels) } diff --git a/web/services/channel.go b/web/services/channel.go index 7248598..3cdcdb6 100644 --- a/web/services/channel.go +++ b/web/services/channel.go @@ -7,12 +7,16 @@ import ( "fmt" "log/slog" "math" + "platform/pkg/env" "platform/pkg/orm" "platform/pkg/rds" "platform/pkg/remote" + "platform/pkg/v" "platform/web/common" "platform/web/models" q "platform/web/queries" + "strconv" + "strings" "time" "github.com/google/uuid" @@ -26,143 +30,6 @@ var Channel = &channelService{} type channelService struct { } -// CreateChannel 创建连接通道,并返回连接信息,如果配额不足则返回错误 -func (s *channelService) CreateChannel( - ctx context.Context, - auth *AuthContext, - resourceId int32, - protocol ChannelProtocol, - authType ChannelAuthType, - count int, - nodeFilter ...NodeFilterConfig, -) ([]*models.Channel, error) { - - // 创建通道 - var channels []*models.Channel - err := q.Q.Transaction(func(tx *q.Query) error { - // 查找套餐 - var resource = ResourceInfo{} - err := q.Resource.As("data"). - LeftJoin(q.ResourcePss.As("pss"), q.ResourcePss.ResourceID.EqCol(q.Resource.ID)). - Where(q.Resource.ID.Eq(resourceId)). - Scan(&resource) - if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return ChannelServiceErr("套餐不存在") - } - return err - } - - // 检查使用人 - if auth.Payload.Type == PayloadUser && auth.Payload.Id != resource.UserId { - return common.AuthForbiddenErr("无权限访问") - } - - // 检查套餐状态 - if !resource.Active { - return ChannelServiceErr("套餐已失效") - } - - // 检查每日限额 - today := time.Now().Format("2006-01-02") == resource.DailyLast.Format("2006-01-02") - dailyRemain := int(math.Max(float64(resource.DailyLimit-resource.DailyUsed), 0)) - if today && dailyRemain < count { - return ChannelServiceErr("套餐每日配额不足") - } - - // 检查时间或配额 - if resource.Type == 1 { // 包时 - if resource.Expire.Before(time.Now()) { - return ChannelServiceErr("套餐已过期") - } - } else { // 包量 - remain := int(math.Max(float64(resource.Quota-resource.Used), 0)) - if remain < count { - return ChannelServiceErr("套餐配额不足") - } - } - - // 筛选可用节点 - nodes, err := Node.Filter(ctx, auth.Payload.Id, count, nodeFilter...) - if err != nil { - return err - } - - // 获取用户配置白名单 - whitelist, err := q.Whitelist.Where( - q.Whitelist.UserID.Eq(auth.Payload.Id), - ).Find() - if err != nil { - return err - } - - // 创建连接通道 - channels = make([]*models.Channel, 0, len(nodes)*len(whitelist)) - for _, node := range nodes { - for _, allowed := range whitelist { - username, password := genPassPair() - channels = append(channels, &models.Channel{ - UserID: auth.Payload.Id, - NodeID: node.ID, - UserHost: allowed.Host, - NodeHost: node.Host, - ProxyPort: node.ProxyPort, - Protocol: string(protocol), - AuthIP: authType == ChannelAuthTypeIp, - AuthPass: authType == ChannelAuthTypePass, - Username: username, - Password: password, - Expiration: time.Now().Add(time.Duration(resource.Live) * time.Second), - }) - } - } - - // 保存到数据库 - err = tx.Channel.Create(channels...) - if err != nil { - return err - } - - // 更新套餐使用记录 - if today { - resource.DailyUsed += int32(count) - resource.Used += int32(count) - } else { - resource.DailyLast = time.Now() - resource.DailyUsed = int32(count) - resource.Used += int32(count) - } - - err = tx.ResourcePss. - Where(q.ResourcePss.ID.Eq(resource.Id)). - Select( - q.ResourcePss.Used, - q.ResourcePss.DailyUsed, - q.ResourcePss.DailyLast). - Save(&models.ResourcePss{ - Used: resource.Used, - DailyUsed: resource.DailyUsed, - DailyLast: resource.DailyLast}) - if err != nil { - return err - } - - return nil - }) - if err != nil { - return nil, err - } - - // 缓存通道信息与异步删除任务 - // err = cache(ctx, channels) - // if err != nil { - // return nil, err - // } - - // 返回连接通道列表 - return channels, errors.New("not implemented") -} - type ChannelAuthType int const ( @@ -178,24 +45,23 @@ const ( ProtocolHttps = ChannelProtocol("https") ) -func genPassPair() (string, string) { - usernameBytes, err := uuid.New().MarshalBinary() - if err != nil { - panic(err) - } - passwordBytes, err := uuid.New().MarshalBinary() - if err != nil { - panic(err) - } - username := base62.EncodeToString(usernameBytes) - password := base62.EncodeToString(passwordBytes) - return username, password +type ResourceInfo struct { + Id int32 + UserId int32 + Active bool + Type int32 + Live int32 + DailyLimit int32 + DailyUsed int32 + DailyLast time.Time + Quota int32 + Used int32 + Expire time.Time } +// region RemoveChannel + func (s *channelService) RemoveChannels(ctx context.Context, auth *AuthContext, id ...int32) error { - - var channels []*models.Channel - // 删除通道 err := q.Q.Transaction(func(tx *q.Query) error { // 查找通道 @@ -206,15 +72,30 @@ func (s *channelService) RemoveChannels(ctx context.Context, auth *AuthContext, return err } - // 检查权限,只有用户自己和管理员能删除 + // 检查权限,如果为用户操作的话,则只能删除自己的通道 for _, channel := range channels { if auth.Payload.Type == PayloadUser && auth.Payload.Id != channel.UserID { return common.AuthForbiddenErr("无权限访问") } } + // 查找代理 + proxySet := make(map[int32]struct{}) + proxyIds := make([]int32, 0) + for _, channel := range channels { + if _, ok := proxySet[channel.ProxyID]; !ok { + proxyIds = append(proxyIds, channel.ProxyID) + proxySet[channel.ProxyID] = struct{}{} + } + } + proxies, err := tx.Proxy.Where( + q.Proxy.ID.In(proxyIds...), + ).Find() + // 删除指定的通道 - result, err := tx.Channel.Delete(channels...) + result, err := tx.Channel. + Where(q.Channel.ID.In(id...)). + Update(q.Channel.DeletedAt, time.Now()) if err != nil { return err } @@ -222,30 +103,103 @@ func (s *channelService) RemoveChannels(ctx context.Context, auth *AuthContext, return ChannelServiceErr("删除通道失败") } + // 删除缓存,异步任务直接在消费端处理删除 + err = deleteCache(ctx, channels) + if err != nil { + return err + } + + // 禁用代理端口并下线用过的节点 + if env.DebugExternalChange { + var configMap = make(map[int32][]remote.PortConfigsReq, len(proxies)) + var proxyMap = make(map[int32]*models.Proxy, len(proxies)) + for _, proxy := range proxies { + configMap[proxy.ID] = make([]remote.PortConfigsReq, 0) + proxyMap[proxy.ID] = proxy + } + var portMap = make(map[uint64]struct{}) + for _, channel := range channels { + var config = remote.PortConfigsReq{ + Port: int(channel.ProxyPort), + Edge: &[]string{}, + AutoEdgeConfig: &remote.AutoEdgeConfig{ + Count: v.P(0), + }, + Status: false, + } + configMap[channel.ProxyID] = append(configMap[channel.ProxyID], config) + + key := uint64(channel.ProxyID)<<32 | uint64(channel.ProxyPort) + portMap[key] = struct{}{} + } + + for proxyId, configs := range configMap { + if len(configs) == 0 { + continue + } + proxy, ok := proxyMap[proxyId] + if !ok { + return ChannelServiceErr("代理不存在") + } + + var secret = strings.Split(proxy.Secret, ":") + gateway := remote.InitGateway( + proxy.Host, + secret[0], + secret[1], + ) + + // 查询配置的节点 + actives, err := gateway.GatewayPortActive() + if err != nil { + return err + } + + // 取消配置 + err = gateway.GatewayPortConfigs(configs) + if err != nil { + return err + } + + // 下线对应节点 + var edges []string + for portStr, active := range actives { + port, err := strconv.Atoi(portStr) + if err != nil { + return err + } + key := uint64(proxyId)<<32 | uint64(port) + if _, ok := portMap[key]; ok { + edges = append(edges, active.Edge...) + } + } + if len(edges) > 0 { + _, err := remote.Client.CloudDisconnect(remote.CloudDisconnectReq{ + Uuid: proxy.Name, + Edge: edges, + }) + if err != nil { + return err + } + } + } + + } + return nil }) if err != nil { return err } - // 删除缓存,异步任务直接在消费端处理删除 - err = deleteCache(ctx, channels) - if err != nil { - return err - } - return nil } -type ChannelServiceErr string +// endregion -func (c ChannelServiceErr) Error() string { - return string(c) -} +// region CreateChannel -// region channel by remote - -func (s *channelService) RemoteCreateChannel( +func (s *channelService) CreateChannel( ctx context.Context, auth *AuthContext, resourceId int32, @@ -253,57 +207,78 @@ func (s *channelService) RemoteCreateChannel( authType ChannelAuthType, count int, nodeFilter ...NodeFilterConfig, -) ([]AssignPortResult, error) { +) ([]string, error) { filter := NodeFilterConfig{} if len(nodeFilter) > 0 { filter = nodeFilter[0] } - // 查找套餐 - var resource = new(ResourceInfo) - data := q.Resource.As("data") - pss := q.ResourcePss.As("pss") - err := data.Debug().Scopes(orm.Alias(data)). - Select( - data.ID, data.UserID, data.Active, - pss.Type, pss.Live, pss.DailyUsed, pss.DailyLimit, pss.DailyLast, pss.Quota, pss.Used, pss.Expire, - ). - LeftJoin(q.ResourcePss.As("pss"), pss.ResourceID.EqCol(data.ID)). - Where(data.ID.Eq(resourceId)). - Scan(&resource) - if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, ChannelServiceErr("套餐不存在") + var addr []string + err := q.Q.Transaction(func(tx *q.Query) error { + + // 查找套餐 + var resource = new(ResourceInfo) + data := q.Resource.As("data") + pss := q.ResourcePss.As("pss") + err := data.Scopes(orm.Alias(data)). + Select( + data.ID, data.UserID, data.Active, + pss.Type, pss.Live, pss.DailyUsed, pss.DailyLimit, pss.DailyLast, pss.Quota, pss.Used, pss.Expire, + ). + LeftJoin(q.ResourcePss.As("pss"), pss.ResourceID.EqCol(data.ID)). + Where(data.ID.Eq(resourceId)). + Scan(&resource) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return ChannelServiceErr("套餐不存在") + } + return err } - return nil, err - } - // 检查用户权限 - err = checkUser(auth, resource, count) - if err != nil { - return nil, err - } - slog.Debug("检查用户权限完成") + // 检查用户权限 + err = checkUser(auth, resource, count) + if err != nil { + return err + } - var postAssigns []AssignPortResult - err = q.Q.Transaction(func(tx *q.Query) error { // 申请节点 edgeAssigns, err := assignEdge(count, filter) if err != nil { return err } - debugAssigned := fmt.Sprintf("%+v", edgeAssigns) - slog.Debug("申请节点完成", "edgeAssigns", debugAssigned) // 分配端口 - expiration := time.Now().Add(time.Duration(resource.Live) * time.Second) - postAssigns, err = assignPort(edgeAssigns, auth.Payload.Id, protocol, authType, expiration, filter) + now := time.Now() + expiration := now.Add(time.Duration(resource.Live) * time.Second) + _addr, channels, err := assignPort(edgeAssigns, auth.Payload.Id, protocol, authType, expiration, filter) + if err != nil { + return err + } + addr = _addr + + // 更新套餐使用记录 + _, err = q.ResourcePss. + Where(q.ResourcePss.ResourceID.Eq(resourceId)). + Select( + q.ResourcePss.Used, + q.ResourcePss.DailyUsed, + q.ResourcePss.DailyLast, + ). + Updates(models.ResourcePss{ + Used: resource.Used + int32(count), + DailyUsed: resource.DailyUsed + int32(count), + DailyLast: now, + }) + if err != nil { + return err + } + + // 缓存通道数据 + err = cache(ctx, channels) if err != nil { return err } - debugChannels := fmt.Sprintf("%+v", postAssigns) - slog.Debug("分配端口完成", "portAssigns", debugChannels) return nil }) @@ -311,17 +286,9 @@ func (s *channelService) RemoteCreateChannel( return nil, err } - // 缓存并关闭代理 - err = cache(ctx, postAssigns) - if err != nil { - return nil, err - } - - return postAssigns, nil + return addr, nil } -// endregion - func checkUser(auth *AuthContext, resource *ResourceInfo, count int) error { // 检查使用人 @@ -368,17 +335,17 @@ func assignEdge(count int, filter NodeFilterConfig) (*AssignEdgeResult, error) { } // 查询已配置的节点 - allConfigs, err := remote.Client.CloudAutoQuery() + rProxyConfigs, err := remote.Client.CloudAutoQuery() if err != nil { return nil, err } - // 查询已分配的节点 + // 查询已使用的节点 var proxyIds = make([]int32, len(proxies)) for i, proxy := range proxies { proxyIds[i] = proxy.ID } - assigns, err := q.Channel. + channels, err := q.Channel. Select( q.Channel.ProxyID, q.Channel.ProxyPort). @@ -392,80 +359,86 @@ func assignEdge(count int, filter NodeFilterConfig) (*AssignEdgeResult, error) { if err != nil { return nil, err } + var proxyUses = make(map[int32]int, len(channels)) + for _, channel := range channels { + proxyUses[channel.ProxyID]++ + } - // 过滤需要变动的连接配置 - var current = 0 - var result = make([]ProxyConfig, len(proxies)) + // 组织数据 + var infos = make([]*ProxyInfo, len(proxies)) for i, proxy := range proxies { - remoteConfigs, ok := allConfigs[proxy.Name] + infos[i] = &ProxyInfo{ + proxy: proxy, + used: proxyUses[proxy.ID], + } + + rConfigs, ok := rProxyConfigs[proxy.Name] if !ok { - result[i] = ProxyConfig{ - proxy: proxy, - config: &remote.AutoConfig{ - Province: filter.Prov, - City: filter.City, - Isp: filter.Isp, - Count: 0, - }, - } + infos[i].count = 0 continue } - for _, config := range remoteConfigs { - if config.Isp == filter.Isp && config.City == filter.City && config.Province == filter.Prov { - current += config.Count - result[i] = ProxyConfig{ - proxy: proxy, - config: &config, - } + + for _, rConfig := range rConfigs { + if rConfig.Isp == filter.Isp && rConfig.City == filter.City && rConfig.Province == filter.Prov { + infos[i].count = rConfig.Count } } } - // 如果需要新增节点 - var needed = len(assigns) + count - if needed-current > 0 { - slog.Debug("新增新节点", "needed", needed, "current", current) - avg := int(math.Ceil(float64(needed) / float64(len(proxies)))) - for i, assign := range result { - var prev = assign.config.Count - var next = assign.config.Count - if prev >= avg || prev >= needed { - continue - } + // 分配新增的节点 + var configs = make([]*ProxyConfig, len(proxies)) + var needed = len(channels) + count + avg := int(math.Ceil(float64(needed) / float64(len(proxies)))) + for i, info := range infos { + var prev = info.used + var next = int(math.Min(float64(avg), float64(needed))) - next = int(math.Min(float64(avg), float64(needed))) - result[i].config.Count = next - prev - needed -= next + info.used = int(math.Max(float64(prev), float64(next))) + needed -= info.used + + if env.DebugExternalChange && info.used > info.count { + slog.Debug("新增新节点", "proxy", info.proxy.Name, "used", info.used, "count", info.count) err := remote.Client.CloudConnect(remote.CloudConnectReq{ - Uuid: assign.proxy.Name, + Uuid: info.proxy.Name, Edge: nil, AutoConfig: []remote.AutoConfig{{ Province: filter.Prov, City: filter.City, Isp: filter.Isp, - Count: next, + Count: int(math.Ceil(float64(info.used) * 1.1)), }}, }) if err != nil { return nil, err } } + + configs[i] = &ProxyConfig{ + proxy: info.proxy, + count: int(math.Max(float64(next-prev), 0)), + } } return &AssignEdgeResult{ - configs: result, - channels: assigns, + configs: configs, + channels: channels, }, nil } +type ProxyInfo struct { + proxy *models.Proxy + used int + count int +} + type AssignEdgeResult struct { - configs []ProxyConfig + configs []*ProxyConfig channels []*models.Channel } type ProxyConfig struct { - proxy *models.Proxy - config *remote.AutoConfig + proxy *models.Proxy + count int } // assignPort 分配指定数量的端口 @@ -476,9 +449,9 @@ func assignPort( authType ChannelAuthType, expiration time.Time, filter NodeFilterConfig, -) ([]AssignPortResult, error) { +) ([]string, []*models.Channel, error) { var assigns = proxies.configs - var channels = proxies.channels + var exists = proxies.channels // 查询代理已配置端口 var proxyIds = make([]int32, 0, len(assigns)) @@ -488,24 +461,20 @@ func assignPort( // 端口查找表 var proxyPorts = make(map[uint64]struct{}) - for _, channel := range channels { + for _, channel := range exists { key := uint64(channel.ProxyID)<<32 | uint64(channel.ProxyPort) proxyPorts[key] = struct{}{} } // 配置启用代理 - var result = make([]AssignPortResult, len(assigns)) - for i, assign := range assigns { + var result []string + var channels []*models.Channel + for _, assign := range assigns { var err error var proxy = assign.proxy - var count = assign.config.Count - - result[i] = AssignPortResult{ - Proxy: proxy, - } + var count = assign.count // 筛选可用端口 - var channels = result[i].Channels var configs = make([]remote.PortConfigsReq, 0, count) for port := 10000; port < 20000 && len(configs) < count; port++ { // 跳过存在的端口 @@ -521,13 +490,14 @@ func assignPort( Port: port, Edge: nil, Status: true, - AutoEdgeConfig: remote.AutoEdgeConfig{ + AutoEdgeConfig: &remote.AutoEdgeConfig{ Province: filter.Prov, City: filter.City, Isp: filter.Isp, - Count: 1, + Count: v.P(1), }, }) + switch authType { case ChannelAuthTypeIp: var whitelist []string @@ -536,9 +506,10 @@ func assignPort( Select(q.Whitelist.Host). Scan(&whitelist) if err != nil { - return nil, err + return nil, nil, err } - configs[i].Whitelist = whitelist + configs[i].Whitelist = &whitelist + configs[i].Userpass = v.P("") for _, item := range whitelist { channels = append(channels, &models.Channel{ UserID: userId, @@ -553,7 +524,8 @@ func assignPort( } case ChannelAuthTypePass: username, password := genPassPair() - configs[i].Userpass = fmt.Sprintf("%s:%s", username, password) + configs[i].Whitelist = new([]string) + configs[i].Userpass = v.P(fmt.Sprintf("%s:%s", username, password)) channels = append(channels, &models.Channel{ UserID: userId, ProxyID: proxy.ID, @@ -566,66 +538,82 @@ func assignPort( Expiration: expiration, }) } - } - result[i].Channels = channels + result = append(result, fmt.Sprintf("%s://%s:%d", protocol, proxy.Host, port)) + } if len(configs) < count { - return nil, ChannelServiceErr("网关端口数量到达上限,无法分配") - } - - // 提交端口配置 - gateway := remote.InitGateway( - proxy.Host, - "api", - "123456", - ) - err = gateway.GatewayPortConfigs(configs) - if err != nil { - return nil, err + return nil, nil, ChannelServiceErr("网关端口数量到达上限,无法分配") } // 保存到数据库 err = q.Channel. - Omit(q.Channel.NodeID). + Omit( + q.Channel.NodeID, + q.Channel.NodeHost, + q.Channel.Username, + q.Channel.Password, + q.Channel.DeletedAt, + ). Save(channels...) if err != nil { - return nil, err + return nil, nil, err + } + + // 提交端口配置并更新节点列表 + if env.DebugExternalChange { + var secret = strings.Split(proxy.Secret, ":") + gateway := remote.InitGateway( + proxy.Host, + secret[0], + secret[1], + ) + err = gateway.GatewayPortConfigs(configs) + if err != nil { + return nil, nil, err + } } } - return result, nil + return result, channels, nil } -type AssignPortResult struct { - Proxy *models.Proxy - Channels []*models.Channel +// endregion + +func genPassPair() (string, string) { + usernameBytes, err := uuid.New().MarshalBinary() + if err != nil { + panic(err) + } + passwordBytes, err := uuid.New().MarshalBinary() + if err != nil { + panic(err) + } + username := base62.EncodeToString(usernameBytes) + password := base62.EncodeToString(passwordBytes) + return username, password } func chKey(channel *models.Channel) string { - return fmt.Sprintf("channel:%s:%s", channel.UserHost, channel.NodeHost) + return fmt.Sprintf("channel:%d", channel.ID) } -func cache(ctx context.Context, assigns []AssignPortResult) error { +func cache(ctx context.Context, channels []*models.Channel) error { pipe := rds.Client.TxPipeline() - zList := make([]redis.Z, 0, len(assigns)) - for _, assign := range assigns { - var channels = assign.Channels - for _, channel := range channels { - marshal, err := json.Marshal(assign) - if err != nil { - return err - } - - pipe.Set(ctx, chKey(channel), string(marshal), channel.Expiration.Sub(time.Now())) - zList = append(zList, redis.Z{ - Score: float64(channel.Expiration.Unix()), - Member: channel.ID, - }) + zList := make([]redis.Z, 0, len(channels)) + for _, channel := range channels { + marshal, err := json.Marshal(channel) + if err != nil { + return err } + pipe.Set(ctx, chKey(channel), string(marshal), channel.Expiration.Sub(time.Now())) + zList = append(zList, redis.Z{ + Score: float64(channel.Expiration.Unix()), + Member: channel.ID, + }) } - pipe.ZAdd(ctx, "tasks:assign", zList...) + pipe.ZAdd(ctx, "tasks:channel", zList...) _, err := pipe.Exec(ctx) if err != nil { @@ -636,14 +624,11 @@ func cache(ctx context.Context, assigns []AssignPortResult) error { } func deleteCache(ctx context.Context, channels []*models.Channel) error { - pipe := rds.Client.TxPipeline() - keys := make([]string, 0, len(channels)) - for i := range keys { + keys := make([]string, len(channels)) + for i := range channels { keys[i] = chKey(channels[i]) } - pipe.Del(ctx, keys...) - // 忽略异步任务,zrem 效率较低,在使用时再删除 - _, err := pipe.Exec(ctx) + _, err := rds.Client.Del(ctx, keys...).Result() if err != nil { return err } @@ -651,16 +636,8 @@ func deleteCache(ctx context.Context, channels []*models.Channel) error { return nil } -type ResourceInfo struct { - Id int32 - UserId int32 - Active bool - Type int32 - Live int32 - DailyLimit int32 - DailyUsed int32 - DailyLast time.Time - Quota int32 - Used int32 - Expire time.Time +type ChannelServiceErr string + +func (c ChannelServiceErr) Error() string { + return string(c) } diff --git a/web/web.go b/web/web.go index a510427..a0ec864 100644 --- a/web/web.go +++ b/web/web.go @@ -1,13 +1,17 @@ package web import ( + "net/http" "platform/pkg/env" + "log/slog" + "github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2/middleware/logger" "github.com/gofiber/fiber/v2/middleware/requestid" ) -import "log/slog" + +import _ "net/http/pprof" type Config struct { Listen string @@ -30,6 +34,7 @@ func New(config *Config) (*Server, error) { } func (s *Server) Run() error { + s.fiber = fiber.New(fiber.Config{ ErrorHandler: ErrorHandler, }) @@ -39,6 +44,13 @@ func (s *Server) Run() error { ApplyRouters(s.fiber) + go func() { + err := http.ListenAndServe(":6060", nil) + if err != nil { + slog.Error("pprof 服务错误", slog.Any("err", err)) + } + }() + port := env.AppPort slog.Info("Server started on :" + port) err := s.fiber.Listen(":" + port)