diff --git a/README.md b/README.md index c8c67af..3b0a70c 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,11 @@ ## todo +读取 conn 时加上超时机制 + 检查 ip 时需要判断同一 ip 的不同写法 +客户端重连后出现连接卡死的情况 + ### 长期 考虑一下连接安全性 diff --git a/cmd/client/.env.example b/cmd/client/.env.example new file mode 100644 index 0000000..321e160 --- /dev/null +++ b/cmd/client/.env.example @@ -0,0 +1,7 @@ +FWD_PORT=20001 + +FRP_HOST=127.0.0.1 +FRP_CTRL_PORT=18080 +FRP_DATA_PORT=18081 + +RETRY_INTERVAL=5 diff --git a/cmd/client/main.go b/cmd/client/main.go index ec61e82..cdc835c 100644 --- a/cmd/client/main.go +++ b/cmd/client/main.go @@ -3,154 +3,208 @@ package main import ( "bufio" "encoding/binary" + "fmt" "github.com/joho/godotenv" + "github.com/pkg/errors" "io" "log/slog" "net" "os" "proxy-server/pkg/utils" "strconv" + "time" ) -func main() { - slog.SetLogLoggerLevel(slog.LevelDebug) +type Config struct { + FrpHost string + FrpCtrlPort uint16 + FrpDataPort uint16 + FwdPort uint16 + RetryInterval int +} - // 初始化环境变量 +var cfg Config + +var frpCtrlAddr string +var frpDataAddr string + +func main() { + + initLog() + initEnv() + + frpCtrlAddr = net.JoinHostPort(cfg.FrpHost, strconv.Itoa(int(cfg.FrpCtrlPort))) + frpDataAddr = net.JoinHostPort(cfg.FrpHost, strconv.Itoa(int(cfg.FrpDataPort))) + + // 建立控制通道 + for { + slog.Info("建立控制通道", "addr", frpCtrlAddr) + err := control() + if err != nil { + slog.Error("建立控制通道失败", err) + slog.Info(fmt.Sprintf("%d 秒后重试", cfg.RetryInterval)) + time.Sleep(time.Duration(cfg.RetryInterval) * time.Second) + } + } +} + +func control() error { + conn, err := net.Dial("tcp", frpCtrlAddr) + if err != nil { + return errors.Wrap(err, "连接失败") + } + defer utils.Close(&conn) + + // 请求转发端口 + slog.Info("注册转发端口", "port", cfg.FwdPort) + portBuf := make([]byte, 2) + binary.BigEndian.PutUint16(portBuf, cfg.FwdPort) + _, err = conn.Write(portBuf) + if err != nil { + return errors.Wrap(err, "注册转发端口失败") + } + + // 等待用户连接 + // 读写失败后退出重连,防止后续数据读写顺序错位导致卡死控制通道 + for { + slog.Info("等待用户连接") + reader := bufio.NewReader(conn) + + tagLen, err := utils.ReadByte(reader) + if err != nil { + return errors.Wrap(err, "接收 tagLen 失败") + } + tagBuf, err := utils.ReadBuffer(reader, int(tagLen)) + if err != nil { + return errors.Wrap(err, "接收 tagBuf 失败") + } + + // 建立数据通道 + go func() { + slog.Info("收到用户连接,建立数据通道") + err := data(tagLen, tagBuf) + if err != nil { + slog.Error("建立数据通道失败", err) + } + }() + } +} + +func data(tagLen byte, tagBuf []byte) error { + timerAll := time.Now() + src, err := net.Dial("tcp", frpDataAddr) + if err != nil { + return errors.Wrap(err, "连接失败") + } + defer utils.Close(&src) + + // 发送 tag + slog.Info("准备代理流量") + writeBuf := make([]byte, 1+len(tagBuf)) + writeBuf[0] = tagLen + copy(writeBuf[1:], tagBuf) + _, err = src.Write(writeBuf) + if err != nil { + return errors.Wrap(err, "发送 tag 失败") + } + + // 接收目标地址 + slog.Info("接收目标地址") + addrLen, err := utils.ReadByte(src) + if err != nil { + return errors.Wrap(err, "接收 addrLen 失败") + } + addrBuf, err := utils.ReadBuffer(src, int(addrLen)) + if err != nil { + return errors.Wrap(err, "接收 addrBuf 失败") + } + addr := string(addrBuf) + + // 数据转发 + slog.Info("向目标 " + addr + " 建立连接") + dest, err := net.Dial("tcp", addr) + if err != nil { + return errors.Wrap(err, "连接失败") + } + defer utils.Close(&dest) + + slog.Info("开始代理流量 " + src.RemoteAddr().String() + " <-> " + dest.RemoteAddr().String()) + timer := time.Now() + errCh := make(chan error) + go func() { + written, err := io.Copy(dest, src) + if err != nil && !errors.Is(err, net.ErrClosed) { + slog.Error("上行流量代理失败", "err", err) + errCh <- err + return + } else { + slog.Info("上行流量代理结束") + } + slog.Info("上行流量", "bytes", written) + errCh <- nil + }() + go func() { + written, err := io.Copy(src, dest) + if err != nil && !errors.Is(err, net.ErrClosed) { + slog.Error("下行流量代理失败", "err", err) + errCh <- err + return + } else { + slog.Info("下行流量代理结束") + } + slog.Info("下行流量", "bytes", written) + errCh <- nil + }() + <-errCh + slog.Info("代理流量结束", "time", time.Since(timer)) + slog.Info("数据通道结束", "time", time.Since(timerAll)) + return nil +} + +func initEnv() { err := godotenv.Load() if err != nil { slog.Debug("没有本地环境变量文件") } - // 建立控制连接 - for { - slog.Info("与服务端建立控制连接") - frpHost := os.Getenv("FRP_SERVER") - frpPort := os.Getenv("FRP_PORT") - frpAddr := net.JoinHostPort(frpHost, frpPort) - slog.Info("frpAddr", frpAddr) - conn, err := net.Dial("tcp", frpAddr) - if err != nil { - slog.Error("dial error", err) - return - } + cfg.FrpHost = os.Getenv("FRP_HOST") - reader := bufio.NewReader(conn) + frpCtrlPort, err := strconv.ParseUint(os.Getenv("FRP_CTRL_PORT"), 10, 16) + if err != nil { + panic("环境变量 FRP_CTRL_PORT 的值不合法 " + err.Error()) + } + cfg.FrpCtrlPort = uint16(frpCtrlPort) + if cfg.FrpCtrlPort == 0 { + panic("环境变量 FRP_CTRL_PORT 不能为空") + } - // 请求转发端口 - slog.Info("请求转发端口") - fwdPortStr := os.Getenv("FWD_PORT") - fwdPort, err := strconv.ParseUint(fwdPortStr, 10, 16) - if err != nil { - slog.Error("parse error", err) - return - } - portBuf := make([]byte, 2) - binary.BigEndian.PutUint16(portBuf, uint16(fwdPort)) - _, err = conn.Write(portBuf) - if err != nil { - slog.Error("write error", err) - return - } + frpDataPort, err := strconv.ParseUint(os.Getenv("FRP_DATA_PORT"), 10, 16) + if err != nil { + panic("环境变量 FRP_DATA_PORT 的值不合法 " + err.Error()) + } + cfg.FrpDataPort = uint16(frpDataPort) + if cfg.FrpDataPort == 0 { + panic("环境变量 FRP_DATA_PORT 不能为空") + } - // 读取目标地址 - for { - slog.Info("等待建立数据通道命令") + fwdPort, err := strconv.ParseUint(os.Getenv("FWD_PORT"), 10, 16) + if err != nil { + panic("环境变量 FWD_PORT 的值不合法 " + err.Error()) + } + cfg.FwdPort = uint16(fwdPort) + if cfg.FwdPort == 0 { + panic("环境变量 FWD_PORT 不能为空") + } - tagLen, err := reader.ReadByte() - if err != nil { - slog.Error("read error", err) - return - } - - tagBuf := make([]byte, tagLen) - _, err = io.ReadFull(conn, tagBuf) - if err != nil { - slog.Error("read error", err) - return - } - - // 建立数据通道 - go dataTun(tagLen, tagBuf) - } + cfg.RetryInterval, err = strconv.Atoi(os.Getenv("RETRY_INTERVAL")) + if err != nil { + panic("环境变量 RETRY_INTERVAL 的值不合法 " + err.Error()) + } + if cfg.RetryInterval == 0 { + cfg.RetryInterval = 5 } } -func dataTun(tagLen byte, tagBuff []byte) { - - slog.Info("建立数据通道") - src, err := net.Dial("tcp", "localhost:18081") - if err != nil { - slog.Error("建立数据通道失败", err) - return - } - defer func() { - err := src.Close() - if err != nil { - slog.Error("close error", err) - } - }() - - // 发送 tag - _, err = src.Write([]byte{tagLen}) - if err != nil { - slog.Error("write error", err) - return - } - _, err = src.Write(tagBuff) - if err != nil { - slog.Error("write error", err) - return - } - - // 接收目标地址 - slog.Info("等待目标地址") - - addrLen, err := utils.ReadByte(src) - if err != nil { - slog.Error("接收目标地址失败", err) - return - } - addrBuf, err := utils.ReadBuffer(src, int(addrLen)) - if err != nil { - slog.Error("接收目标地址失败", err) - return - } - addr := string(addrBuf) - - // 数据转发 - slog.Info("向 " + addr + " 建立连接") - dest, err := net.Dial("tcp", addr) - if err != nil { - slog.Error("与目标地址连接建立失败", err) - return - } - defer func() { - err = dest.Close() - if err != nil { - slog.Error("close error", err) - } - }() - - slog.Info("开始数据转发 " + src.RemoteAddr().String() + " <-> " + dest.RemoteAddr().String()) - - errCh := make(chan error, 2) - go func() { - _, err := io.Copy(src, dest) - if err != nil { - slog.Error("copy error f2t", err) - errCh <- err - return - } - errCh <- nil - }() - go func() { - _, err := io.Copy(dest, src) - if err != nil { - slog.Error("copy error t2f", err) - errCh <- err - return - } - errCh <- nil - }() - <-errCh +func initLog() { + slog.SetLogLoggerLevel(slog.LevelDebug) } diff --git a/.env.example b/cmd/server/.env.example similarity index 100% rename from .env.example rename to cmd/server/.env.example diff --git a/config/test/docker-compose.yaml b/config/test/docker-compose.yaml index 88457fb..33d0618 100644 --- a/config/test/docker-compose.yaml +++ b/config/test/docker-compose.yaml @@ -34,6 +34,7 @@ services: ports: - "${APP_CTRL_PORT}:${APP_CTRL_PORT}" - "${APP_DATA_PORT}:${APP_DATA_PORT}" + - "20000-20100:20000-20100" networks: - proxy-server-test depends_on: diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go index 900f7fc..d0ef909 100644 --- a/pkg/utils/utils.go +++ b/pkg/utils/utils.go @@ -1,6 +1,9 @@ package utils -import "io" +import ( + "io" + "log/slog" +) func ReadByte(reader io.Reader) (byte, error) { buffer, err := ReadBuffer(reader, 1) @@ -20,3 +23,16 @@ func ReadBuffer(reader io.Reader, size int) ([]byte, error) { return buffer, nil } + +func Close[T any](v *T) { + if v == nil { + return + } + closer, ok := any(*v).(io.Closer) + if ok { + err := closer.Close() + if err != nil { + slog.Warn("对象关闭失败", "err", err) + } + } +}