新增代理服务与边缘节点注册功能

This commit is contained in:
2025-05-13 18:48:17 +08:00
parent 536f36ae02
commit d69a77df38
17 changed files with 573 additions and 203 deletions

2
.gitignore vendored
View File

@@ -35,3 +35,5 @@ build/
.env .env
.env.* .env.*
!.env.example !.env.example
cmd/playground/

View File

@@ -28,6 +28,9 @@ COPY --from=builder /build/bin/proxy_linux_amd64 /app/proxy
# 设置可执行权限 # 设置可执行权限
RUN chmod +x /app/proxy RUN chmod +x /app/proxy
# 锁定文件
VOLUME /app/proxy.lock
# 声明暴露端口 # 声明暴露端口
EXPOSE 8080 EXPOSE 8080

View File

@@ -29,24 +29,54 @@ ERR: 除非有必要,否则全部 error 都使用 `errors.Wrap()` 包裹(如
其他级别日志就地打印Info 只用来跟踪关键流程 其他级别日志就地打印Info 只用来跟踪关键流程
### proxy.lock 文件格式
| mag_num(1) | name(16) |
|-------------|-----------|
| 魔法数,固定 0x72 | 服务名称uuid |
## 协议 ## 协议
### 建立连接 ### 步骤说明
客户端(控制通道): 1. 启动转发服务,尝试注册自身到后端服务,随后持续报告心跳
2. 启动边缘节点后,尝试注册自身到后端服务,随后持续报告心跳
3. 后端服务根据在线转发服务的状态,返回分配给边缘节点的转发服务地址
4. 边缘节点根据配置尝试连接到转发服务(建立控制通道)
5. 连接成功后,控制通道将长期保留,边缘节点定时发送保活心跳,代理服务丢弃所有心跳包
6. 当用户请求代理时,转发服务通过控制通道向边缘节点提供代理目标信息
7. 边缘节点尝试连接到目标地址,同时尝试建立数据通道
8. 当成功建立数据通道后,边缘节点将数据通道标识以及对目标地址的连接结果提供给转发服务
9. 如果连接成功建立,则开始代理流量,如果连接失败,则关闭数据通道
`version(1)` `id_len(1)` `id_buf(n)` ### 协议报文详情
服务端(控制通道): 协议中所有数值都以大端形式传输
`status(1)` #### 建立控制通道
### 开启代理 客户端:
服务端(控制通道): | version(1) | name_len(1) | name_buf(n) |
|------------|-------------|-------------|
| 版本号 | 名称长度 | 名称 |
`dst_len(1)` `dst_buf(n)` `tag_len(1)` `tag_buf(n)` 服务端:
客户端(数据通道): | status(1) |
|-----------|
| 状态,固定为 1 |
`status(1)` `tag_len(1)` `tag_buf(n)` #### 建立数据通道
服务端:
| tag(16) | dst_len(2) | dst_buf(n) |
|---------|------------|------------|
| 通道标识 | 目标地址长度 | 目标地址 |
客户端:
| tag(16) | status(1) |
|---------|-----------------------|
| 通道标识 | 目标地址连接结果,成功为 1不成功为 0 |

View File

@@ -2,99 +2,112 @@ package client
import ( import (
"bufio" "bufio"
"flag"
"fmt" "fmt"
"io" "io"
"log/slog" "log/slog"
"net" "net"
"net/http"
"os" "os"
"proxy-server/client/core"
"proxy-server/client/geo"
"proxy-server/client/report"
"proxy-server/pkg/utils" "proxy-server/pkg/utils"
"runtime"
"strconv"
"time" "time"
"errors"
"github.com/joho/godotenv" "github.com/joho/godotenv"
"github.com/pkg/errors"
_ "net/http/pprof" _ "net/http/pprof"
) )
const Version byte = 1 var Geo geo.Func = geo.Ipapi
type Config struct { func Start() error {
Name string
FwdHost string // 初始化环境变量
FwdCtrlPort uint slog.SetLogLoggerLevel(slog.LevelDebug)
FwdDataPort uint
RetryInterval uint err := godotenv.Load()
if err != nil {
slog.Debug("没有本地环境变量文件")
} else {
online := os.Getenv("ENDPOINT_ONLINE")
if online != "" {
core.EndpointOnline = online
}
offline := os.Getenv("ENDPOINT_OFFLINE")
if offline != "" {
core.EndpointOffline = offline
}
} }
var cfg Config
var frpCtrlAddr string
var frpDataAddr string
func Start() {
initLog()
initCmd()
initDevEnv()
frpCtrlAddr = net.JoinHostPort(cfg.FwdHost, strconv.Itoa(int(cfg.FwdCtrlPort)))
frpDataAddr = net.JoinHostPort(cfg.FwdHost, strconv.Itoa(int(cfg.FwdDataPort)))
// 性能监控 // 性能监控
go func() { // go func() {
runtime.SetBlockProfileRate(1) // runtime.SetBlockProfileRate(1)
err := http.ListenAndServe(":7070", nil) // err := http.ListenAndServe(":7070", nil)
// if err != nil {
// slog.Error("性能监控服务启动失败", "err", err)
// }
// }()
// 获取归属地
slog.Debug("获取节点归属地...")
prov, city, isp, err := Geo()
if err != nil { if err != nil {
slog.Error("性能监控服务启动失败", "err", err) slog.Error("获取归属地失败", "err", err)
}
// 注册节点
slog.Debug("注册节点...")
host, err := report.Online(prov, city, isp)
if err != nil {
slog.Error("节点注册失败", "err", err)
return err
} }
}()
// 建立控制通道 // 建立控制通道
for { for {
err := ctrl() err := ctrl(host)
if err != nil { if err != nil {
slog.Error("建立控制通道失败", err) slog.Error("建立控制通道失败", "err", err)
slog.Info(fmt.Sprintf("%d 秒后重试", cfg.RetryInterval)) slog.Info(fmt.Sprintf("%d 秒后重试", core.RetryInterval))
time.Sleep(time.Duration(cfg.RetryInterval) * time.Second) time.Sleep(time.Duration(core.RetryInterval) * time.Second)
} }
} }
} }
func ctrl() error { func ctrl(host string) error {
slog.Info("建立控制通道", "addr", frpCtrlAddr) ctrlAddr := net.JoinHostPort(host, fmt.Sprintf("%d", core.FwdCtrlPort))
dataAddr := net.JoinHostPort(host, fmt.Sprintf("%d", core.FwdDataPort))
conn, err := net.Dial("tcp", frpCtrlAddr) slog.Info("建立控制通道", "addr", ctrlAddr)
conn, err := net.Dial("tcp", ctrlAddr)
if err != nil { if err != nil {
return errors.Wrap(err, "连接失败") return errors.New("连接失败")
} }
defer utils.Close(conn) defer utils.Close(conn)
reader := bufio.NewReader(conn) reader := bufio.NewReader(conn)
// 请求转发端口 // 请求转发端口
_, err = conn.Write([]byte{Version}) _, err = conn.Write([]byte{core.Version})
if err != nil { if err != nil {
return errors.Wrap(err, "发送版本号失败") return errors.New("发送版本号失败")
} }
// 发送客户端名称 // 发送客户端名称
nameLen := byte(len(cfg.Name)) nameLen := byte(len(core.Name))
nameBuf := make([]byte, 1+nameLen) nameBuf := make([]byte, 1+nameLen)
nameBuf[0] = nameLen nameBuf[0] = nameLen
copy(nameBuf[1:], cfg.Name) copy(nameBuf[1:], core.Name)
_, err = conn.Write(nameBuf) _, err = conn.Write(nameBuf)
if err != nil { if err != nil {
return errors.Wrap(err, "发送 name 失败") return errors.New("发送 name 失败")
} }
// 等待服务端响应 // 等待服务端响应
respBuf, err := reader.ReadByte() respBuf, err := reader.ReadByte()
if err != nil { if err != nil {
return errors.Wrap(err, "接收响应失败") return errors.New("接收响应失败")
} }
if respBuf != 1 { if respBuf != 1 {
return errors.New("服务端响应失败") return errors.New("服务端响应失败")
@@ -110,40 +123,40 @@ func ctrl() error {
// 接收 dst // 接收 dst
dstLen, err := reader.ReadByte() dstLen, err := reader.ReadByte()
if err != nil { if err != nil {
return errors.Wrap(err, "接收 dstLen 失败") return errors.New("接收 dstLen 失败")
} }
dstBuf, err := utils.ReadBuffer(reader, int(dstLen)) dstBuf, err := utils.ReadBuffer(reader, int(dstLen))
if err != nil { if err != nil {
return errors.Wrap(err, "接收 dstBuf 失败") return errors.New("接收 dstBuf 失败")
} }
addr := string(dstBuf) addr := string(dstBuf)
// 接收 tag // 接收 tag
tagLen, err := reader.ReadByte() tagLen, err := reader.ReadByte()
if err != nil { if err != nil {
return errors.Wrap(err, "接收 tagLen 失败") return errors.New("接收 tagLen 失败")
} }
tagBuf, err := utils.ReadBuffer(reader, int(tagLen)) tagBuf, err := utils.ReadBuffer(reader, int(tagLen))
if err != nil { if err != nil {
return errors.Wrap(err, "接收 tagBuf 失败") return errors.New("接收 tagBuf 失败")
} }
// 建立数据通道 // 建立数据通道
go func() { go func() {
err := data(addr, tagBuf) err := data(dataAddr, addr, tagBuf)
if err != nil { if err != nil {
slog.Error("建立数据通道失败", err) slog.Error("建立数据通道失败", "err", err)
} }
}() }()
} }
} }
func data(addr string, tag []byte) error { func data(dataAddr string, dest string, tag []byte) error {
// 向服务端建立连接 // 向服务端建立连接
src, err := net.Dial("tcp", frpDataAddr) src, err := net.Dial("tcp", dataAddr)
if err != nil { if err != nil {
return errors.Wrap(err, "连接服务端失败") return errors.New("连接服务端失败")
} }
tagLen := byte(len(tag)) tagLen := byte(len(tag))
@@ -152,7 +165,7 @@ func data(addr string, tag []byte) error {
copy(tagBuf[2:], tag) copy(tagBuf[2:], tag)
// 向目标地址建立连接 // 向目标地址建立连接
dst, dstErr := net.Dial("tcp", addr) dst, dstErr := net.Dial("tcp", dest)
if dstErr != nil { if dstErr != nil {
tagBuf[0] = 0 tagBuf[0] = 0
} else { } else {
@@ -166,7 +179,7 @@ func data(addr string, tag []byte) error {
if dst != nil { if dst != nil {
utils.Close(dst) utils.Close(dst)
} }
return errors.Wrap(err, "发送连接状态失败") return errors.New("发送连接状态失败")
} }
if tagBuf[0] == 0 { if tagBuf[0] == 0 {
@@ -174,7 +187,7 @@ func data(addr string, tag []byte) error {
if dst != nil { if dst != nil {
utils.Close(dst) utils.Close(dst)
} }
return errors.Wrap(dstErr, "连接目标地址失败") return errors.New("连接目标地址失败")
} }
go func() { go func() {
@@ -193,31 +206,3 @@ func data(addr string, tag []byte) error {
}() }()
return nil return nil
} }
func initLog() {
slog.SetLogLoggerLevel(slog.LevelDebug)
}
func initCmd() {
flag.StringVar(&cfg.Name, "n", "", "客户端名称")
flag.StringVar(&cfg.FwdHost, "h", "", "转发服务器地址")
flag.UintVar(&cfg.FwdCtrlPort, "c", 18080, "转发服务器控制通道端口")
flag.UintVar(&cfg.FwdDataPort, "d", 18081, "转发服务器数据通道端口")
flag.UintVar(&cfg.RetryInterval, "r", 5, "重试间隔时间")
flag.Parse()
if cfg.Name == "" {
slog.Error("客户端名称不能为空")
flag.Usage()
os.Exit(1)
}
}
func initDevEnv() {
err := godotenv.Load()
if err != nil {
slog.Debug("没有本地环境变量文件")
}
cfg.FwdHost = os.Getenv("FWD_HOST")
}

12
client/core/env.go Normal file
View File

@@ -0,0 +1,12 @@
package core
const Version byte = 1
const Name = "test-edge"
var FwdCtrlPort uint = 18080
var FwdDataPort uint = 18081
var RetryInterval uint = 5
var EndpointOnline = "https://api.lanhuip.com/api/edge/online"
var EndpointOffline = "https://api.lanhuip.com/api/edge/offline"
var EndpointGeo = "http://cip.cc"

57
client/geo/cip.go Normal file
View File

@@ -0,0 +1,57 @@
package geo
import (
"bufio"
"github.com/pkg/errors"
"log/slog"
"net/http"
"net/textproto"
"strings"
)
func Cip() (prov, city, isp string, err error) {
const endpoint = "http://cip.cc"
req, err := http.NewRequest("GET", endpoint, nil)
if err != nil {
return "", "", "", errors.Wrap(err, "创建请求失败")
}
req.Header.Set("User-Agent", "curl/8.9.1")
req.Header.Set("Accept", "*/*")
resp, err := http.DefaultClient.Do(req)
if err != nil {
return "", "", "", errors.Wrap(err, "请求失败")
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", "", "", errors.New("请求失败,状态码: " + resp.Status)
}
reader := textproto.NewReader(bufio.NewReader(resp.Body))
_, err = reader.ReadLine()
if err != nil {
return "", "", "", errors.Wrap(err, "读取响应失败")
}
addrLine, err := reader.ReadLine()
if err != nil {
return "", "", "", errors.Wrap(err, "读取响应失败")
}
addr := strings.Split(strings.Split(addrLine, ":")[1], " ")
prov = strings.TrimSpace(addr[1])
city = strings.TrimSpace(addr[2])
ispLine, err := reader.ReadLine()
if err != nil {
return "", "", "", errors.Wrap(err, "读取响应失败")
}
isp = strings.TrimSpace(strings.Split(ispLine, ":")[1])
if prov == "" || city == "" || isp == "" {
return "", "", "", errors.New("解析数据为空")
}
slog.Debug("获取归属地", "prov", prov, "city", city, "isp", isp)
return prov, city, isp, nil
}

3
client/geo/geo.go Normal file
View File

@@ -0,0 +1,3 @@
package geo
type Func func() (prov, city, isp string, err error)

40
client/geo/ipapi.go Normal file
View File

@@ -0,0 +1,40 @@
package geo
import (
"encoding/json"
"github.com/pkg/errors"
"net/http"
)
func Ipapi() (prov, city, isp string, err error) {
const endpoint = "http://ip-api.com/json/?fields=regionName,city,as&lang=zh-CN"
resp, err := http.Get(endpoint)
if err != nil {
return "", "", "", err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", "", "", err
}
var data struct {
RegionName string `json:"regionName"`
City string `json:"city"`
As string `json:"as"`
}
err = json.NewDecoder(resp.Body).Decode(&data)
if err != nil {
return "", "", "", err
}
prov = data.RegionName
city = data.City
isp = data.As
if prov == "" || city == "" || isp == "" {
return "", "", "", errors.New("解析数据为空")
}
return prov, city, isp, nil
}

66
client/report/online.go Normal file
View File

@@ -0,0 +1,66 @@
package report
import (
"encoding/json"
"errors"
"io"
"net/http"
"proxy-server/client/core"
"strings"
)
func Online(prov, city, isp string) (host string, err error) {
var ispInt = 0
switch isp {
case "电信":
ispInt = 1
case "联通":
ispInt = 2
case "移动":
ispInt = 3
}
body, err := json.Marshal(map[string]any{
"prov": prov,
"city": city,
"isp": ispInt,
"name": core.Name,
"version": core.Version,
})
if err != nil {
return "", err
}
req, err := http.NewRequest("POST", core.EndpointOnline, strings.NewReader(string(body)))
if err != nil {
return "", errors.New("创建节点注册请求失败")
}
req.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(req)
if err != nil {
return "", errors.New("节点注册失败")
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", errors.New("节点注册失败,状态码: " + resp.Status)
}
bytes, err := io.ReadAll(resp.Body)
if err != nil {
return "", errors.New("读取节点注册响应失败")
}
var respBody struct {
Host string `json:"host"`
}
err = json.Unmarshal(bytes, &respBody)
if err != nil {
return "", errors.New("解析节点注册响应失败")
}
if respBody.Host == "" {
return "", errors.New("节点注册失败,响应体为空")
}
return respBody.Host, nil
}

View File

@@ -3,5 +3,8 @@ package main
import "proxy-server/client" import "proxy-server/client"
func main() { func main() {
client.Start() err := client.Start()
if err != nil {
println(err)
}
} }

View File

@@ -1,25 +1,41 @@
package main package main
import ( import (
"proxy-server/server/pkg/env" "gorm.io/driver/postgres"
"proxy-server/server/pkg/orm"
"gorm.io/gen" "gorm.io/gen"
"gorm.io/gorm"
"gorm.io/gorm/schema"
) )
func main() { func main() {
env.Init() // 初始化
orm.Init() db, _ := gorm.Open(
postgres.Open("host=localhost user=test password=test dbname=app port=5432 sslmode=disable TimeZone=Asia/Shanghai"),
g := gen.NewGenerator(gen.Config{ &gorm.Config{
OutPath: "../../temp-out", NamingStrategy: schema.NamingStrategy{
Mode: gen.WithoutContext | gen.WithDefaultQuery | gen.WithQueryInterface, SingularTable: true,
}) },
g.UseDB(orm.DB) },
g.ApplyBasic(
g.GenerateAllTable()...,
) )
g.Execute() g := gen.NewGenerator(gen.Config{
OutPath: "server/repo/queries",
ModelPkgPath: "models",
Mode: gen.WithDefaultQuery | gen.WithoutContext,
})
g.UseDB(db)
common := []gen.ModelOpt{
gen.FieldModify(func(field gen.Field) gen.Field {
if field.Type == "time.Time" {
field.Type = "orm.LocalDateTime"
}
return field
}),
}
// 生成需要的模型
g.GenerateModel("channel", common...)
g.GenerateModel("node", common...)
} }

View File

@@ -5,5 +5,7 @@ import (
) )
func main() { func main() {
server.Start() var app = server.New()
var err = app.Run()
println(err)
} }

1
cmd/server/proxy.lock Normal file
View File

@@ -0,0 +1 @@
rモ;鉦$4Nzゥ?ヘ

1
go.mod
View File

@@ -4,6 +4,7 @@ go 1.24
require ( require (
github.com/gin-gonic/gin v1.10.0 github.com/gin-gonic/gin v1.10.0
github.com/google/uuid v1.6.0
github.com/joho/godotenv v1.5.1 github.com/joho/godotenv v1.5.1
github.com/lmittmann/tint v1.0.7 github.com/lmittmann/tint v1.0.7
github.com/mattn/go-colorable v0.1.14 github.com/mattn/go-colorable v0.1.14

2
go.sum
View File

@@ -36,6 +36,8 @@ github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EO
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=

137
server/pkg/env/env.go vendored
View File

@@ -2,104 +2,139 @@ package env
import ( import (
"fmt" "fmt"
"log/slog"
"os" "os"
"strconv" "strconv"
"github.com/joho/godotenv"
) )
var ( var (
AppCtrlPort uint16 AppCtrlPort uint16 = 18080
AppDataPort uint16 AppDataPort uint16 = 18081
AppWebPort uint16 AppWebPort uint16 = 8848
AppLogMode string AppLogMode string = "dev"
ClientId string
ClientSecret string
DbHost string DbHost string
DbPort uint16 DbPort uint16 = 5432
DbDatabase string DbDatabase string
DbUsername string DbUsername string
DbPassword string DbPassword string
DbTimezone string DbTimezone string = "Asia/Shanghai"
EndpointOnline string
EndpointOffline string
) )
func Init() { func Init() {
var err = godotenv.Load()
// AppCtrlPort if err != nil {
appCtrlPortStr := os.Getenv("APP_CTRL_PORT") slog.Debug("没有本地环境变量文件")
if appCtrlPortStr == "" {
panic("环境变量 APP_CTRL_PORT 未设置")
} }
appCtrlPort, err := strconv.ParseUint(appCtrlPortStr, 10, 16) var value string
value = os.Getenv("APP_CTRL_PORT")
if value != "" {
appCtrlPort, err := strconv.Atoi(value)
if err != nil { if err != nil {
panic(fmt.Sprintf("环境变量 APP_CTRL_PORT 格式错误: %v", err)) panic(fmt.Sprintf("环境变量 APP_CTRL_PORT 格式错误: %v", err))
} }
AppCtrlPort = uint16(appCtrlPort) AppCtrlPort = uint16(appCtrlPort)
// AppDataPort
appDataPortStr := os.Getenv("APP_DATA_PORT")
if appDataPortStr == "" {
panic("环境变量 APP_DATA_PORT 未设置")
} }
appDataPort, err := strconv.ParseUint(appDataPortStr, 10, 16)
value = os.Getenv("APP_DATA_PORT")
if value != "" {
appDataPort, err := strconv.Atoi(value)
if err != nil { if err != nil {
panic(fmt.Sprintf("环境变量 APP_DATA_PORT 格式错误: %v", err)) panic(fmt.Sprintf("环境变量 APP_DATA_PORT 格式错误: %v", err))
} }
AppDataPort = uint16(appDataPort) AppDataPort = uint16(appDataPort)
// AppWebPort
appWebPortStr := os.Getenv("APP_WEB_PORT")
if appWebPortStr == "" {
appWebPortStr = "8848"
} }
appWebPort, err := strconv.ParseUint(appWebPortStr, 10, 16)
value = os.Getenv("APP_WEB_PORT")
if value != "" {
appWebPort, err := strconv.Atoi(value)
if err != nil { if err != nil {
panic(fmt.Sprintf("环境变量 APP_WEB_PORT 格式错误: %v", err)) panic(fmt.Sprintf("环境变量 APP_WEB_PORT 格式错误: %v", err))
} }
AppWebPort = uint16(appWebPort) AppWebPort = uint16(appWebPort)
// AppLogMode
appLogMode := os.Getenv("APP_LOG_MODE")
if appLogMode == "" {
AppLogMode = "dev"
} }
AppLogMode = appLogMode
// DbHost value = os.Getenv("APP_LOG_MODE")
if value != "" {
AppLogMode = value
}
value = os.Getenv("CLIENT_ID")
if value != "" {
ClientId = value
} else {
panic("环境变量 CLIENT_ID 未设置")
}
value = os.Getenv("CLIENT_SECRET")
if value != "" {
ClientSecret = value
} else {
panic("环境变量 CLIENT_SECRET 未设置")
}
value = os.Getenv("DB_HOST")
if value != "" {
DbHost = os.Getenv("DB_HOST") DbHost = os.Getenv("DB_HOST")
if DbHost == "" { } else {
panic("环境变量 DB_HOST 未设置") panic("环境变量 DB_HOST 未设置")
} }
// DbPort value = os.Getenv("DB_PORT")
dbPortStr := os.Getenv("DB_PORT") if value != "" {
if dbPortStr == "" { dbPort, err := strconv.Atoi(value)
dbPortStr = "5432"
}
dbPort, err := strconv.ParseUint(dbPortStr, 10, 16)
if err != nil { if err != nil {
panic(fmt.Sprintf("环境变量 DB_PORT 格式错误: %v", err)) panic(fmt.Sprintf("环境变量 DB_PORT 格式错误: %v", err))
} }
DbPort = uint16(dbPort) DbPort = uint16(dbPort)
}
// DbDatabase value = os.Getenv("DB_DATABASE")
DbDatabase = os.Getenv("DB_DATABASE") if value != "" {
if DbDatabase == "" { DbDatabase = value
} else {
panic("环境变量 DB_DATABASE 未设置") panic("环境变量 DB_DATABASE 未设置")
} }
// DbUsername value = os.Getenv("DB_USERNAME")
DbUsername = os.Getenv("DB_USERNAME") if value != "" {
if DbUsername == "" { DbUsername = value
} else {
panic("环境变量 DB_USERNAME 未设置") panic("环境变量 DB_USERNAME 未设置")
} }
// DbPassword value = os.Getenv("DB_PASSWORD")
DbPassword = os.Getenv("DB_PASSWORD") if value != "" {
if DbPassword == "" { DbPassword = value
} else {
panic("环境变量 DB_PASSWORD 未设置") panic("环境变量 DB_PASSWORD 未设置")
} }
// DbTimezone value = os.Getenv("DB_TIMEZONE")
DbTimezone = os.Getenv("DB_TIMEZONE") if value != "" {
if DbTimezone == "" { DbTimezone = value
DbTimezone = "Asia/Shanghai" }
value = os.Getenv("ENDPOINT_ONLINE")
if value != "" {
EndpointOnline = value
} else {
panic("环境变量 ENDPOINT_ONLINE 未设置")
}
value = os.Getenv("ENDPOINT_OFFLINE")
if value != "" {
EndpointOffline = value
} else {
panic("环境变量 ENDPOINT_OFFLINE 未设置")
} }
} }

View File

@@ -2,6 +2,8 @@ package server
import ( import (
"context" "context"
"encoding/base64"
"encoding/json"
"log/slog" "log/slog"
"net/http" "net/http"
"os" "os"
@@ -13,43 +15,52 @@ import (
"proxy-server/server/pkg/log" "proxy-server/server/pkg/log"
"proxy-server/server/pkg/orm" "proxy-server/server/pkg/orm"
"proxy-server/server/web" "proxy-server/server/web"
"runtime" "strings"
"sync" "sync"
"syscall" "syscall"
"time" "time"
"github.com/google/uuid"
"github.com/joho/godotenv" "github.com/joho/godotenv"
_ "net/http/pprof"
) )
import _ "net/http/pprof" const (
Version = 1
RestoreMagic = 0x72
)
type Context struct { type server struct {
context.Context name string
log *slog.Logger
} }
func Start() { func New() *server {
return &server{}
}
func (s *server) Run() (err error) {
// 初始化 // 初始化
err := godotenv.Load() err = s.init()
if err != nil { if err != nil {
println("没有本地环境变量文件") return err
} }
log.Init() // 恢复服务状态
env.Init() err = s.restore()
orm.Init() if err != nil {
return err
}
// 退出信号 // 准备子服务
osQuit := make(chan os.Signal)
signal.Notify(osQuit, os.Interrupt, syscall.SIGTERM)
// 启动服务
slog.Info("启动服务")
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
wg := sync.WaitGroup{} wg := sync.WaitGroup{}
// 转发服务
wg.Add(1) wg.Add(1)
fwdQuit := make(chan struct{}, 1) fwdQuit := make(chan struct{}, 1)
go func() { go func() {
@@ -57,19 +68,19 @@ func Start() {
defer close(fwdQuit) defer close(fwdQuit)
err := startFwdServer(ctx) err := startFwdServer(ctx)
if err != nil { if err != nil {
slog.Error("代理服务发生错误", "err", err) slog.Error("转发服务发生错误", "err", err)
} }
fwdQuit <- struct{}{} fwdQuit <- struct{}{}
}() }()
// 启动 web 服务 // 接口服务
wg.Add(1) wg.Add(1)
apiQuit := make(chan struct{}, 1) apiQuit := make(chan struct{}, 1)
go func() { go func() {
defer wg.Done() defer wg.Done()
err := startWebServer(ctx) err := startWebServer(ctx)
if err != nil { if err != nil {
slog.Error("web 服务发生错误", "err", err) slog.Error("接口服务发生错误", "err", err)
} }
apiQuit <- struct{}{} apiQuit <- struct{}{}
}() }()
@@ -80,15 +91,22 @@ func Start() {
}() }()
// 性能监控 // 性能监控
go func() { // go func() {
runtime.SetBlockProfileRate(1) // runtime.SetBlockProfileRate(1)
err := http.ListenAndServe(":6060", nil) // err := http.ListenAndServe(":6060", nil)
if err != nil { // if err != nil {
slog.Error("性能监控服务发生错误", "err", err) // slog.Error("性能监控服务发生错误", "err", err)
} // }
}() // }()
// 报告上线
slog.Debug("报告服务上线")
go reportOnline(ctx, s.name)
// 等待退出信号 // 等待退出信号
osQuit := make(chan os.Signal, 1)
signal.Notify(osQuit, os.Interrupt, syscall.SIGTERM)
select { select {
case <-osQuit: case <-osQuit:
slog.Info("服务主动退出") slog.Info("服务主动退出")
@@ -97,6 +115,11 @@ func Start() {
case <-apiQuit: case <-apiQuit:
slog.Warn("web 服务异常退出") slog.Warn("web 服务异常退出")
} }
// 报告下线
slog.Debug("报告服务下线")
go reportOffline(ctx, s.name)
// 退出其他服务 // 退出其他服务
cancel() cancel()
@@ -109,6 +132,51 @@ func Start() {
case <-timeout.Done(): case <-timeout.Done():
slog.Warn("退出超时,强制退出") slog.Warn("退出超时,强制退出")
} }
return nil
}
func (s *server) restore() (err error) {
var file = "proxy.lock"
bytes, err := os.ReadFile(file)
if err != nil {
return err
}
if len(bytes) == 17 && bytes[0] == RestoreMagic {
s.name = uuid.UUID(bytes[1:]).String()
slog.Info("恢复服务名称", "name", s.name)
} else {
var u = uuid.New()
s.name = u.String()
bytes = make([]byte, 17)
bytes[0] = RestoreMagic
copy(bytes[1:], u[:])
err := os.WriteFile(file, bytes, 0644)
if err != nil {
return err
}
slog.Info("生成服务名称", "name", s.name)
}
return nil
}
func (s *server) init() error {
err := godotenv.Load()
if err != nil {
println("没有本地环境变量文件")
}
log.Init()
env.Init()
orm.Init()
return nil
} }
func startFwdServer(ctx context.Context) error { func startFwdServer(ctx context.Context) error {
@@ -124,3 +192,47 @@ func startFwdServer(ctx context.Context) error {
func startWebServer(ctx context.Context) error { func startWebServer(ctx context.Context) error {
return web.Start(ctx) return web.Start(ctx)
} }
func reportOnline(ctx context.Context, name string) {
reportRepeat(ctx, env.EndpointOnline, map[string]any{
"name": name,
"version": Version,
})
}
func reportOffline(ctx context.Context, name string) {
reportRepeat(ctx, env.EndpointOffline, map[string]any{
"name": name,
"version": Version,
})
}
func reportRepeat(ctx context.Context, endpoint string, body any) {
bodyStr, err := json.Marshal(body)
if err != nil {
panic(err)
}
for {
req, err := http.NewRequest("POST", endpoint, strings.NewReader(string(bodyStr)))
if err != nil {
panic(err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Basic "+base64.RawURLEncoding.EncodeToString([]byte("proxy:proxy")))
resp, err := http.DefaultClient.Do(req)
if resp != nil && resp.StatusCode == http.StatusOK {
return
}
select {
case <-ctx.Done():
return
default:
}
slog.Warn("服务注册失败,五秒后重试", "err", err)
time.Sleep(5 * time.Second)
}
}