diff --git a/README.md b/README.md index a61c670..2c456f3 100644 --- a/README.md +++ b/README.md @@ -1,49 +1,34 @@ ## todo -排查启动速度很慢的问题 +鉴权时判断授权的协议 -可配置 processUserConn 超时等待时间 +建立通道时,发送的 dst 和 tag 等信息,可以用字节表示而非 string,提高效率 + +建立数据通道失败后,根据用户所选协议返回对应失败响应 测试跳过认证时的最大 qps(需要注意单机连接数上限,会导致连接失败) -简化数据传递时的 tag 文本量(找一个无重复 hash 的办法),并且在控制通道直接传输目标地址,客户端可以同时开始数据通道和目标地址的连接建立 - -读取 conn 时加上超时机制 - -代理节点超时控制 - -网关根据代理节点对目标服务连接的反馈,决定向用户返回的 socks 响应 - 数据通道池化 -协程池化 + +可配配置环境变量 + +- 退出等待时间 +- 数据通道连接超时等待时间 +- 目标地址连接超时等待时间 ### 长期 -配置退出等待时间 +协程池化 需要测试,考虑是否切换到 gnet -实现一个 socks context 以在子组件中获取 socks 相关信息 - -代理端口支持混合端口转发 - 数据通道支持 tcp 多路复用(分离逻辑流) 👆 进阶黑魔法 multipath tcp + 多路复用 考虑一下连接安全性 -内部接口 rtt 是否还有优化空间(当前30-300ms,根据内容大小增长) - -### 代码清理 - -检查 slog 级别: - -ERR: 除非有必要,否则全部 error 都使用 `errors.Wrap()` 包裹(如果下游有返回 err),并附带本层业务信息,return 到上层统一打印 - -其他级别日志就地打印,Info 只用来跟踪关键流程 - ## 开发相关 ### 环境变量 @@ -68,3 +53,34 @@ ERR: 除非有必要,否则全部 error 都使用 `errors.Wrap()` 包裹(如 1. 关闭接听端口,防止新连接接入(user, data, ctrl) 2. 通知并等待所有正在运行的 conn 处理协程全部关闭(user, data, ctrl) 3. 结束所有保存且未使用的 conn 连接(user, ctrl) + + +### 代码清理 + +检查 slog 级别: + +ERR: 除非有必要,否则全部 error 都使用 `errors.Wrap()` 包裹(如果下游有返回 err),并附带本层业务信息,return 到上层统一打印 + +其他级别日志就地打印,Info 只用来跟踪关键流程 + +## 协议 + +### 建立连接 + +客户端(控制通道): + +`version(1)` `id_len(1)` `id_buf(n)` + +服务端(控制通道): + +`status(1)` + +### 开启代理 + +服务端(控制通道): + +`dst_len(1)` `dst_buf(n)` `tag_len(1)` `tag_buf(n)` + +客户端(数据通道): + +`status(1)` `tag_len(1)` `tag_buf(n)` \ No newline at end of file diff --git a/client/client.go b/client/client.go index 5f0fede..e17ab7b 100644 --- a/client/client.go +++ b/client/client.go @@ -2,26 +2,32 @@ package client import ( "bufio" - "encoding/binary" + "flag" "fmt" "io" "log/slog" "net" + "net/http" "os" "proxy-server/pkg/utils" + "runtime" "strconv" "time" "github.com/joho/godotenv" "github.com/pkg/errors" + + _ "net/http/pprof" ) +const Version byte = 1 + type Config struct { - FrpHost string - FrpCtrlPort uint16 - FrpDataPort uint16 - FwdPort uint16 - RetryInterval int + Name string + FwdHost string + FwdCtrlPort uint + FwdDataPort uint + RetryInterval uint } var cfg Config @@ -32,15 +38,24 @@ var frpDataAddr string func Start() { initLog() - initEnv() + initCmd() + initDevEnv() - frpCtrlAddr = net.JoinHostPort(cfg.FrpHost, strconv.Itoa(int(cfg.FrpCtrlPort))) - frpDataAddr = net.JoinHostPort(cfg.FrpHost, strconv.Itoa(int(cfg.FrpDataPort))) + frpCtrlAddr = net.JoinHostPort(cfg.FwdHost, strconv.Itoa(int(cfg.FwdCtrlPort))) + frpDataAddr = net.JoinHostPort(cfg.FwdHost, strconv.Itoa(int(cfg.FwdDataPort))) + + // 性能监控 + go func() { + runtime.SetBlockProfileRate(1) + err := http.ListenAndServe(":6060", nil) + if err != nil { + slog.Error("性能监控服务启动失败", "err", err) + } + }() // 建立控制通道 for { - slog.Info("建立控制通道", "addr", frpCtrlAddr) - err := control() + err := ctrl() if err != nil { slog.Error("建立控制通道失败", err) slog.Info(fmt.Sprintf("%d 秒后重试", cfg.RetryInterval)) @@ -49,7 +64,9 @@ func Start() { } } -func control() error { +func ctrl() error { + slog.Info("建立控制通道", "addr", frpCtrlAddr) + conn, err := net.Dial("tcp", frpCtrlAddr) if err != nil { return errors.Wrap(err, "连接失败") @@ -59,19 +76,50 @@ func control() error { reader := bufio.NewReader(conn) // 请求转发端口 - slog.Info("注册转发端口", "port", cfg.FwdPort) - portBuf := make([]byte, 2) - binary.BigEndian.PutUint16(portBuf, cfg.FwdPort) - _, err = conn.Write(portBuf) + _, err = conn.Write([]byte{Version}) if err != nil { - return errors.Wrap(err, "注册转发端口失败") + return errors.Wrap(err, "发送版本号失败") + } + + // 发送客户端名称 + nameLen := byte(len(cfg.Name)) + nameBuf := make([]byte, 1+nameLen) + nameBuf[0] = nameLen + copy(nameBuf[1:], cfg.Name) + _, err = conn.Write(nameBuf) + if err != nil { + return errors.Wrap(err, "发送 name 失败") + } + + // 等待服务端响应 + respBuf, err := reader.ReadByte() + if err != nil { + return errors.Wrap(err, "接收响应失败") + } + if respBuf != 1 { + return errors.New("服务端响应失败") + } else { + slog.Info("成功建立连接") } // 等待用户连接 // 读写失败后退出重连,防止后续数据读写顺序错位导致卡死控制通道 slog.Info("等待用户连接") for { - tagLen, err := utils.ReadByte(reader) + + // 接收 dst + dstLen, err := reader.ReadByte() + if err != nil { + return errors.Wrap(err, "接收 dstLen 失败") + } + dstBuf, err := utils.ReadBuffer(reader, int(dstLen)) + if err != nil { + return errors.Wrap(err, "接收 dstBuf 失败") + } + addr := string(dstBuf) + + // 接收 tag + tagLen, err := reader.ReadByte() if err != nil { return errors.Wrap(err, "接收 tagLen 失败") } @@ -81,9 +129,8 @@ func control() error { } // 建立数据通道 - slog.Info("收到用户连接,建立数据通道", "tag", string(tagBuf)) go func() { - err := data(tagLen, tagBuf) + err := data(addr, tagBuf) if err != nil { slog.Error("建立数据通道失败", err) } @@ -91,121 +138,86 @@ func control() error { } } -func data(tagLen byte, tagBuf []byte) error { - timerAll := time.Now() +func data(addr string, tag []byte) error { + + // 向服务端建立连接 src, err := net.Dial("tcp", frpDataAddr) if err != nil { - return errors.Wrap(err, "连接失败") - } - defer utils.Close(src) - - // 发送 tag - slog.Info("准备代理流量") - writeBuf := make([]byte, 1+tagLen) - writeBuf[0] = tagLen - copy(writeBuf[1:], tagBuf) - _, err = src.Write(writeBuf) - if err != nil { - return errors.Wrap(err, "发送 tag 失败") + return errors.Wrap(err, "连接服务端失败") } - // 接收目标地址 - slog.Info("接收目标地址") - addrLen, err := utils.ReadByte(src) - if err != nil { - return errors.Wrap(err, "接收 addrLen 失败") - } - addrBuf, err := utils.ReadBuffer(src, int(addrLen)) - if err != nil { - return errors.Wrap(err, "接收 addrBuf 失败") - } - addr := string(addrBuf) + tagLen := byte(len(tag)) + tagBuf := make([]byte, 2+tagLen) + tagBuf[1] = tagLen + copy(tagBuf[2:], tag) - // 数据转发 - slog.Info("向目标 " + addr + " 建立连接") - dest, err := net.Dial("tcp", addr) + // 向目标地址建立连接 + dst, err := net.Dial("tcp", addr) if err != nil { - return errors.Wrap(err, "连接失败") + tagBuf[0] = 0 + } else { + tagBuf[0] = 1 + } + + // 发送连接状态 + _, err = src.Write(tagBuf) + if err != nil { + utils.Close(src) + if dst != nil { + utils.Close(dst) + } + return errors.Wrap(err, "发送连接状态失败") + } + + if tagBuf[0] == 0 { + utils.Close(src) + if dst != nil { + utils.Close(dst) + } + return errors.New("目标地址连接失败") } - defer utils.Close(dest) - slog.Info("开始代理流量 " + src.RemoteAddr().String() + " <-> " + dest.RemoteAddr().String()) - timer := time.Now() - errCh := make(chan error) go func() { - written, err := io.Copy(dest, src) + defer utils.Close(dst) + _, err := io.Copy(dst, src) if err != nil && !errors.Is(err, net.ErrClosed) { slog.Error("上行流量代理失败", "err", err) - errCh <- err - return - } else { - slog.Debug("上行流量代理结束") } - slog.Debug("上行流量", "bytes", written) - errCh <- nil }() go func() { - written, err := io.Copy(src, dest) + defer utils.Close(src) + _, err := io.Copy(src, dst) if err != nil && !errors.Is(err, net.ErrClosed) { slog.Error("下行流量代理失败", "err", err) - errCh <- err - return - } else { - slog.Debug("下行流量代理结束") } - slog.Debug("下行流量", "bytes", written) - errCh <- nil }() - <-errCh - slog.Debug("代理流量结束", "time", time.Since(timer)) - slog.Debug("数据通道结束", "time", time.Since(timerAll)) return nil } -func initEnv() { +func initLog() { + slog.SetLogLoggerLevel(slog.LevelDebug) +} + +func initCmd() { + flag.StringVar(&cfg.Name, "n", "", "客户端名称") + flag.StringVar(&cfg.FwdHost, "h", "", "转发服务器地址") + flag.UintVar(&cfg.FwdCtrlPort, "c", 18080, "转发服务器控制通道端口") + flag.UintVar(&cfg.FwdDataPort, "d", 18081, "转发服务器数据通道端口") + flag.UintVar(&cfg.RetryInterval, "r", 5, "重试间隔时间") + flag.Parse() + + if cfg.Name == "" { + slog.Error("客户端名称不能为空") + flag.Usage() + os.Exit(1) + } +} + +func initDevEnv() { err := godotenv.Load() if err != nil { slog.Debug("没有本地环境变量文件") } - cfg.FrpHost = os.Getenv("FRP_HOST") - - frpCtrlPort, err := strconv.ParseUint(os.Getenv("FRP_CTRL_PORT"), 10, 16) - if err != nil { - panic("环境变量 FRP_CTRL_PORT 的值不合法 " + err.Error()) - } - cfg.FrpCtrlPort = uint16(frpCtrlPort) - if cfg.FrpCtrlPort == 0 { - panic("环境变量 FRP_CTRL_PORT 不能为空") - } - - frpDataPort, err := strconv.ParseUint(os.Getenv("FRP_DATA_PORT"), 10, 16) - if err != nil { - panic("环境变量 FRP_DATA_PORT 的值不合法 " + err.Error()) - } - cfg.FrpDataPort = uint16(frpDataPort) - if cfg.FrpDataPort == 0 { - panic("环境变量 FRP_DATA_PORT 不能为空") - } - - fwdPort, err := strconv.ParseUint(os.Getenv("FWD_PORT"), 10, 16) - if err != nil { - panic("环境变量 FWD_PORT 的值不合法 " + err.Error()) - } - cfg.FwdPort = uint16(fwdPort) - if cfg.FwdPort == 0 { - panic("环境变量 FWD_PORT 不能为空") - } - - cfg.RetryInterval, err = strconv.Atoi(os.Getenv("RETRY_INTERVAL")) - if err != nil { - panic("环境变量 RETRY_INTERVAL 的值不合法 " + err.Error()) - } - if cfg.RetryInterval == 0 { - cfg.RetryInterval = 5 - } -} - -func initLog() { - slog.SetLogLoggerLevel(slog.LevelInfo) + cfg.FwdHost = os.Getenv("FWD_HOST") } diff --git a/cmd/client/.env.example b/cmd/client/.env.example index 86dbc79..e412693 100644 --- a/cmd/client/.env.example +++ b/cmd/client/.env.example @@ -1,7 +1 @@ -FWD_PORT=20001 - -FRP_HOST=127.0.0.1 -FRP_CTRL_PORT=18080 -FRP_DATA_PORT=18081 - -RETRY_INTERVAL=3 +FWD_HOST=127.0.0.1 diff --git a/cmd/testClient/main.go b/cmd/testClient/main.go new file mode 100644 index 0000000..8b4b521 --- /dev/null +++ b/cmd/testClient/main.go @@ -0,0 +1,25 @@ +package main + +import ( + "io" + "net" +) + +func main() { + + ls, _ := net.Listen("tcp", ":8081") + for { + src, _ := ls.Accept() + go func() { + dst, _ := net.Dial("tcp", ":8080") + go func() { + defer dst.Close() + io.Copy(dst, src) + }() + go func() { + defer src.Close() + io.Copy(src, dst) + }() + }() + } +} diff --git a/scripts/sql/init.sql b/scripts/sql/init.sql index d63b8cc..98cb48d 100644 --- a/scripts/sql/init.sql +++ b/scripts/sql/init.sql @@ -3,9 +3,10 @@ drop table if exists nodes cascade; create table nodes ( id serial primary key, name varchar(255) not null unique, + version int not null, + fwd_port int not null, provider varchar(255) not null, location varchar(255) not null, - ip_address varchar(255) not null, created_at timestamp default current_timestamp, updated_at timestamp default current_timestamp, deleted_at timestamp diff --git a/server/fwd/analysis.go b/server/fwd/analysis.go index a380f8e..88d03ca 100644 --- a/server/fwd/analysis.go +++ b/server/fwd/analysis.go @@ -19,7 +19,7 @@ func analysisAndLog(conn *core.Conn, reader io.Reader) error { if err != nil { err = errors.Wrap(err, "analysis sniffing error") } else { - slog.Info( + slog.Debug( "用户访问记录", slog.Uint64("uid", uint64(conn.Auth.Payload.ID)), slog.String("user", conn.RemoteAddr().String()), diff --git a/server/fwd/core/auth.go b/server/fwd/core/auth.go index 356c4ca..f8b76e9 100644 --- a/server/fwd/core/auth.go +++ b/server/fwd/core/auth.go @@ -40,7 +40,7 @@ func CheckIp(conn net.Conn) (*AuthContext, error) { _, localPort, err := net.SplitHostPort(localAddr) // 查询权限记录 - slog.Info("用户 " + remoteHost + " 请求连接到 " + localPort) + slog.Debug("用户 " + remoteHost + " 请求连接到 " + localPort) var channels []models.Channel err = orm.DB. Joins("INNER JOIN public.nodes n ON channels.node_id = n.id AND n.name = ?", localPort). diff --git a/server/fwd/ctrl.go b/server/fwd/ctrl.go new file mode 100644 index 0000000..205c522 --- /dev/null +++ b/server/fwd/ctrl.go @@ -0,0 +1,238 @@ +package fwd + +import ( + "bufio" + "context" + "log/slog" + "net" + "proxy-server/pkg/utils" + "proxy-server/server/fwd/core" + "proxy-server/server/fwd/dispatcher" + "proxy-server/server/models" + "proxy-server/server/pkg/env" + "proxy-server/server/pkg/orm" + "strconv" + "strings" + "time" + + "github.com/pkg/errors" +) + +type CtrlCmd struct { + conn net.Conn + buf []byte +} + +var ctrlCmdChan = make(chan CtrlCmd, 1024) + +func (s *Service) startCtrlTun() error { + ctrlPort := env.AppCtrlPort + slog.Debug("监听控制通道", slog.Uint64("port", uint64(ctrlPort))) + + // 监听端口 + ls, err := net.Listen("tcp", ":"+strconv.Itoa(int(ctrlPort))) + if err != nil { + return errors.Wrap(err, "监听控制通道失败") + } + defer utils.Close(ls) + + // 处理连接 + connCh := utils.ChanConnAccept(s.ctx, ls) + err = nil + for loop := true; loop; { + select { + case <-s.ctx.Done(): + loop = false + case conn, ok := <-connCh: + if !ok { + err = errors.New("获取连接失败") + loop = false + } + s.ctrlConnWg.Add(1) + go func() { + defer s.ctrlConnWg.Done() + defer utils.Close(conn) + err := s.processCtrlConn(conn) + if err != nil { + slog.Error("处理控制通道连接失败", "err", err) + } + }() + } + } + + return err +} + +func (s *Service) processCtrlConn(conn net.Conn) error { + reader := bufio.NewReader(conn) + + // version + version, err := reader.ReadByte() + if err != nil { + _ = ctrlResp(conn, CtrlFail) + return errors.Wrap(err, "获取版本号失败") + } + + // name + nameLen, err := reader.ReadByte() + if err != nil { + _ = ctrlResp(conn, CtrlFail) + return errors.Wrap(err, "获取 name 失败") + } + nameBuf, err := utils.ReadBuffer(reader, int(nameLen)) + if err != nil { + _ = ctrlResp(conn, CtrlFail) + return errors.Wrap(err, "获取 name 失败") + } + name := string(nameBuf) + + if name == "" { + _ = ctrlResp(conn, CtrlFail) + return errors.New("客户端名称不能为空") + } + + // 检查客户端 + var node models.Node + err = orm.DB.First(&node, &models.Node{ + Name: name, + }).Error + if err != nil { + _ = ctrlResp(conn, CtrlFail) + return errors.Wrap(err, "查询客户端失败") + } + + if version != node.Version { + _ = ctrlResp(conn, CtrlFail) + return errors.New("客户端版本不匹配") + } + + err = ctrlResp(conn, CtrlDone) + if err != nil { + return errors.Wrap(err, "向客户端发送响应失败") + } + + port := node.FwdPort + slog.Info("监听转发端口", "port", port, "client", name) + + // 启动转发服务 + proxy, err := dispatcher.New(port) + if err != nil { + return errors.Wrap(err, "创建 socks 转发服务失败") + } + defer proxy.Close() + + s.fwdLesWg.Add(1) + go func() { + defer s.fwdLesWg.Done() + err := proxy.Run() + if err != nil { + slog.Error("代理服务运行失败", "err", err) + } + }() + + // 监听控制通道连接 + errCh := make(chan error, 1) + go func() { + defer close(errCh) + _, err := reader.ReadByte() + errCh <- err + }() + + // 批量同步写入 + go func() { + for { + select { + case <-s.ctx.Done(): + return + case cmd := <-ctrlCmdChan: + _, err := cmd.conn.Write(cmd.buf) + if err != nil { + slog.Error("批量写入失败", "err", err) + utils.Close(cmd.conn) + } + } + } + }() + + // 处理连接 + for { + select { + case <-s.ctx.Done(): + return nil + case err := <-errCh: + switch { + case strings.Contains(err.Error(), "An existing connection was forcibly closed by the remote host."): + slog.Debug("客户端主动断开连接") + return nil + case err == nil: + return errors.New("客户端握手失败") + default: + return errors.Wrap(err, "客户端意外断开连接") + } + case user := <-proxy.Conn: + s.userConnWg.Add(1) + go func() { + defer s.userConnWg.Done() + err := s.processUserConn(user, conn) + if err != nil { + slog.Error("处理用户连接失败", "err", err) + utils.Close(user) + } + }() + } + } +} + +func (s *Service) processUserConn(user *core.Conn, ctrl net.Conn) error { + + // 组织写入信息 + dst := user.DestAddr().String() + dstLen := len(dst) + + tag := user.Tag + tagLen := len(tag) + + writeBuf := make([]byte, 2+dstLen+tagLen) + writeBuf[0] = byte(dstLen) + copy(writeBuf[1:], dst) + writeBuf[1+dstLen] = byte(tagLen) + copy(writeBuf[2+dstLen:], tag) + + // 异步写入命令 + ctrlCmdChan <- CtrlCmd{ + conn: ctrl, + buf: writeBuf, + } + + // 记录用户连接 + s.userConnMap.Store(user.Tag, user) + + // 如果限定时间内没有建立数据通道,则关闭连接 + timeout, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + select { + case <-s.ctx.Done(): + // 服务会在退出时统一关闭未消费的连接 + case <-timeout.Done(): + storedUser, ok := s.userConnMap.LoadAndDelete(user.Tag) + if ok { + slog.Debug("建立数据通道超时", "tag", user.Tag) + utils.Close(storedUser) + } + } + + return nil +} + +type CtrlResult byte + +const ( + CtrlFail CtrlResult = iota + CtrlDone +) + +func ctrlResp(conn net.Conn, result CtrlResult) error { + _, err := conn.Write([]byte{byte(result)}) + return err +} diff --git a/server/fwd/data.go b/server/fwd/data.go new file mode 100644 index 0000000..9fa4e7b --- /dev/null +++ b/server/fwd/data.go @@ -0,0 +1,120 @@ +package fwd + +import ( + "io" + "log/slog" + "net" + "proxy-server/pkg/utils" + "proxy-server/server/pkg/env" + "strconv" + "sync" + + "github.com/pkg/errors" +) + +func (s *Service) startDataTun() error { + dataPort := env.AppDataPort + slog.Debug("监听数据通道", slog.Uint64("port", uint64(dataPort))) + + // 监听端口 + ls, err := net.Listen("tcp", ":"+strconv.Itoa(int(dataPort))) + if err != nil { + return errors.Wrap(err, "监听数据通道失败") + } + defer utils.Close(ls) + + go func() { + <-s.ctx.Done() + utils.Close(ls) + }() + + for { + conn, err := ls.Accept() + if err != nil { + return errors.Wrap(err, "监听数据通道失败") + } + + select { + case <-s.ctx.Done(): + utils.Close(conn) + return nil + default: + s.dataConnWg.Add(1) + go func() { + defer s.dataConnWg.Done() + defer utils.Close(conn) + err := s.processDataConn(conn) + if err != nil { + slog.Error("建立数据通道失败失败", "err", err) + } + }() + } + } +} + +func (s *Service) processDataConn(client net.Conn) error { + + // 接收 status + status, err := utils.ReadByte(client) + if err != nil { + return errors.Wrap(err, "从客户端获取 status 失败") + } + + // 接收 tag + tagLen, err := utils.ReadByte(client) + if err != nil { + return errors.Wrap(err, "从客户端获取 tag 失败") + } + tagBuf, err := utils.ReadBuffer(client, int(tagLen)) + if err != nil { + return errors.Wrap(err, "从客户端获取 tag 失败") + } + tag := string(tagBuf) + + // 找到用户连接 + user, ok := s.userConnMap.LoadAndDelete(tag) + if !ok { + return errors.New("用户连接已关闭,tag:" + tag) + } + defer utils.Close(user) + + // 检查状态 + if status != 1 { + return errors.New("目标地址建立连接失败") + } + + // 数据转发 + userPipeReader, userPipeWriter := io.Pipe() + defer utils.Close(userPipeWriter) + teeUser := io.TeeReader(user, userPipeWriter) + go func() { + err := analysisAndLog(user, userPipeReader) + if err != nil { + slog.Error("数据解析失败", "err", err) + } + }() + + wg := sync.WaitGroup{} + wg.Add(2) + go func() { + defer wg.Done() + _, err := io.Copy(client, teeUser) + if err != nil { + slog.Error("数据转发失败 user->client", "err", err) + } + }() + go func() { + defer wg.Done() + _, err := io.Copy(user, client) + if err != nil { + slog.Error("数据转发失败 client->user", "err", err) + } + }() + + select { + case <-s.ctx.Done(): + case <-utils.ChanWgWait(s.ctx, &wg): + } + + return nil +} diff --git a/server/fwd/dispatcher/dispatch.go b/server/fwd/dispatcher/dispatch.go index 181a3b4..67f0223 100644 --- a/server/fwd/dispatcher/dispatch.go +++ b/server/fwd/dispatcher/dispatch.go @@ -77,8 +77,9 @@ func (s *Server) Run() error { } }() - errCh := make(chan error) + errCh := make(chan error, 1) go func() { + defer close(errCh) err = m.Serve() if err != nil { err = errors.Wrap(err, "dispatcher serve error") diff --git a/server/fwd/fwd.go b/server/fwd/fwd.go index d0204b6..2a9d1e1 100644 --- a/server/fwd/fwd.go +++ b/server/fwd/fwd.go @@ -1,22 +1,11 @@ package fwd import ( - "bufio" "context" - "encoding/binary" - "io" "log/slog" - "net" "proxy-server/pkg/utils" "proxy-server/server/fwd/core" - "proxy-server/server/fwd/dispatcher" - "proxy-server/server/pkg/env" - "strconv" - "strings" "sync" - "time" - - "github.com/pkg/errors" ) type Config struct { @@ -122,256 +111,3 @@ func (s *Service) Run() { wg.Wait() slog.Info("fwd 服务已退出") } - -func (s *Service) startCtrlTun() error { - ctrlPort := env.AppCtrlPort - slog.Debug("监听控制通道", slog.Uint64("port", uint64(ctrlPort))) - - // 监听端口 - ls, err := net.Listen("tcp", ":"+strconv.Itoa(int(ctrlPort))) - if err != nil { - return errors.Wrap(err, "监听控制通道失败") - } - defer utils.Close(ls) - - // 处理连接 - connCh := utils.ChanConnAccept(s.ctx, ls) - for { - select { - case <-s.ctx.Done(): - return nil - case conn, ok := <-connCh: - if !ok { - return errors.New("获取连接失败") - } - s.ctrlConnWg.Add(1) - go func() { - defer s.ctrlConnWg.Done() - err := s.processCtrlConn(conn) - if err != nil { - slog.Error("处理控制通道连接失败", "err", err) - utils.Close(conn) - } - }() - } - } -} - -func (s *Service) processCtrlConn(conn net.Conn) error { - reader := bufio.NewReader(conn) - - // 获取转发端口 - portBuf, err := utils.ReadBuffer(reader, 2) - if err != nil { - return errors.Wrap(err, "获取转发端口失败") - } - port := binary.BigEndian.Uint16(portBuf) - - slog.Info("客户端注册", "addr", conn.RemoteAddr().String(), "port", port) - - // 启动转发服务 - proxy, err := dispatcher.New(port) - if err != nil { - return errors.Wrap(err, "创建 socks 转发服务失败") - } - defer proxy.Close() - - s.fwdLesWg.Add(1) - go func() { - defer s.fwdLesWg.Done() - err := proxy.Run() - if err != nil { - slog.Error("代理服务运行失败", "err", err) - return - } - }() - - // 监听客户端连接 - errCh := make(chan error) - defer close(errCh) - go func() { - _, err := reader.ReadByte() - errCh <- err - }() - - // 处理连接 - for { - select { - case <-s.ctx.Done(): - return nil - case err := <-errCh: - switch { - case strings.Contains(err.Error(), "An existing connection was forcibly closed by the remote host."): - slog.Debug("客户端主动断开连接") - return nil - case err == nil: - return errors.New("客户端握手失败") - default: - return errors.Wrap(err, "客户端意外断开连接") - } - case user := <-proxy.Conn: - s.userConnWg.Add(1) - go func() { - defer s.userConnWg.Done() - err := s.processUserConn(user, conn) - if err != nil { - slog.Error("处理用户连接失败", "err", err) - utils.Close(user) - } - }() - } - } -} - -func (s *Service) startDataTun() error { - dataPort := env.AppDataPort - slog.Debug("监听数据通道", slog.Uint64("port", uint64(dataPort))) - - // 监听端口 - ls, err := net.Listen("tcp", ":"+strconv.Itoa(int(dataPort))) - if err != nil { - return errors.Wrap(err, "监听数据通道失败") - } - defer utils.Close(ls) - - go func() { - <-s.ctx.Done() - utils.Close(ls) - }() - - for { - conn, err := ls.Accept() - if err != nil { - return errors.Wrap(err, "监听数据通道失败") - } - - select { - case <-s.ctx.Done(): - utils.Close(conn) - return nil - default: - s.dataConnWg.Add(1) - go func() { - defer s.dataConnWg.Done() - defer utils.Close(conn) - err := s.processDataConn(conn) - if err != nil { - slog.Error("处理数据通道失败", "err", err) - } - }() - } - } -} - -func (s *Service) processDataConn(client net.Conn) error { - - // 读取 tag - var tag string - select { - case <-s.ctx.Done(): - return nil - default: - tagLen, err := utils.ReadByte(client) - if err != nil { - return errors.Wrap(err, "从客户端获取 tag 失败") - } - tagBuf, err := utils.ReadBuffer(client, int(tagLen)) - if err != nil { - return errors.Wrap(err, "从客户端获取 tag 失败") - } - tag = string(tagBuf) - } - - // 找到用户连接 - user, ok := s.userConnMap.LoadAndDelete(tag) - if !ok { - return errors.New("查找用户连接失败") - } - defer utils.Close(user) - - // 发送目标地址 - select { - case <-s.ctx.Done(): - return nil - default: - dest := user.Dest.String() - destLen := len(dest) - destBuf := make([]byte, 1+destLen) - destBuf[0] = byte(destLen) - copy(destBuf[1:], dest) - _, err := client.Write(destBuf) - if err != nil { - return errors.Wrap(err, "向客户端发送目标地址失败") - } - } - - // 数据转发 - userPipeReader, userPipeWriter := io.Pipe() - defer utils.Close(userPipeWriter) - teeUser := io.TeeReader(user, userPipeWriter) - go func() { - err := analysisAndLog(user, userPipeReader) - if err != nil { - slog.Error("数据解析失败", "err", err) - } - }() - - wg := sync.WaitGroup{} - wg.Add(2) - go func() { - defer wg.Done() - _, err := io.Copy(client, teeUser) - if err != nil { - slog.Error("数据转发失败 user->client", "err", err) - } - }() - go func() { - defer wg.Done() - _, err := io.Copy(user, client) - if err != nil { - slog.Error("数据转发失败 client->user", "err", err, "errType") - } - }() - - select { - case <-s.ctx.Done(): - case <-utils.ChanWgWait(s.ctx, &wg): - } - - return nil -} - -func (s *Service) processUserConn(user *core.Conn, ctrl net.Conn) error { - - // 发送 tag - tag := user.Tag - tagLen := len(tag) - tagBuf := make([]byte, 1+tagLen) - tagBuf[0] = byte(tagLen) - copy(tagBuf[1:], tag) - - _, err := ctrl.Write(tagBuf) - if err != nil { - return errors.Wrap(err, "向控制通道发送 tag 失败") - } - - // 记录用户连接 - s.userConnMap.Store(user.Tag, user) - - // 如果限定时间内没有建立数据通道,则关闭连接 - timeout, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - - select { - case <-s.ctx.Done(): - // 服务会在退出时统一关闭未消费的连接 - case <-timeout.Done(): - storedUser, ok := s.userConnMap.LoadAndDelete(user.Tag) - if ok { - slog.Debug("用户连接超时,关闭连接", "tag", user.Tag) - utils.Close(storedUser) - } - } - - return nil -} diff --git a/server/models/node.go b/server/models/node.go index 349bc97..1cf57df 100644 --- a/server/models/node.go +++ b/server/models/node.go @@ -5,10 +5,11 @@ import "gorm.io/gorm" // Node 客户端模型 type Node struct { gorm.Model - Name string - Provider string - Location string - IPAddress string + Name string + Version byte + FwdPort uint16 + Provider string + Location string Channels []Channel `gorm:"foreignKey:NodeId"` } diff --git a/server/pkg/log/logs.go b/server/pkg/log/logs.go index ff7fec0..df96203 100644 --- a/server/pkg/log/logs.go +++ b/server/pkg/log/logs.go @@ -15,11 +15,13 @@ func Init() { mode = "dev" } + level := slog.LevelInfo + switch mode { case "dev": writer := colorable.NewColorable(os.Stdout) logger := slog.New(tint.NewHandler(writer, &tint.Options{ - Level: slog.LevelDebug, + Level: level, TimeFormat: time.RFC3339, ReplaceAttr: func(_ []string, attr slog.Attr) slog.Attr { err, ok := attr.Value.Any().(error) @@ -32,7 +34,7 @@ func Init() { slog.SetDefault(logger) case "test": logger := slog.New(slog.NewJSONHandler(os.Stdout, &slog.HandlerOptions{ - Level: slog.LevelInfo, + Level: level, })) slog.SetDefault(logger) default: