diff --git a/README.md b/README.md index 2c456f3..314b345 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,9 @@ ## todo +找一个其他方式即时关闭未成功建立数据通道的连接 + +排查下套接字重复的问题 + 鉴权时判断授权的协议 建立通道时,发送的 dst 和 tag 等信息,可以用字节表示而非 string,提高效率 @@ -10,7 +14,6 @@ 数据通道池化 - 可配配置环境变量 - 退出等待时间 diff --git a/client/client.go b/client/client.go index e17ab7b..246b151 100644 --- a/client/client.go +++ b/client/client.go @@ -47,7 +47,7 @@ func Start() { // 性能监控 go func() { runtime.SetBlockProfileRate(1) - err := http.ListenAndServe(":6060", nil) + err := http.ListenAndServe(":7070", nil) if err != nil { slog.Error("性能监控服务启动失败", "err", err) } @@ -152,8 +152,8 @@ func data(addr string, tag []byte) error { copy(tagBuf[2:], tag) // 向目标地址建立连接 - dst, err := net.Dial("tcp", addr) - if err != nil { + dst, dstErr := net.Dial("tcp", addr) + if dstErr != nil { tagBuf[0] = 0 } else { tagBuf[0] = 1 @@ -174,7 +174,7 @@ func data(addr string, tag []byte) error { if dst != nil { utils.Close(dst) } - return errors.New("目标地址连接失败") + return errors.Wrap(dstErr, "连接目标地址失败") } go func() { diff --git a/cmd/mock/main.go b/cmd/mock/main.go index c7ea927..2d7eb5e 100644 --- a/cmd/mock/main.go +++ b/cmd/mock/main.go @@ -1,7 +1,6 @@ package main import ( - "math/rand" "net/http" "time" ) @@ -14,8 +13,9 @@ func main() { func mock() { http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - waiting := rand.Intn(450) + 50 - time.Sleep(time.Duration(waiting) * time.Millisecond) + // waiting := rand.Intn(450) + 50 + // time.Sleep(time.Duration(waiting) * time.Millisecond) + time.Sleep(200 * time.Millisecond) w.Write([]byte("Hello World")) }) diff --git a/cmd/server/.env.example b/cmd/server/.env.example index 4d63cd8..2b7e1d3 100644 --- a/cmd/server/.env.example +++ b/cmd/server/.env.example @@ -1,6 +1,7 @@ # 应用配置 APP_CTRL_PORT=18080 APP_DATA_PORT=18081 +APP_WEB_PORT=8848 APP_LOG_MODE=dev# dev | test # 数据库配置 diff --git a/server/debug/debug.go b/server/debug/debug.go new file mode 100644 index 0000000..98a6e3d --- /dev/null +++ b/server/debug/debug.go @@ -0,0 +1,51 @@ +package debug + +import ( + "container/ring" + "context" + "time" +) + +func Start(ctx context.Context) { + go startReceiveConsuming(ctx) +} + +type Consuming struct { + Auth time.Duration + Data time.Duration + Proxy time.Duration + Total time.Duration +} + +var ConsumingCh = make(chan Consuming, 1024) +var consumingList = ring.New(3000) + +func startReceiveConsuming(ctx context.Context) { + InitConsumingList() + for { + select { + case <-ctx.Done(): + return + case c := <-ConsumingCh: + consumingList.Value = c + consumingList = consumingList.Next() + } + } +} + +func InitConsumingList() { + for i := 0; i < consumingList.Len(); i++ { + consumingList.Value = nil + consumingList = consumingList.Next() + } +} + +func ConsumingList() []Consuming { + consuming := make([]Consuming, 0, consumingList.Len()) + consumingList.Do(func(value interface{}) { + if value != nil { + consuming = append(consuming, value.(Consuming)) + } + }) + return consuming +} diff --git a/server/fwd/core/auth.go b/server/fwd/core/auth.go index f8b76e9..27ee4a4 100644 --- a/server/fwd/core/auth.go +++ b/server/fwd/core/auth.go @@ -3,7 +3,7 @@ package core import ( "log/slog" "net" - "proxy-server/server/models" + models2 "proxy-server/server/pkg/models" "proxy-server/server/pkg/orm" "time" @@ -41,12 +41,12 @@ func CheckIp(conn net.Conn) (*AuthContext, error) { // 查询权限记录 slog.Debug("用户 " + remoteHost + " 请求连接到 " + localPort) - var channels []models.Channel + var channels []models2.Channel err = orm.DB. Joins("INNER JOIN public.nodes n ON channels.node_id = n.id AND n.name = ?", localPort). Joins("INNER JOIN public.users u ON channels.user_id = u.id"). Joins("INNER JOIN public.user_ips ip ON u.id = ip.user_id AND ip.ip_address = ?", remoteHost). - Where(&models.Channel{ + Where(&models2.Channel{ AuthIp: true, }). Find(&channels).Error @@ -88,9 +88,9 @@ func CheckPass(conn net.Conn, username, password string) (*AuthContext, error) { }, nil // 查询通道配置 - var channel models.Channel + var channel models2.Channel err := orm.DB. - Where(&models.Channel{ + Where(&models2.Channel{ Username: username, AuthPass: true, }). @@ -125,7 +125,7 @@ func CheckPass(conn net.Conn, username, password string) (*AuthContext, error) { var ips int64 err = orm.DB. - Where(&models.UserIp{ + Where(&models2.UserIp{ UserId: channel.UserId, IpAddress: remoteHost, }). diff --git a/server/fwd/ctrl.go b/server/fwd/ctrl.go index 205c522..7f07adc 100644 --- a/server/fwd/ctrl.go +++ b/server/fwd/ctrl.go @@ -8,8 +8,9 @@ import ( "proxy-server/pkg/utils" "proxy-server/server/fwd/core" "proxy-server/server/fwd/dispatcher" - "proxy-server/server/models" + "proxy-server/server/fwd/metrics" "proxy-server/server/pkg/env" + "proxy-server/server/pkg/models" "proxy-server/server/pkg/orm" "strconv" "strings" @@ -170,6 +171,7 @@ func (s *Service) processCtrlConn(conn net.Conn) error { return errors.Wrap(err, "客户端意外断开连接") } case user := <-proxy.Conn: + metrics.TimerAuth.Store(user.Conn, time.Now()) s.userConnWg.Add(1) go func() { defer s.userConnWg.Done() diff --git a/server/fwd/data.go b/server/fwd/data.go index 9fa4e7b..db8303c 100644 --- a/server/fwd/data.go +++ b/server/fwd/data.go @@ -5,9 +5,12 @@ import ( "log/slog" "net" "proxy-server/pkg/utils" + "proxy-server/server/debug" + "proxy-server/server/fwd/metrics" "proxy-server/server/pkg/env" "strconv" "sync" + "time" "github.com/pkg/errors" ) @@ -77,6 +80,7 @@ func (s *Service) processDataConn(client net.Conn) error { return errors.New("用户连接已关闭,tag:" + tag) } defer utils.Close(user) + data := time.Now() // 检查状态 if status != 1 { @@ -116,5 +120,34 @@ func (s *Service) processDataConn(client net.Conn) error { case <-utils.ChanWgWait(s.ctx, &wg): } + proxy := time.Now() + + start, startOk := metrics.TimerStart.Load(user.Conn) + auth, authOk := metrics.TimerAuth.Load(user.Conn) + + var authDuration time.Duration + if startOk && authOk { + authDuration = auth.(time.Time).Sub(start.(time.Time)) + } + + var dataDuration time.Duration + if authOk { + dataDuration = data.Sub(auth.(time.Time)) + } + + proxyDuration := proxy.Sub(data) + + var totalDuration time.Duration + if startOk { + totalDuration = proxy.Sub(start.(time.Time)) + } + + debug.ConsumingCh <- debug.Consuming{ + Auth: authDuration, + Data: dataDuration, + Proxy: proxyDuration, + Total: totalDuration, + } + return nil } diff --git a/server/fwd/dispatcher/dispatch.go b/server/fwd/dispatcher/dispatch.go index 67f0223..3c05b22 100644 --- a/server/fwd/dispatcher/dispatch.go +++ b/server/fwd/dispatcher/dispatch.go @@ -7,6 +7,7 @@ import ( "proxy-server/pkg/utils" "proxy-server/server/fwd/core" "proxy-server/server/fwd/http" + "proxy-server/server/fwd/metrics" "proxy-server/server/fwd/socks" "strconv" "strings" @@ -112,6 +113,8 @@ func (s *Server) acceptHttp(ls net.Listener) error { return errors.Wrap(err, "dispatcher http accept error") } + metrics.TimerStart.Store(conn, time.Now()) + go func() { user, err := http.Process(s.ctx, conn) if err != nil { @@ -142,6 +145,8 @@ func (s *Server) acceptSocks(ls net.Listener) error { return errors.Wrap(err, "dispatcher socks accept error") } + metrics.TimerStart.Store(conn, time.Now()) + go func() { user, err := socks.Process(s.ctx, conn) if err != nil { diff --git a/server/fwd/fwd.go b/server/fwd/fwd.go index 2a9d1e1..302b774 100644 --- a/server/fwd/fwd.go +++ b/server/fwd/fwd.go @@ -34,13 +34,6 @@ func New(config *Config) *Service { Config: config, ctx: ctx, cancel: cancel, - - userConnMap: core.ConnMap{}, - - fwdLesWg: utils.CountWaitGroup{}, - ctrlConnWg: utils.CountWaitGroup{}, - dataConnWg: utils.CountWaitGroup{}, - userConnWg: utils.CountWaitGroup{}, } } diff --git a/server/fwd/metrics/debug.go b/server/fwd/metrics/debug.go new file mode 100644 index 0000000..9b623cc --- /dev/null +++ b/server/fwd/metrics/debug.go @@ -0,0 +1,6 @@ +package metrics + +import "sync" + +var TimerStart sync.Map +var TimerAuth sync.Map diff --git a/server/pkg/env/env.go b/server/pkg/env/env.go index 48e7345..87f2611 100644 --- a/server/pkg/env/env.go +++ b/server/pkg/env/env.go @@ -9,6 +9,7 @@ import ( var ( AppCtrlPort uint16 AppDataPort uint16 + AppWebPort uint16 AppLogMode string DbHost string @@ -43,11 +44,23 @@ func Init() { } AppDataPort = uint16(appDataPort) + // AppWebPort + appWebPortStr := os.Getenv("APP_WEB_PORT") + if appWebPortStr == "" { + appWebPortStr = "8848" + } + appWebPort, err := strconv.ParseUint(appWebPortStr, 10, 16) + if err != nil { + panic(fmt.Sprintf("环境变量 APP_WEB_PORT 格式错误: %v", err)) + } + AppWebPort = uint16(appWebPort) + // AppLogMode appLogMode := os.Getenv("APP_LOG_MODE") if appLogMode == "" { AppLogMode = "dev" } + AppLogMode = appLogMode // DbHost DbHost = os.Getenv("DB_HOST") diff --git a/server/models/channel.go b/server/pkg/models/channel.go similarity index 100% rename from server/models/channel.go rename to server/pkg/models/channel.go diff --git a/server/models/node.go b/server/pkg/models/node.go similarity index 100% rename from server/models/node.go rename to server/pkg/models/node.go diff --git a/server/models/user-ip.go b/server/pkg/models/user-ip.go similarity index 100% rename from server/models/user-ip.go rename to server/pkg/models/user-ip.go diff --git a/server/models/user.go b/server/pkg/models/user.go similarity index 100% rename from server/models/user.go rename to server/pkg/models/user.go diff --git a/server/pkg/resp/resp.go b/server/pkg/resp/resp.go deleted file mode 100644 index 1aa1449..0000000 --- a/server/pkg/resp/resp.go +++ /dev/null @@ -1,21 +0,0 @@ -package resp - -type Data struct { - Error bool - Cause string - Data interface{} -} - -func Done(data interface{}) *Data { - return &Data{ - Error: false, - Data: data, - } -} - -func Fail(cause string) *Data { - return &Data{ - Error: true, - Cause: cause, - } -} diff --git a/server/server.go b/server/server.go index 007be82..dbf2331 100644 --- a/server/server.go +++ b/server/server.go @@ -7,10 +7,12 @@ import ( "os" "os/signal" "proxy-server/pkg/utils" + "proxy-server/server/debug" "proxy-server/server/fwd" "proxy-server/server/pkg/env" "proxy-server/server/pkg/log" "proxy-server/server/pkg/orm" + "proxy-server/server/web" "runtime" "sync" "syscall" @@ -38,6 +40,45 @@ func Start() { env.Init() orm.Init() + // 退出信号 + osQuit := make(chan os.Signal) + signal.Notify(osQuit, os.Interrupt, syscall.SIGTERM) + + // 启动服务 + slog.Info("启动服务") + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + wg := sync.WaitGroup{} + + wg.Add(1) + fwdQuit := make(chan struct{}, 1) + go func() { + defer wg.Done() + defer close(fwdQuit) + err := startFwdServer(ctx) + if err != nil { + slog.Error("代理服务发生错误", "err", err) + } + fwdQuit <- struct{}{} + }() + + // 启动 web 服务 + wg.Add(1) + apiQuit := make(chan struct{}, 1) + go func() { + defer wg.Done() + err := startWebServer(ctx) + if err != nil { + slog.Error("web 服务发生错误", "err", err) + } + apiQuit <- struct{}{} + }() + + // debug + go func() { + debug.Start(ctx) + }() + // 性能监控 go func() { runtime.SetBlockProfileRate(1) @@ -47,44 +88,21 @@ func Start() { } }() - // 退出信号 - osQuit := make(chan os.Signal) - signal.Notify(osQuit, os.Interrupt, syscall.SIGTERM) - - // 启动服务 - slog.Info("启动服务") - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - wg := sync.WaitGroup{} - wg.Add(1) - errQuit := make(chan struct{}, 1) - defer close(errQuit) - go func() { - defer wg.Done() - err := startFwdServer(ctx) - if err != nil { - slog.Error("代理服务发生错误", "err", err) - } - errQuit <- struct{}{} - }() - // 等待退出信号 select { case <-osQuit: slog.Info("服务主动退出") - case <-errQuit: - slog.Warn("服务异常退出") + case <-fwdQuit: + slog.Warn("fwd 服务异常退出") + case <-apiQuit: + slog.Warn("web 服务异常退出") } - - // 退出服务 + // 退出其他服务 cancel() timeout, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - wg.Wait() - select { case <-utils.ChanWgWait(timeout, &wg): slog.Info("服务已退出") @@ -103,6 +121,6 @@ func startFwdServer(ctx context.Context) error { return nil } -func startWebServer(ctx context.Context) { - +func startWebServer(ctx context.Context) error { + return web.Start(ctx) } diff --git a/server/web/auth/auth.go b/server/web/auth/auth.go deleted file mode 100644 index 9ffc6e2..0000000 --- a/server/web/auth/auth.go +++ /dev/null @@ -1,10 +0,0 @@ -package auth - -import "github.com/gin-gonic/gin" - -type Config struct { -} - -func Apply(r *gin.Engine, config *Config) { - r.Use(middleware) -} diff --git a/server/web/auth/context.go b/server/web/auth/context.go deleted file mode 100644 index ff737ab..0000000 --- a/server/web/auth/context.go +++ /dev/null @@ -1,41 +0,0 @@ -package auth - -type Context interface { - Permissions() map[string]struct{} - PermitAll(permissions ...string) bool - PermitAny(permissions ...string) bool -} - -// region DeviceContext - -type DeviceContext struct { - ID uint - IpAddress string - Permissions map[string]struct{} -} - -func (c DeviceContext) PermitAny(permissions ...string) bool { - if _, exist := c.Permissions["*"]; exist { - return true - } - for _, permission := range permissions { - if _, ok := c.Permissions[permission]; ok { - return true - } - } - return false -} - -func (c DeviceContext) PermitAll(permissions ...string) bool { - if _, exist := c.Permissions["*"]; exist { - return true - } - for _, permission := range permissions { - if _, ok := c.Permissions[permission]; !ok { - return false - } - } - return true -} - -// endregion diff --git a/server/web/auth/middleware.go b/server/web/auth/middleware.go deleted file mode 100644 index b080e83..0000000 --- a/server/web/auth/middleware.go +++ /dev/null @@ -1,98 +0,0 @@ -package auth - -import ( - "encoding/base64" - "log/slog" - "net/http" - "os" - "proxy-server/server/pkg/resp" - "slices" - "strings" - - "github.com/gin-gonic/gin" - "github.com/pkg/errors" -) - -func middleware(c *gin.Context) { - - auth := check(c) - if auth { - secret, err := getSecret(c) - if err != nil { - slog.Error("认证失败", err) - fail400(c, err) - return - } - err = authenticate(c, secret) - if err != nil { - slog.Error("认证失败", err) - fail401(c, err) - return - } - } - - c.Next() -} - -var ( - securedPaths = []string{ - "/connect", - } -) - -func check(c *gin.Context) bool { - path := c.Request.URL.Path - if slices.Contains(securedPaths, path) { - return true - } - return false -} - -func getSecret(c *gin.Context) (string, error) { - - // 获取认证信息 - header := strings.Split(c.GetHeader("Authorization"), " ") - if len(header) != 2 { - return "", errors.New("无认证信息") - } - - // 检查认证类型 - schema := header[0] - if schema != "Secret" { - return "", errors.New("不支持的认证类型 " + schema) - } - - // 解码密钥 - parameters := header[1] - result, err := base64.URLEncoding.DecodeString(parameters) - if err != nil { - return "", errors.Wrap(err, "密钥解析失败") - } - - return string(result), nil -} - -func authenticate(_ *gin.Context, secret string) error { - if secret != os.Getenv("SECRET") { - return errors.New("认证失败") - } - return nil -} - -func fail400(c *gin.Context, err error) { - _ = c.Error(err) - c.Abort() - c.JSON( - http.StatusBadRequest, - resp.Fail(err.Error()), - ) -} - -func fail401(c *gin.Context, err error) { - _ = c.Error(err) - c.Abort() - c.JSON( - http.StatusUnauthorized, - resp.Fail(err.Error()), - ) -} diff --git a/server/web/handlers/channel.go b/server/web/handlers/channel.go deleted file mode 100644 index bffe3c2..0000000 --- a/server/web/handlers/channel.go +++ /dev/null @@ -1,167 +0,0 @@ -package handlers - -import ( - "log/slog" - "proxy-server/server/models" - "proxy-server/server/pkg/orm" - "proxy-server/server/pkg/resp" - "strings" - "time" - - "github.com/gin-gonic/gin" - "github.com/pkg/errors" -) - -// region frp 接口 - -type FrpData struct { - Reject bool - RejectReason string - Unchange bool -} - -func ChanRequest(c *gin.Context) { - type Body struct { - Content struct { - ProxyName string `json:"proxy_name"` - ProxyType string `json:"proxy_type"` - RemoteAddr string `json:"remote_addr"` - User interface{} - } - } - - op := c.Query("op") - if op != "NewUserConn" { - _ = c.Error(errors.New("不支持的操作")) - return - } - - id := c.GetHeader("X-Frp-Reqid") - if id == "" { - _ = c.Error(errors.New("请求头中缺少 X-Frp-Reqid")) - return - } - - var body Body - err := c.ShouldBindJSON(&body) - if err != nil { - _ = c.Error(errors.Wrap(err, "解析请求正文失败")) - return - } - content := body.Content - - // 检查此 ip 是否有权限访问目标 node - clientIp := strings.Split(content.RemoteAddr, ":")[0] - targetNode := content.ProxyName - slog.Debug(id + " 用户 " + clientIp + " 请求连接到 " + targetNode) - - var channels []models.Channel - err = orm.DB. - Joins("INNER JOIN public.nodes n ON channels.node_id = n.id AND n.name = ?", targetNode). - Joins("INNER JOIN public.users u ON channels.user_id = u.id"). - Joins("INNER JOIN public.user_ips ip ON u.id = ip.user_id AND ip.ip_address = ?", clientIp). - Find(&channels).Error - if err != nil { - _ = c.Error(errors.Wrap(err, "查询用户权限失败")) - return - } - - // 返回响应 - rsCount := len(channels) - if rsCount > 1 { - slog.Warn(clientIp + " + " + targetNode + "的组合有多个权限结果,这是不应当存在的") - } - - if rsCount == 0 { - slog.Debug(id + " 没有权限") - reject(c) - return - } - channel := channels[0] - if channel.Expiration.Before(time.Now()) { - slog.Debug(id + " 权限已过期") - reject(c) - return - } - - slog.Debug(id + " 通过验证") - confirm(c) -} - -func ChanTest(c *gin.Context) { - var body map[string]interface{} - err := c.ShouldBindJSON(&body) - if err != nil { - slog.Error("解析请求正文失败", err) - } - for k, v := range body { - slog.Debug("map", "key: ", k, " value: ", v) - } - confirm(c) -} - -func confirm(c *gin.Context) { - c.JSON(200, FrpData{ - Reject: false, - Unchange: true, - }) -} - -func reject(c *gin.Context) { - c.JSON(401, FrpData{ - Reject: true, - RejectReason: "客户端没有权限访问该节点", - }) -} - -// endregion - -func ChanAuth(c *gin.Context) { - type Body struct { - Username string `json:"username"` - Password string `json:"password"` - } - type Data struct { - Timeout uint64 `json:"timeout"` - } - - var body Body - err := c.ShouldBindJSON(&body) - if err != nil { - _ = c.Error(err) - c.JSON(400, resp.Fail("请求参数错误")) - return - } - - // 查找通道 - var result *models.Channel - orm.DB. - Model(&models.Channel{}). - Where(&models.Channel{ - Username: body.Username, - Password: body.Password, - }). - First(&result) - if result == nil { - _ = c.Error(errors.New("用户信息不存在")) - c.JSON(401, resp.Fail("账号密码错误")) - return - } - - // 验证账号密码 todo 哈希密码验证 - if result.Username != body.Username || result.Password != body.Password { - _ = c.Error(errors.New("账号密码错误")) - c.JSON(401, resp.Fail("账号密码错误")) - return - } - - // 计算到期时间 - timeout := result.Expiration.Sub(time.Now()) - - // todo 保存会话 对于大量短连接的情况,考虑如何保存连接会话信息 - - // 返回结果 - c.JSON(200, resp.Done(Data{ - Timeout: uint64(timeout.Seconds()), - })) -} diff --git a/server/web/handlers/debug.go b/server/web/handlers/debug.go new file mode 100644 index 0000000..06fbbf1 --- /dev/null +++ b/server/web/handlers/debug.go @@ -0,0 +1,36 @@ +package handlers + +import ( + "fmt" + "proxy-server/server/debug" + "slices" + + "github.com/gin-gonic/gin" +) + +func GetConsuming(c *gin.Context) { + list := debug.ConsumingList() + // sort by total time + slices.SortFunc(list, func(a debug.Consuming, b debug.Consuming) int { + if a.Total < b.Total { + return 1 + } else if a.Total > b.Total { + return -1 + } + return 0 + }) + // map to string + strList := make([]string, len(list)) + for i := 0; i < len(list); i++ { + times := list[i] + strList[i] = fmt.Sprintf("Auth: %s, Data: %s, Proxy: %s, Total: %s", times.Auth, times.Data, times.Proxy, times.Total) + } + c.JSON(200, strList) +} + +func RestConsuming(c *gin.Context) { + debug.InitConsumingList() + c.JSON(200, gin.H{ + "message": "success", + }) +} diff --git a/server/web/handlers/node.go b/server/web/handlers/node.go deleted file mode 100644 index 205f5bc..0000000 --- a/server/web/handlers/node.go +++ /dev/null @@ -1,116 +0,0 @@ -package handlers - -import ( - "os" - "proxy-server/server/models" - "proxy-server/server/pkg/orm" - - "github.com/gin-gonic/gin" - "github.com/pkg/errors" - "gorm.io/gorm" -) - -type NodeRegisterReq struct { - Name string - Secret string -} - -func NodeRegister(c *gin.Context) { - - // 请求参数 - var req NodeRegisterReq - err := c.ShouldBind(&req) - if err != nil { - _ = c.Error(errors.Wrap(err, "参数解析错误")) - return - } - - // 验证 secret - secret := os.Getenv("SECRET") - if req.Secret != secret { - _ = c.Error(errors.New("拒绝连接")) - return - } - - // 注册节点 - // todo 查询运营商和地区 - err = orm.DB.Transaction(func(tx *gorm.DB) error { - - // 查询节点是否已存在 - var count int64 - err := orm.DB.Where(&models.Node{ - Name: req.Name, - }).Count(&count).Error - if err != nil { - return err - } - - // 不存在则注册 - if count == 0 { - ipAddress := c.ClientIP() - node := models.Node{ - Name: req.Name, - Provider: "", - Location: "", - IPAddress: ipAddress, - } - err = orm.DB.Create(&node).Error - if err != nil { - return err - } - } - - return nil - }) - if err != nil { - _ = c.Error(errors.Wrap(err, "注册节点失败")) - return - } - - c.Status(200) -} - -type NodeReportReq struct { - Name string -} - -func NodeReport(c *gin.Context) { - - // 请求参数 - var req NodeReportReq - err := c.ShouldBind(&req) - if err != nil { - _ = c.Error(errors.Wrap(err, "参数解析错误")) - return - } - - // 上报节点信息 - err = orm.DB.Transaction(func(tx *gorm.DB) error { - - // 查询节点 - var node models.Node - err = orm.DB.Where(&models.Node{ - Name: req.Name, - }).First(&node).Error - if err != nil { - return err - } - - // 更新节点信息 - ipAddress := c.ClientIP() - if ipAddress != node.IPAddress { - err = orm.DB.Model(&node).Update("ip_address", ipAddress).Error - if err != nil { - return err - } - } - - return nil - }) - if err != nil { - _ = c.Error(errors.Wrap(err, "上报节点信息失败")) - return - } - - c.Status(200) -} diff --git a/server/web/handlers/user.go b/server/web/handlers/user.go deleted file mode 100644 index 5ac8282..0000000 --- a/server/web/handlers/user.go +++ /dev/null @@ -1 +0,0 @@ -package handlers diff --git a/server/web/router.go b/server/web/router.go new file mode 100644 index 0000000..85b0eb4 --- /dev/null +++ b/server/web/router.go @@ -0,0 +1,12 @@ +package web + +import ( + "proxy-server/server/web/handlers" + + "github.com/gin-gonic/gin" +) + +func Router(r *gin.Engine) { + r.Handle("GET", "/debug/consuming/list", handlers.GetConsuming) + r.Handle("GET", "/debug/consuming/reset", handlers.RestConsuming) +} diff --git a/server/web/router/router.go b/server/web/router/router.go deleted file mode 100644 index ef77fd4..0000000 --- a/server/web/router/router.go +++ /dev/null @@ -1,17 +0,0 @@ -package router - -import ( - handlers2 "proxy-server/server/web/handlers" - - "github.com/gin-gonic/gin" -) - -func Apply(r *gin.Engine) { - - r.POST("/node/register", handlers2.NodeRegister) - r.POST("/node/report", handlers2.NodeReport) - - r.POST("/chan/request", handlers2.ChanRequest) - r.POST("/chan/auth", handlers2.ChanAuth) - r.POST("/chan/test", handlers2.ChanTest) -} diff --git a/server/web/web.go b/server/web/web.go index 5331d81..d47c365 100644 --- a/server/web/web.go +++ b/server/web/web.go @@ -4,42 +4,37 @@ import ( "context" "log/slog" "net/http" - "os" - "proxy-server/server/web/auth" - "proxy-server/server/web/router" + "proxy-server/server/pkg/env" + "strconv" "github.com/gin-gonic/gin" + "github.com/pkg/errors" ) var server *http.Server -func Start(ctx context.Context, errCh chan error) { - address := ":" + os.Getenv("PORT") +func Start(ctx context.Context) error { + address := ":" + strconv.Itoa(int(env.AppWebPort)) engine := gin.Default() server = &http.Server{Addr: address, Handler: engine} + // 配置中间件和路由 + Router(engine) + // 监听关闭信号 go func() { <-ctx.Done() - slog.Info("web 服务被动关闭") - err := server.Shutdown(ctx) + err := server.Shutdown(context.Background()) if err != nil { slog.Error("web 服务关闭失败", err) - return } }() - // 配置中间件和路由 - auth.Apply(engine, nil) - router.Apply(engine) - // 启动服务 err := server.ListenAndServe() if err != nil { - errCh <- err - return + return errors.Wrap(err, "web 服务启动失败") } - slog.Debug("web 服务主动结束") - errCh <- nil + return nil }