diff --git a/.gitignore b/.gitignore index 44b6088..3bd5f65 100644 --- a/.gitignore +++ b/.gitignore @@ -19,3 +19,5 @@ scripts/* !scripts/env/dev/ !scripts/pre/ !scripts/sql/ + +*/uploads/ diff --git a/README.md b/README.md index c8f7a7f..f6bef66 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ ## TODO ---- +上传文件平铺到 uploads,不分子文件夹 错误提示增强,展示整链路信息 diff --git a/pkg/env/env.go b/pkg/env/env.go index 0e94616..c7e8e21 100644 --- a/pkg/env/env.go +++ b/pkg/env/env.go @@ -45,6 +45,9 @@ var ( BaiyinCloudUrl string BaiyinTokenUrl string + GostApiPort = 8900 + GostApiPathPrefix = "" + IdenCallbackUrl string IdenAccessKey string IdenSecretKey string @@ -129,6 +132,8 @@ func Init() { errs = append(errs, parse(&BaiyinCloudUrl, "BAIYIN_CLOUD_URL", false, nil)) errs = append(errs, parse(&BaiyinTokenUrl, "BAIYIN_TOKEN_URL", false, nil)) + errs = append(errs, parse(&GostApiPort, "GOST_API_PORT", true, nil)) + errs = append(errs, parse(&GostApiPathPrefix, "GOST_API_PATH_PREFIX", true, nil)) errs = append(errs, parse(&IdenCallbackUrl, "IDEN_CALLBACK_URL", false, nil)) errs = append(errs, parse(&IdenAccessKey, "IDEN_ACCESS_KEY", false, nil)) diff --git a/scripts/sql/init.sql b/scripts/sql/init.sql index 7a98b2a..a3c9907 100644 --- a/scripts/sql/init.sql +++ b/scripts/sql/init.sql @@ -625,7 +625,7 @@ comment on column proxy.mac is '代理服务名称'; comment on column proxy.ip is '代理服务地址'; comment on column proxy.host is '代理服务域名'; comment on column proxy.secret is '代理服务密钥'; -comment on column proxy.type is '代理服务类型:1-自有,2-白银'; +comment on column proxy.type is '代理服务类型:1-自有,2-白银,3-GOST'; comment on column proxy.status is '代理服务状态:0-离线,1-在线'; comment on column proxy.meta is '代理服务元信息'; comment on column proxy.created_at is '创建时间'; @@ -640,6 +640,7 @@ create table edge ( version int not null, mac text not null, ip inet not null, + port int, isp int not null, prov text not null, city text not null, @@ -659,10 +660,11 @@ create index idx_edge_created_at on edge (created_at) where deleted_at is null; -- edge表字段注释 comment on table edge is '节点表'; comment on column edge.id is '节点ID'; -comment on column edge.type is '节点类型:1-自建'; +comment on column edge.type is '节点类型:1-自建,2-GOST chain'; comment on column edge.version is '节点版本'; -comment on column edge.mac is '节点 mac 地址'; -comment on column edge.ip is '节点地址'; +comment on column edge.mac is '节点 mac 地址或 GOST chain 名称'; +comment on column edge.ip is '节点地址或 GOST chain addr 的 IP'; +comment on column edge.port is 'GOST chain addr 的端口'; comment on column edge.isp is '运营商:1-电信,2-联通,3-移动'; comment on column edge.prov is '省份'; comment on column edge.city is '城市'; diff --git a/web/globals/gost.go b/web/globals/gost.go new file mode 100644 index 0000000..555e9b2 --- /dev/null +++ b/web/globals/gost.go @@ -0,0 +1,215 @@ +package globals + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "platform/web/core" + "strings" +) + +var ErrGostNotFound = errors.New("gost resource not found") + +func IsGostNotFound(err error) bool { + return errors.Is(err, ErrGostNotFound) +} + +type GostClient interface { + GetChain(name string) (*GostChainConfig, error) + CreateService(service *GostServiceConfig) error + DeleteService(name string) error + CreateAuther(auther *GostAutherConfig) error + DeleteAuther(name string) error + CreateAdmission(admission *GostAdmissionConfig) error + DeleteAdmission(name string) error +} + +type gostClient struct { + baseURL string + pathPrefix string + username string + password string +} + +var GostInitializer = func(host string, port int, pathPrefix, username, password string) GostClient { + baseURL := strings.TrimSpace(host) + if !strings.Contains(baseURL, "://") { + baseURL = fmt.Sprintf("http://%s:%d", baseURL, port) + } + + return &gostClient{ + baseURL: strings.TrimRight(baseURL, "/"), + pathPrefix: normalizeGostPathPrefix(pathPrefix), + username: username, + password: password, + } +} + +func NewGost(host string, port int, pathPrefix, username, password string) GostClient { + return GostInitializer(host, port, pathPrefix, username, password) +} + +type GostChainConfig struct { + Name string `json:"name"` +} + +type GostServiceConfig struct { + Name string `json:"name"` + Addr string `json:"addr"` + Admission string `json:"admission,omitempty"` + Handler GostHandlerConfig `json:"handler"` + Listener GostListenerConfig `json:"listener"` +} + +type GostHandlerConfig struct { + Type string `json:"type"` + Chain string `json:"chain,omitempty"` + Auther string `json:"auther,omitempty"` +} + +type GostListenerConfig struct { + Type string `json:"type"` +} + +type GostAutherConfig struct { + Name string `json:"name"` + Auths []GostAuthConfig `json:"auths"` +} + +type GostAuthConfig struct { + Username string `json:"username"` + Password string `json:"password"` +} + +type GostAdmissionConfig struct { + Name string `json:"name"` + Whitelist bool `json:"whitelist"` + Matchers []string `json:"matchers"` +} + +func (c *gostClient) GetChain(name string) (*GostChainConfig, error) { + body, err := c.get("/config/chains/" + url.PathEscape(name)) + if err != nil { + return nil, err + } + if len(body) == 0 { + return &GostChainConfig{Name: name}, nil + } + + var direct GostChainConfig + if err := json.Unmarshal(body, &direct); err == nil && direct.Name != "" { + return &direct, nil + } + + var wrapper struct { + Data *GostChainConfig `json:"data"` + } + if err := json.Unmarshal(body, &wrapper); err == nil && wrapper.Data != nil && wrapper.Data.Name != "" { + return wrapper.Data, nil + } + + return &GostChainConfig{Name: name}, nil +} + +func (c *gostClient) CreateService(service *GostServiceConfig) error { + return c.create("/config/services", service) +} + +func (c *gostClient) DeleteService(name string) error { + return c.delete("/config/services/" + url.PathEscape(name)) +} + +func (c *gostClient) CreateAuther(auther *GostAutherConfig) error { + return c.create("/config/authers", auther) +} + +func (c *gostClient) DeleteAuther(name string) error { + return c.delete("/config/authers/" + url.PathEscape(name)) +} + +func (c *gostClient) CreateAdmission(admission *GostAdmissionConfig) error { + return c.create("/config/admissions", admission) +} + +func (c *gostClient) DeleteAdmission(name string) error { + return c.delete("/config/admissions/" + url.PathEscape(name)) +} + +func (c *gostClient) create(path string, payload any) error { + _, err := c.request(http.MethodPost, path, payload) + return err +} + +func (c *gostClient) get(path string) ([]byte, error) { + body, err := c.request(http.MethodGet, path, nil) + if err != nil { + return nil, err + } + return body, nil +} + +func (c *gostClient) delete(path string) error { + _, err := c.request(http.MethodDelete, path, nil) + return err +} + +func (c *gostClient) request(method string, path string, payload any) ([]byte, error) { + var bodyReader io.Reader + if payload != nil { + data, err := json.Marshal(payload) + if err != nil { + return nil, err + } + bodyReader = bytes.NewReader(data) + } + + req, err := http.NewRequest(method, c.endpoint(path), bodyReader) + if err != nil { + return nil, err + } + req.SetBasicAuth(c.username, c.password) + if payload != nil { + req.Header.Set("Content-Type", "application/json") + } + + resp, err := core.Fetch(req) + if err != nil { + return nil, err + } + defer func(Body io.ReadCloser) { + _ = Body.Close() + }(resp.Body) + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + if resp.StatusCode == http.StatusNotFound { + return nil, fmt.Errorf("%w: %s", ErrGostNotFound, string(body)) + } + if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { + return nil, fmt.Errorf("gost api %s %s failed: %d %s", method, path, resp.StatusCode, string(body)) + } + + return body, nil +} + +func (c *gostClient) endpoint(path string) string { + return c.baseURL + c.pathPrefix + path +} + +func normalizeGostPathPrefix(prefix string) string { + prefix = strings.TrimSpace(prefix) + if prefix == "" { + return "" + } + if !strings.HasPrefix(prefix, "/") { + prefix = "/" + prefix + } + return strings.TrimRight(prefix, "/") +} diff --git a/web/globals/gost_test.go b/web/globals/gost_test.go new file mode 100644 index 0000000..dd65949 --- /dev/null +++ b/web/globals/gost_test.go @@ -0,0 +1,96 @@ +package globals + +import ( + "encoding/base64" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestGostClientCreateServiceUsesBasicAuthAndPathPrefix(t *testing.T) { + var ( + gotPath string + gotAuth string + gotBody GostServiceConfig + ) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotPath = r.URL.Path + gotAuth = r.Header.Get("Authorization") + if err := json.NewDecoder(r.Body).Decode(&gotBody); err != nil { + t.Fatalf("Decode failed: %v", err) + } + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + client := NewGost(server.URL, 0, "/api", "user", "pass") + err := client.CreateService(&GostServiceConfig{ + Name: "svc-1", + Addr: ":10000", + Handler: GostHandlerConfig{ + Type: "auto", + Chain: "chain-a", + Auther: "auther-a", + }, + Listener: GostListenerConfig{Type: "tcp"}, + }) + if err != nil { + t.Fatalf("CreateService returned error: %v", err) + } + + if gotPath != "/api/config/services" { + t.Fatalf("unexpected path: %s", gotPath) + } + if gotAuth != "Basic "+base64.StdEncoding.EncodeToString([]byte("user:pass")) { + t.Fatalf("unexpected auth header: %s", gotAuth) + } + if gotBody.Name != "svc-1" || gotBody.Handler.Type != "auto" || gotBody.Handler.Chain != "chain-a" { + t.Fatalf("unexpected request body: %+v", gotBody) + } +} + +func TestGostClientDeleteServiceTreats404AsIdempotent(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.NotFound(w, r) + })) + defer server.Close() + + client := NewGost(server.URL, 0, "", "user", "pass") + if err := client.DeleteService("svc-1"); !IsGostNotFound(err) { + t.Fatalf("expected gost not found error, got: %v", err) + } +} + +func TestGostClientGetChainReadsTopLevelName(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/config/chains/chain-a" { + t.Fatalf("unexpected path: %s", r.URL.Path) + } + _, _ = w.Write([]byte(`{"name":"chain-a"}`)) + })) + defer server.Close() + + client := NewGost(server.URL, 0, "", "user", "pass") + chain, err := client.GetChain("chain-a") + if err != nil { + t.Fatalf("GetChain returned error: %v", err) + } + if chain.Name != "chain-a" { + t.Fatalf("unexpected chain: %+v", chain) + } +} + +func TestNormalizeGostPathPrefix(t *testing.T) { + if got := normalizeGostPathPrefix("api/"); got != "/api" { + t.Fatalf("unexpected prefix: %s", got) + } + if got := normalizeGostPathPrefix(""); got != "" { + t.Fatalf("unexpected empty prefix: %s", got) + } + if !strings.HasPrefix(normalizeGostPathPrefix("/v1"), "/") { + t.Fatal("expected normalized prefix to start with slash") + } +} diff --git a/web/models/edge.go b/web/models/edge.go index c10c042..68794a7 100644 --- a/web/models/edge.go +++ b/web/models/edge.go @@ -8,16 +8,17 @@ import ( // Edge 节点表 type Edge struct { core.Model - Type EdgeType `json:"type" gorm:"column:type"` // 节点类型:1-自建 - Version int32 `json:"version" gorm:"column:version"` // 节点版本 - Mac string `json:"mac" gorm:"column:mac"` // 节点 mac 地址 - IP orm.Inet `json:"ip" gorm:"column:ip;not null"` // 节点地址 - ISP EdgeISP `json:"isp" gorm:"column:isp"` // 运营商:0-未知,1-电信,2-联通,3-移动 - Prov string `json:"prov" gorm:"column:prov"` // 省份 - City string `json:"city" gorm:"column:city"` // 城市 - Status EdgeStatus `json:"status" gorm:"column:status"` // 节点状态:0-离线,1-正常 - RTT int32 `json:"rtt" gorm:"column:rtt"` // 最近平均延迟 - Loss int32 `json:"loss" gorm:"column:loss"` // 最近丢包率 + Type EdgeType `json:"type" gorm:"column:type"` // 节点类型:1-自建,2-GOST chain + Version int32 `json:"version" gorm:"column:version"` // 节点版本 + Mac string `json:"mac" gorm:"column:mac"` // 节点 mac 地址或 GOST chain 名称 + IP orm.Inet `json:"ip" gorm:"column:ip;not null"` // 节点地址或 GOST chain addr 的 IP + Port *uint16 `json:"port,omitempty" gorm:"column:port"` // GOST chain addr 的端口 + ISP EdgeISP `json:"isp" gorm:"column:isp"` // 运营商:0-未知,1-电信,2-联通,3-移动 + Prov string `json:"prov" gorm:"column:prov"` // 省份 + City string `json:"city" gorm:"column:city"` // 城市 + Status EdgeStatus `json:"status" gorm:"column:status"` // 节点状态:0-离线,1-正常 + RTT int32 `json:"rtt" gorm:"column:rtt"` // 最近平均延迟 + Loss int32 `json:"loss" gorm:"column:loss"` // 最近丢包率 } // EdgeType 节点类型枚举 @@ -25,6 +26,7 @@ type EdgeType int const ( EdgeTypeSelfBuilt EdgeType = 1 // 自建 + EdgeTypeGostChain EdgeType = 2 // GOST chain ) // EdgeStatus 节点状态枚举 diff --git a/web/models/proxy.go b/web/models/proxy.go index c55eebd..04b1a20 100644 --- a/web/models/proxy.go +++ b/web/models/proxy.go @@ -28,6 +28,7 @@ type ProxyType int const ( ProxyTypeSelfHosted ProxyType = 1 // 自有 ProxyTypeBaiYin ProxyType = 2 // 白银 + ProxyTypeGost ProxyType = 3 // GOST ) // ProxyStatus 代理服务状态枚举 diff --git a/web/services/channel.go b/web/services/channel.go index cfbbae8..cc85994 100644 --- a/web/services/channel.go +++ b/web/services/channel.go @@ -4,24 +4,28 @@ import ( "context" "errors" "fmt" + "log/slog" "math/rand/v2" "net/netip" "platform/pkg/env" "platform/pkg/u" "platform/web/core" + e "platform/web/events" g "platform/web/globals" m "platform/web/models" q "platform/web/queries" "strconv" + "strings" "time" + "github.com/hibiken/asynq" "github.com/redis/go-redis/v9" "gorm.io/gen/field" ) // 通道服务 var Channel = &channelServer{ - provider: &channelBaiyinProvider{}, + provider: &channelGostProvider{}, } type ChannelServiceProvider interface { @@ -46,13 +50,77 @@ func (s *channelServer) ClearExpiredChannels(proxyId int32) (int, error) { return s.provider.ClearExpiredChannels(proxyId) } +func lockChannelCreateKey(resourceNo string) string { + return fmt.Sprintf("platform:channel:create:%s", resourceNo) +} + +func lockChannelRemoveKey(bid string) string { + return fmt.Sprintf("platform:batch:remove_expired:%s", bid) +} + +func selectPorts(proxyId int32, batchNo string, count int, expire time.Time) ([]netip.AddrPort, error) { + chans, err := lockChans(proxyId, batchNo, count) + if err != nil { + return nil, core.NewBizErr("无可用通道,请稍后再试", err) + } + + _, err = g.Asynq.Enqueue( + e.NewRemoveChannel(batchNo), + asynq.ProcessAt(expire), + ) + if err != nil { + return nil, core.NewServErr("注册异步关闭通道任务失败", err) + } + + return chans, nil +} + +func selectProxyByType(proxyType m.ProxyType, count int) (*m.Proxy, error) { + proxies, err := q.Proxy.Where( + q.Proxy.Type.Eq(int(proxyType)), + q.Proxy.Status.Eq(int(m.ProxyStatusOnline)), + ).Find() + if err != nil { + return nil, core.NewBizErr("获取可用代理失败", err) + } + if len(proxies) == 0 { + return nil, core.NewBizErr("无可用代理") + } + + proxyIDs := make([]int32, 0, len(proxies)) + proxyMap := make(map[int32]*m.Proxy, len(proxies)) + for _, item := range proxies { + proxyIDs = append(proxyIDs, item.ID) + proxyMap[item.ID] = item + } + + maxID := int32(0) + maxCount := -1 + for _, id := range proxyIDs { + idCount, err := g.Redis.SCard(context.Background(), freeChansKey(id)).Result() + if err != nil { + return nil, core.NewServErr("查询可用通道数量失败", err) + } + if idCount > int64(maxCount) { + maxCount = int(idCount) + maxID = id + } + } + if maxCount < count { + return nil, core.NewBizErr("无可用代理") + } + + return proxyMap[maxID], nil +} + func (s *channelServer) RefreshEdges() error { if env.RunMode != env.RunModeProd { return nil } - // 找到所有网关 + // 仅白银网关支持边缘节点刷新,GOST 不参与此流程。 proxies, err := q.Proxy.Where( + q.Proxy.Type.Eq(int(m.ProxyTypeBaiYin)), q.Proxy.Status.Eq(int(m.ProxyStatusOnline)), ).Find() if err != nil { @@ -282,6 +350,83 @@ func usedChansKey(proxy int32, batch string) string { return "channel:used:" + strconv.Itoa(int(proxy)) + ":" + batch } +type usedChanBatch struct { + ProxyID int32 + Chans []netip.AddrPort +} + +func findUsedChanBatch(batch string) (*usedChanBatch, error) { + keys, err := g.Redis.Keys(context.Background(), "channel:used:*:"+batch).Result() + if err != nil { + return nil, core.NewServErr("查询使用中通道失败", err) + } + + key, ok, err := selectUsedChanBatchKey(batch, keys) + if err != nil { + return nil, err + } + if !ok { + return nil, nil + } + + chans, err := g.Redis.LRange(context.Background(), key, 0, -1).Result() + if err != nil { + return nil, core.NewServErr("查询使用中通道失败", err) + } + + return parseUsedChanBatch(key, chans) +} + +func selectUsedChanBatchKey(batch string, keys []string) (string, bool, error) { + switch len(keys) { + case 0: + return "", false, nil + case 1: + return keys[0], true, nil + default: + slog.Error("batchNo 全局唯一约束被破坏", "batch", batch, "keys", keys) + return "", false, core.NewServErr( + fmt.Sprintf("检测到重复 usedChans 键,batchNo 全局唯一被破坏: %s", batch), + fmt.Errorf("keys=%s", strings.Join(keys, ",")), + ) + } +} + +func parseUsedChanBatch(key string, chans []string) (*usedChanBatch, error) { + proxyID, err := parseUsedChansKey(key) + if err != nil { + return nil, err + } + + addrs := make([]netip.AddrPort, len(chans)) + for i, ch := range chans { + addr, err := netip.ParseAddrPort(ch) + if err != nil { + return nil, core.NewServErr(fmt.Sprintf("解析通道数据失败: %s", ch), err) + } + addrs[i] = addr + } + + return &usedChanBatch{ + ProxyID: proxyID, + Chans: addrs, + }, nil +} + +func parseUsedChansKey(key string) (int32, error) { + parts := strings.Split(key, ":") + if len(parts) != 4 { + return 0, core.NewServErr(fmt.Sprintf("使用中通道键格式错误: %s", key), nil) + } + + proxyID, err := strconv.Atoi(parts[2]) + if err != nil { + return 0, core.NewServErr(fmt.Sprintf("使用中通道键格式错误: %s", key), err) + } + + return int32(proxyID), nil +} + // 扩容通道 func regChans(proxy int32, chans []netip.AddrPort) error { strs := make([]any, len(chans)) diff --git a/web/services/channel_baiyin.go b/web/services/channel_baiyin.go index 2600494..8af05e5 100644 --- a/web/services/channel_baiyin.go +++ b/web/services/channel_baiyin.go @@ -9,7 +9,6 @@ import ( "platform/pkg/env" "platform/pkg/u" "platform/web/core" - e "platform/web/events" g "platform/web/globals" "platform/web/globals/orm" m "platform/web/models" @@ -17,7 +16,6 @@ import ( "strings" "time" - "github.com/hibiken/asynq" "gorm.io/gen" "gorm.io/gen/field" ) @@ -213,41 +211,19 @@ func (s *channelBaiyinProvider) RemoveChannels(batchNo string) error { return g.Redsync.WithLock(lockChannelRemoveKey(batchNo), func() error { start := time.Now() - // 获取连接数据 - channels, err := q.Channel.Where(q.Channel.BatchNo.Eq(batchNo)).Find() + batch, err := findUsedChanBatch(batchNo) if err != nil { - return core.NewServErr(fmt.Sprintf("获取通道数据失败,batch:%s", batchNo), err) + return err } - if len(channels) == 0 { - slog.Warn(fmt.Sprintf("未找到通道数据,batch:%s", batchNo)) + if batch == nil { + slog.Debug("通道为空,跳过清理", "batch", batchNo) return nil } - proxy, err := q.Proxy.Where(q.Proxy.ID.Eq(channels[0].ProxyID)).Take() - if err != nil { - return core.NewServErr(fmt.Sprintf("获取代理数据失败,batch:%s", batchNo), err) - } - - // 检查通道是否存在 - chans, err := g.Redis.LRange(context.Background(), usedChansKey(proxy.ID, batchNo), 0, -1).Result() - if err != nil { - return core.NewServErr("查询使用中通道失败", err) - } - if len(chans) == 0 { - slog.Debug("通道为空,跳过清理", "key", usedChansKey(proxy.ID, batchNo)) - return nil // 没有使用中通道,已经被清理过了 - } - - // 准备配置数据 - configs := make([]*g.PortConfigsReq, len(chans)) - for i, ch := range chans { - ap, err := netip.ParseAddrPort(ch) - if err != nil { - return core.NewServErr(fmt.Sprintf("解析通道数据失败: %s", ch), err) - } - + configs := make([]*g.PortConfigsReq, len(batch.Chans)) + for i, ch := range batch.Chans { configs[i] = &g.PortConfigsReq{ - Port: int(ap.Port()), + Port: int(ch.Port()), Edge: &[]string{}, AutoEdgeConfig: &g.AutoEdgeConfig{Count: u.P(0)}, Status: false, @@ -256,6 +232,11 @@ func (s *channelBaiyinProvider) RemoveChannels(batchNo string) error { // 提交配置 if env.RunMode == env.RunModeProd { + proxy, err := q.Proxy.Where(q.Proxy.ID.Eq(batch.ProxyID)).Take() + if err != nil { + return core.NewServErr("获取代理数据失败", err) + } + gateway, err := proxyGateway(proxy) if err != nil { return core.NewServErr("创建代理网关失败", err) @@ -271,13 +252,11 @@ func (s *channelBaiyinProvider) RemoveChannels(batchNo string) error { } } - // 释放端口 - err = freeChans(proxy.ID, batchNo) - if err != nil { + if err := freeChans(batch.ProxyID, batchNo); err != nil { return err } - slog.Debug("清除代理端口配置", "proxy", proxy.ID, "batch", batchNo, "duration", time.Since(start).String()) + slog.Debug("清除代理端口配置", "proxy", batch.ProxyID, "batch", batchNo, "duration", time.Since(start).String()) return nil }) } @@ -335,53 +314,12 @@ func (s *channelBaiyinProvider) ClearExpiredChannels(proxyId int32) (int, error) return len(batchSet), nil } -func lockChannelCreateKey(resourceNo string) string { - return fmt.Sprintf("platform:channel:create:%s", resourceNo) -} - -func lockChannelRemoveKey(bid string) string { - return fmt.Sprintf("platform:batch:remove_expired:%s", bid) -} - func selectProxy(count int) (*m.Proxy, g.GatewayClient, error) { - // 获取在线节点 - proxies, err := q.Proxy.Where( - q.Proxy.Type.Eq(int(m.ProxyTypeBaiYin)), - q.Proxy.Status.Eq(int(m.ProxyStatusOnline)), - ).Find() + proxy, err := selectProxyByType(m.ProxyTypeBaiYin, count) if err != nil { - return nil, nil, core.NewBizErr("获取可用代理失败", err) - } - if len(proxies) == 0 { - return nil, nil, core.NewBizErr("无可用代理") + return nil, nil, err } - proxyIDs := make([]int32, 0, len(proxies)) - proxyMap := make(map[int32]*m.Proxy, len(proxies)) - for _, item := range proxies { - proxyIDs = append(proxyIDs, item.ID) - proxyMap[item.ID] = item - } - - // 获取最空闲节点 - maxId := int32(0) - maxCount := -1 - for _, id := range proxyIDs { - idCount, err := g.Redis.SCard(context.Background(), freeChansKey(id)).Result() - if err != nil { - return nil, nil, fmt.Errorf("查询可用通道数量失败: %w", err) - } - - if idCount > int64(maxCount) { - maxCount = int(idCount) - maxId = id - } - } - if maxCount < count { - return nil, nil, core.NewBizErr("无可用代理") - } - - proxy := proxyMap[maxId] gateway, err := proxyGateway(proxy) if err != nil { return nil, nil, core.NewServErr("创建代理网关失败", err) @@ -390,23 +328,6 @@ func selectProxy(count int) (*m.Proxy, g.GatewayClient, error) { return proxy, gateway, nil } -func selectPorts(proxyId int32, batchNo string, count int, expire time.Time) ([]netip.AddrPort, error) { - chans, err := lockChans(proxyId, batchNo, count) - if err != nil { - return nil, core.NewBizErr("无可用通道,请稍后再试", err) - } - - _, err = g.Asynq.Enqueue( - e.NewRemoveChannel(batchNo), - asynq.ProcessAt(expire), - ) - if err != nil { - return nil, core.NewServErr("注册异步关闭通道任务失败", err) - } - - return chans, nil -} - // ensureEdges 检查本地节点是否足够,如果不足从云端连入 // 本地节点通过 Assigned = false 排除已分配节点 // 云端节点通过 NoRepeat = true 排除已分配节点 diff --git a/web/services/channel_gost.go b/web/services/channel_gost.go new file mode 100644 index 0000000..57bd0d6 --- /dev/null +++ b/web/services/channel_gost.go @@ -0,0 +1,390 @@ +package services + +import ( + "context" + "fmt" + "log/slog" + "net/netip" + "platform/pkg/env" + "platform/pkg/u" + "platform/web/core" + g "platform/web/globals" + "platform/web/globals/orm" + m "platform/web/models" + q "platform/web/queries" + "strings" + "time" + + "gorm.io/gen" + "gorm.io/gen/field" +) + +type channelGostProvider struct{} + +func (s *channelGostProvider) CreateChannels(source netip.Addr, resourceNo string, authWhitelist bool, authPassword bool, count int, filter *EdgeFilter) ([]*m.Channel, error) { + now := time.Now() + batchNo := ID.GenReadable("bat") + channels := make([]*m.Channel, count) + if filter == nil { + filter = &EdgeFilter{} + } + + err := g.Redsync.WithLock(lockChannelCreateKey(resourceNo), func() error { + resource, whitelists, err := ensure(now, source, resourceNo, authWhitelist, count) + if err != nil { + return err + } + + user := resource.User + expire := now.Add(resource.Live) + + proxy, err := s.selectProxy(count) + if err != nil { + return err + } + + chans, err := selectPorts(proxy.ID, batchNo, count, expire) + if err != nil { + return err + } + + edges, err := s.selectEdge(filter, count) + if err != nil { + return err + } + + client, err := proxyGost(proxy) + if err != nil { + return err + } + + admissions := make([]*g.GostAdmissionConfig, 0, count) + authers := make([]*g.GostAutherConfig, 0, count) + services := make([]*g.GostServiceConfig, count) + + for i := range count { + ch := chans[i] + edge := edges[i] + port := ch.Port() + host := u.Else(proxy.Host, proxy.IP.String()) + + serviceName := gostServiceName(batchNo, port) + channel := &m.Channel{ + UserID: user.ID, + ResourceID: resource.ID, + BatchNo: batchNo, + ProxyID: proxy.ID, + Host: host, + Port: port, + EdgeID: u.P(edge.ID), + EdgeRef: u.P(serviceName), + FilterISP: filter.Isp, + FilterProv: filter.Prov, + FilterCity: filter.City, + IP: u.P(edge.IP), + ExpiredAt: expire, + Proxy: proxy, + } + + service := &g.GostServiceConfig{ + Name: serviceName, + Addr: fmt.Sprintf(":%d", port), + Handler: g.GostHandlerConfig{ + Type: "auto", + Chain: edge.Mac, + }, + Listener: g.GostListenerConfig{ + Type: "tcp", + }, + } + + if authWhitelist { + channel.Whitelists = u.P(strings.Join(whitelists, ",")) + service.Admission = gostAdmissionName(batchNo, port) + admission := &g.GostAdmissionConfig{ + Name: service.Admission, + Whitelist: true, + Matchers: whitelists, + } + admissions = append(admissions, admission) + } + + if authPassword { + username, password := genPassPair() + channel.Username = &username + channel.Password = &password + service.Handler.Auther = gostAutherName(batchNo, port) + auther := &g.GostAutherConfig{ + Name: service.Handler.Auther, + Auths: []g.GostAuthConfig{{ + Username: username, + Password: password, + }}, + } + authers = append(authers, auther) + } + + services[i] = service + channels[i] = channel + } + + for _, admission := range admissions { + if err := client.CreateAdmission(admission); err != nil { + return core.NewServErr(fmt.Sprintf("创建 GOST admission 失败: %s", admission.Name), err) + } + } + for _, auther := range authers { + if err := client.CreateAuther(auther); err != nil { + return core.NewServErr(fmt.Sprintf("创建 GOST auther 失败: %s", auther.Name), err) + } + } + for _, service := range services { + if err := client.CreateService(service); err != nil { + return core.NewServErr(fmt.Sprintf("创建 GOST service 失败: %s", service.Name), err) + } + } + + err = q.Q.Transaction(func(tx *q.Query) error { + var result gen.ResultInfo + var err error + switch resource.Type { + case m.ResourceTypeShort: + result, err = tx.ResourceShort. + Where( + tx.ResourceShort.ID.Eq(*resource.ShortId), + tx.ResourceShort.Used.Eq(resource.Used), + tx.ResourceShort.Daily.Eq(resource.Daily), + ). + UpdateSimple( + tx.ResourceShort.Used.Add(int32(count)), + tx.ResourceShort.Daily.Value(int32(resource.Today+count)), + tx.ResourceShort.LastAt.Value(now), + ) + case m.ResourceTypeLong: + result, err = tx.ResourceLong. + Where( + tx.ResourceLong.ID.Eq(*resource.LongId), + tx.ResourceLong.Used.Eq(resource.Used), + tx.ResourceLong.Daily.Eq(resource.Daily), + ). + UpdateSimple( + tx.ResourceLong.Used.Add(int32(count)), + tx.ResourceLong.Daily.Value(int32(resource.Today+count)), + tx.ResourceLong.LastAt.Value(now), + ) + default: + return core.NewBizErr("套餐类型不正确,无法更新") + } + if err != nil { + return core.NewServErr("更新套餐使用记录失败", err) + } + if result.RowsAffected == 0 { + return core.NewBizErr("套餐状态已过期") + } + + if err := tx.Channel.Omit(field.AssociationFields).Create(channels...); err != nil { + return core.NewServErr("保存通道失败", err) + } + + if err := tx.LogsUserUsage.Create(&m.LogsUserUsage{ + UserID: user.ID, + ResourceID: resource.ID, + BatchNo: batchNo, + Count: int32(count), + ISP: u.X(filter.Isp.String()), + Prov: filter.Prov, + City: filter.City, + IP: orm.Inet{Addr: source}, + Time: now, + }); err != nil { + return core.NewServErr("保存用户使用记录失败", err) + } + + return nil + }) + if err != nil { + return err + } + + return nil + }) + if err != nil { + return nil, err + } + + return channels, nil +} + +func (s *channelGostProvider) RemoveChannels(batchNo string) error { + return g.Redsync.WithLock(lockChannelRemoveKey(batchNo), func() error { + start := time.Now() + + batch, err := findUsedChanBatch(batchNo) + if err != nil { + return err + } + if batch == nil { + slog.Debug("通道为空,跳过清理", "batch", batchNo) + return nil + } + + if env.RunMode == env.RunModeProd { + proxy, err := q.Proxy.Where(q.Proxy.ID.Eq(batch.ProxyID)).Take() + if err != nil { + return core.NewServErr("获取代理数据失败", err) + } + + client, err := proxyGost(proxy) + if err != nil { + return core.NewServErr("创建 GOST 客户端失败", err) + } + var deleteErrs []error + for _, ch := range batch.Chans { + port := ch.Port() + serviceName := gostServiceName(batchNo, port) + deleteErrs = append(deleteErrs, deleteGostResource("service", serviceName, func() error { + return client.DeleteService(serviceName) + })) + + autherName := gostAutherName(batchNo, port) + deleteErrs = append(deleteErrs, deleteGostResource("auther", autherName, func() error { + return client.DeleteAuther(autherName) + })) + + admissionName := gostAdmissionName(batchNo, port) + deleteErrs = append(deleteErrs, deleteGostResource("admission", admissionName, func() error { + return client.DeleteAdmission(admissionName) + })) + } + if err := u.CombineErrors(deleteErrs); err != nil { + return err + } + } + + if err := freeChans(batch.ProxyID, batchNo); err != nil { + return err + } + + slog.Debug("清除 GOST 端口配置", "proxy", batch.ProxyID, "batch", batchNo, "duration", time.Since(start).String()) + return nil + }) +} + +func (s *channelGostProvider) ClearExpiredChannels(proxyId int32) (int, error) { + now := time.Now() + + keys, err := g.Redis.Keys(context.Background(), usedChansKey(proxyId, "*")).Result() + if err != nil { + return 0, core.NewServErr("查询使用中通道失败", err) + } + if len(keys) == 0 { + return 0, nil + } + + batchList := make([]string, len(keys)) + batchSet := make(map[string]struct{}, len(keys)) + for i, key := range keys { + parts := strings.Split(key, ":") + if len(parts) != 4 { + return 0, core.NewServErr(fmt.Sprintf("使用中通道键格式错误: %s", key), nil) + } + batchList[i] = parts[3] + batchSet[parts[3]] = struct{}{} + } + + var batchQueried []struct{ BatchNo string } + err = q.Channel. + Select(q.Channel.BatchNo). + Where( + q.Channel.BatchNo.In(batchList...), + q.Channel.ExpiredAt.Gte(now.UTC()), + ). + Group(q.Channel.BatchNo). + Scan(&batchQueried) + if err != nil { + return 0, core.NewServErr("查询过期通道失败", err) + } + for _, batch := range batchQueried { + delete(batchSet, batch.BatchNo) + } + + slog.Info("批量清理过期 GOST 通道", "count", len(batchSet)) + for batchNo := range batchSet { + if err := s.RemoveChannels(batchNo); err != nil { + slog.Error("清理过期 GOST 通道失败", "batch", batchNo, "error", err) + } + } + + return len(batchSet), nil +} + +func (s *channelGostProvider) selectProxy(count int) (*m.Proxy, error) { + return selectProxyByType(m.ProxyTypeGost, count) +} + +func (s *channelGostProvider) selectEdge(filter *EdgeFilter, count int) ([]*m.Edge, error) { + if filter == nil { + filter = &EdgeFilter{} + } + + do := q.Edge.Where( + q.Edge.Type.Eq(int(m.EdgeTypeGostChain)), + q.Edge.Status.Eq(int(m.EdgeStatusNormal)), + ) + if prov := u.N(filter.Prov); prov != nil { + do = do.Where(q.Edge.Prov.Eq(*prov)) + } + if city := u.N(filter.City); city != nil { + do = do.Where(q.Edge.City.Eq(*city)) + } + if isp := u.X(filter.Isp.String()); isp != nil { + do = do.Where(q.Edge.ISP.Eq(int(*filter.Isp))) + } + + edges, err := q.Edge.Where(do).Order(q.Edge.ID).Limit(count).Find() + if err != nil { + return nil, core.NewBizErr("查询可用节点失败", err) + } + return expandGostEdges(edges, count) +} + +func expandGostEdges(edges []*m.Edge, count int) ([]*m.Edge, error) { + if len(edges) == 0 { + return nil, core.NewBizErr("地区可用节点数量不足") + } + + result := make([]*m.Edge, count) + for i := range count { + result[i] = edges[i%len(edges)] + } + return result, nil +} + +func proxyGost(proxy *m.Proxy) (g.GostClient, error) { + secret := strings.Split(u.Z(proxy.Secret), ":") + if len(secret) != 2 { + return nil, core.NewServErr(fmt.Sprintf("代理 %s 密钥格式错误", proxy.IP.String()), nil) + } + + host := u.Else(proxy.Host, proxy.IP.String()) + return g.NewGost(host, env.GostApiPort, env.GostApiPathPrefix, secret[0], secret[1]), nil +} + +func deleteGostResource(kind string, name string, deleteFn func() error) error { + if err := deleteFn(); err != nil && !g.IsGostNotFound(err) { + return core.NewServErr(fmt.Sprintf("删除 GOST %s 配置失败: %s", kind, name), err) + } + return nil +} + +func gostServiceName(batchNo string, port uint16) string { + return fmt.Sprintf("gost-svc-%s-%d", batchNo, port) +} + +func gostAutherName(batchNo string, port uint16) string { + return fmt.Sprintf("gost-auther-%s-%d", batchNo, port) +} + +func gostAdmissionName(batchNo string, port uint16) string { + return fmt.Sprintf("gost-adm-%s-%d", batchNo, port) +} diff --git a/web/services/channel_gost_test.go b/web/services/channel_gost_test.go new file mode 100644 index 0000000..3f80a46 --- /dev/null +++ b/web/services/channel_gost_test.go @@ -0,0 +1,74 @@ +package services + +import ( + "testing" + + m "platform/web/models" +) + +func TestExpandGostEdgesRejectsEmpty(t *testing.T) { + _, err := expandGostEdges(nil, 1) + if err == nil { + t.Fatal("expected error, got nil") + } +} + +func TestExpandGostEdgesReusesWhenInsufficient(t *testing.T) { + edges := []*m.Edge{ + {Mac: "chain-a"}, + {Mac: "chain-b"}, + } + + result, err := expandGostEdges(edges, 5) + if err != nil { + t.Fatalf("expandGostEdges returned error: %v", err) + } + if len(result) != 5 { + t.Fatalf("unexpected edge count: %d", len(result)) + } + + expected := []string{"chain-a", "chain-b", "chain-a", "chain-b", "chain-a"} + for i, edge := range result { + if edge.Mac != expected[i] { + t.Fatalf("unexpected edge at %d: %s", i, edge.Mac) + } + } +} + +func TestEdgeFilterIsEmpty(t *testing.T) { + if !(*EdgeFilter)(nil).IsEmpty() { + t.Fatal("nil filter should be empty") + } + if (&EdgeFilter{}).IsEmpty() != true { + t.Fatal("empty filter should be empty") + } + if (&EdgeFilter{Prov: strPtr("")}).IsEmpty() != true { + t.Fatal("filter with empty province should be empty") + } + if (&EdgeFilter{City: strPtr("")}).IsEmpty() != true { + t.Fatal("filter with empty city should be empty") + } + if (&EdgeFilter{Isp: ispPtr(m.ToEdgeISP(0))}).IsEmpty() != true { + t.Fatal("filter with zero ISP should be empty") + } + if (&EdgeFilter{Isp: ispPtr(m.ToEdgeISP(99))}).IsEmpty() != true { + t.Fatal("filter with invalid ISP should be empty") + } + + prov := "江苏" + if (&EdgeFilter{Prov: &prov}).IsEmpty() { + t.Fatal("filter with province should not be empty") + } + isp := m.EdgeISPTelecom + if (&EdgeFilter{Isp: &isp}).IsEmpty() { + t.Fatal("filter with valid ISP should not be empty") + } +} + +func strPtr(v string) *string { + return &v +} + +func ispPtr(v m.EdgeISP) *m.EdgeISP { + return &v +} diff --git a/web/services/edge.go b/web/services/edge.go index 4a346c9..d2ea407 100644 --- a/web/services/edge.go +++ b/web/services/edge.go @@ -11,14 +11,14 @@ var Edge = &edgeService{} type edgeService struct{} func (s *edgeService) AllEdges(count int, filter EdgeFilter) ([]*m.Edge, error) { - do := q.Edge.Where() - if filter.Prov != nil { - do = do.Where(q.Edge.Prov.Eq(*filter.Prov)) + do := q.Edge.Where(q.Edge.Type.Eq(int(m.EdgeTypeSelfBuilt))) + if prov := u.N(filter.Prov); prov != nil { + do = do.Where(q.Edge.Prov.Eq(*prov)) } - if filter.City != nil { - do = do.Where(q.Edge.City.Eq(*filter.City)) + if city := u.N(filter.City); city != nil { + do = do.Where(q.Edge.City.Eq(*city)) } - if filter.Isp != nil { + if isp := u.X(filter.Isp.String()); isp != nil { do = do.Where(q.Edge.ISP.Eq(int(*filter.Isp))) } if count > 0 { @@ -44,9 +44,5 @@ func (f *EdgeFilter) IsEmpty() bool { return true } - if f.Isp.String() == "" || u.Z(f.Prov) != "" || u.Z(f.City) != "" { - return false - } - - return false + return u.X(f.Isp.String()) == nil && u.N(f.Prov) == nil && u.N(f.City) == nil }