实现 gost 网关

This commit is contained in:
2026-06-08 17:24:55 +08:00
parent b00782b3f6
commit c5453557ae
13 changed files with 972 additions and 123 deletions

2
.gitignore vendored
View File

@@ -19,3 +19,5 @@ scripts/*
!scripts/env/dev/ !scripts/env/dev/
!scripts/pre/ !scripts/pre/
!scripts/sql/ !scripts/sql/
*/uploads/

View File

@@ -1,6 +1,6 @@
## TODO ## TODO
--- 上传文件平铺到 uploads不分子文件夹
错误提示增强,展示整链路信息 错误提示增强,展示整链路信息

5
pkg/env/env.go vendored
View File

@@ -45,6 +45,9 @@ var (
BaiyinCloudUrl string BaiyinCloudUrl string
BaiyinTokenUrl string BaiyinTokenUrl string
GostApiPort = 8900
GostApiPathPrefix = ""
IdenCallbackUrl string IdenCallbackUrl string
IdenAccessKey string IdenAccessKey string
IdenSecretKey string IdenSecretKey string
@@ -129,6 +132,8 @@ func Init() {
errs = append(errs, parse(&BaiyinCloudUrl, "BAIYIN_CLOUD_URL", false, nil)) errs = append(errs, parse(&BaiyinCloudUrl, "BAIYIN_CLOUD_URL", false, nil))
errs = append(errs, parse(&BaiyinTokenUrl, "BAIYIN_TOKEN_URL", false, nil)) errs = append(errs, parse(&BaiyinTokenUrl, "BAIYIN_TOKEN_URL", false, nil))
errs = append(errs, parse(&GostApiPort, "GOST_API_PORT", true, nil))
errs = append(errs, parse(&GostApiPathPrefix, "GOST_API_PATH_PREFIX", true, nil))
errs = append(errs, parse(&IdenCallbackUrl, "IDEN_CALLBACK_URL", false, nil)) errs = append(errs, parse(&IdenCallbackUrl, "IDEN_CALLBACK_URL", false, nil))
errs = append(errs, parse(&IdenAccessKey, "IDEN_ACCESS_KEY", false, nil)) errs = append(errs, parse(&IdenAccessKey, "IDEN_ACCESS_KEY", false, nil))

View File

@@ -625,7 +625,7 @@ comment on column proxy.mac is '代理服务名称';
comment on column proxy.ip is '代理服务地址'; comment on column proxy.ip is '代理服务地址';
comment on column proxy.host is '代理服务域名'; comment on column proxy.host is '代理服务域名';
comment on column proxy.secret is '代理服务密钥'; comment on column proxy.secret is '代理服务密钥';
comment on column proxy.type is '代理服务类型1-自有2-白银'; comment on column proxy.type is '代理服务类型1-自有2-白银3-GOST';
comment on column proxy.status is '代理服务状态0-离线1-在线'; comment on column proxy.status is '代理服务状态0-离线1-在线';
comment on column proxy.meta is '代理服务元信息'; comment on column proxy.meta is '代理服务元信息';
comment on column proxy.created_at is '创建时间'; comment on column proxy.created_at is '创建时间';
@@ -640,6 +640,7 @@ create table edge (
version int not null, version int not null,
mac text not null, mac text not null,
ip inet not null, ip inet not null,
port int,
isp int not null, isp int not null,
prov text not null, prov text not null,
city text not null, city text not null,
@@ -659,10 +660,11 @@ create index idx_edge_created_at on edge (created_at) where deleted_at is null;
-- edge表字段注释 -- edge表字段注释
comment on table edge is '节点表'; comment on table edge is '节点表';
comment on column edge.id is '节点ID'; comment on column edge.id is '节点ID';
comment on column edge.type is '节点类型1-自建'; comment on column edge.type is '节点类型1-自建2-GOST chain';
comment on column edge.version is '节点版本'; comment on column edge.version is '节点版本';
comment on column edge.mac is '节点 mac 地址'; comment on column edge.mac is '节点 mac 地址或 GOST chain 名称';
comment on column edge.ip is '节点地址'; comment on column edge.ip is '节点地址或 GOST chain addr 的 IP';
comment on column edge.port is 'GOST chain addr 的端口';
comment on column edge.isp is '运营商1-电信2-联通3-移动'; comment on column edge.isp is '运营商1-电信2-联通3-移动';
comment on column edge.prov is '省份'; comment on column edge.prov is '省份';
comment on column edge.city is '城市'; comment on column edge.city is '城市';

215
web/globals/gost.go Normal file
View File

@@ -0,0 +1,215 @@
package globals
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"platform/web/core"
"strings"
)
var ErrGostNotFound = errors.New("gost resource not found")
func IsGostNotFound(err error) bool {
return errors.Is(err, ErrGostNotFound)
}
type GostClient interface {
GetChain(name string) (*GostChainConfig, error)
CreateService(service *GostServiceConfig) error
DeleteService(name string) error
CreateAuther(auther *GostAutherConfig) error
DeleteAuther(name string) error
CreateAdmission(admission *GostAdmissionConfig) error
DeleteAdmission(name string) error
}
type gostClient struct {
baseURL string
pathPrefix string
username string
password string
}
var GostInitializer = func(host string, port int, pathPrefix, username, password string) GostClient {
baseURL := strings.TrimSpace(host)
if !strings.Contains(baseURL, "://") {
baseURL = fmt.Sprintf("http://%s:%d", baseURL, port)
}
return &gostClient{
baseURL: strings.TrimRight(baseURL, "/"),
pathPrefix: normalizeGostPathPrefix(pathPrefix),
username: username,
password: password,
}
}
func NewGost(host string, port int, pathPrefix, username, password string) GostClient {
return GostInitializer(host, port, pathPrefix, username, password)
}
type GostChainConfig struct {
Name string `json:"name"`
}
type GostServiceConfig struct {
Name string `json:"name"`
Addr string `json:"addr"`
Admission string `json:"admission,omitempty"`
Handler GostHandlerConfig `json:"handler"`
Listener GostListenerConfig `json:"listener"`
}
type GostHandlerConfig struct {
Type string `json:"type"`
Chain string `json:"chain,omitempty"`
Auther string `json:"auther,omitempty"`
}
type GostListenerConfig struct {
Type string `json:"type"`
}
type GostAutherConfig struct {
Name string `json:"name"`
Auths []GostAuthConfig `json:"auths"`
}
type GostAuthConfig struct {
Username string `json:"username"`
Password string `json:"password"`
}
type GostAdmissionConfig struct {
Name string `json:"name"`
Whitelist bool `json:"whitelist"`
Matchers []string `json:"matchers"`
}
func (c *gostClient) GetChain(name string) (*GostChainConfig, error) {
body, err := c.get("/config/chains/" + url.PathEscape(name))
if err != nil {
return nil, err
}
if len(body) == 0 {
return &GostChainConfig{Name: name}, nil
}
var direct GostChainConfig
if err := json.Unmarshal(body, &direct); err == nil && direct.Name != "" {
return &direct, nil
}
var wrapper struct {
Data *GostChainConfig `json:"data"`
}
if err := json.Unmarshal(body, &wrapper); err == nil && wrapper.Data != nil && wrapper.Data.Name != "" {
return wrapper.Data, nil
}
return &GostChainConfig{Name: name}, nil
}
func (c *gostClient) CreateService(service *GostServiceConfig) error {
return c.create("/config/services", service)
}
func (c *gostClient) DeleteService(name string) error {
return c.delete("/config/services/" + url.PathEscape(name))
}
func (c *gostClient) CreateAuther(auther *GostAutherConfig) error {
return c.create("/config/authers", auther)
}
func (c *gostClient) DeleteAuther(name string) error {
return c.delete("/config/authers/" + url.PathEscape(name))
}
func (c *gostClient) CreateAdmission(admission *GostAdmissionConfig) error {
return c.create("/config/admissions", admission)
}
func (c *gostClient) DeleteAdmission(name string) error {
return c.delete("/config/admissions/" + url.PathEscape(name))
}
func (c *gostClient) create(path string, payload any) error {
_, err := c.request(http.MethodPost, path, payload)
return err
}
func (c *gostClient) get(path string) ([]byte, error) {
body, err := c.request(http.MethodGet, path, nil)
if err != nil {
return nil, err
}
return body, nil
}
func (c *gostClient) delete(path string) error {
_, err := c.request(http.MethodDelete, path, nil)
return err
}
func (c *gostClient) request(method string, path string, payload any) ([]byte, error) {
var bodyReader io.Reader
if payload != nil {
data, err := json.Marshal(payload)
if err != nil {
return nil, err
}
bodyReader = bytes.NewReader(data)
}
req, err := http.NewRequest(method, c.endpoint(path), bodyReader)
if err != nil {
return nil, err
}
req.SetBasicAuth(c.username, c.password)
if payload != nil {
req.Header.Set("Content-Type", "application/json")
}
resp, err := core.Fetch(req)
if err != nil {
return nil, err
}
defer func(Body io.ReadCloser) {
_ = Body.Close()
}(resp.Body)
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
if resp.StatusCode == http.StatusNotFound {
return nil, fmt.Errorf("%w: %s", ErrGostNotFound, string(body))
}
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
return nil, fmt.Errorf("gost api %s %s failed: %d %s", method, path, resp.StatusCode, string(body))
}
return body, nil
}
func (c *gostClient) endpoint(path string) string {
return c.baseURL + c.pathPrefix + path
}
func normalizeGostPathPrefix(prefix string) string {
prefix = strings.TrimSpace(prefix)
if prefix == "" {
return ""
}
if !strings.HasPrefix(prefix, "/") {
prefix = "/" + prefix
}
return strings.TrimRight(prefix, "/")
}

96
web/globals/gost_test.go Normal file
View File

@@ -0,0 +1,96 @@
package globals
import (
"encoding/base64"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func TestGostClientCreateServiceUsesBasicAuthAndPathPrefix(t *testing.T) {
var (
gotPath string
gotAuth string
gotBody GostServiceConfig
)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotPath = r.URL.Path
gotAuth = r.Header.Get("Authorization")
if err := json.NewDecoder(r.Body).Decode(&gotBody); err != nil {
t.Fatalf("Decode failed: %v", err)
}
w.WriteHeader(http.StatusOK)
}))
defer server.Close()
client := NewGost(server.URL, 0, "/api", "user", "pass")
err := client.CreateService(&GostServiceConfig{
Name: "svc-1",
Addr: ":10000",
Handler: GostHandlerConfig{
Type: "auto",
Chain: "chain-a",
Auther: "auther-a",
},
Listener: GostListenerConfig{Type: "tcp"},
})
if err != nil {
t.Fatalf("CreateService returned error: %v", err)
}
if gotPath != "/api/config/services" {
t.Fatalf("unexpected path: %s", gotPath)
}
if gotAuth != "Basic "+base64.StdEncoding.EncodeToString([]byte("user:pass")) {
t.Fatalf("unexpected auth header: %s", gotAuth)
}
if gotBody.Name != "svc-1" || gotBody.Handler.Type != "auto" || gotBody.Handler.Chain != "chain-a" {
t.Fatalf("unexpected request body: %+v", gotBody)
}
}
func TestGostClientDeleteServiceTreats404AsIdempotent(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.NotFound(w, r)
}))
defer server.Close()
client := NewGost(server.URL, 0, "", "user", "pass")
if err := client.DeleteService("svc-1"); !IsGostNotFound(err) {
t.Fatalf("expected gost not found error, got: %v", err)
}
}
func TestGostClientGetChainReadsTopLevelName(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/config/chains/chain-a" {
t.Fatalf("unexpected path: %s", r.URL.Path)
}
_, _ = w.Write([]byte(`{"name":"chain-a"}`))
}))
defer server.Close()
client := NewGost(server.URL, 0, "", "user", "pass")
chain, err := client.GetChain("chain-a")
if err != nil {
t.Fatalf("GetChain returned error: %v", err)
}
if chain.Name != "chain-a" {
t.Fatalf("unexpected chain: %+v", chain)
}
}
func TestNormalizeGostPathPrefix(t *testing.T) {
if got := normalizeGostPathPrefix("api/"); got != "/api" {
t.Fatalf("unexpected prefix: %s", got)
}
if got := normalizeGostPathPrefix(""); got != "" {
t.Fatalf("unexpected empty prefix: %s", got)
}
if !strings.HasPrefix(normalizeGostPathPrefix("/v1"), "/") {
t.Fatal("expected normalized prefix to start with slash")
}
}

