添加在线调试 api

This commit is contained in:
2025-03-08 10:59:31 +08:00
parent 053041ae34
commit 5786ac9d99
28 changed files with 236 additions and 539 deletions

View File

@@ -1,5 +1,9 @@
## todo ## todo
找一个其他方式即时关闭未成功建立数据通道的连接
排查下套接字重复的问题
鉴权时判断授权的协议 鉴权时判断授权的协议
建立通道时,发送的 dst 和 tag 等信息,可以用字节表示而非 string提高效率 建立通道时,发送的 dst 和 tag 等信息,可以用字节表示而非 string提高效率
@@ -10,7 +14,6 @@
数据通道池化 数据通道池化
可配配置环境变量 可配配置环境变量
- 退出等待时间 - 退出等待时间

View File

@@ -47,7 +47,7 @@ func Start() {
// 性能监控 // 性能监控
go func() { go func() {
runtime.SetBlockProfileRate(1) runtime.SetBlockProfileRate(1)
err := http.ListenAndServe(":6060", nil) err := http.ListenAndServe(":7070", nil)
if err != nil { if err != nil {
slog.Error("性能监控服务启动失败", "err", err) slog.Error("性能监控服务启动失败", "err", err)
} }
@@ -152,8 +152,8 @@ func data(addr string, tag []byte) error {
copy(tagBuf[2:], tag) copy(tagBuf[2:], tag)
// 向目标地址建立连接 // 向目标地址建立连接
dst, err := net.Dial("tcp", addr) dst, dstErr := net.Dial("tcp", addr)
if err != nil { if dstErr != nil {
tagBuf[0] = 0 tagBuf[0] = 0
} else { } else {
tagBuf[0] = 1 tagBuf[0] = 1
@@ -174,7 +174,7 @@ func data(addr string, tag []byte) error {
if dst != nil { if dst != nil {
utils.Close(dst) utils.Close(dst)
} }
return errors.New("目标地址连接失败") return errors.Wrap(dstErr, "连接目标地址失败")
} }
go func() { go func() {

View File

@@ -1,7 +1,6 @@
package main package main
import ( import (
"math/rand"
"net/http" "net/http"
"time" "time"
) )
@@ -14,8 +13,9 @@ func main() {
func mock() { func mock() {
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
waiting := rand.Intn(450) + 50 // waiting := rand.Intn(450) + 50
time.Sleep(time.Duration(waiting) * time.Millisecond) // time.Sleep(time.Duration(waiting) * time.Millisecond)
time.Sleep(200 * time.Millisecond)
w.Write([]byte("Hello World")) w.Write([]byte("Hello World"))
}) })

View File

@@ -1,6 +1,7 @@
# 应用配置 # 应用配置
APP_CTRL_PORT=18080 APP_CTRL_PORT=18080
APP_DATA_PORT=18081 APP_DATA_PORT=18081
APP_WEB_PORT=8848
APP_LOG_MODE=dev# dev | test APP_LOG_MODE=dev# dev | test
# 数据库配置 # 数据库配置

51
server/debug/debug.go Normal file
View File

@@ -0,0 +1,51 @@
package debug
import (
"container/ring"
"context"
"time"
)
func Start(ctx context.Context) {
go startReceiveConsuming(ctx)
}
type Consuming struct {
Auth time.Duration
Data time.Duration
Proxy time.Duration
Total time.Duration
}
var ConsumingCh = make(chan Consuming, 1024)
var consumingList = ring.New(3000)
func startReceiveConsuming(ctx context.Context) {
InitConsumingList()
for {
select {
case <-ctx.Done():
return
case c := <-ConsumingCh:
consumingList.Value = c
consumingList = consumingList.Next()
}
}
}
func InitConsumingList() {
for i := 0; i < consumingList.Len(); i++ {
consumingList.Value = nil
consumingList = consumingList.Next()
}
}
func ConsumingList() []Consuming {
consuming := make([]Consuming, 0, consumingList.Len())
consumingList.Do(func(value interface{}) {
if value != nil {
consuming = append(consuming, value.(Consuming))
}
})
return consuming
}

View File

@@ -3,7 +3,7 @@ package core
import ( import (
"log/slog" "log/slog"
"net" "net"
"proxy-server/server/models" models2 "proxy-server/server/pkg/models"
"proxy-server/server/pkg/orm" "proxy-server/server/pkg/orm"
"time" "time"
@@ -41,12 +41,12 @@ func CheckIp(conn net.Conn) (*AuthContext, error) {
// 查询权限记录 // 查询权限记录
slog.Debug("用户 " + remoteHost + " 请求连接到 " + localPort) slog.Debug("用户 " + remoteHost + " 请求连接到 " + localPort)
var channels []models.Channel var channels []models2.Channel
err = orm.DB. err = orm.DB.
Joins("INNER JOIN public.nodes n ON channels.node_id = n.id AND n.name = ?", localPort). Joins("INNER JOIN public.nodes n ON channels.node_id = n.id AND n.name = ?", localPort).
Joins("INNER JOIN public.users u ON channels.user_id = u.id"). Joins("INNER JOIN public.users u ON channels.user_id = u.id").
Joins("INNER JOIN public.user_ips ip ON u.id = ip.user_id AND ip.ip_address = ?", remoteHost). Joins("INNER JOIN public.user_ips ip ON u.id = ip.user_id AND ip.ip_address = ?", remoteHost).
Where(&models.Channel{ Where(&models2.Channel{
AuthIp: true, AuthIp: true,
}). }).
Find(&channels).Error Find(&channels).Error
@@ -88,9 +88,9 @@ func CheckPass(conn net.Conn, username, password string) (*AuthContext, error) {
}, nil }, nil
// 查询通道配置 // 查询通道配置
var channel models.Channel var channel models2.Channel
err := orm.DB. err := orm.DB.
Where(&models.Channel{ Where(&models2.Channel{
Username: username, Username: username,
AuthPass: true, AuthPass: true,
}). }).
@@ -125,7 +125,7 @@ func CheckPass(conn net.Conn, username, password string) (*AuthContext, error) {
var ips int64 var ips int64
err = orm.DB. err = orm.DB.
Where(&models.UserIp{ Where(&models2.UserIp{
UserId: channel.UserId, UserId: channel.UserId,
IpAddress: remoteHost, IpAddress: remoteHost,
}). }).

View File

@@ -8,8 +8,9 @@ import (
"proxy-server/pkg/utils" "proxy-server/pkg/utils"
"proxy-server/server/fwd/core" "proxy-server/server/fwd/core"
"proxy-server/server/fwd/dispatcher" "proxy-server/server/fwd/dispatcher"
"proxy-server/server/models" "proxy-server/server/fwd/metrics"
"proxy-server/server/pkg/env" "proxy-server/server/pkg/env"
"proxy-server/server/pkg/models"
"proxy-server/server/pkg/orm" "proxy-server/server/pkg/orm"
"strconv" "strconv"
"strings" "strings"
@@ -170,6 +171,7 @@ func (s *Service) processCtrlConn(conn net.Conn) error {
return errors.Wrap(err, "客户端意外断开连接") return errors.Wrap(err, "客户端意外断开连接")
} }
case user := <-proxy.Conn: case user := <-proxy.Conn:
metrics.TimerAuth.Store(user.Conn, time.Now())
s.userConnWg.Add(1) s.userConnWg.Add(1)
go func() { go func() {
defer s.userConnWg.Done() defer s.userConnWg.Done()

View File

@@ -5,9 +5,12 @@ import (
"log/slog" "log/slog"
"net" "net"
"proxy-server/pkg/utils" "proxy-server/pkg/utils"
"proxy-server/server/debug"
"proxy-server/server/fwd/metrics"
"proxy-server/server/pkg/env" "proxy-server/server/pkg/env"
"strconv" "strconv"
"sync" "sync"
"time"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
@@ -77,6 +80,7 @@ func (s *Service) processDataConn(client net.Conn) error {
return errors.New("用户连接已关闭tag" + tag) return errors.New("用户连接已关闭tag" + tag)
} }
defer utils.Close(user) defer utils.Close(user)
data := time.Now()
// 检查状态 // 检查状态
if status != 1 { if status != 1 {
@@ -116,5 +120,34 @@ func (s *Service) processDataConn(client net.Conn) error {
case <-utils.ChanWgWait(s.ctx, &wg): case <-utils.ChanWgWait(s.ctx, &wg):
} }
proxy := time.Now()
start, startOk := metrics.TimerStart.Load(user.Conn)
auth, authOk := metrics.TimerAuth.Load(user.Conn)
var authDuration time.Duration
if startOk && authOk {
authDuration = auth.(time.Time).Sub(start.(time.Time))
}
var dataDuration time.Duration
if authOk {
dataDuration = data.Sub(auth.(time.Time))
}
proxyDuration := proxy.Sub(data)
var totalDuration time.Duration
if startOk {
totalDuration = proxy.Sub(start.(time.Time))
}
debug.ConsumingCh <- debug.Consuming{
Auth: authDuration,
Data: dataDuration,
Proxy: proxyDuration,
Total: totalDuration,
}
return nil return nil
} }

View File

@@ -7,6 +7,7 @@ import (
"proxy-server/pkg/utils" "proxy-server/pkg/utils"
"proxy-server/server/fwd/core" "proxy-server/server/fwd/core"
"proxy-server/server/fwd/http" "proxy-server/server/fwd/http"
"proxy-server/server/fwd/metrics"
"proxy-server/server/fwd/socks" "proxy-server/server/fwd/socks"
"strconv" "strconv"
"strings" "strings"
@@ -112,6 +113,8 @@ func (s *Server) acceptHttp(ls net.Listener) error {
return errors.Wrap(err, "dispatcher http accept error") return errors.Wrap(err, "dispatcher http accept error")
} }
metrics.TimerStart.Store(conn, time.Now())
go func() { go func() {
user, err := http.Process(s.ctx, conn) user, err := http.Process(s.ctx, conn)
if err != nil { if err != nil {
@@ -142,6 +145,8 @@ func (s *Server) acceptSocks(ls net.Listener) error {
return errors.Wrap(err, "dispatcher socks accept error") return errors.Wrap(err, "dispatcher socks accept error")
} }
metrics.TimerStart.Store(conn, time.Now())
go func() { go func() {
user, err := socks.Process(s.ctx, conn) user, err := socks.Process(s.ctx, conn)
if err != nil { if err != nil {

View File

@@ -34,13 +34,6 @@ func New(config *Config) *Service {
Config: config, Config: config,
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
userConnMap: core.ConnMap{},
fwdLesWg: utils.CountWaitGroup{},
ctrlConnWg: utils.CountWaitGroup{},
dataConnWg: utils.CountWaitGroup{},
userConnWg: utils.CountWaitGroup{},
} }
} }

View File

@@ -0,0 +1,6 @@
package metrics
import "sync"
var TimerStart sync.Map
var TimerAuth sync.Map

13
server/pkg/env/env.go vendored
View File

@@ -9,6 +9,7 @@ import (
var ( var (
AppCtrlPort uint16 AppCtrlPort uint16
AppDataPort uint16 AppDataPort uint16
AppWebPort uint16
AppLogMode string AppLogMode string
DbHost string DbHost string
@@ -43,11 +44,23 @@ func Init() {
} }
AppDataPort = uint16(appDataPort) AppDataPort = uint16(appDataPort)
// AppWebPort
appWebPortStr := os.Getenv("APP_WEB_PORT")
if appWebPortStr == "" {
appWebPortStr = "8848"
}
appWebPort, err := strconv.ParseUint(appWebPortStr, 10, 16)
if err != nil {
panic(fmt.Sprintf("环境变量 APP_WEB_PORT 格式错误: %v", err))
}
AppWebPort = uint16(appWebPort)
// AppLogMode // AppLogMode
appLogMode := os.Getenv("APP_LOG_MODE") appLogMode := os.Getenv("APP_LOG_MODE")
if appLogMode == "" { if appLogMode == "" {
AppLogMode = "dev" AppLogMode = "dev"
} }
AppLogMode = appLogMode
// DbHost // DbHost
DbHost = os.Getenv("DB_HOST") DbHost = os.Getenv("DB_HOST")

View File

@@ -1,21 +0,0 @@
package resp
type Data struct {
Error bool
Cause string
Data interface{}
}
func Done(data interface{}) *Data {
return &Data{
Error: false,
Data: data,
}
}
func Fail(cause string) *Data {
return &Data{
Error: true,
Cause: cause,
}
}

View File

@@ -7,10 +7,12 @@ import (
"os" "os"
"os/signal" "os/signal"
"proxy-server/pkg/utils" "proxy-server/pkg/utils"
"proxy-server/server/debug"
"proxy-server/server/fwd" "proxy-server/server/fwd"
"proxy-server/server/pkg/env" "proxy-server/server/pkg/env"
"proxy-server/server/pkg/log" "proxy-server/server/pkg/log"
"proxy-server/server/pkg/orm" "proxy-server/server/pkg/orm"
"proxy-server/server/web"
"runtime" "runtime"
"sync" "sync"
"syscall" "syscall"
@@ -38,6 +40,45 @@ func Start() {
env.Init() env.Init()
orm.Init() orm.Init()
// 退出信号
osQuit := make(chan os.Signal)
signal.Notify(osQuit, os.Interrupt, syscall.SIGTERM)
// 启动服务
slog.Info("启动服务")
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
wg := sync.WaitGroup{}
wg.Add(1)
fwdQuit := make(chan struct{}, 1)
go func() {
defer wg.Done()
defer close(fwdQuit)
err := startFwdServer(ctx)
if err != nil {
slog.Error("代理服务发生错误", "err", err)
}
fwdQuit <- struct{}{}
}()
// 启动 web 服务
wg.Add(1)
apiQuit := make(chan struct{}, 1)
go func() {
defer wg.Done()
err := startWebServer(ctx)
if err != nil {
slog.Error("web 服务发生错误", "err", err)
}
apiQuit <- struct{}{}
}()
// debug
go func() {
debug.Start(ctx)
}()
// 性能监控 // 性能监控
go func() { go func() {
runtime.SetBlockProfileRate(1) runtime.SetBlockProfileRate(1)
@@ -47,44 +88,21 @@ func Start() {
} }
}() }()
// 退出信号
osQuit := make(chan os.Signal)
signal.Notify(osQuit, os.Interrupt, syscall.SIGTERM)
// 启动服务
slog.Info("启动服务")
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
wg := sync.WaitGroup{}
wg.Add(1)
errQuit := make(chan struct{}, 1)
defer close(errQuit)
go func() {
defer wg.Done()
err := startFwdServer(ctx)
if err != nil {
slog.Error("代理服务发生错误", "err", err)
}
errQuit <- struct{}{}
}()
// 等待退出信号 // 等待退出信号
select { select {
case <-osQuit: case <-osQuit:
slog.Info("服务主动退出") slog.Info("服务主动退出")
case <-errQuit: case <-fwdQuit:
slog.Warn("服务异常退出") slog.Warn("fwd 服务异常退出")
case <-apiQuit:
slog.Warn("web 服务异常退出")
} }
// 退出其他服务
// 退出服务
cancel() cancel()
timeout, cancel := context.WithTimeout(context.Background(), 5*time.Second) timeout, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
wg.Wait()
select { select {
case <-utils.ChanWgWait(timeout, &wg): case <-utils.ChanWgWait(timeout, &wg):
slog.Info("服务已退出") slog.Info("服务已退出")
@@ -103,6 +121,6 @@ func startFwdServer(ctx context.Context) error {
return nil return nil
} }
func startWebServer(ctx context.Context) { func startWebServer(ctx context.Context) error {
return web.Start(ctx)
} }

View File

@@ -1,10 +0,0 @@
package auth
import "github.com/gin-gonic/gin"
type Config struct {
}
func Apply(r *gin.Engine, config *Config) {
r.Use(middleware)
}

View File

@@ -1,41 +0,0 @@
package auth
type Context interface {
Permissions() map[string]struct{}
PermitAll(permissions ...string) bool
PermitAny(permissions ...string) bool
}
// region DeviceContext
type DeviceContext struct {
ID uint
IpAddress string
Permissions map[string]struct{}
}
func (c DeviceContext) PermitAny(permissions ...string) bool {
if _, exist := c.Permissions["*"]; exist {
return true
}
for _, permission := range permissions {
if _, ok := c.Permissions[permission]; ok {
return true
}
}
return false
}
func (c DeviceContext) PermitAll(permissions ...string) bool {
if _, exist := c.Permissions["*"]; exist {
return true
}
for _, permission := range permissions {
if _, ok := c.Permissions[permission]; !ok {
return false
}
}
return true
}
// endregion

View File

@@ -1,98 +0,0 @@
package auth
import (
"encoding/base64"
"log/slog"
"net/http"
"os"
"proxy-server/server/pkg/resp"
"slices"
"strings"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
)
func middleware(c *gin.Context) {
auth := check(c)
if auth {
secret, err := getSecret(c)
if err != nil {
slog.Error("认证失败", err)
fail400(c, err)
return
}
err = authenticate(c, secret)
if err != nil {
slog.Error("认证失败", err)
fail401(c, err)
return
}
}
c.Next()
}
var (
securedPaths = []string{
"/connect",
}
)
func check(c *gin.Context) bool {
path := c.Request.URL.Path
if slices.Contains(securedPaths, path) {
return true
}
return false
}
func getSecret(c *gin.Context) (string, error) {
// 获取认证信息
header := strings.Split(c.GetHeader("Authorization"), " ")
if len(header) != 2 {
return "", errors.New("无认证信息")
}
// 检查认证类型
schema := header[0]
if schema != "Secret" {
return "", errors.New("不支持的认证类型 " + schema)
}
// 解码密钥
parameters := header[1]
result, err := base64.URLEncoding.DecodeString(parameters)
if err != nil {
return "", errors.Wrap(err, "密钥解析失败")
}
return string(result), nil
}
func authenticate(_ *gin.Context, secret string) error {
if secret != os.Getenv("SECRET") {
return errors.New("认证失败")
}
return nil
}
func fail400(c *gin.Context, err error) {
_ = c.Error(err)
c.Abort()
c.JSON(
http.StatusBadRequest,
resp.Fail(err.Error()),
)
}
func fail401(c *gin.Context, err error) {
_ = c.Error(err)
c.Abort()
c.JSON(
http.StatusUnauthorized,
resp.Fail(err.Error()),
)
}

View File

@@ -1,167 +0,0 @@
package handlers
import (
"log/slog"
"proxy-server/server/models"
"proxy-server/server/pkg/orm"
"proxy-server/server/pkg/resp"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
)
// region frp 接口
type FrpData struct {
Reject bool
RejectReason string
Unchange bool
}
func ChanRequest(c *gin.Context) {
type Body struct {
Content struct {
ProxyName string `json:"proxy_name"`
ProxyType string `json:"proxy_type"`
RemoteAddr string `json:"remote_addr"`
User interface{}
}
}
op := c.Query("op")
if op != "NewUserConn" {
_ = c.Error(errors.New("不支持的操作"))
return
}
id := c.GetHeader("X-Frp-Reqid")
if id == "" {
_ = c.Error(errors.New("请求头中缺少 X-Frp-Reqid"))
return
}
var body Body
err := c.ShouldBindJSON(&body)
if err != nil {
_ = c.Error(errors.Wrap(err, "解析请求正文失败"))
return
}
content := body.Content
// 检查此 ip 是否有权限访问目标 node
clientIp := strings.Split(content.RemoteAddr, ":")[0]
targetNode := content.ProxyName
slog.Debug(id + " 用户 " + clientIp + " 请求连接到 " + targetNode)
var channels []models.Channel
err = orm.DB.
Joins("INNER JOIN public.nodes n ON channels.node_id = n.id AND n.name = ?", targetNode).
Joins("INNER JOIN public.users u ON channels.user_id = u.id").
Joins("INNER JOIN public.user_ips ip ON u.id = ip.user_id AND ip.ip_address = ?", clientIp).
Find(&channels).Error
if err != nil {
_ = c.Error(errors.Wrap(err, "查询用户权限失败"))
return
}
// 返回响应
rsCount := len(channels)
if rsCount > 1 {
slog.Warn(clientIp + " + " + targetNode + "的组合有多个权限结果,这是不应当存在的")
}
if rsCount == 0 {
slog.Debug(id + " 没有权限")
reject(c)
return
}
channel := channels[0]
if channel.Expiration.Before(time.Now()) {
slog.Debug(id + " 权限已过期")
reject(c)
return
}
slog.Debug(id + " 通过验证")
confirm(c)
}
func ChanTest(c *gin.Context) {
var body map[string]interface{}
err := c.ShouldBindJSON(&body)
if err != nil {
slog.Error("解析请求正文失败", err)
}
for k, v := range body {
slog.Debug("map", "key: ", k, " value: ", v)
}
confirm(c)
}
func confirm(c *gin.Context) {
c.JSON(200, FrpData{
Reject: false,
Unchange: true,
})
}
func reject(c *gin.Context) {
c.JSON(401, FrpData{
Reject: true,
RejectReason: "客户端没有权限访问该节点",
})
}
// endregion
func ChanAuth(c *gin.Context) {
type Body struct {
Username string `json:"username"`
Password string `json:"password"`
}
type Data struct {
Timeout uint64 `json:"timeout"`
}
var body Body
err := c.ShouldBindJSON(&body)
if err != nil {
_ = c.Error(err)
c.JSON(400, resp.Fail("请求参数错误"))
return
}
// 查找通道
var result *models.Channel
orm.DB.
Model(&models.Channel{}).
Where(&models.Channel{
Username: body.Username,
Password: body.Password,
}).
First(&result)
if result == nil {
_ = c.Error(errors.New("用户信息不存在"))
c.JSON(401, resp.Fail("账号密码错误"))
return
}
// 验证账号密码 todo 哈希密码验证
if result.Username != body.Username || result.Password != body.Password {
_ = c.Error(errors.New("账号密码错误"))
c.JSON(401, resp.Fail("账号密码错误"))
return
}
// 计算到期时间
timeout := result.Expiration.Sub(time.Now())
// todo 保存会话 对于大量短连接的情况,考虑如何保存连接会话信息
// 返回结果
c.JSON(200, resp.Done(Data{
Timeout: uint64(timeout.Seconds()),
}))
}

View File

@@ -0,0 +1,36 @@
package handlers
import (
"fmt"
"proxy-server/server/debug"
"slices"
"github.com/gin-gonic/gin"
)
func GetConsuming(c *gin.Context) {
list := debug.ConsumingList()
// sort by total time
slices.SortFunc(list, func(a debug.Consuming, b debug.Consuming) int {
if a.Total < b.Total {
return 1
} else if a.Total > b.Total {
return -1
}
return 0
})
// map to string
strList := make([]string, len(list))
for i := 0; i < len(list); i++ {
times := list[i]
strList[i] = fmt.Sprintf("Auth: %s, Data: %s, Proxy: %s, Total: %s", times.Auth, times.Data, times.Proxy, times.Total)
}
c.JSON(200, strList)
}
func RestConsuming(c *gin.Context) {
debug.InitConsumingList()
c.JSON(200, gin.H{
"message": "success",
})
}

View File

@@ -1,116 +0,0 @@
package handlers
import (
"os"
"proxy-server/server/models"
"proxy-server/server/pkg/orm"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
"gorm.io/gorm"
)
type NodeRegisterReq struct {
Name string
Secret string
}
func NodeRegister(c *gin.Context) {
// 请求参数
var req NodeRegisterReq
err := c.ShouldBind(&req)
if err != nil {
_ = c.Error(errors.Wrap(err, "参数解析错误"))
return
}
// 验证 secret
secret := os.Getenv("SECRET")
if req.Secret != secret {
_ = c.Error(errors.New("拒绝连接"))
return
}
// 注册节点
// todo 查询运营商和地区
err = orm.DB.Transaction(func(tx *gorm.DB) error {
// 查询节点是否已存在
var count int64
err := orm.DB.Where(&models.Node{
Name: req.Name,
}).Count(&count).Error
if err != nil {
return err
}
// 不存在则注册
if count == 0 {
ipAddress := c.ClientIP()
node := models.Node{
Name: req.Name,
Provider: "",
Location: "",
IPAddress: ipAddress,
}
err = orm.DB.Create(&node).Error
if err != nil {
return err
}
}
return nil
})
if err != nil {
_ = c.Error(errors.Wrap(err, "注册节点失败"))
return
}
c.Status(200)
}
type NodeReportReq struct {
Name string
}
func NodeReport(c *gin.Context) {
// 请求参数
var req NodeReportReq
err := c.ShouldBind(&req)
if err != nil {
_ = c.Error(errors.Wrap(err, "参数解析错误"))
return
}
// 上报节点信息
err = orm.DB.Transaction(func(tx *gorm.DB) error {
// 查询节点
var node models.Node
err = orm.DB.Where(&models.Node{
Name: req.Name,
}).First(&node).Error
if err != nil {
return err
}
// 更新节点信息
ipAddress := c.ClientIP()
if ipAddress != node.IPAddress {
err = orm.DB.Model(&node).Update("ip_address", ipAddress).Error
if err != nil {
return err
}
}
return nil
})
if err != nil {
_ = c.Error(errors.Wrap(err, "上报节点信息失败"))
return
}
c.Status(200)
}

View File

@@ -1 +0,0 @@
package handlers

12
server/web/router.go Normal file
View File

@@ -0,0 +1,12 @@
package web
import (
"proxy-server/server/web/handlers"
"github.com/gin-gonic/gin"
)
func Router(r *gin.Engine) {
r.Handle("GET", "/debug/consuming/list", handlers.GetConsuming)
r.Handle("GET", "/debug/consuming/reset", handlers.RestConsuming)
}

View File

@@ -1,17 +0,0 @@
package router
import (
handlers2 "proxy-server/server/web/handlers"
"github.com/gin-gonic/gin"
)
func Apply(r *gin.Engine) {
r.POST("/node/register", handlers2.NodeRegister)
r.POST("/node/report", handlers2.NodeReport)
r.POST("/chan/request", handlers2.ChanRequest)
r.POST("/chan/auth", handlers2.ChanAuth)
r.POST("/chan/test", handlers2.ChanTest)
}

View File

@@ -4,42 +4,37 @@ import (
"context" "context"
"log/slog" "log/slog"
"net/http" "net/http"
"os" "proxy-server/server/pkg/env"
"proxy-server/server/web/auth" "strconv"
"proxy-server/server/web/router"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/pkg/errors"
) )
var server *http.Server var server *http.Server
func Start(ctx context.Context, errCh chan error) { func Start(ctx context.Context) error {
address := ":" + os.Getenv("PORT") address := ":" + strconv.Itoa(int(env.AppWebPort))
engine := gin.Default() engine := gin.Default()
server = &http.Server{Addr: address, Handler: engine} server = &http.Server{Addr: address, Handler: engine}
// 配置中间件和路由
Router(engine)
// 监听关闭信号 // 监听关闭信号
go func() { go func() {
<-ctx.Done() <-ctx.Done()
slog.Info("web 服务被动关闭") err := server.Shutdown(context.Background())
err := server.Shutdown(ctx)
if err != nil { if err != nil {
slog.Error("web 服务关闭失败", err) slog.Error("web 服务关闭失败", err)
return
} }
}() }()
// 配置中间件和路由
auth.Apply(engine, nil)
router.Apply(engine)
// 启动服务 // 启动服务
err := server.ListenAndServe() err := server.ListenAndServe()
if err != nil { if err != nil {
errCh <- err return errors.Wrap(err, "web 服务启动失败")
return
} }
slog.Debug("web 服务主动结束") return nil
errCh <- nil
} }