239 lines
4.9 KiB
Go
239 lines
4.9 KiB
Go
package fwd
|
|
|
|
import (
|
|
"bufio"
|
|
"context"
|
|
"log/slog"
|
|
"net"
|
|
"proxy-server/pkg/utils"
|
|
"proxy-server/server/fwd/core"
|
|
"proxy-server/server/fwd/dispatcher"
|
|
"proxy-server/server/models"
|
|
"proxy-server/server/pkg/env"
|
|
"proxy-server/server/pkg/orm"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/pkg/errors"
|
|
)
|
|
|
|
// CtrlCmd is one queued write for a control connection: the bytes in buf are
// written to conn by whichever goroutine drains ctrlCmdChan.
type CtrlCmd struct {
	conn net.Conn // control connection that receives the payload
	buf  []byte   // raw bytes to write to conn
}

// ctrlCmdChan is the package-wide queue of pending control-channel writes.
// processUserConn enqueues commands and the writer goroutine started in
// processCtrlConn dequeues them. It buffers up to 1024 commands.
var (
	ctrlCmdChan = make(chan CtrlCmd, 1<<10)
)
|
|
|
|
func (s *Service) startCtrlTun() error {
|
|
ctrlPort := env.AppCtrlPort
|
|
slog.Debug("监听控制通道", slog.Uint64("port", uint64(ctrlPort)))
|
|
|
|
// 监听端口
|
|
ls, err := net.Listen("tcp", ":"+strconv.Itoa(int(ctrlPort)))
|
|
if err != nil {
|
|
return errors.Wrap(err, "监听控制通道失败")
|
|
}
|
|
defer utils.Close(ls)
|
|
|
|
// 处理连接
|
|
connCh := utils.ChanConnAccept(s.ctx, ls)
|
|
err = nil
|
|
for loop := true; loop; {
|
|
select {
|
|
case <-s.ctx.Done():
|
|
loop = false
|
|
case conn, ok := <-connCh:
|
|
if !ok {
|
|
err = errors.New("获取连接失败")
|
|
loop = false
|
|
}
|
|
s.ctrlConnWg.Add(1)
|
|
go func() {
|
|
defer s.ctrlConnWg.Done()
|
|
defer utils.Close(conn)
|
|
err := s.processCtrlConn(conn)
|
|
if err != nil {
|
|
slog.Error("处理控制通道连接失败", "err", err)
|
|
}
|
|
}()
|
|
}
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
func (s *Service) processCtrlConn(conn net.Conn) error {
|
|
reader := bufio.NewReader(conn)
|
|
|
|
// version
|
|
version, err := reader.ReadByte()
|
|
if err != nil {
|
|
_ = ctrlResp(conn, CtrlFail)
|
|
return errors.Wrap(err, "获取版本号失败")
|
|
}
|
|
|
|
// name
|
|
nameLen, err := reader.ReadByte()
|
|
if err != nil {
|
|
_ = ctrlResp(conn, CtrlFail)
|
|
return errors.Wrap(err, "获取 name 失败")
|
|
}
|
|
nameBuf, err := utils.ReadBuffer(reader, int(nameLen))
|
|
if err != nil {
|
|
_ = ctrlResp(conn, CtrlFail)
|
|
return errors.Wrap(err, "获取 name 失败")
|
|
}
|
|
name := string(nameBuf)
|
|
|
|
if name == "" {
|
|
_ = ctrlResp(conn, CtrlFail)
|
|
return errors.New("客户端名称不能为空")
|
|
}
|
|
|
|
// 检查客户端
|
|
var node models.Node
|
|
err = orm.DB.First(&node, &models.Node{
|
|
Name: name,
|
|
}).Error
|
|
if err != nil {
|
|
_ = ctrlResp(conn, CtrlFail)
|
|
return errors.Wrap(err, "查询客户端失败")
|
|
}
|
|
|
|
if version != node.Version {
|
|
_ = ctrlResp(conn, CtrlFail)
|
|
return errors.New("客户端版本不匹配")
|
|
}
|
|
|
|
err = ctrlResp(conn, CtrlDone)
|
|
if err != nil {
|
|
return errors.Wrap(err, "向客户端发送响应失败")
|
|
}
|
|
|
|
port := node.FwdPort
|
|
slog.Info("监听转发端口", "port", port, "client", name)
|
|
|
|
// 启动转发服务
|
|
proxy, err := dispatcher.New(port)
|
|
if err != nil {
|
|
return errors.Wrap(err, "创建 socks 转发服务失败")
|
|
}
|
|
defer proxy.Close()
|
|
|
|
s.fwdLesWg.Add(1)
|
|
go func() {
|
|
defer s.fwdLesWg.Done()
|
|
err := proxy.Run()
|
|
if err != nil {
|
|
slog.Error("代理服务运行失败", "err", err)
|
|
}
|
|
}()
|
|
|
|
// 监听控制通道连接
|
|
errCh := make(chan error, 1)
|
|
go func() {
|
|
defer close(errCh)
|
|
_, err := reader.ReadByte()
|
|
errCh <- err
|
|
}()
|
|
|
|
// 批量同步写入
|
|
go func() {
|
|
for {
|
|
select {
|
|
case <-s.ctx.Done():
|
|
return
|
|
case cmd := <-ctrlCmdChan:
|
|
_, err := cmd.conn.Write(cmd.buf)
|
|
if err != nil {
|
|
slog.Error("批量写入失败", "err", err)
|
|
utils.Close(cmd.conn)
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
|
|
// 处理连接
|
|
for {
|
|
select {
|
|
case <-s.ctx.Done():
|
|
return nil
|
|
case err := <-errCh:
|
|
switch {
|
|
case strings.Contains(err.Error(), "An existing connection was forcibly closed by the remote host."):
|
|
slog.Debug("客户端主动断开连接")
|
|
return nil
|
|
case err == nil:
|
|
return errors.New("客户端握手失败")
|
|
default:
|
|
return errors.Wrap(err, "客户端意外断开连接")
|
|
}
|
|
case user := <-proxy.Conn:
|
|
s.userConnWg.Add(1)
|
|
go func() {
|
|
defer s.userConnWg.Done()
|
|
err := s.processUserConn(user, conn)
|
|
if err != nil {
|
|
slog.Error("处理用户连接失败", "err", err)
|
|
utils.Close(user)
|
|
}
|
|
}()
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *Service) processUserConn(user *core.Conn, ctrl net.Conn) error {
|
|
|
|
// 组织写入信息
|
|
dst := user.DestAddr().String()
|
|
dstLen := len(dst)
|
|
|
|
tag := user.Tag
|
|
tagLen := len(tag)
|
|
|
|
writeBuf := make([]byte, 2+dstLen+tagLen)
|
|
writeBuf[0] = byte(dstLen)
|
|
copy(writeBuf[1:], dst)
|
|
writeBuf[1+dstLen] = byte(tagLen)
|
|
copy(writeBuf[2+dstLen:], tag)
|
|
|
|
// 异步写入命令
|
|
ctrlCmdChan <- CtrlCmd{
|
|
conn: ctrl,
|
|
buf: writeBuf,
|
|
}
|
|
|
|
// 记录用户连接
|
|
s.userConnMap.Store(user.Tag, user)
|
|
|
|
// 如果限定时间内没有建立数据通道,则关闭连接
|
|
timeout, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
|
defer cancel()
|
|
|
|
select {
|
|
case <-s.ctx.Done():
|
|
// 服务会在退出时统一关闭未消费的连接
|
|
case <-timeout.Done():
|
|
storedUser, ok := s.userConnMap.LoadAndDelete(user.Tag)
|
|
if ok {
|
|
slog.Debug("建立数据通道超时", "tag", user.Tag)
|
|
utils.Close(storedUser)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// CtrlResult is the one-byte status sent back on a control connection.
type CtrlResult byte

// Control-channel status codes as they appear on the wire.
const (
	CtrlFail CtrlResult = 0 // request rejected
	CtrlDone CtrlResult = 1 // request accepted
)

// ctrlResp writes result to conn as a single byte and returns the error
// reported by the write, if any.
func ctrlResp(conn net.Conn, result CtrlResult) error {
	payload := [1]byte{byte(result)}
	if _, err := conn.Write(payload[:]); err != nil {
		return err
	}
	return nil
}
|