重构提取逻辑,新增 area 表

This commit is contained in:
2026-06-10 14:32:45 +08:00
parent dd482dd6b0
commit ebac8042ea
26 changed files with 7939 additions and 666 deletions

112
web/services/area.go Normal file
View File

@@ -0,0 +1,112 @@
package services
import (
"errors"
"platform/pkg/u"
"platform/web/core"
m "platform/web/models"
q "platform/web/queries"
"gorm.io/gorm"
)
var Area = &areaService{}
type areaService struct{}
func (s *areaService) ListAreas() ([]*m.Area, error) {
areas, err := q.Area.
Order(q.Area.Level, q.Area.ParentID, q.Area.ID).
Find()
if err != nil {
return nil, core.NewServErr("查询地区失败", err)
}
return areas, nil
}
func (s *areaService) FindIdByFilter(prov *string, city *string) (*int32, error) {
prov = u.N(prov)
city = u.N(city)
if prov == nil && city == nil {
return nil, nil
}
switch {
case prov != nil && city == nil:
area, err := q.Area.
Where(
q.Area.Level.Eq(int(m.AreaLevelProvince)),
q.Area.Name.Eq(*prov),
).
Order(q.Area.ID).
Take()
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrAreaNotExist
}
if err != nil {
return nil, core.NewServErr("查询地区失败", err)
}
return u.P(area.ID), nil
case prov == nil && city != nil:
area, err := q.Area.
Where(
q.Area.Level.Eq(int(m.AreaLevelCity)),
q.Area.Name.Eq(*city),
).
Order(q.Area.ID).
Take()
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrAreaNotExist
}
if err != nil {
return nil, core.NewServErr("查询地区失败", err)
}
return u.P(area.ID), nil
default:
province, err := q.Area.
Where(
q.Area.Level.Eq(int(m.AreaLevelProvince)),
q.Area.Name.Eq(*prov),
).
Order(q.Area.ID).
Take()
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrAreaNotExist
}
if err != nil {
return nil, core.NewServErr("查询地区失败", err)
}
area, err := q.Area.
Where(
q.Area.ParentID.Eq(province.ID),
q.Area.Level.Eq(int(m.AreaLevelCity)),
q.Area.Name.Eq(*city),
).
Order(q.Area.ID).
Take()
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrAreaNotExist
}
if err != nil {
return nil, core.NewServErr("查询地区失败", err)
}
return u.P(area.ID), nil
}
}
func (s *areaService) Get(id int32) (*m.Area, error) {
area, err := q.Area.
Preload(q.Area.Parent).
Where(q.Area.ID.Eq(id)).
Take()
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, ErrAreaNotExist
}
if err != nil {
return nil, core.NewServErr("查询地区失败", err)
}
return area, nil
}
var ErrAreaNotExist = core.NewBizErr("地区不存在")

View File

