diff --git a/.gitignore b/.gitignore index 9668f99..8ccd745 100644 --- a/.gitignore +++ b/.gitignore @@ -34,4 +34,6 @@ build/ .env .env.* -!.env.example \ No newline at end of file +!.env.example + +cmd/playground/ \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index fd64a33..3152709 100644 --- a/Dockerfile +++ b/Dockerfile @@ -28,6 +28,9 @@ COPY --from=builder /build/bin/proxy_linux_amd64 /app/proxy # 设置可执行权限 RUN chmod +x /app/proxy +# 锁定文件 +VOLUME /app/proxy.lock + # 声明暴露端口 EXPOSE 8080 diff --git a/README.md b/README.md index 4a759ee..909af2c 100644 --- a/README.md +++ b/README.md @@ -29,24 +29,54 @@ ERR: 除非有必要,否则全部 error 都使用 `errors.Wrap()` 包裹(如 其他级别日志就地打印,Info 只用来跟踪关键流程 +### proxy.lock 文件格式 + +| mag_num(1) | name(16) | +|-------------|-----------| +| 魔法数,固定 0x72 | 服务名称,uuid | + ## 协议 -### 建立连接 +### 步骤说明 -客户端(控制通道): +1. 启动转发服务,尝试注册自身到后端服务,随后持续报告心跳 +2. 启动边缘节点后,尝试注册自身到后端服务,随后持续报告心跳 +3. 后端服务根据在线转发服务的状态,返回分配给边缘节点的转发服务地址 +4. 边缘节点根据配置尝试连接到转发服务(建立控制通道) +5. 连接成功后,控制通道将长期保留,边缘节点定时发送保活心跳,代理服务丢弃所有心跳包 +6. 当用户请求代理时,转发服务通过控制通道向边缘节点提供代理目标信息 +7. 边缘节点尝试连接到目标地址,同时尝试建立数据通道 +8. 当成功建立数据通道后,边缘节点将数据通道标识以及对目标地址的连接结果提供给转发服务 +9. 如果连接成功建立,则开始代理流量,如果连接失败,则关闭数据通道 -`version(1)` `id_len(1)` `id_buf(n)` +### 协议报文详情 -服务端(控制通道): +协议中所有数值都以大端形式传输 -`status(1)` +#### 建立控制通道 -### 开启代理 +客户端: -服务端(控制通道): +| version(1) | name_len(1) | name_buf(n) | +|------------|-------------|-------------| +| 版本号 | 名称长度 | 名称 | -`dst_len(1)` `dst_buf(n)` `tag_len(1)` `tag_buf(n)` +服务端: -客户端(数据通道): +| status(1) | +|-----------| +| 状态,固定为 1 | -`status(1)` `tag_len(1)` `tag_buf(n)` \ No newline at end of file +#### 建立数据通道 + +服务端: + +| tag(16) | dst_len(2) | dst_buf(n) | +|---------|------------|------------| +| 通道标识 | 目标地址长度 | 目标地址 | + +客户端: + +| tag(16) | status(1) | +|---------|-----------------------| +| 通道标识 | 目标地址连接结果,成功为 1,不成功为 0 | diff --git a/client/client.go b/client/client.go index 246b151..d62059f 100644 --- a/client/client.go +++ b/client/client.go @@ -2,99 +2,112 @@ package client import ( "bufio" - "flag" "fmt" "io" "log/slog" "net" - "net/http" "os" + "proxy-server/client/core" + "proxy-server/client/geo" + "proxy-server/client/report" "proxy-server/pkg/utils" - "runtime" - "strconv" "time" + "errors" "github.com/joho/godotenv" - "github.com/pkg/errors" _ "net/http/pprof" ) -const Version byte = 1 +var Geo geo.Func = geo.Ipapi -type Config struct { - Name string - FwdHost string - FwdCtrlPort uint - FwdDataPort uint - RetryInterval uint -} +func Start() error { -var cfg Config + // 初始化环境变量 + slog.SetLogLoggerLevel(slog.LevelDebug) -var frpCtrlAddr string -var frpDataAddr string - -func Start() { - - initLog() - initCmd() - initDevEnv() - - frpCtrlAddr = net.JoinHostPort(cfg.FwdHost, strconv.Itoa(int(cfg.FwdCtrlPort))) - frpDataAddr = net.JoinHostPort(cfg.FwdHost, strconv.Itoa(int(cfg.FwdDataPort))) + err := godotenv.Load() + if err != nil { + slog.Debug("没有本地环境变量文件") + } else { + online := os.Getenv("ENDPOINT_ONLINE") + if online != "" { + core.EndpointOnline = online + } + offline := os.Getenv("ENDPOINT_OFFLINE") + if offline != "" { + core.EndpointOffline = offline + } + } // 性能监控 - go func() { - runtime.SetBlockProfileRate(1) - err := http.ListenAndServe(":7070", nil) - if err != nil { - slog.Error("性能监控服务启动失败", "err", err) - } - }() + // go func() { + // runtime.SetBlockProfileRate(1) + // err := http.ListenAndServe(":7070", nil) + // if err != nil { + // slog.Error("性能监控服务启动失败", "err", err) + // } + // }() + + // 获取归属地 + slog.Debug("获取节点归属地...") + prov, city, isp, err := Geo() + if err != nil { + slog.Error("获取归属地失败", "err", err) + } + + // 注册节点 + slog.Debug("注册节点...") + host, err := report.Online(prov, city, isp) + if err != nil { + slog.Error("节点注册失败", "err", err) + return err + } // 建立控制通道 for { - err := ctrl() + err := ctrl(host) if err != nil { - slog.Error("建立控制通道失败", err) - slog.Info(fmt.Sprintf("%d 秒后重试", cfg.RetryInterval)) - time.Sleep(time.Duration(cfg.RetryInterval) * time.Second) + slog.Error("建立控制通道失败", "err", err) + slog.Info(fmt.Sprintf("%d 秒后重试", core.RetryInterval)) + time.Sleep(time.Duration(core.RetryInterval) * time.Second) } } } -func ctrl() error { - slog.Info("建立控制通道", "addr", frpCtrlAddr) +func ctrl(host string) error { + ctrlAddr := net.JoinHostPort(host, fmt.Sprintf("%d", core.FwdCtrlPort)) + dataAddr := net.JoinHostPort(host, fmt.Sprintf("%d", core.FwdDataPort)) - conn, err := net.Dial("tcp", frpCtrlAddr) + slog.Info("建立控制通道", "addr", ctrlAddr) + conn, err := net.Dial("tcp", ctrlAddr) if err != nil { - return errors.Wrap(err, "连接失败") + return errors.New("连接失败") } defer utils.Close(conn) reader := bufio.NewReader(conn) // 请求转发端口 - _, err = conn.Write([]byte{Version}) + _, err = conn.Write([]byte{core.Version}) if err != nil { - return errors.Wrap(err, "发送版本号失败") + return errors.New("发送版本号失败") } // 发送客户端名称 - nameLen := byte(len(cfg.Name)) + nameLen := byte(len(core.Name)) nameBuf := make([]byte, 1+nameLen) nameBuf[0] = nameLen - copy(nameBuf[1:], cfg.Name) + copy(nameBuf[1:], core.Name) _, err = conn.Write(nameBuf) if err != nil { - return errors.Wrap(err, "发送 name 失败") + return errors.New("发送 name 失败") } // 等待服务端响应 respBuf, err := reader.ReadByte() if err != nil { - return errors.Wrap(err, "接收响应失败") + return errors.New("接收响应失败") } if respBuf != 1 { return errors.New("服务端响应失败") @@ -110,40 +123,40 @@ func ctrl() error { // 接收 dst dstLen, err := reader.ReadByte() if err != nil { - return errors.Wrap(err, "接收 dstLen 失败") + return errors.New("接收 dstLen 失败") } dstBuf, err := utils.ReadBuffer(reader, int(dstLen)) if err != nil { - return errors.Wrap(err, "接收 dstBuf 失败") + return errors.New("接收 dstBuf 失败") } addr := string(dstBuf) // 接收 tag tagLen, err := reader.ReadByte() if err != nil { - return errors.Wrap(err, "接收 tagLen 失败") + return errors.New("接收 tagLen 失败") } tagBuf, err := utils.ReadBuffer(reader, int(tagLen)) if err != nil { - return errors.Wrap(err, "接收 tagBuf 失败") + return errors.New("接收 tagBuf 失败") } // 建立数据通道 go func() { - err := data(addr, tagBuf) + err := data(dataAddr, addr, tagBuf) if err != nil { - slog.Error("建立数据通道失败", err) + slog.Error("建立数据通道失败", "err", err) } }() } } -func data(addr string, tag []byte) error { +func data(dataAddr string, dest string, tag []byte) error { // 向服务端建立连接 - src, err := net.Dial("tcp", frpDataAddr) + src, err := net.Dial("tcp", dataAddr) if err != nil { - return errors.Wrap(err, "连接服务端失败") + return errors.New("连接服务端失败") } tagLen := byte(len(tag)) @@ -152,7 +165,7 @@ func data(addr string, tag []byte) error { copy(tagBuf[2:], tag) // 向目标地址建立连接 - dst, dstErr := net.Dial("tcp", addr) + dst, dstErr := net.Dial("tcp", dest) if dstErr != nil { tagBuf[0] = 0 } else { @@ -166,7 +179,7 @@ func data(addr string, tag []byte) error { if dst != nil { utils.Close(dst) } - return errors.Wrap(err, "发送连接状态失败") + return errors.New("发送连接状态失败") } if tagBuf[0] == 0 { @@ -174,7 +187,7 @@ func data(addr string, tag []byte) error { if dst != nil { utils.Close(dst) } - return errors.Wrap(dstErr, "连接目标地址失败") + return errors.New("连接目标地址失败") } go func() { @@ -193,31 +206,3 @@ func data(addr string, tag []byte) error { }() return nil } - -func initLog() { - slog.SetLogLoggerLevel(slog.LevelDebug) -} - -func initCmd() { - flag.StringVar(&cfg.Name, "n", "", "客户端名称") - flag.StringVar(&cfg.FwdHost, "h", "", "转发服务器地址") - flag.UintVar(&cfg.FwdCtrlPort, "c", 18080, "转发服务器控制通道端口") - flag.UintVar(&cfg.FwdDataPort, "d", 18081, "转发服务器数据通道端口") - flag.UintVar(&cfg.RetryInterval, "r", 5, "重试间隔时间") - flag.Parse() - - if cfg.Name == "" { - slog.Error("客户端名称不能为空") - flag.Usage() - os.Exit(1) - } -} - -func initDevEnv() { - err := godotenv.Load() - if err != nil { - slog.Debug("没有本地环境变量文件") - } - - cfg.FwdHost = os.Getenv("FWD_HOST") -} diff --git a/client/core/env.go b/client/core/env.go new file mode 100644 index 0000000..8d8a615 --- /dev/null +++ b/client/core/env.go @@ -0,0 +1,12 @@ +package core + +const Version byte = 1 +const Name = "test-edge" + +var FwdCtrlPort uint = 18080 +var FwdDataPort uint = 18081 +var RetryInterval uint = 5 + +var EndpointOnline = "https://api.lanhuip.com/api/edge/online" +var EndpointOffline = "https://api.lanhuip.com/api/edge/offline" +var EndpointGeo = "http://cip.cc" diff --git a/client/geo/cip.go b/client/geo/cip.go new file mode 100644 index 0000000..bffe221 --- /dev/null +++ b/client/geo/cip.go @@ -0,0 +1,57 @@ +package geo + +import ( + "bufio" + "github.com/pkg/errors" + "log/slog" + "net/http" + "net/textproto" + "strings" +) + +func Cip() (prov, city, isp string, err error) { + const endpoint = "http://cip.cc" + + req, err := http.NewRequest("GET", endpoint, nil) + if err != nil { + return "", "", "", errors.Wrap(err, "创建请求失败") + } + req.Header.Set("User-Agent", "curl/8.9.1") + req.Header.Set("Accept", "*/*") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return "", "", "", errors.Wrap(err, "请求失败") + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return "", "", "", errors.New("请求失败,状态码: " + resp.Status) + } + + reader := textproto.NewReader(bufio.NewReader(resp.Body)) + _, err = reader.ReadLine() + if err != nil { + return "", "", "", errors.Wrap(err, "读取响应失败") + } + + addrLine, err := reader.ReadLine() + if err != nil { + return "", "", "", errors.Wrap(err, "读取响应失败") + } + addr := strings.Split(strings.Split(addrLine, ":")[1], " ") + prov = strings.TrimSpace(addr[1]) + city = strings.TrimSpace(addr[2]) + + ispLine, err := reader.ReadLine() + if err != nil { + return "", "", "", errors.Wrap(err, "读取响应失败") + } + isp = strings.TrimSpace(strings.Split(ispLine, ":")[1]) + + if prov == "" || city == "" || isp == "" { + return "", "", "", errors.New("解析数据为空") + } + + slog.Debug("获取归属地", "prov", prov, "city", city, "isp", isp) + return prov, city, isp, nil +} diff --git a/client/geo/geo.go b/client/geo/geo.go new file mode 100644 index 0000000..81a6ed1 --- /dev/null +++ b/client/geo/geo.go @@ -0,0 +1,3 @@ +package geo + +type Func func() (prov, city, isp string, err error) diff --git a/client/geo/ipapi.go b/client/geo/ipapi.go new file mode 100644 index 0000000..c4c87b6 --- /dev/null +++ b/client/geo/ipapi.go @@ -0,0 +1,40 @@ +package geo + +import ( + "encoding/json" + "github.com/pkg/errors" + "net/http" +) + +func Ipapi() (prov, city, isp string, err error) { + const endpoint = "http://ip-api.com/json/?fields=regionName,city,as&lang=zh-CN" + + resp, err := http.Get(endpoint) + if err != nil { + return "", "", "", err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", "", "", err + } + + var data struct { + RegionName string `json:"regionName"` + City string `json:"city"` + As string `json:"as"` + } + err = json.NewDecoder(resp.Body).Decode(&data) + if err != nil { + return "", "", "", err + } + + prov = data.RegionName + city = data.City + isp = data.As + if prov == "" || city == "" || isp == "" { + return "", "", "", errors.New("解析数据为空") + } + + return prov, city, isp, nil +} diff --git a/client/report/online.go b/client/report/online.go new file mode 100644 index 0000000..dd8d54c --- /dev/null +++ b/client/report/online.go @@ -0,0 +1,66 @@ +package report + +import ( + "encoding/json" + "errors" + "io" + "net/http" + "proxy-server/client/core" + "strings" +) + +func Online(prov, city, isp string) (host string, err error) { + + var ispInt = 0 + switch isp { + case "电信": + ispInt = 1 + case "联通": + ispInt = 2 + case "移动": + ispInt = 3 + } + + body, err := json.Marshal(map[string]any{ + "prov": prov, + "city": city, + "isp": ispInt, + "name": core.Name, + "version": core.Version, + }) + if err != nil { + return "", err + } + + req, err := http.NewRequest("POST", core.EndpointOnline, strings.NewReader(string(body))) + if err != nil { + return "", errors.New("创建节点注册请求失败") + } + + req.Header.Set("Content-Type", "application/json") + resp, err := http.DefaultClient.Do(req) + if err != nil { + return "", errors.New("节点注册失败") + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return "", errors.New("节点注册失败,状态码: " + resp.Status) + } + + bytes, err := io.ReadAll(resp.Body) + if err != nil { + return "", errors.New("读取节点注册响应失败") + } + var respBody struct { + Host string `json:"host"` + } + err = json.Unmarshal(bytes, &respBody) + if err != nil { + return "", errors.New("解析节点注册响应失败") + } + if respBody.Host == "" { + return "", errors.New("节点注册失败,响应体为空") + } + + return respBody.Host, nil +} diff --git a/cmd/client/main.go b/cmd/client/main.go index 205cc8e..d6e574a 100644 --- a/cmd/client/main.go +++ b/cmd/client/main.go @@ -3,5 +3,8 @@ package main import "proxy-server/client" func main() { - client.Start() + err := client.Start() + if err != nil { + println(err) + } } diff --git a/cmd/gen/gen.go b/cmd/gen/gen.go index 25b65ce..f06e273 100644 --- a/cmd/gen/gen.go +++ b/cmd/gen/gen.go @@ -1,25 +1,41 @@ package main import ( - "proxy-server/server/pkg/env" - "proxy-server/server/pkg/orm" - + "gorm.io/driver/postgres" "gorm.io/gen" + "gorm.io/gorm" + "gorm.io/gorm/schema" ) func main() { - env.Init() - orm.Init() - - g := gen.NewGenerator(gen.Config{ - OutPath: "../../temp-out", - Mode: gen.WithoutContext | gen.WithDefaultQuery | gen.WithQueryInterface, - }) - g.UseDB(orm.DB) - g.ApplyBasic( - g.GenerateAllTable()..., + // 初始化 + db, _ := gorm.Open( + postgres.Open("host=localhost user=test password=test dbname=app port=5432 sslmode=disable TimeZone=Asia/Shanghai"), + &gorm.Config{ + NamingStrategy: schema.NamingStrategy{ + SingularTable: true, + }, + }, ) - g.Execute() + g := gen.NewGenerator(gen.Config{ + OutPath: "server/repo/queries", + ModelPkgPath: "models", + Mode: gen.WithDefaultQuery | gen.WithoutContext, + }) + g.UseDB(db) + + common := []gen.ModelOpt{ + gen.FieldModify(func(field gen.Field) gen.Field { + if field.Type == "time.Time" { + field.Type = "orm.LocalDateTime" + } + return field + }), + } + + // 生成需要的模型 + g.GenerateModel("channel", common...) + g.GenerateModel("node", common...) } diff --git a/cmd/server/main.go b/cmd/server/main.go index 69fe366..3b826bd 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -5,5 +5,7 @@ import ( ) func main() { - server.Start() + var app = server.New() + var err = app.Run() + println(err) } diff --git a/cmd/server/proxy.lock b/cmd/server/proxy.lock new file mode 100644 index 0000000..fb62c6a --- /dev/null +++ b/cmd/server/proxy.lock @@ -0,0 +1 @@ +r;$4Nz?w \ No newline at end of file diff --git a/go.mod b/go.mod index 48ee8da..c5b3b3e 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.24 require ( github.com/gin-gonic/gin v1.10.0 + github.com/google/uuid v1.6.0 github.com/joho/godotenv v1.5.1 github.com/lmittmann/tint v1.0.7 github.com/mattn/go-colorable v0.1.14 diff --git a/go.sum b/go.sum index f9c1dac..7d5e193 100644 --- a/go.sum +++ b/go.sum @@ -36,6 +36,8 @@ github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EO github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= diff --git a/server/pkg/env/env.go b/server/pkg/env/env.go index 87f2611..1d09593 100644 --- a/server/pkg/env/env.go +++ b/server/pkg/env/env.go @@ -2,104 +2,139 @@ package env import ( "fmt" + "log/slog" "os" "strconv" + + "github.com/joho/godotenv" ) var ( - AppCtrlPort uint16 - AppDataPort uint16 - AppWebPort uint16 - AppLogMode string + AppCtrlPort uint16 = 18080 + AppDataPort uint16 = 18081 + AppWebPort uint16 = 8848 + AppLogMode string = "dev" + + ClientId string + ClientSecret string DbHost string - DbPort uint16 + DbPort uint16 = 5432 DbDatabase string DbUsername string DbPassword string - DbTimezone string + DbTimezone string = "Asia/Shanghai" + + EndpointOnline string + EndpointOffline string ) func Init() { - - // AppCtrlPort - appCtrlPortStr := os.Getenv("APP_CTRL_PORT") - if appCtrlPortStr == "" { - panic("环境变量 APP_CTRL_PORT 未设置") - } - appCtrlPort, err := strconv.ParseUint(appCtrlPortStr, 10, 16) + var err = godotenv.Load() if err != nil { - panic(fmt.Sprintf("环境变量 APP_CTRL_PORT 格式错误: %v", err)) + slog.Debug("没有本地环境变量文件") } - AppCtrlPort = uint16(appCtrlPort) + var value string - // AppDataPort - appDataPortStr := os.Getenv("APP_DATA_PORT") - if appDataPortStr == "" { - panic("环境变量 APP_DATA_PORT 未设置") + value = os.Getenv("APP_CTRL_PORT") + if value != "" { + appCtrlPort, err := strconv.Atoi(value) + if err != nil { + panic(fmt.Sprintf("环境变量 APP_CTRL_PORT 格式错误: %v", err)) + } + AppCtrlPort = uint16(appCtrlPort) } - appDataPort, err := strconv.ParseUint(appDataPortStr, 10, 16) - if err != nil { - panic(fmt.Sprintf("环境变量 APP_DATA_PORT 格式错误: %v", err)) - } - AppDataPort = uint16(appDataPort) - // AppWebPort - appWebPortStr := os.Getenv("APP_WEB_PORT") - if appWebPortStr == "" { - appWebPortStr = "8848" + value = os.Getenv("APP_DATA_PORT") + if value != "" { + appDataPort, err := strconv.Atoi(value) + if err != nil { + panic(fmt.Sprintf("环境变量 APP_DATA_PORT 格式错误: %v", err)) + } + AppDataPort = uint16(appDataPort) } - appWebPort, err := strconv.ParseUint(appWebPortStr, 10, 16) - if err != nil { - panic(fmt.Sprintf("环境变量 APP_WEB_PORT 格式错误: %v", err)) - } - AppWebPort = uint16(appWebPort) - // AppLogMode - appLogMode := os.Getenv("APP_LOG_MODE") - if appLogMode == "" { - AppLogMode = "dev" + value = os.Getenv("APP_WEB_PORT") + if value != "" { + appWebPort, err := strconv.Atoi(value) + if err != nil { + panic(fmt.Sprintf("环境变量 APP_WEB_PORT 格式错误: %v", err)) + } + AppWebPort = uint16(appWebPort) } - AppLogMode = appLogMode - // DbHost - DbHost = os.Getenv("DB_HOST") - if DbHost == "" { + value = os.Getenv("APP_LOG_MODE") + if value != "" { + AppLogMode = value + } + + value = os.Getenv("CLIENT_ID") + if value != "" { + ClientId = value + } else { + panic("环境变量 CLIENT_ID 未设置") + } + + value = os.Getenv("CLIENT_SECRET") + if value != "" { + ClientSecret = value + } else { + panic("环境变量 CLIENT_SECRET 未设置") + } + + value = os.Getenv("DB_HOST") + if value != "" { + DbHost = os.Getenv("DB_HOST") + } else { panic("环境变量 DB_HOST 未设置") } - // DbPort - dbPortStr := os.Getenv("DB_PORT") - if dbPortStr == "" { - dbPortStr = "5432" + value = os.Getenv("DB_PORT") + if value != "" { + dbPort, err := strconv.Atoi(value) + if err != nil { + panic(fmt.Sprintf("环境变量 DB_PORT 格式错误: %v", err)) + } + DbPort = uint16(dbPort) } - dbPort, err := strconv.ParseUint(dbPortStr, 10, 16) - if err != nil { - panic(fmt.Sprintf("环境变量 DB_PORT 格式错误: %v", err)) - } - DbPort = uint16(dbPort) - // DbDatabase - DbDatabase = os.Getenv("DB_DATABASE") - if DbDatabase == "" { + value = os.Getenv("DB_DATABASE") + if value != "" { + DbDatabase = value + } else { panic("环境变量 DB_DATABASE 未设置") } - // DbUsername - DbUsername = os.Getenv("DB_USERNAME") - if DbUsername == "" { + value = os.Getenv("DB_USERNAME") + if value != "" { + DbUsername = value + } else { panic("环境变量 DB_USERNAME 未设置") } - // DbPassword - DbPassword = os.Getenv("DB_PASSWORD") - if DbPassword == "" { + value = os.Getenv("DB_PASSWORD") + if value != "" { + DbPassword = value + } else { panic("环境变量 DB_PASSWORD 未设置") } - // DbTimezone - DbTimezone = os.Getenv("DB_TIMEZONE") - if DbTimezone == "" { - DbTimezone = "Asia/Shanghai" + value = os.Getenv("DB_TIMEZONE") + if value != "" { + DbTimezone = value + } + + value = os.Getenv("ENDPOINT_ONLINE") + if value != "" { + EndpointOnline = value + } else { + panic("环境变量 ENDPOINT_ONLINE 未设置") + } + + value = os.Getenv("ENDPOINT_OFFLINE") + if value != "" { + EndpointOffline = value + } else { + panic("环境变量 ENDPOINT_OFFLINE 未设置") } } diff --git a/server/server.go b/server/server.go index dbf2331..414315a 100644 --- a/server/server.go +++ b/server/server.go @@ -2,6 +2,8 @@ package server import ( "context" + "encoding/base64" + "encoding/json" "log/slog" "net/http" "os" @@ -13,43 +15,52 @@ import ( "proxy-server/server/pkg/log" "proxy-server/server/pkg/orm" "proxy-server/server/web" - "runtime" + "strings" "sync" "syscall" "time" + "github.com/google/uuid" + "github.com/joho/godotenv" + + _ "net/http/pprof" ) -import _ "net/http/pprof" +const ( + Version = 1 + RestoreMagic = 0x72 +) -type Context struct { - context.Context - log *slog.Logger +type server struct { + name string } -func Start() { +func New() *server { + return &server{} +} + +func (s *server) Run() (err error) { // 初始化 - err := godotenv.Load() + err = s.init() if err != nil { - println("没有本地环境变量文件") + return err } - log.Init() - env.Init() - orm.Init() + // 恢复服务状态 + err = s.restore() + if err != nil { + return err + } - // 退出信号 - osQuit := make(chan os.Signal) - signal.Notify(osQuit, os.Interrupt, syscall.SIGTERM) - - // 启动服务 - slog.Info("启动服务") + // 准备子服务 ctx, cancel := context.WithCancel(context.Background()) defer cancel() + wg := sync.WaitGroup{} + // 转发服务 wg.Add(1) fwdQuit := make(chan struct{}, 1) go func() { @@ -57,19 +68,19 @@ func Start() { defer close(fwdQuit) err := startFwdServer(ctx) if err != nil { - slog.Error("代理服务发生错误", "err", err) + slog.Error("转发服务发生错误", "err", err) } fwdQuit <- struct{}{} }() - // 启动 web 服务 + // 接口服务 wg.Add(1) apiQuit := make(chan struct{}, 1) go func() { defer wg.Done() err := startWebServer(ctx) if err != nil { - slog.Error("web 服务发生错误", "err", err) + slog.Error("接口服务发生错误", "err", err) } apiQuit <- struct{}{} }() @@ -80,15 +91,22 @@ func Start() { }() // 性能监控 - go func() { - runtime.SetBlockProfileRate(1) - err := http.ListenAndServe(":6060", nil) - if err != nil { - slog.Error("性能监控服务发生错误", "err", err) - } - }() + // go func() { + // runtime.SetBlockProfileRate(1) + // err := http.ListenAndServe(":6060", nil) + // if err != nil { + // slog.Error("性能监控服务发生错误", "err", err) + // } + // }() + + // 报告上线 + slog.Debug("报告服务上线") + go reportOnline(ctx, s.name) // 等待退出信号 + osQuit := make(chan os.Signal, 1) + signal.Notify(osQuit, os.Interrupt, syscall.SIGTERM) + select { case <-osQuit: slog.Info("服务主动退出") @@ -97,6 +115,11 @@ func Start() { case <-apiQuit: slog.Warn("web 服务异常退出") } + + // 报告下线 + slog.Debug("报告服务下线") + go reportOffline(ctx, s.name) + // 退出其他服务 cancel() @@ -109,6 +132,51 @@ func Start() { case <-timeout.Done(): slog.Warn("退出超时,强制退出") } + + return nil +} + +func (s *server) restore() (err error) { + var file = "proxy.lock" + + bytes, err := os.ReadFile(file) + if err != nil { + return err + } + + if len(bytes) == 17 && bytes[0] == RestoreMagic { + s.name = uuid.UUID(bytes[1:]).String() + slog.Info("恢复服务名称", "name", s.name) + } else { + var u = uuid.New() + s.name = u.String() + + bytes = make([]byte, 17) + bytes[0] = RestoreMagic + copy(bytes[1:], u[:]) + err := os.WriteFile(file, bytes, 0644) + if err != nil { + return err + } + + slog.Info("生成服务名称", "name", s.name) + } + + return nil +} + +func (s *server) init() error { + + err := godotenv.Load() + if err != nil { + println("没有本地环境变量文件") + } + + log.Init() + env.Init() + orm.Init() + + return nil } func startFwdServer(ctx context.Context) error { @@ -124,3 +192,47 @@ func startFwdServer(ctx context.Context) error { func startWebServer(ctx context.Context) error { return web.Start(ctx) } + +func reportOnline(ctx context.Context, name string) { + reportRepeat(ctx, env.EndpointOnline, map[string]any{ + "name": name, + "version": Version, + }) +} + +func reportOffline(ctx context.Context, name string) { + reportRepeat(ctx, env.EndpointOffline, map[string]any{ + "name": name, + "version": Version, + }) +} + +func reportRepeat(ctx context.Context, endpoint string, body any) { + bodyStr, err := json.Marshal(body) + if err != nil { + panic(err) + } + + for { + req, err := http.NewRequest("POST", endpoint, strings.NewReader(string(bodyStr))) + if err != nil { + panic(err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Basic "+base64.RawURLEncoding.EncodeToString([]byte("proxy:proxy"))) + + resp, err := http.DefaultClient.Do(req) + if resp != nil && resp.StatusCode == http.StatusOK { + return + } + select { + case <-ctx.Done(): + return + default: + } + + slog.Warn("服务注册失败,五秒后重试", "err", err) + time.Sleep(5 * time.Second) + } +}