优化数据连接处理逻辑,避免数据通道连接空等待问题;授权部分添加全局白名单支持;现在节点丢失连接后不会清空授权数据
This commit is contained in:
@@ -33,7 +33,6 @@ func AddEdge(id int32, port uint16) {
|
||||
func DelEdge(port uint16) {
|
||||
id, _ := Assigns.LoadAndDelete(port)
|
||||
Edges.Delete(id)
|
||||
Permits.Delete(id)
|
||||
}
|
||||
|
||||
func LoadPermit(port uint16) *core.Permit {
|
||||
|
||||
@@ -63,7 +63,6 @@ func (a FwdAddr) String() string {
|
||||
}
|
||||
|
||||
type AuthContext struct {
|
||||
Timeout float64
|
||||
Payload Payload
|
||||
Meta map[string]any
|
||||
}
|
||||
|
||||
29
gateway/env/env.go
vendored
29
gateway/env/env.go
vendored
@@ -3,8 +3,10 @@ package env
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/joho/godotenv"
|
||||
)
|
||||
@@ -20,6 +22,8 @@ var (
|
||||
AppDataTimeout = 10 // 等待数据通道连接的超时时间
|
||||
AppUserTimeout = 10 // 等待用户发送数据的超时时间(端口复用需要分析协议,如果用户长期不发送数据,将会阻塞分析协程)
|
||||
|
||||
AuthWhitelist []net.IP // 全局白名单,可以将白名单 IP 视为一个可信任代理
|
||||
|
||||
ClientId string
|
||||
ClientSecret string
|
||||
|
||||
@@ -99,6 +103,31 @@ func Init() {
|
||||
AppDataTimeout = appDataTimeout
|
||||
}
|
||||
|
||||
value = os.Getenv("APP_USER_TIMEOUT")
|
||||
if value != "" {
|
||||
appUserTimeout, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("环境变量 APP_USER_TIMEOUT 格式错误: %v", err))
|
||||
}
|
||||
AppUserTimeout = appUserTimeout
|
||||
}
|
||||
|
||||
value = os.Getenv("AUTH_WHITELIST")
|
||||
if value != "" {
|
||||
ips := strings.Split(value, ",")
|
||||
for _, ip := range ips {
|
||||
ip = strings.TrimSpace(ip)
|
||||
if ip == "" {
|
||||
continue
|
||||
}
|
||||
parsedIP := net.ParseIP(ip)
|
||||
if parsedIP == nil {
|
||||
panic(fmt.Sprintf("环境变量 AUTH_WHITELIST 格式错误: %s", ip))
|
||||
}
|
||||
AuthWhitelist = append(AuthWhitelist, parsedIP)
|
||||
}
|
||||
}
|
||||
|
||||
value = os.Getenv("CLIENT_ID")
|
||||
if value != "" {
|
||||
ClientId = value
|
||||
|
||||
@@ -5,6 +5,8 @@ import (
|
||||
"net"
|
||||
"proxy-server/gateway/app"
|
||||
"proxy-server/gateway/core"
|
||||
"proxy-server/gateway/env"
|
||||
"slices"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
@@ -35,6 +37,21 @@ func Protect(conn net.Conn, proto Protocol, username, password *string) (*core.A
|
||||
return nil, fmt.Errorf("noAuth 认证失败: %w", err)
|
||||
}
|
||||
|
||||
var id, _ = app.Assigns.Load(uint16(localPort))
|
||||
|
||||
// 检查全局白名单
|
||||
var remoteIp = net.ParseIP(remoteHost)
|
||||
if remoteIp == nil {
|
||||
return nil, fmt.Errorf("无法解析 IP 地址: %s", remoteHost)
|
||||
}
|
||||
if slices.ContainsFunc(env.AuthWhitelist, func(ip net.IP) bool { return ip.Equal(remoteIp) }) {
|
||||
return &core.AuthContext{
|
||||
Payload: core.Payload{
|
||||
ID: id,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 查找权限配置
|
||||
var permit = app.LoadPermit(uint16(localPort))
|
||||
if permit == nil {
|
||||
@@ -68,9 +85,7 @@ func Protect(conn net.Conn, proto Protocol, username, password *string) (*core.A
|
||||
}
|
||||
}
|
||||
|
||||
var id, _ = app.Assigns.Load(uint16(localPort))
|
||||
return &core.AuthContext{
|
||||
Timeout: time.Since(permit.Expire).Seconds(),
|
||||
Payload: core.Payload{
|
||||
ID: id,
|
||||
},
|
||||
|
||||
@@ -3,19 +3,19 @@ package fwd
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/google/uuid"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"proxy-server/gateway/app"
|
||||
"proxy-server/gateway/core"
|
||||
"proxy-server/gateway/debug"
|
||||
"proxy-server/gateway/env"
|
||||
"proxy-server/gateway/fwd/metrics"
|
||||
"proxy-server/utils"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -60,7 +60,10 @@ func ListenData(ctx context.Context) error {
|
||||
app.DataConnWg.Add(1)
|
||||
go func() {
|
||||
defer app.DataConnWg.Done()
|
||||
defer utils.Close(conn)
|
||||
defer func() {
|
||||
utils.Close(conn)
|
||||
slog.Debug("关闭数据通道连接")
|
||||
}()
|
||||
err := processDataConn(ctx, conn)
|
||||
if err != nil {
|
||||
slog.Error("处理数据通道连接失败", "err", err)
|
||||
@@ -80,88 +83,107 @@ func processDataConn(ctx context.Context, client net.Conn) error {
|
||||
return fmt.Errorf("从节点获取连接结果失败: %w", err)
|
||||
}
|
||||
|
||||
tag := buf[0:16]
|
||||
tag := hex.EncodeToString(buf[0:16])
|
||||
status := buf[16]
|
||||
|
||||
// 加载用户连接
|
||||
var tagStr = uuid.UUID(tag).String()
|
||||
user, ok := app.UserConnMap.LoadAndDelete(tagStr)
|
||||
user, ok := app.UserConnMap.LoadAndDelete(tag)
|
||||
if !ok {
|
||||
return fmt.Errorf("用户连接已关闭,tag:%s", tagStr)
|
||||
return fmt.Errorf("用户连接已关闭,tag:%s", tag)
|
||||
}
|
||||
defer utils.Close(user)
|
||||
defer func() {
|
||||
utils.Close(user)
|
||||
slog.Debug("关闭用户连接")
|
||||
}()
|
||||
|
||||
// 检查状态
|
||||
if status != 1 {
|
||||
return errors.New("目标地址建立连接失败")
|
||||
}
|
||||
|
||||
// 转发数据
|
||||
data := time.Now()
|
||||
|
||||
userPipeReader, userPipeWriter := io.Pipe()
|
||||
defer utils.Close(userPipeWriter)
|
||||
// 复制用户流量进行访问目标分析
|
||||
userCopyFrom, userCopyTo := io.Pipe()
|
||||
defer utils.Close(userCopyTo)
|
||||
|
||||
teeUser := io.TeeReader(user, userPipeWriter)
|
||||
teeUser := io.TeeReader(user, userCopyTo)
|
||||
go func() {
|
||||
err := analysisAndLog(user, userPipeReader)
|
||||
err := analysisAndLog(user, userCopyFrom)
|
||||
if err != nil {
|
||||
slog.Error("数据解析失败", "err", err)
|
||||
}
|
||||
}()
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(2)
|
||||
// 复制节点数据到用户
|
||||
var waitEdge = make(chan error)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, err := io.Copy(client, teeUser)
|
||||
if err != nil {
|
||||
slog.Error("数据转发失败 user->client", "err", err)
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, err := io.Copy(user, reader)
|
||||
if err != nil {
|
||||
slog.Error("数据转发失败 client->user", "err", err)
|
||||
switch {
|
||||
case errors.Is(err, net.ErrClosed):
|
||||
slog.Debug("节点连接意外关闭")
|
||||
case err != nil:
|
||||
slog.Error("读取节点数据失败", "err", err)
|
||||
default:
|
||||
slog.Debug("节点数据读取完成")
|
||||
}
|
||||
waitEdge <- err
|
||||
}()
|
||||
|
||||
// 复制用户数据到节点
|
||||
var waitUser = make(chan error)
|
||||
go func() {
|
||||
_, err := io.Copy(client, teeUser)
|
||||
switch {
|
||||
case errors.Is(err, net.ErrClosed):
|
||||
slog.Debug("用户连接意外关闭")
|
||||
case err != nil:
|
||||
slog.Error("读取用户数据失败", "err", err)
|
||||
default:
|
||||
slog.Debug("用户数据读取完成")
|
||||
}
|
||||
waitUser <- err
|
||||
}()
|
||||
|
||||
// 等待数据转发完成,关闭数据通道的时机:
|
||||
select {
|
||||
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
slog.Debug("服务关闭")
|
||||
case <-waitEdge:
|
||||
case <-waitUser:
|
||||
storeConnMatrics(user, data)
|
||||
}
|
||||
|
||||
case <-utils.WgWait(&wg):
|
||||
proxy := time.Now()
|
||||
return nil
|
||||
}
|
||||
|
||||
start, startOk := metrics.TimerStart.Load(user.Conn)
|
||||
auth, authOk := metrics.TimerAuth.Load(user.Conn)
|
||||
func storeConnMatrics(user *core.Conn, data time.Time) {
|
||||
proxy := time.Now()
|
||||
|
||||
var authDuration time.Duration
|
||||
if startOk && authOk {
|
||||
authDuration = auth.(time.Time).Sub(start.(time.Time))
|
||||
}
|
||||
start, startOk := metrics.TimerStart.Load(user.Conn)
|
||||
auth, authOk := metrics.TimerAuth.Load(user.Conn)
|
||||
|
||||
var dataDuration time.Duration
|
||||
if authOk {
|
||||
dataDuration = data.Sub(auth.(time.Time))
|
||||
}
|
||||
var authDuration time.Duration
|
||||
if startOk && authOk {
|
||||
authDuration = auth.(time.Time).Sub(start.(time.Time))
|
||||
}
|
||||
|
||||
proxyDuration := proxy.Sub(data)
|
||||
var dataDuration time.Duration
|
||||
if authOk {
|
||||
dataDuration = data.Sub(auth.(time.Time))
|
||||
}
|
||||
|
||||
var totalDuration time.Duration
|
||||
if startOk {
|
||||
totalDuration = proxy.Sub(start.(time.Time))
|
||||
}
|
||||
proxyDuration := proxy.Sub(data)
|
||||
|
||||
debug.ConsumingCh <- debug.Consuming{
|
||||
Auth: authDuration,
|
||||
Data: dataDuration,
|
||||
Proxy: proxyDuration,
|
||||
Total: totalDuration,
|
||||
}
|
||||
var totalDuration time.Duration
|
||||
if startOk {
|
||||
totalDuration = proxy.Sub(start.(time.Time))
|
||||
}
|
||||
|
||||
return nil
|
||||
debug.ConsumingCh <- debug.Consuming{
|
||||
Auth: authDuration,
|
||||
Data: dataDuration,
|
||||
Proxy: proxyDuration,
|
||||
Total: totalDuration,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -67,7 +67,8 @@ func processUserConn(ctx context.Context, user *core.Conn, ctrl io.Writer) (err
|
||||
}
|
||||
|
||||
// 保存用户连接
|
||||
app.UserConnMap.Store(hex.EncodeToString(user.Tag[:]), user)
|
||||
var tag = hex.EncodeToString(user.Tag[:])
|
||||
app.UserConnMap.Store(tag, user)
|
||||
|
||||
// 如果限定时间内没有建立数据通道,则关闭连接
|
||||
var timeout, cancel = context.WithTimeout(context.Background(), time.Duration(env.AppDataTimeout)*time.Second)
|
||||
@@ -80,11 +81,11 @@ func processUserConn(ctx context.Context, user *core.Conn, ctrl io.Writer) (err
|
||||
err = ctx.Err()
|
||||
}
|
||||
|
||||
_, ok := app.UserConnMap.LoadAndDelete(hex.EncodeToString(user.Tag[:]))
|
||||
_, ok := app.UserConnMap.LoadAndDelete(tag)
|
||||
if ok {
|
||||
utils.Close(user)
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
slog.Error("用户连接超时", "tag", hex.EncodeToString(user.Tag[:]), "addr", user.RemoteAddr().String())
|
||||
slog.Error("用户连接超时", "tag", tag, "addr", user.RemoteAddr().String())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user