重新规划网关与节点的交互协议,实现统一命令位的识别和处理
This commit is contained in:
@@ -1,12 +1,15 @@
|
||||
package app
|
||||
|
||||
import "proxy-server/server/core"
|
||||
import (
|
||||
"proxy-server/server/core"
|
||||
)
|
||||
|
||||
var (
|
||||
Id int32
|
||||
Name string
|
||||
PlatformSecret string // 平台密钥,验证接收的请求是否属于平台
|
||||
|
||||
Assigns = make(map[uint16]int32) // 转发端口 -> 转发服务ID
|
||||
Permits = make(map[uint16]core.Permit) // 转发端口 -> 权限配置
|
||||
Clients = core.SyncMap[int32, uint16]{} // 节点 ID -> 转发端口
|
||||
Assigns = core.SyncMap[uint16, int32]{} // 转发端口 -> 节点 ID
|
||||
Permits = core.SyncMap[uint16, *core.Permit]{} // 转发端口 -> 权限配置
|
||||
)
|
||||
|
||||
@@ -4,40 +4,13 @@ import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ConnMap struct {
|
||||
_map sync.Map
|
||||
}
|
||||
|
||||
func (c *ConnMap) LoadAndDelete(key string) (*Conn, bool) {
|
||||
_value, ok := c._map.LoadAndDelete(key)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
return _value.(*Conn), true
|
||||
}
|
||||
|
||||
func (c *ConnMap) Store(key string, value *Conn) {
|
||||
c._map.Store(key, value)
|
||||
}
|
||||
|
||||
func (c *ConnMap) Range(f func(key string, value *Conn) bool) {
|
||||
c._map.Range(func(key, value any) bool {
|
||||
return f(key.(string), value.(*Conn))
|
||||
})
|
||||
}
|
||||
|
||||
func (c *ConnMap) Clear() {
|
||||
c._map.Clear()
|
||||
}
|
||||
|
||||
type Conn struct {
|
||||
Conn net.Conn
|
||||
Reader *bufio.Reader
|
||||
Tag string
|
||||
Tag [16]byte
|
||||
Protocol string
|
||||
Dest *FwdAddr
|
||||
Auth *AuthContext
|
||||
@@ -75,10 +48,6 @@ func (c Conn) SetWriteDeadline(t time.Time) error {
|
||||
return c.Conn.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
func (c Conn) DestAddr() net.Addr {
|
||||
return c.Dest
|
||||
}
|
||||
|
||||
type FwdAddr struct {
|
||||
IP net.IP
|
||||
Port int
|
||||
65
server/core/map.go
Normal file
65
server/core/map.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package core
|
||||
|
||||
import "sync"
|
||||
|
||||
type SyncMap[K any, V any] struct {
|
||||
_map sync.Map
|
||||
}
|
||||
|
||||
func (m *SyncMap[K, V]) Store(key K, value V) {
|
||||
m._map.Store(key, value)
|
||||
}
|
||||
|
||||
func (m *SyncMap[K, V]) Load(key K) (value V, ok bool) {
|
||||
v, ok := m._map.Load(key)
|
||||
if ok {
|
||||
value = v.(V)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (m *SyncMap[K, V]) Swap(key K, value V) (previous V, loaded bool) {
|
||||
v, loaded := m._map.Swap(key, value)
|
||||
if loaded {
|
||||
previous = v.(V)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (m *SyncMap[K, V]) Delete(key K) {
|
||||
m._map.Delete(key)
|
||||
}
|
||||
|
||||
func (m *SyncMap[K, V]) Clear() {
|
||||
m._map.Clear()
|
||||
}
|
||||
|
||||
func (m *SyncMap[K, V]) Range(f func(key K, value V) bool) {
|
||||
m._map.Range(func(k, v any) bool {
|
||||
return f(k.(K), v.(V))
|
||||
})
|
||||
}
|
||||
|
||||
func (m *SyncMap[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) {
|
||||
v, loaded := m._map.LoadOrStore(key, value)
|
||||
if loaded {
|
||||
actual = v.(V)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (m *SyncMap[K, V]) LoadAndDelete(key K) (value V, ok bool) {
|
||||
v, ok := m._map.LoadAndDelete(key)
|
||||
if ok {
|
||||
value = v.(V)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (m *SyncMap[K, V]) CompareAndSwap(key K, old, new V) (swapped bool) {
|
||||
return m._map.CompareAndSwap(key, old, new)
|
||||
}
|
||||
|
||||
func (m *SyncMap[K, V]) CompareAndDelete(key K, old V) (deleted bool) {
|
||||
return m._map.CompareAndDelete(key, old)
|
||||
}
|
||||
29
server/env/env.go
vendored
29
server/env/env.go
vendored
@@ -10,10 +10,13 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
AppCtrlPort uint16 = 18080
|
||||
AppDataPort uint16 = 18081
|
||||
AppWebPort uint16 = 8848
|
||||
AppLogMode = "dev"
|
||||
AppCtrlPort uint16 = 18080
|
||||
AppDataPort uint16 = 18081
|
||||
AppWebPort uint16 = 8848
|
||||
AppLogMode = "dev"
|
||||
AppExitTimeout = 5 // 等待服务停止的超时时间
|
||||
AppDataTimeout = 10 // 等待数据通道连接的超时时间
|
||||
AppUserTimeout = 10 // 等待用户发送数据的超时时间(端口复用需要分析协议,如果用户长期不发送数据,将会阻塞分析协程)
|
||||
|
||||
ClientId string
|
||||
ClientSecret string
|
||||
@@ -67,6 +70,24 @@ func Init() {
|
||||
AppLogMode = value
|
||||
}
|
||||
|
||||
value = os.Getenv("APP_EXIT_TIMEOUT")
|
||||
if value != "" {
|
||||
appExitTimeout, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("环境变量 APP_EXIT_TIMEOUT 格式错误: %v", err))
|
||||
}
|
||||
AppExitTimeout = appExitTimeout
|
||||
}
|
||||
|
||||
value = os.Getenv("APP_DATA_TIMEOUT")
|
||||
if value != "" {
|
||||
appDataTimeout, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("环境变量 APP_DATA_TIMEOUT 格式错误: %v", err))
|
||||
}
|
||||
AppDataTimeout = appDataTimeout
|
||||
}
|
||||
|
||||
value = os.Getenv("CLIENT_ID")
|
||||
if value != "" {
|
||||
ClientId = value
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"io"
|
||||
"log/slog"
|
||||
"proxy-server/pkg/utils"
|
||||
"proxy-server/server/fwd/core"
|
||||
"proxy-server/server/core"
|
||||
"strings"
|
||||
|
||||
"errors"
|
||||
@@ -27,7 +27,7 @@ func analysisAndLog(conn *core.Conn, reader io.Reader) error {
|
||||
slog.String("proxy", conn.Protocol),
|
||||
slog.String("node", conn.LocalAddr().String()),
|
||||
slog.String("proto", proto),
|
||||
slog.String("dest", conn.DestAddr().String()),
|
||||
slog.String("dest", conn.Dest.String()),
|
||||
slog.String("domain", domain),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"proxy-server/server/app"
|
||||
"proxy-server/server/fwd/core"
|
||||
"proxy-server/server/core"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
@@ -36,7 +36,7 @@ func Protect(conn net.Conn, proto Protocol, username, password *string) (*core.A
|
||||
}
|
||||
|
||||
// 查找权限配置
|
||||
var permit, ok = app.Permits[uint16(localPort)]
|
||||
var permit, ok = app.Permits.Load(uint16(localPort))
|
||||
if !ok {
|
||||
return nil, errors.New("没有权限")
|
||||
}
|
||||
@@ -68,10 +68,11 @@ func Protect(conn net.Conn, proto Protocol, username, password *string) (*core.A
|
||||
}
|
||||
}
|
||||
|
||||
var id, _ = app.Assigns.Load(uint16(localPort))
|
||||
return &core.AuthContext{
|
||||
Timeout: time.Since(permit.Expire).Seconds(),
|
||||
Payload: core.Payload{
|
||||
ID: app.Assigns[uint16(localPort)],
|
||||
ID: id,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
@@ -11,25 +12,21 @@ import (
|
||||
"proxy-server/pkg/utils"
|
||||
"proxy-server/server/app"
|
||||
"proxy-server/server/env"
|
||||
"proxy-server/server/fwd/core"
|
||||
"proxy-server/server/fwd/dispatcher"
|
||||
"proxy-server/server/fwd/metrics"
|
||||
"proxy-server/server/report"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"errors"
|
||||
)
|
||||
|
||||
type CtrlCmd struct {
|
||||
conn net.Conn
|
||||
buf []byte
|
||||
}
|
||||
type CtrlCmdType int
|
||||
|
||||
var ctrlCmdChan = make(chan CtrlCmd, 1024)
|
||||
const (
|
||||
CtrlCmdPong CtrlCmdType = iota + 1
|
||||
CtrlCmdPing
|
||||
CtrlCmdOpen
|
||||
CtrlCmdClose
|
||||
CtrlCmdProxy
|
||||
)
|
||||
|
||||
func (s *Service) startCtrlTun() error {
|
||||
func (s *Service) listenCtrl() error {
|
||||
ctrlPort := env.AppCtrlPort
|
||||
slog.Debug("监听控制通道", slog.Uint64("port", uint64(ctrlPort)))
|
||||
|
||||
@@ -56,7 +53,7 @@ func (s *Service) startCtrlTun() error {
|
||||
go func() {
|
||||
defer s.ctrlConnWg.Done()
|
||||
defer utils.Close(conn)
|
||||
err := s.processCtrlConn(conn)
|
||||
err := s.processCtrlConn(s.ctx, conn)
|
||||
if err != nil {
|
||||
slog.Error("处理控制通道连接失败", "err", err)
|
||||
}
|
||||
@@ -67,25 +64,80 @@ func (s *Service) startCtrlTun() error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Service) processCtrlConn(conn net.Conn) error {
|
||||
func (s *Service) processCtrlConn(ctx context.Context, conn net.Conn) (err error) {
|
||||
reader := bufio.NewReader(conn)
|
||||
for {
|
||||
// 循环等待直到服务关闭
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
default:
|
||||
}
|
||||
|
||||
var recv = make([]byte, 4)
|
||||
_, err := io.ReadFull(reader, recv)
|
||||
if err != nil {
|
||||
return fmt.Errorf("读取客户端 ID 失败: %w", err)
|
||||
// 读取命令
|
||||
cmdByte, err := reader.ReadByte()
|
||||
if err != nil {
|
||||
return fmt.Errorf("读取节点命令失败: %w", err)
|
||||
}
|
||||
var cmd = CtrlCmdType(cmdByte)
|
||||
switch cmd {
|
||||
|
||||
// 连接建立命令
|
||||
case CtrlCmdOpen:
|
||||
var recv = make([]byte, 4)
|
||||
_, err = io.ReadFull(reader, recv)
|
||||
if err != nil {
|
||||
return fmt.Errorf("读取节点 ID 失败: %w", err)
|
||||
}
|
||||
var client = int32(binary.BigEndian.Uint32(recv))
|
||||
err = s.onOpen(conn, client)
|
||||
if err != nil {
|
||||
return fmt.Errorf("处理连接建立命令失败: %w", err)
|
||||
}
|
||||
|
||||
// 心跳命令
|
||||
case CtrlCmdPing:
|
||||
err = s.onPing(conn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("处理心跳命令失败: %w", err)
|
||||
}
|
||||
|
||||
// 连接关闭命令
|
||||
case CtrlCmdClose:
|
||||
err = s.onClose(conn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("处理关闭命令失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
|
||||
// 忽略其他不应该由节点发起的命令
|
||||
default:
|
||||
return fmt.Errorf("无法处理控制命令: %d", cmd)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) onPing(conn net.Conn) (err error) {
|
||||
return s.sendPong(conn)
|
||||
}
|
||||
|
||||
func (s *Service) onOpen(conn net.Conn, client int32) (err error) {
|
||||
// open 命令全局只执行一次
|
||||
_, ok := app.Clients.Load(client)
|
||||
if ok {
|
||||
return fmt.Errorf("节点 ID %d 已经连接", client)
|
||||
}
|
||||
var client = int32(binary.BigEndian.Uint32(recv))
|
||||
|
||||
// 分配端口
|
||||
var minim uint16 = 20000
|
||||
var maxim uint16 = 60000
|
||||
var port uint16
|
||||
for i := minim; i < maxim; i++ {
|
||||
var _, ok = app.Assigns[i]
|
||||
var _, ok = app.Assigns.Load(i)
|
||||
if !ok {
|
||||
port = i
|
||||
app.Assigns[i] = client
|
||||
app.Assigns.Store(i, client)
|
||||
app.Clients.Store(client, i)
|
||||
break
|
||||
}
|
||||
}
|
||||
@@ -94,126 +146,75 @@ func (s *Service) processCtrlConn(conn net.Conn) error {
|
||||
}
|
||||
|
||||
// 报告端口分配
|
||||
err = report.Assigned(client, port)
|
||||
if err != nil {
|
||||
if err = report.Assigned(client, port); err != nil {
|
||||
return fmt.Errorf("报告端口分配失败: %w", err)
|
||||
}
|
||||
|
||||
// 响应客户端
|
||||
_, err = conn.Write([]byte{1})
|
||||
if err != nil {
|
||||
if err = s.sendPong(conn); err != nil {
|
||||
return fmt.Errorf("响应客户端失败: %w", err)
|
||||
}
|
||||
|
||||
// 启动转发服务
|
||||
slog.Info("监听转发端口", "port", port, "client", client)
|
||||
proxy, err := dispatcher.New(port)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer proxy.Close()
|
||||
|
||||
s.fwdLesWg.Add(1)
|
||||
go func() {
|
||||
defer s.fwdLesWg.Done()
|
||||
err := proxy.Run()
|
||||
slog.Info("监听转发端口", "port", port, "client", client)
|
||||
err = s.listenUser(port, conn)
|
||||
if err != nil {
|
||||
slog.Error("代理服务运行失败", "err", err)
|
||||
slog.Error("监听转发端口失败", "port", port, "client", client, "err", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// 监听控制通道连接
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
defer close(errCh)
|
||||
_, err := reader.ReadByte()
|
||||
errCh <- err
|
||||
}()
|
||||
|
||||
// 批量同步写入
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
case cmd := <-ctrlCmdChan:
|
||||
_, err := cmd.conn.Write(cmd.buf)
|
||||
if err != nil {
|
||||
slog.Error("批量写入失败", "err", err)
|
||||
utils.Close(cmd.conn)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// 处理连接
|
||||
for {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return nil
|
||||
case err := <-errCh:
|
||||
switch {
|
||||
case strings.Contains(err.Error(), "An existing connection was forcibly closed by the remote host."):
|
||||
slog.Debug("客户端主动断开连接")
|
||||
return nil
|
||||
case err == nil:
|
||||
return errors.New("客户端握手失败")
|
||||
default:
|
||||
return fmt.Errorf("客户端意外断开连接: %w", err)
|
||||
}
|
||||
case user := <-proxy.Conn:
|
||||
metrics.TimerAuth.Store(user.Conn, time.Now())
|
||||
s.userConnWg.Add(1)
|
||||
go func() {
|
||||
defer s.userConnWg.Done()
|
||||
err := s.processUserConn(user, conn)
|
||||
if err != nil {
|
||||
slog.Error("处理用户连接失败", "err", err)
|
||||
utils.Close(user)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) processUserConn(user *core.Conn, ctrl net.Conn) error {
|
||||
|
||||
// 组织写入信息
|
||||
dst := user.DestAddr().String()
|
||||
dstLen := len(dst)
|
||||
|
||||
tag := user.Tag
|
||||
tagLen := len(tag)
|
||||
|
||||
writeBuf := make([]byte, 2+dstLen+tagLen)
|
||||
writeBuf[0] = byte(dstLen)
|
||||
copy(writeBuf[1:], dst)
|
||||
writeBuf[1+dstLen] = byte(tagLen)
|
||||
copy(writeBuf[2+dstLen:], tag)
|
||||
|
||||
// 异步写入命令
|
||||
ctrlCmdChan <- CtrlCmd{
|
||||
conn: ctrl,
|
||||
buf: writeBuf,
|
||||
func (s *Service) onClose(conn net.Conn) (err error) {
|
||||
_, portStr, err := net.SplitHostPort(conn.LocalAddr().String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 记录用户连接
|
||||
s.userConnMap.Store(user.Tag, user)
|
||||
port, err := strconv.ParseUint(portStr, 10, 16)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 如果限定时间内没有建立数据通道,则关闭连接
|
||||
timeout, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
id, _ := app.Assigns.LoadAndDelete(uint16(port))
|
||||
app.Clients.Delete(id)
|
||||
app.Assigns.Delete(uint16(port))
|
||||
app.Permits.Delete(uint16(port))
|
||||
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
// 服务会在退出时统一关闭未消费的连接
|
||||
case <-timeout.Done():
|
||||
storedUser, ok := s.userConnMap.LoadAndDelete(user.Tag)
|
||||
if ok {
|
||||
slog.Debug("建立数据通道超时", "tag", user.Tag)
|
||||
utils.Close(storedUser)
|
||||
}
|
||||
err = s.sendPong(conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) sendPong(conn net.Conn) (err error) {
|
||||
_, err = conn.Write([]byte{byte(CtrlCmdPong)})
|
||||
if err != nil {
|
||||
return fmt.Errorf("响应客户端失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) sendProxy(conn net.Conn, tag [16]byte, addr string) (err error) {
|
||||
if len(addr) > 65535 {
|
||||
return fmt.Errorf("代理地址过长: %s", addr)
|
||||
}
|
||||
|
||||
buf := make([]byte, 1+16+2+len(addr))
|
||||
buf[0] = byte(CtrlCmdProxy)
|
||||
copy(buf[1:], tag[:])
|
||||
binary.BigEndian.PutUint16(buf[17:], uint16(len(addr)))
|
||||
copy(buf[19:], addr)
|
||||
|
||||
_, err = conn.Write(buf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("发送代理命令失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
package fwd
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"github.com/google/uuid"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
@@ -16,7 +18,7 @@ import (
|
||||
"errors"
|
||||
)
|
||||
|
||||
func (s *Service) startDataTun() error {
|
||||
func (s *Service) listenData() error {
|
||||
dataPort := env.AppDataPort
|
||||
slog.Debug("监听数据通道", slog.Uint64("port", uint64(dataPort)))
|
||||
|
||||
@@ -57,38 +59,34 @@ func (s *Service) startDataTun() error {
|
||||
}
|
||||
|
||||
func (s *Service) processDataConn(client net.Conn) error {
|
||||
var reader = bufio.NewReader(client)
|
||||
|
||||
// 接收 status
|
||||
status, err := utils.ReadByte(client)
|
||||
// 接收连接结果
|
||||
var buf = make([]byte, 17)
|
||||
_, err := io.ReadFull(reader, buf)
|
||||
if err != nil {
|
||||
return fmt.Errorf("从客户端获取 status 失败: %w", err)
|
||||
return fmt.Errorf("从客户端获取连接结果失败: %w", err)
|
||||
}
|
||||
|
||||
// 接收 tag
|
||||
tagLen, err := utils.ReadByte(client)
|
||||
if err != nil {
|
||||
return fmt.Errorf("从客户端获取 tag 失败: %w", err)
|
||||
}
|
||||
tagBuf, err := utils.ReadBuffer(client, int(tagLen))
|
||||
if err != nil {
|
||||
return fmt.Errorf("从客户端获取 tag 失败: %w", err)
|
||||
}
|
||||
tag := string(tagBuf)
|
||||
tag := buf[0:16]
|
||||
status := buf[16]
|
||||
|
||||
// 找到用户连接
|
||||
user, ok := s.userConnMap.LoadAndDelete(tag)
|
||||
// 加载用户连接
|
||||
var tagStr = uuid.UUID(tag).String()
|
||||
user, ok := s.userConnMap.LoadAndDelete(tagStr)
|
||||
if !ok {
|
||||
return errors.New("用户连接已关闭,tag:" + tag)
|
||||
return fmt.Errorf("用户连接已关闭,tag:%s", tagStr)
|
||||
}
|
||||
defer utils.Close(user)
|
||||
data := time.Now()
|
||||
|
||||
// 检查状态
|
||||
if status != 1 {
|
||||
return errors.New("目标地址建立连接失败")
|
||||
}
|
||||
|
||||
// 数据转发
|
||||
// 转发数据
|
||||
data := time.Now()
|
||||
|
||||
userPipeReader, userPipeWriter := io.Pipe()
|
||||
defer utils.Close(userPipeWriter)
|
||||
teeUser := io.TeeReader(user, userPipeWriter)
|
||||
@@ -110,7 +108,7 @@ func (s *Service) processDataConn(client net.Conn) error {
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, err := io.Copy(user, client)
|
||||
_, err := io.Copy(user, reader)
|
||||
if err != nil {
|
||||
slog.Error("数据转发失败 client->user", "err", err)
|
||||
}
|
||||
@@ -118,7 +116,7 @@ func (s *Service) processDataConn(client net.Conn) error {
|
||||
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
case <-utils.ChanWgWait(s.ctx, &wg):
|
||||
case <-utils.WgWait(&wg):
|
||||
}
|
||||
|
||||
proxy := time.Now()
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"log/slog"
|
||||
"net"
|
||||
"proxy-server/pkg/utils"
|
||||
"proxy-server/server/fwd/core"
|
||||
"proxy-server/server/core"
|
||||
"proxy-server/server/fwd/http"
|
||||
"proxy-server/server/fwd/metrics"
|
||||
"proxy-server/server/fwd/socks"
|
||||
@@ -19,13 +19,14 @@ import (
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
Port uint16
|
||||
Conn chan *core.Conn
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
readTimeout time.Duration
|
||||
Port uint16
|
||||
Conn chan *core.Conn
|
||||
}
|
||||
|
||||
func New(port uint16) (*Server, error) {
|
||||
func New(port uint16, readTimeout time.Duration) (*Server, error) {
|
||||
|
||||
if port == 0 {
|
||||
return nil, errors.New("port is required")
|
||||
@@ -35,6 +36,7 @@ func New(port uint16) (*Server, error) {
|
||||
return &Server{
|
||||
ctx,
|
||||
cancel,
|
||||
readTimeout,
|
||||
port,
|
||||
make(chan *core.Conn),
|
||||
}, nil
|
||||
@@ -54,7 +56,7 @@ func (s *Server) Run() error {
|
||||
defer utils.Close(ls)
|
||||
|
||||
m := cmux.New(ls)
|
||||
m.SetReadTimeout(5 * time.Second)
|
||||
m.SetReadTimeout(s.readTimeout)
|
||||
defer m.Close()
|
||||
|
||||
socksLs := m.Match(cmux.PrefixMatcher(string([]byte{0x05})))
|
||||
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"proxy-server/pkg/utils"
|
||||
"proxy-server/server/fwd/core"
|
||||
"proxy-server/server/core"
|
||||
"sync"
|
||||
)
|
||||
|
||||
@@ -12,7 +12,7 @@ type Service struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
userConnMap core.ConnMap
|
||||
userConnMap core.SyncMap[string, *core.Conn]
|
||||
|
||||
fwdLesWg utils.CountWaitGroup
|
||||
ctrlConnWg utils.CountWaitGroup
|
||||
@@ -40,7 +40,7 @@ func (s *Service) Run() error {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
err := s.startCtrlTun()
|
||||
err := s.listenCtrl()
|
||||
if err != nil {
|
||||
slog.Error("fwd 控制通道监听发生错误", "err", err)
|
||||
errQuit <- struct{}{}
|
||||
@@ -52,7 +52,7 @@ func (s *Service) Run() error {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
err := s.startDataTun()
|
||||
err := s.listenData()
|
||||
if err != nil {
|
||||
slog.Error("fwd 数据通道监听发生错误", "err", err)
|
||||
errQuit <- struct{}{}
|
||||
@@ -75,13 +75,6 @@ func (s *Service) Run() error {
|
||||
s.fwdLesWg.Wait()
|
||||
s.userConnWg.Wait()
|
||||
|
||||
// 清理资源
|
||||
s.userConnMap.Range(func(key string, value *core.Conn) bool {
|
||||
utils.Close(value)
|
||||
return true
|
||||
})
|
||||
s.userConnMap.Clear()
|
||||
|
||||
s.ctrlConnWg.Wait()
|
||||
slog.Debug("控制通道连接已关闭")
|
||||
s.dataConnWg.Wait()
|
||||
|
||||
@@ -5,12 +5,13 @@ import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"github.com/google/uuid"
|
||||
"io"
|
||||
"net"
|
||||
"net/textproto"
|
||||
"net/url"
|
||||
"proxy-server/server/core"
|
||||
"proxy-server/server/fwd/auth"
|
||||
"proxy-server/server/fwd/core"
|
||||
"strings"
|
||||
|
||||
"errors"
|
||||
@@ -132,7 +133,7 @@ func processHttps(ctx context.Context, req *Request) (*core.Conn, error) {
|
||||
return &core.Conn{
|
||||
Conn: req.conn,
|
||||
Reader: req.reader,
|
||||
Tag: req.conn.RemoteAddr().String() + "_" + req.conn.LocalAddr().String(),
|
||||
Tag: uuid.New(),
|
||||
Protocol: "http",
|
||||
Dest: req.dest,
|
||||
Auth: req.auth,
|
||||
@@ -176,7 +177,7 @@ func processHttp(ctx context.Context, req *Request) (*core.Conn, error) {
|
||||
return &core.Conn{
|
||||
Conn: req.conn,
|
||||
Reader: newReader,
|
||||
Tag: req.conn.RemoteAddr().String() + "_" + req.conn.LocalAddr().String(),
|
||||
Tag: uuid.New(),
|
||||
Protocol: "http",
|
||||
Dest: req.dest,
|
||||
Auth: req.auth,
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
package repo
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// Channel 连接认证模型
|
||||
type Channel struct {
|
||||
gorm.Model
|
||||
UserId uint
|
||||
NodeId uint
|
||||
UserAddr string
|
||||
NodePort int
|
||||
AuthIp bool
|
||||
AuthPass bool
|
||||
Protocol string
|
||||
Username string
|
||||
Password string
|
||||
Expiration time.Time
|
||||
}
|
||||
@@ -1,15 +0,0 @@
|
||||
package repo
|
||||
|
||||
import "gorm.io/gorm"
|
||||
|
||||
// Node 客户端模型
|
||||
type Node struct {
|
||||
gorm.Model
|
||||
Name string
|
||||
Version byte
|
||||
FwdPort uint16
|
||||
Provider string
|
||||
Location string
|
||||
|
||||
Channels []Channel `gorm:"foreignKey:NodeId"`
|
||||
}
|
||||
@@ -6,12 +6,13 @@ import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/google/uuid"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"proxy-server/pkg/utils"
|
||||
"proxy-server/server/core"
|
||||
"proxy-server/server/fwd/auth"
|
||||
"proxy-server/server/fwd/core"
|
||||
"slices"
|
||||
)
|
||||
|
||||
@@ -83,7 +84,7 @@ func Process(ctx context.Context, conn net.Conn) (*core.Conn, error) {
|
||||
Conn: conn,
|
||||
Reader: reader,
|
||||
Protocol: "socks5",
|
||||
Tag: conn.RemoteAddr().String() + "_" + conn.LocalAddr().String(),
|
||||
Tag: uuid.New(),
|
||||
Dest: request.DestAddr,
|
||||
Auth: authCtx,
|
||||
}, nil
|
||||
|
||||
82
server/fwd/user.go
Normal file
82
server/fwd/user.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package fwd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"net"
|
||||
"proxy-server/pkg/utils"
|
||||
"proxy-server/server/core"
|
||||
"proxy-server/server/env"
|
||||
"proxy-server/server/fwd/dispatcher"
|
||||
"proxy-server/server/fwd/metrics"
|
||||
"time"
|
||||
)
|
||||
|
||||
func (s *Service) listenUser(port uint16, ctrl net.Conn) error {
|
||||
dspt, err := dispatcher.New(port, time.Duration(env.AppUserTimeout)*time.Second)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer dspt.Close()
|
||||
|
||||
go func() {
|
||||
err := dspt.Run()
|
||||
if err != nil {
|
||||
slog.Error("代理服务运行失败", "err", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// 处理连接
|
||||
for {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return nil
|
||||
case user := <-dspt.Conn:
|
||||
metrics.TimerAuth.Store(user.Conn, time.Now())
|
||||
s.userConnWg.Add(1)
|
||||
go func() {
|
||||
defer s.userConnWg.Done()
|
||||
err := s.processUserConn(user, ctrl)
|
||||
if err != nil {
|
||||
slog.Error("处理用户连接失败", "err", err)
|
||||
utils.Close(user)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) processUserConn(user *core.Conn, ctrl net.Conn) (err error) {
|
||||
|
||||
// 发送代理命令
|
||||
err = s.sendProxy(ctrl, user.Tag, user.Dest.String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 保存用户连接
|
||||
s.userConnMap.Store(hex.EncodeToString(user.Tag[:]), user)
|
||||
|
||||
// 如果限定时间内没有建立数据通道,则关闭连接
|
||||
var timeout, cancel = context.WithTimeout(context.Background(), time.Duration(env.AppDataTimeout)*time.Second)
|
||||
defer cancel()
|
||||
|
||||
select {
|
||||
case <-timeout.Done():
|
||||
err = timeout.Err()
|
||||
case <-s.ctx.Done():
|
||||
err = s.ctx.Err()
|
||||
}
|
||||
|
||||
_, ok := s.userConnMap.LoadAndDelete(hex.EncodeToString(user.Tag[:]))
|
||||
if ok {
|
||||
utils.Close(user)
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
slog.Error("用户连接超时", "tag", hex.EncodeToString(user.Tag[:]), "addr", user.RemoteAddr().String())
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -109,24 +109,26 @@ func (s *server) Run() (err error) {
|
||||
}
|
||||
cancel()
|
||||
|
||||
timeout, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
// 主协程退出流程
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
// 报告下线
|
||||
slog.Debug("报告服务下线")
|
||||
err = report.Offline(app.Name)
|
||||
if err != nil {
|
||||
slog.Error("服务下线失败", "err", err)
|
||||
}
|
||||
|
||||
// 报告下线
|
||||
slog.Debug("报告服务下线")
|
||||
err = report.Offline(app.Name)
|
||||
if err != nil {
|
||||
slog.Error("服务下线失败", "err", err)
|
||||
}
|
||||
|
||||
// 关闭 redis
|
||||
g.ExitRedis()
|
||||
// 关闭 redis
|
||||
g.ExitRedis()
|
||||
}()
|
||||
|
||||
// 等待其它服务关闭
|
||||
select {
|
||||
case <-utils.ChanWgWait(timeout, &wg):
|
||||
case <-utils.WgWait(&wg):
|
||||
slog.Info("服务正常关闭")
|
||||
case <-timeout.Done():
|
||||
case <-time.After(time.Duration(env.AppExitTimeout) * time.Second):
|
||||
slog.Warn("超时强制关闭")
|
||||
}
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ func Auth(ctx *fiber.Ctx) (err error) {
|
||||
}
|
||||
|
||||
// 保存授权配置
|
||||
app.Permits[req.Port] = req.Permit
|
||||
app.Permits.Store(req.Port, &req.Permit)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user