重构提取逻辑,新增 area 表
This commit is contained in:
112
web/services/area.go
Normal file
112
web/services/area.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"platform/pkg/u"
|
||||
"platform/web/core"
|
||||
m "platform/web/models"
|
||||
q "platform/web/queries"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
var Area = &areaService{}
|
||||
|
||||
type areaService struct{}
|
||||
|
||||
func (s *areaService) ListAreas() ([]*m.Area, error) {
|
||||
areas, err := q.Area.
|
||||
Order(q.Area.Level, q.Area.ParentID, q.Area.ID).
|
||||
Find()
|
||||
if err != nil {
|
||||
return nil, core.NewServErr("查询地区失败", err)
|
||||
}
|
||||
return areas, nil
|
||||
}
|
||||
|
||||
func (s *areaService) FindIdByFilter(prov *string, city *string) (*int32, error) {
|
||||
prov = u.N(prov)
|
||||
city = u.N(city)
|
||||
if prov == nil && city == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
switch {
|
||||
case prov != nil && city == nil:
|
||||
area, err := q.Area.
|
||||
Where(
|
||||
q.Area.Level.Eq(int(m.AreaLevelProvince)),
|
||||
q.Area.Name.Eq(*prov),
|
||||
).
|
||||
Order(q.Area.ID).
|
||||
Take()
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrAreaNotExist
|
||||
}
|
||||
if err != nil {
|
||||
return nil, core.NewServErr("查询地区失败", err)
|
||||
}
|
||||
return u.P(area.ID), nil
|
||||
case prov == nil && city != nil:
|
||||
area, err := q.Area.
|
||||
Where(
|
||||
q.Area.Level.Eq(int(m.AreaLevelCity)),
|
||||
q.Area.Name.Eq(*city),
|
||||
).
|
||||
Order(q.Area.ID).
|
||||
Take()
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrAreaNotExist
|
||||
}
|
||||
if err != nil {
|
||||
return nil, core.NewServErr("查询地区失败", err)
|
||||
}
|
||||
return u.P(area.ID), nil
|
||||
default:
|
||||
province, err := q.Area.
|
||||
Where(
|
||||
q.Area.Level.Eq(int(m.AreaLevelProvince)),
|
||||
q.Area.Name.Eq(*prov),
|
||||
).
|
||||
Order(q.Area.ID).
|
||||
Take()
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrAreaNotExist
|
||||
}
|
||||
if err != nil {
|
||||
return nil, core.NewServErr("查询地区失败", err)
|
||||
}
|
||||
|
||||
area, err := q.Area.
|
||||
Where(
|
||||
q.Area.ParentID.Eq(province.ID),
|
||||
q.Area.Level.Eq(int(m.AreaLevelCity)),
|
||||
q.Area.Name.Eq(*city),
|
||||
).
|
||||
Order(q.Area.ID).
|
||||
Take()
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrAreaNotExist
|
||||
}
|
||||
if err != nil {
|
||||
return nil, core.NewServErr("查询地区失败", err)
|
||||
}
|
||||
return u.P(area.ID), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *areaService) Get(id int32) (*m.Area, error) {
|
||||
area, err := q.Area.
|
||||
Preload(q.Area.Parent).
|
||||
Where(q.Area.ID.Eq(id)).
|
||||
Take()
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, ErrAreaNotExist
|
||||
}
|
||||
if err != nil {
|
||||
return nil, core.NewServErr("查询地区失败", err)
|
||||
}
|
||||
return area, nil
|
||||
}
|
||||
|
||||
var ErrAreaNotExist = core.NewBizErr("地区不存在")
|
||||
@@ -40,12 +40,21 @@ type channelServer struct {
|
||||
}
|
||||
|
||||
func (s *channelServer) CreateChannels(source netip.Addr, resourceNo string, authWhitelist bool, authPassword bool, count int, edgeFilter *EdgeFilter) ([]*m.Channel, error) {
|
||||
var area *m.Area
|
||||
if edgeFilter.AreaID != nil {
|
||||
var err error
|
||||
area, err = Area.Get(*edgeFilter.AreaID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := validateChannelArea(area); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
batchNo := ID.GenReadable("bat")
|
||||
var channels []*m.Channel
|
||||
if edgeFilter == nil {
|
||||
edgeFilter = &EdgeFilter{}
|
||||
}
|
||||
|
||||
var whitelistText *string
|
||||
err := g.Redsync.WithLock(lockChannelCreateKey(resourceNo), func() error {
|
||||
@@ -80,6 +89,7 @@ func (s *channelServer) CreateChannels(source netip.Addr, resourceNo string, aut
|
||||
Expire: expire,
|
||||
Count: count,
|
||||
Filter: edgeFilter,
|
||||
Area: area,
|
||||
AuthWhitelist: authWhitelist,
|
||||
AuthPassword: authPassword,
|
||||
Whitelists: whitelists,
|
||||
@@ -160,6 +170,7 @@ type channelCreateContext struct {
|
||||
Expire time.Time
|
||||
Count int
|
||||
Filter *EdgeFilter
|
||||
Area *m.Area
|
||||
AuthWhitelist bool
|
||||
AuthPassword bool
|
||||
Whitelists []string
|
||||
@@ -172,6 +183,7 @@ type channelCreateResult struct {
|
||||
}
|
||||
|
||||
func newBaseChannel(ctx *channelCreateContext, port uint16) *m.Channel {
|
||||
prov, city := areaProvinceCity(ctx.Area)
|
||||
return &m.Channel{
|
||||
UserID: ctx.Resource.User.ID,
|
||||
ResourceID: ctx.Resource.ID,
|
||||
@@ -180,8 +192,8 @@ func newBaseChannel(ctx *channelCreateContext, port uint16) *m.Channel {
|
||||
Host: u.Else(ctx.Proxy.Host, ctx.Proxy.IP.String()),
|
||||
Port: port,
|
||||
FilterISP: ctx.Filter.Isp,
|
||||
FilterProv: ctx.Filter.Prov,
|
||||
FilterCity: ctx.Filter.City,
|
||||
FilterProv: prov,
|
||||
FilterCity: city,
|
||||
ExpiredAt: ctx.Expire,
|
||||
Proxy: ctx.Proxy,
|
||||
}
|
||||
@@ -202,6 +214,7 @@ func applyChannelAuth(ctx *channelCreateContext, channel *m.Channel) (username s
|
||||
}
|
||||
|
||||
func persistChannelCreate(ctx *channelCreateContext, channels []*m.Channel) error {
|
||||
prov, city := areaProvinceCity(ctx.Area)
|
||||
return q.Q.Transaction(func(tx *q.Query) error {
|
||||
var (
|
||||
result gen.ResultInfo
|
||||
@@ -252,8 +265,8 @@ func persistChannelCreate(ctx *channelCreateContext, channels []*m.Channel) erro
|
||||
BatchNo: ctx.BatchNo,
|
||||
Count: int32(ctx.Count),
|
||||
ISP: u.X(ctx.Filter.Isp.String()),
|
||||
Prov: ctx.Filter.Prov,
|
||||
City: ctx.Filter.City,
|
||||
Prov: prov,
|
||||
City: city,
|
||||
IP: orm.Inet{Addr: ctx.Source},
|
||||
Time: ctx.Now,
|
||||
}); err != nil {
|
||||
@@ -264,6 +277,37 @@ func persistChannelCreate(ctx *channelCreateContext, channels []*m.Channel) erro
|
||||
})
|
||||
}
|
||||
|
||||
func validateChannelArea(area *m.Area) error {
|
||||
if area == nil {
|
||||
return nil
|
||||
}
|
||||
switch area.Level {
|
||||
case m.AreaLevelProvince:
|
||||
return nil
|
||||
case m.AreaLevelCity:
|
||||
if area.ParentID == nil || area.Parent == nil {
|
||||
return core.NewServErr("地区数据异常", nil)
|
||||
}
|
||||
return nil
|
||||
default:
|
||||
return core.NewBizErr("地区层级不支持")
|
||||
}
|
||||
}
|
||||
|
||||
func areaProvinceCity(area *m.Area) (prov *string, city *string) {
|
||||
if area == nil {
|
||||
return nil, nil
|
||||
}
|
||||
switch area.Level {
|
||||
case m.AreaLevelProvince:
|
||||
return u.P(area.Name), nil
|
||||
case m.AreaLevelCity:
|
||||
return u.P(area.Parent.Name), u.P(area.Name)
|
||||
default:
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
func findExpiredChannelBatches(proxyId int32, now time.Time) (map[string]struct{}, error) {
|
||||
keys, err := g.Redis.Keys(context.Background(), usedChansKey(proxyId, "*")).Result()
|
||||
if err != nil {
|
||||
@@ -778,6 +822,20 @@ redis.call("DEL", batch_key)
|
||||
return 1
|
||||
`)
|
||||
|
||||
// 节点筛选条件
|
||||
type EdgeFilter struct {
|
||||
Isp *m.EdgeISP `json:"isp"`
|
||||
AreaID *int32 `json:"area_id"`
|
||||
}
|
||||
|
||||
func (f *EdgeFilter) IsEmpty() bool {
|
||||
if f == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
return u.X(f.Isp.String()) == nil && f.AreaID == nil
|
||||
}
|
||||
|
||||
// 错误信息
|
||||
var (
|
||||
ErrResourceNotExist = core.NewBizErr("套餐不存在")
|
||||
|
||||
@@ -21,6 +21,7 @@ func (s *channelBaiyinProvider) prepareCreate(ctx *channelCreateContext) (*chann
|
||||
if err != nil {
|
||||
return nil, core.NewServErr("创建代理网关失败", err)
|
||||
}
|
||||
prov, city := areaProvinceCity(ctx.Area)
|
||||
|
||||
channels := make([]*m.Channel, len(ctx.Ports))
|
||||
chanConfigs := make([]*g.PortConfigsReq, len(ctx.Ports))
|
||||
@@ -30,8 +31,8 @@ func (s *channelBaiyinProvider) prepareCreate(ctx *channelCreateContext) (*chann
|
||||
Port: int(portRef.Port()),
|
||||
Status: true,
|
||||
AutoEdgeConfig: &g.AutoEdgeConfig{
|
||||
Province: u.Z(ctx.Filter.Prov),
|
||||
City: u.Z(ctx.Filter.City),
|
||||
Province: u.Z(prov),
|
||||
City: u.Z(city),
|
||||
Isp: ctx.Filter.Isp.String(),
|
||||
Count: u.P(1),
|
||||
},
|
||||
@@ -52,7 +53,7 @@ func (s *channelBaiyinProvider) prepareCreate(ctx *channelCreateContext) (*chann
|
||||
Channels: channels,
|
||||
applyRemote: func() error {
|
||||
slog.Debug("提交代理端口配置", "proxy", ctx.Proxy.IP.String(), "total_count", len(chanConfigs))
|
||||
if err := ensureEdges(ctx.Proxy, gateway, ctx.Filter, ctx.Count); err != nil {
|
||||
if err := ensureEdges(ctx.Proxy, gateway, ctx.Area, ctx.Filter.Isp, ctx.Count); err != nil {
|
||||
slog.Warn("ensureEdges 失败", "err", err)
|
||||
}
|
||||
if len(chanConfigs) > 0 {
|
||||
@@ -96,16 +97,17 @@ func (s *channelBaiyinProvider) removeRemote(_ string, batch *usedChanBatch) err
|
||||
// ensureEdges 检查本地节点是否足够,如果不足从云端连入
|
||||
// 本地节点通过 Assigned = false 排除已分配节点
|
||||
// 云端节点通过 NoRepeat = true 排除已分配节点
|
||||
func ensureEdges(proxy *m.Proxy, gateway g.GatewayClient, filter *EdgeFilter, count int) error {
|
||||
if filter.IsEmpty() {
|
||||
func ensureEdges(proxy *m.Proxy, gateway g.GatewayClient, area *m.Area, isp *m.EdgeISP, count int) error {
|
||||
prov, city := areaProvinceCity(area)
|
||||
if prov == nil && city == nil && u.X(isp.String()) == nil {
|
||||
return nil // 没有过滤条件,直接返回空,避免无意义的查询
|
||||
}
|
||||
|
||||
// 先查本地
|
||||
localEdges, err := gateway.GatewayEdge(&g.GatewayEdgeReq{
|
||||
Province: filter.Prov,
|
||||
City: filter.City,
|
||||
Isp: u.X(filter.Isp.String()),
|
||||
Province: prov,
|
||||
City: city,
|
||||
Isp: u.X(isp.String()),
|
||||
Limit: &count,
|
||||
Assigned: u.P(false),
|
||||
})
|
||||
@@ -119,9 +121,9 @@ func ensureEdges(proxy *m.Proxy, gateway g.GatewayClient, filter *EdgeFilter, co
|
||||
// 再查云端
|
||||
remaining := count - len(localEdges)
|
||||
cloudEdges, err := g.Cloud.CloudEdges(&g.CloudEdgesReq{
|
||||
Province: filter.Prov,
|
||||
City: filter.City,
|
||||
Isp: u.X(filter.Isp.String()),
|
||||
Province: prov,
|
||||
City: city,
|
||||
Isp: u.X(isp.String()),
|
||||
Limit: &remaining,
|
||||
NoRepeat: u.P(true),
|
||||
ActiveTime: u.P(3600),
|
||||
|
||||
@@ -9,12 +9,14 @@ import (
|
||||
m "platform/web/models"
|
||||
q "platform/web/queries"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gen"
|
||||
)
|
||||
|
||||
type channelGostProvider struct{}
|
||||
|
||||
func (s *channelGostProvider) prepareCreate(ctx *channelCreateContext) (*channelCreateResult, error) {
|
||||
edges, err := s.selectEdge(ctx.Filter, ctx.Count)
|
||||
edges, err := s.selectEdge(ctx.Filter, ctx.Area, ctx.Count)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -131,26 +133,38 @@ func (s *channelGostProvider) selectProxy(count int) (*m.Proxy, error) {
|
||||
return selectProxyByType(m.ProxyTypeGost, count)
|
||||
}
|
||||
|
||||
func (s *channelGostProvider) selectEdge(filter *EdgeFilter, count int) ([]*m.Edge, error) {
|
||||
func (s *channelGostProvider) selectEdge(filter *EdgeFilter, area *m.Area, count int) ([]*m.Edge, error) {
|
||||
if filter == nil {
|
||||
filter = &EdgeFilter{}
|
||||
}
|
||||
|
||||
do := q.Edge.Where(
|
||||
conds := []gen.Condition{
|
||||
q.Edge.Type.Eq(int(m.EdgeTypeGostChain)),
|
||||
q.Edge.Status.Eq(int(m.EdgeStatusNormal)),
|
||||
)
|
||||
if prov := u.N(filter.Prov); prov != nil {
|
||||
do = do.Where(q.Edge.Prov.Eq(*prov))
|
||||
}
|
||||
if city := u.N(filter.City); city != nil {
|
||||
do = do.Where(q.Edge.City.Eq(*city))
|
||||
}
|
||||
if isp := u.X(filter.Isp.String()); isp != nil {
|
||||
do = do.Where(q.Edge.ISP.Eq(int(*filter.Isp)))
|
||||
conds = append(conds, q.Edge.ISP.Eq(int(*filter.Isp)))
|
||||
}
|
||||
|
||||
edges, err := q.Edge.Where(do).Order(q.Edge.ID).Limit(count).Find()
|
||||
query := q.Edge.Where(conds...)
|
||||
if area != nil {
|
||||
switch area.Level {
|
||||
case m.AreaLevelProvince:
|
||||
edgeArea := q.Area.As("EdgeArea")
|
||||
query = query.
|
||||
Join(edgeArea, edgeArea.ID.EqCol(q.Edge.AreaID)).
|
||||
Where(edgeArea.ParentID.Eq(area.ID))
|
||||
case m.AreaLevelCity:
|
||||
query = query.Where(q.Edge.AreaID.Eq(area.ID))
|
||||
default:
|
||||
return nil, core.NewBizErr("地区层级不支持")
|
||||
}
|
||||
}
|
||||
|
||||
edges, err := query.
|
||||
Order(q.Edge.ID).
|
||||
Limit(count).
|
||||
Find()
|
||||
if err != nil {
|
||||
return nil, core.NewBizErr("查询可用节点失败", err)
|
||||
}
|
||||
|
||||
@@ -1,74 +0,0 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
m "platform/web/models"
|
||||
)
|
||||
|
||||
func TestExpandGostEdgesRejectsEmpty(t *testing.T) {
|
||||
_, err := expandGostEdges(nil, 1)
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExpandGostEdgesReusesWhenInsufficient(t *testing.T) {
|
||||
edges := []*m.Edge{
|
||||
{Mac: "chain-a"},
|
||||
{Mac: "chain-b"},
|
||||
}
|
||||
|
||||
result, err := expandGostEdges(edges, 5)
|
||||
if err != nil {
|
||||
t.Fatalf("expandGostEdges returned error: %v", err)
|
||||
}
|
||||
if len(result) != 5 {
|
||||
t.Fatalf("unexpected edge count: %d", len(result))
|
||||
}
|
||||
|
||||
expected := []string{"chain-a", "chain-b", "chain-a", "chain-b", "chain-a"}
|
||||
for i, edge := range result {
|
||||
if edge.Mac != expected[i] {
|
||||
t.Fatalf("unexpected edge at %d: %s", i, edge.Mac)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEdgeFilterIsEmpty(t *testing.T) {
|
||||
if !(*EdgeFilter)(nil).IsEmpty() {
|
||||
t.Fatal("nil filter should be empty")
|
||||
}
|
||||
if (&EdgeFilter{}).IsEmpty() != true {
|
||||
t.Fatal("empty filter should be empty")
|
||||
}
|
||||
if (&EdgeFilter{Prov: strPtr("")}).IsEmpty() != true {
|
||||
t.Fatal("filter with empty province should be empty")
|
||||
}
|
||||
if (&EdgeFilter{City: strPtr("")}).IsEmpty() != true {
|
||||
t.Fatal("filter with empty city should be empty")
|
||||
}
|
||||
if (&EdgeFilter{Isp: ispPtr(m.ToEdgeISP(0))}).IsEmpty() != true {
|
||||
t.Fatal("filter with zero ISP should be empty")
|
||||
}
|
||||
if (&EdgeFilter{Isp: ispPtr(m.ToEdgeISP(99))}).IsEmpty() != true {
|
||||
t.Fatal("filter with invalid ISP should be empty")
|
||||
}
|
||||
|
||||
prov := "江苏"
|
||||
if (&EdgeFilter{Prov: &prov}).IsEmpty() {
|
||||
t.Fatal("filter with province should not be empty")
|
||||
}
|
||||
isp := m.EdgeISPTelecom
|
||||
if (&EdgeFilter{Isp: &isp}).IsEmpty() {
|
||||
t.Fatal("filter with valid ISP should not be empty")
|
||||
}
|
||||
}
|
||||
|
||||
func strPtr(v string) *string {
|
||||
return &v
|
||||
}
|
||||
|
||||
func ispPtr(v m.EdgeISP) *m.EdgeISP {
|
||||
return &v
|
||||
}
|
||||
@@ -1,48 +0,0 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"platform/pkg/u"
|
||||
m "platform/web/models"
|
||||
q "platform/web/queries"
|
||||
)
|
||||
|
||||
var Edge = &edgeService{}
|
||||
|
||||
type edgeService struct{}
|
||||
|
||||
func (s *edgeService) AllEdges(count int, filter EdgeFilter) ([]*m.Edge, error) {
|
||||
do := q.Edge.Where(q.Edge.Type.Eq(int(m.EdgeTypeSelfBuilt)))
|
||||
if prov := u.N(filter.Prov); prov != nil {
|
||||
do = do.Where(q.Edge.Prov.Eq(*prov))
|
||||
}
|
||||
if city := u.N(filter.City); city != nil {
|
||||
do = do.Where(q.Edge.City.Eq(*city))
|
||||
}
|
||||
if isp := u.X(filter.Isp.String()); isp != nil {
|
||||
do = do.Where(q.Edge.ISP.Eq(int(*filter.Isp)))
|
||||
}
|
||||
if count > 0 {
|
||||
do = do.Limit(count)
|
||||
}
|
||||
|
||||
edges, err := q.Edge.Where(do).Find()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return edges, nil
|
||||
}
|
||||
|
||||
type EdgeFilter struct {
|
||||
Isp *m.EdgeISP `json:"isp"`
|
||||
Prov *string `json:"prov"`
|
||||
City *string `json:"city"`
|
||||
}
|
||||
|
||||
func (f *EdgeFilter) IsEmpty() bool {
|
||||
if f == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
return u.X(f.Isp.String()) == nil && u.N(f.Prov) == nil && u.N(f.City) == nil
|
||||
}
|
||||
@@ -161,6 +161,17 @@ func (s *proxyService) Update(update *UpdateProxy) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *proxyService) SyncPool(id int32) error {
|
||||
proxy, err := q.Proxy.Where(q.Proxy.ID.Eq(id)).Select(q.Proxy.ID, q.Proxy.IP).First()
|
||||
if err != nil {
|
||||
return core.NewServErr("获取代理数据失败", err)
|
||||
}
|
||||
if proxy == nil {
|
||||
return core.NewBizErr("代理不存在")
|
||||
}
|
||||
return rebuildFreeChans(id, proxy.IP.Addr)
|
||||
}
|
||||
|
||||
func (s *proxyService) Remove(id int32) error {
|
||||
used, err := hasUsedChans(id)
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user