View File

@@ -8,10 +8,11 @@ import (
// Edge 节点表 // Edge 节点表
type Edge struct { type Edge struct {
core.Model core.Model
Type EdgeType `json:"type" gorm:"column:type"` // 节点类型1-自建 Type EdgeType `json:"type" gorm:"column:type"` // 节点类型1-自建2-GOST chain
Version int32 `json:"version" gorm:"column:version"` // 节点版本 Version int32 `json:"version" gorm:"column:version"` // 节点版本
Mac string `json:"mac" gorm:"column:mac"` // 节点 mac 地址 Mac string `json:"mac" gorm:"column:mac"` // 节点 mac 地址或 GOST chain 名称
IP orm.Inet `json:"ip" gorm:"column:ip;not null"` // 节点地址 IP orm.Inet `json:"ip" gorm:"column:ip;not null"` // 节点地址或 GOST chain addr 的 IP
Port *uint16 `json:"port,omitempty" gorm:"column:port"` // GOST chain addr 的端口
ISP EdgeISP `json:"isp" gorm:"column:isp"` // 运营商0-未知1-电信2-联通3-移动 ISP EdgeISP `json:"isp" gorm:"column:isp"` // 运营商0-未知1-电信2-联通3-移动
Prov string `json:"prov" gorm:"column:prov"` // 省份 Prov string `json:"prov" gorm:"column:prov"` // 省份
City string `json:"city" gorm:"column:city"` // 城市 City string `json:"city" gorm:"column:city"` // 城市
@@ -25,6 +26,7 @@ type EdgeType int
const ( const (
EdgeTypeSelfBuilt EdgeType = 1 // 自建 EdgeTypeSelfBuilt EdgeType = 1 // 自建
EdgeTypeGostChain EdgeType = 2 // GOST chain
) )
// EdgeStatus 节点状态枚举 // EdgeStatus 节点状态枚举

View File

@@ -28,6 +28,7 @@ type ProxyType int
const ( const (
ProxyTypeSelfHosted ProxyType = 1 // 自有 ProxyTypeSelfHosted ProxyType = 1 // 自有
ProxyTypeBaiYin ProxyType = 2 // 白银 ProxyTypeBaiYin ProxyType = 2 // 白银
ProxyTypeGost ProxyType = 3 // GOST
) )
// ProxyStatus 代理服务状态枚举 // ProxyStatus 代理服务状态枚举

View File

@@ -4,24 +4,28 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"log/slog"
"math/rand/v2" "math/rand/v2"
"net/netip" "net/netip"
"platform/pkg/env" "platform/pkg/env"
"platform/pkg/u" "platform/pkg/u"
"platform/web/core" "platform/web/core"
e "platform/web/events"
g "platform/web/globals" g "platform/web/globals"
m "platform/web/models" m "platform/web/models"
q "platform/web/queries" q "platform/web/queries"
"strconv" "strconv"
"strings"
"time" "time"
"github.com/hibiken/asynq"
"github.com/redis/go-redis/v9" "github.com/redis/go-redis/v9"
"gorm.io/gen/field" "gorm.io/gen/field"
) )
// 通道服务 // 通道服务
var Channel = &channelServer{ var Channel = &channelServer{
provider: &channelBaiyinProvider{}, provider: &channelGostProvider{},
} }
type ChannelServiceProvider interface { type ChannelServiceProvider interface {
@@ -46,13 +50,77 @@ func (s *channelServer) ClearExpiredChannels(proxyId int32) (int, error) {
return s.provider.ClearExpiredChannels(proxyId) return s.provider.ClearExpiredChannels(proxyId)
} }
func lockChannelCreateKey(resourceNo string) string {
return fmt.Sprintf("platform:channel:create:%s", resourceNo)
}
func lockChannelRemoveKey(bid string) string {
return fmt.Sprintf("platform:batch:remove_expired:%s", bid)
}
func selectPorts(proxyId int32, batchNo string, count int, expire time.Time) ([]netip.AddrPort, error) {
chans, err := lockChans(proxyId, batchNo, count)
if err != nil {
return nil, core.NewBizErr("无可用通道,请稍后再试", err)
}
_, err = g.Asynq.Enqueue(
e.NewRemoveChannel(batchNo),
asynq.ProcessAt(expire),
)
if err != nil {
return nil, core.NewServErr("注册异步关闭通道任务失败", err)
}
return chans, nil
}
func selectProxyByType(proxyType m.ProxyType, count int) (*m.Proxy, error) {
proxies, err := q.Proxy.Where(
q.Proxy.Type.Eq(int(proxyType)),
q.Proxy.Status.Eq(int(m.ProxyStatusOnline)),
).Find()
if err != nil {
return nil, core.NewBizErr("获取可用代理失败", err)
}
if len(proxies) == 0 {
return nil, core.NewBizErr("无可用代理")
}
proxyIDs := make([]int32, 0, len(proxies))
proxyMap := make(map[int32]*m.Proxy, len(proxies))
for _, item := range proxies {
proxyIDs = append(proxyIDs, item.ID)
proxyMap[item.ID] = item
}
maxID := int32(0)
maxCount := -1
for _, id := range proxyIDs {
idCount, err := g.Redis.SCard(context.Background(), freeChansKey(id)).Result()
if err != nil {
return nil, core.NewServErr("查询可用通道数量失败", err)
}
if idCount > int64(maxCount) {
maxCount = int(idCount)
maxID = id
}
}
if maxCount < count {
return nil, core.NewBizErr("无可用代理")
}
return proxyMap[maxID], nil
}
func (s *channelServer) RefreshEdges() error { func (s *channelServer) RefreshEdges() error {
if env.RunMode != env.RunModeProd { if env.RunMode != env.RunModeProd {
return nil return nil
} }
// 找到所有网关 // 仅白银网关支持边缘节点刷新GOST 不参与此流程。
proxies, err := q.Proxy.Where( proxies, err := q.Proxy.Where(
q.Proxy.Type.Eq(int(m.ProxyTypeBaiYin)),
q.Proxy.Status.Eq(int(m.ProxyStatusOnline)), q.Proxy.Status.Eq(int(m.ProxyStatusOnline)),
).Find() ).Find()
if err != nil { if err != nil {
@@ -282,6 +350,83 @@ func usedChansKey(proxy int32, batch string) string {
return "channel:used:" + strconv.Itoa(int(proxy)) + ":" + batch return "channel:used:" + strconv.Itoa(int(proxy)) + ":" + batch
} }
type usedChanBatch struct {
ProxyID int32
Chans []netip.AddrPort
}
func findUsedChanBatch(batch string) (*usedChanBatch, error) {
keys, err := g.Redis.Keys(context.Background(), "channel:used:*:"+batch).Result()
if err != nil {
return nil, core.NewServErr("查询使用中通道失败", err)
}
key, ok, err := selectUsedChanBatchKey(batch, keys)
if err != nil {
return nil, err
}
if !ok {
return nil, nil
}
chans, err := g.Redis.LRange(context.Background(), key, 0, -1).Result()
if err != nil {
return nil, core.NewServErr("查询使用中通道失败", err)
}
return parseUsedChanBatch(key, chans)
}
func selectUsedChanBatchKey(batch string, keys []string) (string, bool, error) {
switch len(keys) {
case 0:
return "", false, nil
case 1:
return keys[0], true, nil
default:
slog.Error("batchNo 全局唯一约束被破坏", "batch", batch, "keys", keys)
return "", false, core.NewServErr(
fmt.Sprintf("检测到重复 usedChans 键batchNo 全局唯一被破坏: %s", batch),
fmt.Errorf("keys=%s", strings.Join(keys, ",")),
)
}
}
func parseUsedChanBatch(key string, chans []string) (*usedChanBatch, error) {
proxyID, err := parseUsedChansKey(key)
if err != nil {
return nil, err
}
addrs := make([]netip.AddrPort, len(chans))
for i, ch := range chans {
addr, err := netip.ParseAddrPort(ch)
if err != nil {
return nil, core.NewServErr(fmt.Sprintf("解析通道数据失败: %s", ch), err)
}
addrs[i] = addr
}
return &usedChanBatch{
ProxyID: proxyID,
Chans: addrs,
}, nil
}
func parseUsedChansKey(key string) (int32, error) {
parts := strings.Split(key, ":")
if len(parts) != 4 {
return 0, core.NewServErr(fmt.Sprintf("使用中通道键格式错误: %s", key), nil)
}
proxyID, err := strconv.Atoi(parts[2])
if err != nil {
return 0, core.NewServErr(fmt.Sprintf("使用中通道键格式错误: %s", key), err)
}
return int32(proxyID), nil
}
// 扩容通道 // 扩容通道
func regChans(proxy int32, chans []netip.AddrPort) error { func regChans(proxy int32, chans []netip.AddrPort) error {
strs := make([]any, len(chans)) strs := make([]any, len(chans))

View File

@@ -9,7 +9,6 @@ import (
"platform/pkg/env" "platform/pkg/env"
"platform/pkg/u" "platform/pkg/u"
"platform/web/core" "platform/web/core"
e "platform/web/events"
g "platform/web/globals" g "platform/web/globals"
"platform/web/globals/orm" "platform/web/globals/orm"
m "platform/web/models" m "platform/web/models"
@@ -17,7 +16,6 @@ import (
"strings" "strings"
"time" "time"
"github.com/hibiken/asynq"
"gorm.io/gen" "gorm.io/gen"
"gorm.io/gen/field" "gorm.io/gen/field"
) )
@@ -213,41 +211,19 @@ func (s *channelBaiyinProvider) RemoveChannels(batchNo string) error {
return g.Redsync.WithLock(lockChannelRemoveKey(batchNo), func() error { return g.Redsync.WithLock(lockChannelRemoveKey(batchNo), func() error {
start := time.Now() start := time.Now()
// 获取连接数据 batch, err := findUsedChanBatch(batchNo)
channels, err := q.Channel.Where(q.Channel.BatchNo.Eq(batchNo)).Find()
if err != nil { if err != nil {
return core.NewServErr(fmt.Sprintf("获取通道数据失败batch%s", batchNo), err) return err
} }
if len(channels) == 0 { if batch == nil {
slog.Warn(fmt.Sprintf("未找到通道数据,batch%s", batchNo)) slog.Debug("通道为空,跳过清理", "batch", batchNo)
return nil return nil
} }
proxy, err := q.Proxy.Where(q.Proxy.ID.Eq(channels[0].ProxyID)).Take() configs := make([]*g.PortConfigsReq, len(batch.Chans))
if err != nil { for i, ch := range batch.Chans {
return core.NewServErr(fmt.Sprintf("获取代理数据失败batch%s", batchNo), err)
}
// 检查通道是否存在
chans, err := g.Redis.LRange(context.Background(), usedChansKey(proxy.ID, batchNo), 0, -1).Result()
if err != nil {
return core.NewServErr("查询使用中通道失败", err)
}
if len(chans) == 0 {
slog.Debug("通道为空,跳过清理", "key", usedChansKey(proxy.ID, batchNo))
return nil // 没有使用中通道,已经被清理过了
}
// 准备配置数据
configs := make([]*g.PortConfigsReq, len(chans))
for i, ch := range chans {
ap, err := netip.ParseAddrPort(ch)
if err != nil {
return core.NewServErr(fmt.Sprintf("解析通道数据失败: %s", ch), err)
}
configs[i] = &g.PortConfigsReq{ configs[i] = &g.PortConfigsReq{
Port: int(ap.Port()), Port: int(ch.Port()),
Edge: &[]string{}, Edge: &[]string{},
AutoEdgeConfig: &g.AutoEdgeConfig{Count: u.P(0)}, AutoEdgeConfig: &g.AutoEdgeConfig{Count: u.P(0)},
Status: false, Status: false,
@@ -256,6 +232,11 @@ func (s *channelBaiyinProvider) RemoveChannels(batchNo string) error {
// 提交配置 // 提交配置
if env.RunMode == env.RunModeProd { if env.RunMode == env.RunModeProd {
proxy, err := q.Proxy.Where(q.Proxy.ID.Eq(batch.ProxyID)).Take()
if err != nil {
return core.NewServErr("获取代理数据失败", err)
}
gateway, err := proxyGateway(proxy) gateway, err := proxyGateway(proxy)
if err != nil { if err != nil {
return core.NewServErr("创建代理网关失败", err) return core.NewServErr("创建代理网关失败", err)
@@ -271,13 +252,11 @@ func (s *channelBaiyinProvider) RemoveChannels(batchNo string) error {
} }
} }
// 释放端口 if err := freeChans(batch.ProxyID, batchNo); err != nil {
err = freeChans(proxy.ID, batchNo)
if err != nil {
return err return err
} }
slog.Debug("清除代理端口配置", "proxy", proxy.ID, "batch", batchNo, "duration", time.Since(start).String()) slog.Debug("清除代理端口配置", "proxy", batch.ProxyID, "batch", batchNo, "duration", time.Since(start).String())
return nil return nil
}) })
} }
@@ -335,53 +314,12 @@ func (s *channelBaiyinProvider) ClearExpiredChannels(proxyId int32) (int, error)
return len(batchSet), nil return len(batchSet), nil
} }
func lockChannelCreateKey(resourceNo string) string {
return fmt.Sprintf("platform:channel:create:%s", resourceNo)
}
func lockChannelRemoveKey(bid string) string {
return fmt.Sprintf("platform:batch:remove_expired:%s", bid)
}
func selectProxy(count int) (*m.Proxy, g.GatewayClient, error) { func selectProxy(count int) (*m.Proxy, g.GatewayClient, error) {
// 获取在线节点 proxy, err := selectProxyByType(m.ProxyTypeBaiYin, count)
proxies, err := q.Proxy.Where(
q.Proxy.Type.Eq(int(m.ProxyTypeBaiYin)),
q.Proxy.Status.Eq(int(m.ProxyStatusOnline)),
).Find()
if err != nil { if err != nil {
return nil, nil, core.NewBizErr("获取可用代理失败", err) return nil, nil, err
}
if len(proxies) == 0 {
return nil, nil, core.NewBizErr("无可用代理")
} }
proxyIDs := make([]int32, 0, len(proxies))
proxyMap := make(map[int32]*m.Proxy, len(proxies))
for _, item := range proxies {
proxyIDs = append(proxyIDs, item.ID)
proxyMap[item.ID] = item
}
// 获取最空闲节点
maxId := int32(0)
maxCount := -1
for _, id := range proxyIDs {
idCount, err := g.Redis.SCard(context.Background(), freeChansKey(id)).Result()
if err != nil {
return nil, nil, fmt.Errorf("查询可用通道数量失败: %w", err)
}
if idCount > int64(maxCount) {
maxCount = int(idCount)
maxId = id
}
}
if maxCount < count {
return nil, nil, core.NewBizErr("无可用代理")
}
proxy := proxyMap[maxId]
gateway, err := proxyGateway(proxy) gateway, err := proxyGateway(proxy)
if err != nil { if err != nil {
return nil, nil, core.NewServErr("创建代理网关失败", err) return nil, nil, core.NewServErr("创建代理网关失败", err)
@@ -390,23 +328,6 @@ func selectProxy(count int) (*m.Proxy, g.GatewayClient, error) {
return proxy, gateway, nil return proxy, gateway, nil
} }
func selectPorts(proxyId int32, batchNo string, count int, expire time.Time) ([]netip.AddrPort, error) {
chans, err := lockChans(proxyId, batchNo, count)
if err != nil {
return nil, core.NewBizErr("无可用通道,请稍后再试", err)
}
_, err = g.Asynq.Enqueue(
e.NewRemoveChannel(batchNo),
asynq.ProcessAt(expire),
)
if err != nil {
return nil, core.NewServErr("注册异步关闭通道任务失败", err)
}
return chans, nil
}
// ensureEdges 检查本地节点是否足够,如果不足从云端连入 // ensureEdges 检查本地节点是否足够,如果不足从云端连入
// 本地节点通过 Assigned = false 排除已分配节点 // 本地节点通过 Assigned = false 排除已分配节点
// 云端节点通过 NoRepeat = true 排除已分配节点 // 云端节点通过 NoRepeat = true 排除已分配节点

View File

@@ -0,0 +1,390 @@
package services
import (
"context"
"fmt"
"log/slog"
"net/netip"
"platform/pkg/env"
"platform/pkg/u"
"platform/web/core"
g "platform/web/globals"
"platform/web/globals/orm"
m "platform/web/models"
q "platform/web/queries"
"strings"
"time"
"gorm.io/gen"
"gorm.io/gen/field"
)
type channelGostProvider struct{}
func (s *channelGostProvider) CreateChannels(source netip.Addr, resourceNo string, authWhitelist bool, authPassword bool, count int, filter *EdgeFilter) ([]*m.Channel, error) {
now := time.Now()
batchNo := ID.GenReadable("bat")
channels := make([]*m.Channel, count)
if filter == nil {
filter = &EdgeFilter{}
}
err := g.Redsync.WithLock(lockChannelCreateKey(resourceNo), func() error {
resource, whitelists, err := ensure(now, source, resourceNo, authWhitelist, count)
if err != nil {
return err
}
user := resource.User
expire := now.Add(resource.Live)
proxy, err := s.selectProxy(count)
if err != nil {
return err
}
chans, err := selectPorts(proxy.ID, batchNo, count, expire)
if err != nil {
return err
}
edges, err := s.selectEdge(filter, count)
if err != nil {
return err
}
client, err := proxyGost(proxy)
if err != nil {
return err
}
admissions := make([]*g.GostAdmissionConfig, 0, count)
authers := make([]*g.GostAutherConfig, 0, count)
services := make([]*g.GostServiceConfig, count)
for i := range count {
ch := chans[i]
edge := edges[i]
port := ch.Port()
host := u.Else(proxy.Host, proxy.IP.String())
serviceName := gostServiceName(batchNo, port)
channel := &m.Channel{
UserID: user.ID,
ResourceID: resource.ID,
BatchNo: batchNo,
ProxyID: proxy.ID,
Host: host,
Port: port,
EdgeID: u.P(edge.ID),
EdgeRef: u.P(serviceName),
FilterISP: filter.Isp,
FilterProv: filter.Prov,
FilterCity: filter.City,
IP: u.P(edge.IP),
ExpiredAt: expire,
Proxy: proxy,
}
service := &g.GostServiceConfig{
Name: serviceName,
Addr: fmt.Sprintf(":%d", port),
Handler: g.GostHandlerConfig{
Type: "auto",
Chain: edge.Mac,
},
Listener: g.GostListenerConfig{
Type: "tcp",
},
}
if authWhitelist {
channel.Whitelists = u.P(strings.Join(whitelists, ","))
service.Admission = gostAdmissionName(batchNo, port)
admission := &g.GostAdmissionConfig{
Name: service.Admission,
Whitelist: true,
Matchers: whitelists,
}
admissions = append(admissions, admission)
}
if authPassword {
username, password := genPassPair()
channel.Username = &username
channel.Password = &password
service.Handler.Auther = gostAutherName(batchNo, port)
auther := &g.GostAutherConfig{
Name: service.Handler.Auther,
Auths: []g.GostAuthConfig{{
Username: username,
Password: password,
}},
}
authers = append(authers, auther)
}
services[i] = service
channels[i] = channel
}
for _, admission := range admissions {
if err := client.CreateAdmission(admission); err != nil {
return core.NewServErr(fmt.Sprintf("创建 GOST admission 失败: %s", admission.Name), err)
}
}
for _, auther := range authers {
if err := client.CreateAuther(auther); err != nil {
return core.NewServErr(fmt.Sprintf("创建 GOST auther 失败: %s", auther.Name), err)
}
}
for _, service := range services {
if err := client.CreateService(service); err != nil {
return core.NewServErr(fmt.Sprintf("创建 GOST service 失败: %s", service.Name), err)
}
}
err = q.Q.Transaction(func(tx *q.Query) error {
var result gen.ResultInfo
var err error
switch resource.Type {
case m.ResourceTypeShort:
result, err = tx.ResourceShort.
Where(
tx.ResourceShort.ID.Eq(*resource.ShortId),
tx.ResourceShort.Used.Eq(resource.Used),
tx.ResourceShort.Daily.Eq(resource.Daily),
).
UpdateSimple(
tx.ResourceShort.Used.Add(int32(count)),
tx.ResourceShort.Daily.Value(int32(resource.Today+count)),
tx.ResourceShort.LastAt.Value(now),
)
case m.ResourceTypeLong:
result, err = tx.ResourceLong.
Where(
tx.ResourceLong.ID.Eq(*resource.LongId),
tx.ResourceLong.Used.Eq(resource.Used),
tx.ResourceLong.Daily.Eq(resource.Daily),
).
UpdateSimple(
tx.ResourceLong.Used.Add(int32(count)),
tx.ResourceLong.Daily.Value(int32(resource.Today+count)),
tx.ResourceLong.LastAt.Value(now),
)
default:
return core.NewBizErr("套餐类型不正确,无法更新")
}
if err != nil {
return core.NewServErr("更新套餐使用记录失败", err)
}
if result.RowsAffected == 0 {
return core.NewBizErr("套餐状态已过期")
}
if err := tx.Channel.Omit(field.AssociationFields).Create(channels...); err != nil {
return core.NewServErr("保存通道失败", err)
}
if err := tx.LogsUserUsage.Create(&m.LogsUserUsage{
UserID: user.ID,
ResourceID: resource.ID,
BatchNo: batchNo,
Count: int32(count),
ISP: u.X(filter.Isp.String()),
Prov: filter.Prov,
City: filter.City,
IP: orm.Inet{Addr: source},
Time: now,
}); err != nil {
return core.NewServErr("保存用户使用记录失败", err)
}
return nil
})
if err != nil {
return err
}
return nil
})
if err != nil {
return nil, err
}
return channels, nil
}
func (s *channelGostProvider) RemoveChannels(batchNo string) error {
return g.Redsync.WithLock(lockChannelRemoveKey(batchNo), func() error {
start := time.Now()
batch, err := findUsedChanBatch(batchNo)
if err != nil {
return err
}
if batch == nil {
slog.Debug("通道为空,跳过清理", "batch", batchNo)
return nil
}
if env.RunMode == env.RunModeProd {
proxy, err := q.Proxy.Where(q.Proxy.ID.Eq(batch.ProxyID)).Take()
if err != nil {
return core.NewServErr("获取代理数据失败", err)
}
client, err := proxyGost(proxy)
if err != nil {
return core.NewServErr("创建 GOST 客户端失败", err)
}
var deleteErrs []error
for _, ch := range batch.Chans {
port := ch.Port()
serviceName := gostServiceName(batchNo, port)
deleteErrs = append(deleteErrs, deleteGostResource("service", serviceName, func() error {
return client.DeleteService(serviceName)
}))
autherName := gostAutherName(batchNo, port)
deleteErrs = append(deleteErrs, deleteGostResource("auther", autherName, func() error {
return client.DeleteAuther(autherName)
}))
admissionName := gostAdmissionName(batchNo, port)
deleteErrs = append(deleteErrs, deleteGostResource("admission", admissionName, func() error {
return client.DeleteAdmission(admissionName)
}))
}
if err := u.CombineErrors(deleteErrs); err != nil {
return err
}
}
if err := freeChans(batch.ProxyID, batchNo); err != nil {
return err
}
slog.Debug("清除 GOST 端口配置", "proxy", batch.ProxyID, "batch", batchNo, "duration", time.Since(start).String())
return nil
})
}
func (s *channelGostProvider) ClearExpiredChannels(proxyId int32) (int, error) {
now := time.Now()
keys, err := g.Redis.Keys(context.Background(), usedChansKey(proxyId, "*")).Result()
if err != nil {
return 0, core.NewServErr("查询使用中通道失败", err)
}
if len(keys) == 0 {
return 0, nil
}
batchList := make([]string, len(keys))
batchSet := make(map[string]struct{}, len(keys))
for i, key := range keys {
parts := strings.Split(key, ":")
if len(parts) != 4 {
return 0, core.NewServErr(fmt.Sprintf("使用中通道键格式错误: %s", key), nil)
}
batchList[i] = parts[3]
batchSet[parts[3]] = struct{}{}
}
var batchQueried []struct{ BatchNo string }
err = q.Channel.
Select(q.Channel.BatchNo).
Where(
q.Channel.BatchNo.In(batchList...),
q.Channel.ExpiredAt.Gte(now.UTC()),
).
Group(q.Channel.BatchNo).
Scan(&batchQueried)
if err != nil {
return 0, core.NewServErr("查询过期通道失败", err)
}
for _, batch := range batchQueried {
delete(batchSet, batch.BatchNo)
}
slog.Info("批量清理过期 GOST 通道", "count", len(batchSet))
for batchNo := range batchSet {
if err := s.RemoveChannels(batchNo); err != nil {
slog.Error("清理过期 GOST 通道失败", "batch", batchNo, "error", err)
}
}
return len(batchSet), nil
}
func (s *channelGostProvider) selectProxy(count int) (*m.Proxy, error) {
return selectProxyByType(m.ProxyTypeGost, count)
}
func (s *channelGostProvider) selectEdge(filter *EdgeFilter, count int) ([]*m.Edge, error) {
if filter == nil {
filter = &EdgeFilter{}
}
do := q.Edge.Where(
q.Edge.Type.Eq(int(m.EdgeTypeGostChain)),
q.Edge.Status.Eq(int(m.EdgeStatusNormal)),
)
if prov := u.N(filter.Prov); prov != nil {
do = do.Where(q.Edge.Prov.Eq(*prov))
}
if city := u.N(filter.City); city != nil {
do = do.Where(q.Edge.City.Eq(*city))
}
if isp := u.X(filter.Isp.String()); isp != nil {
do = do.Where(q.Edge.ISP.Eq(int(*filter.Isp)))
}
edges, err := q.Edge.Where(do).Order(q.Edge.ID).Limit(count).Find()
if err != nil {
return nil, core.NewBizErr("查询可用节点失败", err)
}
return expandGostEdges(edges, count)
}
func expandGostEdges(edges []*m.Edge, count int) ([]*m.Edge, error) {
if len(edges) == 0 {
return nil, core.NewBizErr("地区可用节点数量不足")
}
result := make([]*m.Edge, count)
for i := range count {
result[i] = edges[i%len(edges)]
}
return result, nil
}
func proxyGost(proxy *m.Proxy) (g.GostClient, error) {
secret := strings.Split(u.Z(proxy.Secret), ":")
if len(secret) != 2 {
return nil, core.NewServErr(fmt.Sprintf("代理 %s 密钥格式错误", proxy.IP.String()), nil)
}
host := u.Else(proxy.Host, proxy.IP.String())
return g.NewGost(host, env.GostApiPort, env.GostApiPathPrefix, secret[0], secret[1]), nil
}
func deleteGostResource(kind string, name string, deleteFn func() error) error {
if err := deleteFn(); err != nil && !g.IsGostNotFound(err) {
return core.NewServErr(fmt.Sprintf("删除 GOST %s 配置失败: %s", kind, name), err)
}
return nil
}
func gostServiceName(batchNo string, port uint16) string {
return fmt.Sprintf("gost-svc-%s-%d", batchNo, port)
}
func gostAutherName(batchNo string, port uint16) string {
return fmt.Sprintf("gost-auther-%s-%d", batchNo, port)
}
func gostAdmissionName(batchNo string, port uint16) string {
return fmt.Sprintf("gost-adm-%s-%d", batchNo, port)
}

View File

@@ -0,0 +1,74 @@
package services
import (
"testing"
m "platform/web/models"
)
func TestExpandGostEdgesRejectsEmpty(t *testing.T) {
_, err := expandGostEdges(nil, 1)
if err == nil {
t.Fatal("expected error, got nil")
}
}
func TestExpandGostEdgesReusesWhenInsufficient(t *testing.T) {
edges := []*m.Edge{
{Mac: "chain-a"},
{Mac: "chain-b"},
}
result, err := expandGostEdges(edges, 5)
if err != nil {
t.Fatalf("expandGostEdges returned error: %v", err)
}
if len(result) != 5 {
t.Fatalf("unexpected edge count: %d", len(result))
}
expected := []string{"chain-a", "chain-b", "chain-a", "chain-b", "chain-a"}
for i, edge := range result {
if edge.Mac != expected[i] {
t.Fatalf("unexpected edge at %d: %s", i, edge.Mac)
}
}
}
func TestEdgeFilterIsEmpty(t *testing.T) {
if !(*EdgeFilter)(nil).IsEmpty() {
t.Fatal("nil filter should be empty")
}
if (&EdgeFilter{}).IsEmpty() != true {
t.Fatal("empty filter should be empty")
}
if (&EdgeFilter{Prov: strPtr("")}).IsEmpty() != true {
t.Fatal("filter with empty province should be empty")
}
if (&EdgeFilter{City: strPtr("")}).IsEmpty() != true {
t.Fatal("filter with empty city should be empty")
}
if (&EdgeFilter{Isp: ispPtr(m.ToEdgeISP(0))}).IsEmpty() != true {
t.Fatal("filter with zero ISP should be empty")
}
if (&EdgeFilter{Isp: ispPtr(m.ToEdgeISP(99))}).IsEmpty() != true {
t.Fatal("filter with invalid ISP should be empty")
}
prov := "江苏"
if (&EdgeFilter{Prov: &prov}).IsEmpty() {
t.Fatal("filter with province should not be empty")
}
isp := m.EdgeISPTelecom
if (&EdgeFilter{Isp: &isp}).IsEmpty() {
t.Fatal("filter with valid ISP should not be empty")
}
}
func strPtr(v string) *string {
return &v
}
func ispPtr(v m.EdgeISP) *m.EdgeISP {
return &v
}

View File

@@ -11,14 +11,14 @@ var Edge = &edgeService{}
type edgeService struct{} type edgeService struct{}
func (s *edgeService) AllEdges(count int, filter EdgeFilter) ([]*m.Edge, error) { func (s *edgeService) AllEdges(count int, filter EdgeFilter) ([]*m.Edge, error) {
do := q.Edge.Where() do := q.Edge.Where(q.Edge.Type.Eq(int(m.EdgeTypeSelfBuilt)))
if filter.Prov != nil { if prov := u.N(filter.Prov); prov != nil {
do = do.Where(q.Edge.Prov.Eq(*filter.Prov)) do = do.Where(q.Edge.Prov.Eq(*prov))
} }
if filter.City != nil { if city := u.N(filter.City); city != nil {
do = do.Where(q.Edge.City.Eq(*filter.City)) do = do.Where(q.Edge.City.Eq(*city))
} }
if filter.Isp != nil { if isp := u.X(filter.Isp.String()); isp != nil {
do = do.Where(q.Edge.ISP.Eq(int(*filter.Isp))) do = do.Where(q.Edge.ISP.Eq(int(*filter.Isp)))
} }
if count > 0 { if count > 0 {
@@ -44,9 +44,5 @@ func (f *EdgeFilter) IsEmpty() bool {
return true return true
} }
if f.Isp.String() == "" || u.Z(f.Prov) != "" || u.Z(f.City) != "" { return u.X(f.Isp.String()) == nil && u.N(f.Prov) == nil && u.N(f.City) == nil
return false
}
return false
} }