@@ -40,12 +40,21 @@ type channelServer struct {
}
func (s *channelServer) CreateChannels(source netip.Addr, resourceNo string, authWhitelist bool, authPassword bool, count int, edgeFilter *EdgeFilter) ([]*m.Channel, error) {
var area *m.Area
if edgeFilter.AreaID != nil {
var err error
area, err = Area.Get(*edgeFilter.AreaID)
if err != nil {
return nil, err
}
if err := validateChannelArea(area); err != nil {
return nil, err
}
}
now := time.Now()
batchNo := ID.GenReadable("bat")
var channels []*m.Channel
if edgeFilter == nil {
edgeFilter = &EdgeFilter{}
}
var whitelistText *string
err := g.Redsync.WithLock(lockChannelCreateKey(resourceNo), func() error {
@@ -80,6 +89,7 @@ func (s *channelServer) CreateChannels(source netip.Addr, resourceNo string, aut
Expire: expire,
Count: count,
Filter: edgeFilter,
Area: area,
AuthWhitelist: authWhitelist,
AuthPassword: authPassword,
Whitelists: whitelists,
@@ -160,6 +170,7 @@ type channelCreateContext struct {
Expire time.Time
Count int
Filter *EdgeFilter
Area *m.Area
AuthWhitelist bool
AuthPassword bool
Whitelists []string
@@ -172,6 +183,7 @@ type channelCreateResult struct {
}
func newBaseChannel(ctx *channelCreateContext, port uint16) *m.Channel {
prov, city := areaProvinceCity(ctx.Area)
return &m.Channel{
UserID: ctx.Resource.User.ID,
ResourceID: ctx.Resource.ID,
@@ -180,8 +192,8 @@ func newBaseChannel(ctx *channelCreateContext, port uint16) *m.Channel {
Host: u.Else(ctx.Proxy.Host, ctx.Proxy.IP.String()),
Port: port,
FilterISP: ctx.Filter.Isp,
FilterProv: ctx.Filter.Prov,
FilterCity: ctx.Filter.City,
FilterProv: prov,
FilterCity: city,
ExpiredAt: ctx.Expire,
Proxy: ctx.Proxy,
}
@@ -202,6 +214,7 @@ func applyChannelAuth(ctx *channelCreateContext, channel *m.Channel) (username s
}
func persistChannelCreate(ctx *channelCreateContext, channels []*m.Channel) error {
prov, city := areaProvinceCity(ctx.Area)
return q.Q.Transaction(func(tx *q.Query) error {
var (
result gen.ResultInfo
@@ -252,8 +265,8 @@ func persistChannelCreate(ctx *channelCreateContext, channels []*m.Channel) erro
BatchNo: ctx.BatchNo,
Count: int32(ctx.Count),
ISP: u.X(ctx.Filter.Isp.String()),
Prov: ctx.Filter.Prov,
City: ctx.Filter.City,
Prov: prov,
City: city,
IP: orm.Inet{Addr: ctx.Source},
Time: ctx.Now,
}); err != nil {
@@ -264,6 +277,37 @@ func persistChannelCreate(ctx *channelCreateContext, channels []*m.Channel) erro
})
}
func validateChannelArea(area *m.Area) error {
if area == nil {
return nil
}
switch area.Level {
case m.AreaLevelProvince:
return nil
case m.AreaLevelCity:
if area.ParentID == nil || area.Parent == nil {
return core.NewServErr("地区数据异常", nil)
}
return nil
default:
return core.NewBizErr("地区层级不支持")
}
}
func areaProvinceCity(area *m.Area) (prov *string, city *string) {
if area == nil {
return nil, nil
}
switch area.Level {
case m.AreaLevelProvince:
return u.P(area.Name), nil
case m.AreaLevelCity:
return u.P(area.Parent.Name), u.P(area.Name)
default:
return nil, nil
}
}
func findExpiredChannelBatches(proxyId int32, now time.Time) (map[string]struct{}, error) {
keys, err := g.Redis.Keys(context.Background(), usedChansKey(proxyId, "*")).Result()
if err != nil {
@@ -778,6 +822,20 @@ redis.call("DEL", batch_key)
return 1
`)
// 节点筛选条件
type EdgeFilter struct {
Isp *m.EdgeISP `json:"isp"`
AreaID *int32 `json:"area_id"`
}
func (f *EdgeFilter) IsEmpty() bool {
if f == nil {
return true
}
return u.X(f.Isp.String()) == nil && f.AreaID == nil
}
// 错误信息
var (
ErrResourceNotExist = core.NewBizErr("套餐不存在")

View File

@@ -21,6 +21,7 @@ func (s *channelBaiyinProvider) prepareCreate(ctx *channelCreateContext) (*chann
if err != nil {
return nil, core.NewServErr("创建代理网关失败", err)
}
prov, city := areaProvinceCity(ctx.Area)
channels := make([]*m.Channel, len(ctx.Ports))
chanConfigs := make([]*g.PortConfigsReq, len(ctx.Ports))
@@ -30,8 +31,8 @@ func (s *channelBaiyinProvider) prepareCreate(ctx *channelCreateContext) (*chann
Port: int(portRef.Port()),
Status: true,
AutoEdgeConfig: &g.AutoEdgeConfig{
Province: u.Z(ctx.Filter.Prov),
City: u.Z(ctx.Filter.City),
Province: u.Z(prov),
City: u.Z(city),
Isp: ctx.Filter.Isp.String(),
Count: u.P(1),
},
@@ -52,7 +53,7 @@ func (s *channelBaiyinProvider) prepareCreate(ctx *channelCreateContext) (*chann
Channels: channels,
applyRemote: func() error {
slog.Debug("提交代理端口配置", "proxy", ctx.Proxy.IP.String(), "total_count", len(chanConfigs))
if err := ensureEdges(ctx.Proxy, gateway, ctx.Filter, ctx.Count); err != nil {
if err := ensureEdges(ctx.Proxy, gateway, ctx.Area, ctx.Filter.Isp, ctx.Count); err != nil {
slog.Warn("ensureEdges 失败", "err", err)
}
if len(chanConfigs) > 0 {
@@ -96,16 +97,17 @@ func (s *channelBaiyinProvider) removeRemote(_ string, batch *usedChanBatch) err
// ensureEdges 检查本地节点是否足够,如果不足从云端连入
// 本地节点通过 Assigned = false 排除已分配节点
// 云端节点通过 NoRepeat = true 排除已分配节点
func ensureEdges(proxy *m.Proxy, gateway g.GatewayClient, filter *EdgeFilter, count int) error {
if filter.IsEmpty() {
func ensureEdges(proxy *m.Proxy, gateway g.GatewayClient, area *m.Area, isp *m.EdgeISP, count int) error {
prov, city := areaProvinceCity(area)
if prov == nil && city == nil && u.X(isp.String()) == nil {
return nil // 没有过滤条件,直接返回空,避免无意义的查询
}
// 先查本地
localEdges, err := gateway.GatewayEdge(&g.GatewayEdgeReq{
Province: filter.Prov,
City: filter.City,
Isp: u.X(filter.Isp.String()),
Province: prov,
City: city,
Isp: u.X(isp.String()),
Limit: &count,
Assigned: u.P(false),
})
@@ -119,9 +121,9 @@ func ensureEdges(proxy *m.Proxy, gateway g.GatewayClient, filter *EdgeFilter, co
// 再查云端
remaining := count - len(localEdges)
cloudEdges, err := g.Cloud.CloudEdges(&g.CloudEdgesReq{
Province: filter.Prov,
City: filter.City,
Isp: u.X(filter.Isp.String()),
Province: prov,
City: city,
Isp: u.X(isp.String()),
Limit: &remaining,
NoRepeat: u.P(true),
ActiveTime: u.P(3600),

View File

@@ -9,12 +9,14 @@ import (
m "platform/web/models"
q "platform/web/queries"
"strings"
"gorm.io/gen"
)
type channelGostProvider struct{}
func (s *channelGostProvider) prepareCreate(ctx *channelCreateContext) (*channelCreateResult, error) {
edges, err := s.selectEdge(ctx.Filter, ctx.Count)
edges, err := s.selectEdge(ctx.Filter, ctx.Area, ctx.Count)
if err != nil {
return nil, err
}
@@ -131,26 +133,38 @@ func (s *channelGostProvider) selectProxy(count int) (*m.Proxy, error) {
return selectProxyByType(m.ProxyTypeGost, count)
}
func (s *channelGostProvider) selectEdge(filter *EdgeFilter, count int) ([]*m.Edge, error) {
func (s *channelGostProvider) selectEdge(filter *EdgeFilter, area *m.Area, count int) ([]*m.Edge, error) {
if filter == nil {
filter = &EdgeFilter{}
}
do := q.Edge.Where(
conds := []gen.Condition{
q.Edge.Type.Eq(int(m.EdgeTypeGostChain)),
q.Edge.Status.Eq(int(m.EdgeStatusNormal)),
)
if prov := u.N(filter.Prov); prov != nil {
do = do.Where(q.Edge.Prov.Eq(*prov))
}
if city := u.N(filter.City); city != nil {
do = do.Where(q.Edge.City.Eq(*city))
}
if isp := u.X(filter.Isp.String()); isp != nil {
do = do.Where(q.Edge.ISP.Eq(int(*filter.Isp)))
conds = append(conds, q.Edge.ISP.Eq(int(*filter.Isp)))
}
edges, err := q.Edge.Where(do).Order(q.Edge.ID).Limit(count).Find()
query := q.Edge.Where(conds...)
if area != nil {
switch area.Level {
case m.AreaLevelProvince:
edgeArea := q.Area.As("EdgeArea")
query = query.
Join(edgeArea, edgeArea.ID.EqCol(q.Edge.AreaID)).
Where(edgeArea.ParentID.Eq(area.ID))
case m.AreaLevelCity:
query = query.Where(q.Edge.AreaID.Eq(area.ID))
default:
return nil, core.NewBizErr("地区层级不支持")
}
}
edges, err := query.
Order(q.Edge.ID).
Limit(count).
Find()
if err != nil {
return nil, core.NewBizErr("查询可用节点失败", err)
}

View File

@@ -1,74 +0,0 @@
package services
import (
"testing"
m "platform/web/models"
)
func TestExpandGostEdgesRejectsEmpty(t *testing.T) {
_, err := expandGostEdges(nil, 1)
if err == nil {
t.Fatal("expected error, got nil")
}
}
func TestExpandGostEdgesReusesWhenInsufficient(t *testing.T) {
edges := []*m.Edge{
{Mac: "chain-a"},
{Mac: "chain-b"},
}
result, err := expandGostEdges(edges, 5)
if err != nil {
t.Fatalf("expandGostEdges returned error: %v", err)
}
if len(result) != 5 {
t.Fatalf("unexpected edge count: %d", len(result))
}
expected := []string{"chain-a", "chain-b", "chain-a", "chain-b", "chain-a"}
for i, edge := range result {
if edge.Mac != expected[i] {
t.Fatalf("unexpected edge at %d: %s", i, edge.Mac)
}
}
}
func TestEdgeFilterIsEmpty(t *testing.T) {
if !(*EdgeFilter)(nil).IsEmpty() {
t.Fatal("nil filter should be empty")
}
if (&EdgeFilter{}).IsEmpty() != true {
t.Fatal("empty filter should be empty")
}
if (&EdgeFilter{Prov: strPtr("")}).IsEmpty() != true {
t.Fatal("filter with empty province should be empty")
}
if (&EdgeFilter{City: strPtr("")}).IsEmpty() != true {
t.Fatal("filter with empty city should be empty")
}
if (&EdgeFilter{Isp: ispPtr(m.ToEdgeISP(0))}).IsEmpty() != true {
t.Fatal("filter with zero ISP should be empty")
}
if (&EdgeFilter{Isp: ispPtr(m.ToEdgeISP(99))}).IsEmpty() != true {
t.Fatal("filter with invalid ISP should be empty")
}
prov := "江苏"
if (&EdgeFilter{Prov: &prov}).IsEmpty() {
t.Fatal("filter with province should not be empty")
}
isp := m.EdgeISPTelecom
if (&EdgeFilter{Isp: &isp}).IsEmpty() {
t.Fatal("filter with valid ISP should not be empty")
}
}
func strPtr(v string) *string {
return &v
}
func ispPtr(v m.EdgeISP) *m.EdgeISP {
return &v
}

View File

@@ -1,48 +0,0 @@
package services
import (
"platform/pkg/u"
m "platform/web/models"
q "platform/web/queries"
)
var Edge = &edgeService{}
type edgeService struct{}
func (s *edgeService) AllEdges(count int, filter EdgeFilter) ([]*m.Edge, error) {
do := q.Edge.Where(q.Edge.Type.Eq(int(m.EdgeTypeSelfBuilt)))
if prov := u.N(filter.Prov); prov != nil {
do = do.Where(q.Edge.Prov.Eq(*prov))
}
if city := u.N(filter.City); city != nil {
do = do.Where(q.Edge.City.Eq(*city))
}
if isp := u.X(filter.Isp.String()); isp != nil {
do = do.Where(q.Edge.ISP.Eq(int(*filter.Isp)))
}
if count > 0 {
do = do.Limit(count)
}
edges, err := q.Edge.Where(do).Find()
if err != nil {
return nil, err
}
return edges, nil
}
type EdgeFilter struct {
Isp *m.EdgeISP `json:"isp"`
Prov *string `json:"prov"`
City *string `json:"city"`
}
func (f *EdgeFilter) IsEmpty() bool {
if f == nil {
return true
}
return u.X(f.Isp.String()) == nil && u.N(f.Prov) == nil && u.N(f.City) == nil
}

View File

@@ -161,6 +161,17 @@ func (s *proxyService) Update(update *UpdateProxy) error {
return nil
}
func (s *proxyService) SyncPool(id int32) error {
proxy, err := q.Proxy.Where(q.Proxy.ID.Eq(id)).Select(q.Proxy.ID, q.Proxy.IP).First()
if err != nil {
return core.NewServErr("获取代理数据失败", err)
}
if proxy == nil {
return core.NewBizErr("代理不存在")
}
return rebuildFreeChans(id, proxy.IP.Addr)
}
func (s *proxyService) Remove(id int32) error {
used, err := hasUsedChans(id)
if err != nil {