diff --git a/README.md b/README.md index 82bec8a..d8519c4 100644 --- a/README.md +++ b/README.md @@ -63,4 +63,8 @@ ERR: 除非有必要,否则全部 error 都使用 `errors.Wrap()` 包裹(如 1. 构建项目 2. 使用测试配置 `.env.test` 远程启动 docker -## 转发服务 +### 转发服务结束时资源清理 + +1. 关闭接听端口,防止新连接接入(user, data, ctrl) +2. 通知并等待所有正在运行的 conn 处理协程全部关闭(user, data, ctrl) +3. 结束所有保存且未使用的 conn 连接(user, ctrl) diff --git a/go.mod b/go.mod index b9e407f..48ee8da 100644 --- a/go.mod +++ b/go.mod @@ -4,11 +4,11 @@ go 1.24 require ( github.com/gin-gonic/gin v1.10.0 - github.com/google/gopacket v1.1.19 github.com/joho/godotenv v1.5.1 github.com/lmittmann/tint v1.0.7 github.com/mattn/go-colorable v0.1.14 github.com/pkg/errors v0.9.1 + github.com/soheilhy/cmux v0.1.5 gorm.io/driver/postgres v1.5.11 gorm.io/gen v0.3.26 gorm.io/gorm v1.25.12 diff --git a/go.sum b/go.sum index 28c0f52..f9c1dac 100644 --- a/go.sum +++ b/go.sum @@ -36,8 +36,6 @@ github.com/golang-sql/sqlexp v0.1.0/go.mod h1:J4ad9Vo8ZCWQ2GMrC4UCQy1JpCbwU9m3EO github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= -github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= -github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= @@ -90,6 +88,8 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= +github.com/soheilhy/cmux v0.1.5 h1:jjzc5WVemNEDTLwv9tlmemhC73tI08BNOIGwBOo10Js= +github.com/soheilhy/cmux v0.1.5/go.mod h1:T7TcVDs9LWfQgPlPsdngu6I6QIoyIFZDDC6sNE1GqG0= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= @@ -109,32 +109,30 @@ github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZ golang.org/x/arch v0.14.0 h1:z9JUEZWr8x4rR0OU6c4/4t6E6jOZ8/QBS2bBYBm4tx4= golang.org/x/arch v0.14.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus= golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M= -golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= -golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.22.0 h1:D4nJWe9zXqHOmWqj4VMOJhvzj7bEZg4wEYa759z1pH4= golang.org/x/mod v0.22.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20201202161906-c7110b5ffcbb/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8= golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w= golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM= golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY= -golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.28.0 h1:WuB6qZ4RPCQo5aP3WdKZS7i595EdWqWR8vqJTlwTVK8= golang.org/x/tools v0.28.0/go.mod h1:dcIOrVd3mfQKTgrDVQHqCPMWy6lnhfhtX3hLXYVLfRw= -golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v1.36.5 h1:tPhr+woSbjfYvY6/GPufUoYizxw1cF/yFoxJ2fmpwlM= google.golang.org/protobuf v1.36.5/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/pkg/utils/chan.go b/pkg/utils/chan.go index 61f485a..40587a6 100644 --- a/pkg/utils/chan.go +++ b/pkg/utils/chan.go @@ -9,21 +9,22 @@ import ( ) func ChanConnAccept(ctx context.Context, ls net.Listener) chan net.Conn { - connCh := make(chan net.Conn) + ch := make(chan net.Conn) go func() { + defer close(ch) for { conn, err := ls.Accept() if err != nil { if errors.Is(err, net.ErrClosed) { return } - slog.Error("接受连接失败", err) // 临时错误重试连接 var ne net.Error if errors.As(err, &ne) && ne.Temporary() { slog.Debug("临时错误重试") continue } + slog.Error("接受连接失败", err) return } // ctx 取消后退出 @@ -31,16 +32,17 @@ func ChanConnAccept(ctx context.Context, ls net.Listener) chan net.Conn { case <-ctx.Done(): Close(conn) return - case connCh <- conn: + case ch <- conn: } } }() - return connCh + return ch } func ChanWgWait[T WaitGroup](ctx context.Context, wg T) chan struct{} { ch := make(chan struct{}) go func() { + defer close(ch) wg.Wait() select { case <-ctx.Done(): diff --git a/server/fwd/analysis.go b/server/fwd/analysis.go index 1d1b686..ba9bc4a 100644 --- a/server/fwd/analysis.go +++ b/server/fwd/analysis.go @@ -6,13 +6,13 @@ import ( "io" "log/slog" "proxy-server/pkg/utils" - "proxy-server/server/fwd/socks" + "proxy-server/server/fwd/core" "strings" "github.com/pkg/errors" ) -func analysisAndLog(conn socks.ProxyConn, reader io.Reader) error { +func analysisAndLog(conn *core.Conn, reader io.Reader) error { buf := bufio.NewReader(reader) domain, proto, err := sniffing(buf) @@ -21,12 +21,12 @@ func analysisAndLog(conn socks.ProxyConn, reader io.Reader) error { } else { slog.Info( "用户访问记录", - slog.Uint64("uid", uint64(conn.Uid)), - slog.String("user", conn.Conn.RemoteAddr().String()), - slog.String("proxy", "socks"), - slog.String("node", conn.Conn.LocalAddr().String()), + slog.Uint64("uid", uint64(conn.Auth.Payload.ID)), + slog.String("user", conn.RemoteAddr().String()), + slog.String("proxy", conn.Protocol), + slog.String("node", conn.LocalAddr().String()), slog.String("proto", proto), - slog.String("dest", conn.Dest), + slog.String("dest", conn.DestAddr().String()), slog.String("domain", domain), ) } diff --git a/server/fwd/auth.go b/server/fwd/auth.go deleted file mode 100644 index a8ebc94..0000000 --- a/server/fwd/auth.go +++ /dev/null @@ -1,206 +0,0 @@ -package fwd - -import ( - "context" - "io" - "log/slog" - "net" - "proxy-server/pkg/utils" - "proxy-server/server/fwd/socks" - "proxy-server/server/pkg/orm" - "proxy-server/server/web/models" - "time" - - "github.com/pkg/errors" -) - -type NoAuthAuthenticator struct { -} - -func (a *NoAuthAuthenticator) Method() socks.AuthMethod { - return socks.NoAuth -} - -func (a *NoAuthAuthenticator) Authenticate(ctx context.Context, reader io.Reader, writer io.Writer) (*socks.Authentication, error) { - - // 获取用户地址 - conn, ok := writer.(net.Conn) - if !ok { - return nil, errors.New("noAuth 认证失败,无法获取连接信息") - } - addr := conn.RemoteAddr().String() - client, _, err := net.SplitHostPort(addr) - if err != nil { - return nil, errors.Wrap(err, "noAuth 认证失败") - } - slog.Debug("用户的地址为 " + client) - - // 获取服务 - server, ok := ctx.Value("service").(*socks.Server) - if !ok { - return nil, errors.New("noAuth 认证失败,无法获取服务信息") - } - node := server.Name - slog.Debug("服务的名称为 " + server.Name) - - // 查询权限记录 - slog.Info("用户 " + client + " 请求连接到 " + node) - var channels []models.Channel - err = orm.DB. - Joins("INNER JOIN public.nodes n ON channels.node_id = n.id AND n.name = ?", node). - Joins("INNER JOIN public.users u ON channels.user_id = u.id"). - Joins("INNER JOIN public.user_ips ip ON u.id = ip.user_id AND ip.ip_address = ?", client). - Where(&models.Channel{ - AuthIp: true, - }). - Find(&channels).Error - if err != nil { - return nil, errors.New("noAuth 查询用户权限失败") - } - - // 记录应该只有一条 - channel, err := orm.MaySingle(channels) - if err != nil { - return nil, errors.Wrap(err, "noAuth 没有权限") - } - - // 检查是否需要密码认证 - if channel.AuthPass { - return nil, errors.New("noAuth 没有权限,需要密码认证") - } - - // 检查权限是否过期 - timeout := channel.Expiration.Sub(time.Now()).Seconds() - slog.Info("用户剩余时间", "timeout", timeout) - if timeout <= 0 { - return nil, errors.New("noAuth 权限已过期") - } - slog.Debug("权限剩余时间", slog.Uint64("timeout", uint64(timeout))) - - return &socks.Authentication{ - Method: socks.NoAuth, - Timeout: uint(timeout), - Payload: socks.Payload{ - ID: channel.UserId, - }, - }, nil -} - -type UserPassAuthenticator struct { -} - -func (a *UserPassAuthenticator) Method() socks.AuthMethod { - return socks.UserPassAuth -} - -func (a *UserPassAuthenticator) Authenticate(ctx context.Context, reader io.Reader, writer io.Writer) (*socks.Authentication, error) { - - // 检查认证版本 - slog.Debug("验证认证版本") - v, err := utils.ReadByte(reader) - if err != nil { - return nil, errors.Wrap(err, "读取版本号失败") - } - if v != socks.AuthVersion { - _, err := writer.Write([]byte{socks.Version, socks.AuthFailure}) - if err != nil { - return nil, errors.Wrap(err, "响应认证失败") - } - return nil, errors.New("认证版本参数不正确") - } - - // 读取账号 - slog.Debug("验证用户账号") - uLen, err := utils.ReadByte(reader) - if err != nil { - return nil, errors.Wrap(err, "读取用户名长度失败") - } - usernameBuf, err := utils.ReadBuffer(reader, int(uLen)) - if err != nil { - return nil, errors.Wrap(err, "读取用户名失败") - } - username := string(usernameBuf) - - // 读取密码 - pLen, err := utils.ReadByte(reader) - if err != nil { - return nil, errors.Wrap(err, "读取密码长度失败") - } - passwordBuf, err := utils.ReadBuffer(reader, int(pLen)) - if err != nil { - return nil, errors.Wrap(err, "读取密码失败") - } - password := string(passwordBuf) - - // 查询通道配置 - var channel models.Channel - err = orm.DB. - Where(&models.Channel{ - Username: username, - AuthPass: true, - }). - First(&channel).Error - if err != nil { - return nil, errors.Wrap(err, "查询用户失败") - } - - // 检查密码 todo 哈希 - if channel.Password != password { - return nil, errors.New("密码错误") - } - - // 检查权限是否过期 - timeout := channel.Expiration.Sub(time.Now()).Seconds() - slog.Info("用户剩余时间", "timeout", timeout) - if timeout <= 0 { - return nil, errors.New("权限已过期") - } - - // 如果用户设置了双验证则检查 ip 是否在白名单中 - if channel.AuthIp { - slog.Debug("验证用户 ip") - - // 获取用户地址 - conn, ok := writer.(net.Conn) - if !ok { - return nil, errors.New("无法获取连接信息") - } - addr := conn.RemoteAddr().String() - client, _, err := net.SplitHostPort(addr) - if err != nil { - return nil, errors.Wrap(err, "无法获取连接信息") - } - - // 查询通道配置 - var ips []models.UserIp - err = orm.DB. - Where(&models.UserIp{ - UserId: channel.UserId, - IpAddress: client, - }). - Find(&ips).Error - if err != nil { - return nil, errors.Wrap(err, "查询用户 ip 失败") - } - - // 检查是否在白名单中 - if len(ips) == 0 { - return nil, errors.New("没有权限") - } - } - - // 响应认证成功 - _, err = writer.Write([]byte{socks.AuthVersion, socks.AuthSuccess}) - if err != nil { - slog.Error("响应认证失败", "err", err) - return nil, err - } - - return &socks.Authentication{ - Method: socks.UserPassAuth, - Timeout: uint(timeout), - Payload: socks.Payload{ - ID: channel.UserId, - }, - }, nil -} diff --git a/server/fwd/core/auth.go b/server/fwd/core/auth.go new file mode 100644 index 0000000..33571ed --- /dev/null +++ b/server/fwd/core/auth.go @@ -0,0 +1,136 @@ +package core + +import ( + "log/slog" + "net" + "proxy-server/server/models" + "proxy-server/server/pkg/orm" + "time" + + "github.com/pkg/errors" +) + +type Payload struct { + ID uint +} + +type AuthContext struct { + Timeout float64 + Payload Payload + Meta map[string]any +} + +func CheckIp(conn net.Conn) (*AuthContext, error) { + + // 获取用户地址 + remoteAddr := conn.RemoteAddr().String() + remoteHost, _, err := net.SplitHostPort(remoteAddr) + if err != nil { + return nil, errors.Wrap(err, "noAuth 认证失败") + } + + // 获取服务端口 + localAddr := conn.LocalAddr().String() + _, localPort, err := net.SplitHostPort(localAddr) + + // 查询权限记录 + slog.Info("用户 " + remoteHost + " 请求连接到 " + localPort) + var channels []models.Channel + err = orm.DB. + Joins("INNER JOIN public.nodes n ON channels.node_id = n.id AND n.name = ?", localPort). + Joins("INNER JOIN public.users u ON channels.user_id = u.id"). + Joins("INNER JOIN public.user_ips ip ON u.id = ip.user_id AND ip.ip_address = ?", remoteHost). + Where(&models.Channel{ + AuthIp: true, + }). + Find(&channels).Error + if err != nil { + return nil, errors.New("查询用户权限失败") + } + + // 记录应该只有一条 + channel, err := orm.MaySingle(channels) + if err != nil { + return nil, errors.Wrap(err, "不在白名单内") + } + + // 检查是否需要密码认证 + if channel.AuthPass { + return nil, errors.New("需要密码认证") + } + + // 检查权限是否过期 + timeout := channel.Expiration.Sub(time.Now()).Seconds() + if timeout <= 0 { + return nil, errors.New("权限已过期") + } + + return &AuthContext{ + Timeout: timeout, + Payload: Payload{ + channel.UserId, + }, + }, nil +} + +func CheckPass(conn net.Conn, username, password string) (*AuthContext, error) { + + // 查询通道配置 + var channel models.Channel + err := orm.DB. + Where(&models.Channel{ + Username: username, + AuthPass: true, + }). + First(&channel).Error + if err != nil { + return nil, errors.Wrap(err, "用户不存在") + } + + // 检查密码 todo 哈希 + if channel.Password != password { + return nil, errors.New("密码错误") + } + + // 检查权限是否过期 + timeout := channel.Expiration.Sub(time.Now()).Seconds() + if timeout <= 0 { + return nil, errors.New("权限已过期") + } + + // 如果用户设置了双验证则检查 ip 是否在白名单中 + if channel.AuthIp { + slog.Debug("验证用户 ip") + + // 获取用户地址 + remoteAddr := conn.RemoteAddr().String() + remoteHost, _, err := net.SplitHostPort(remoteAddr) + if err != nil { + return nil, errors.Wrap(err, "无法获取连接信息") + } + + // 查询通道配置 + + var ips int64 + err = orm.DB. + Where(&models.UserIp{ + UserId: channel.UserId, + IpAddress: remoteHost, + }). + Count(&ips).Error + if err != nil { + return nil, errors.Wrap(err, "查询白名单失败") + } + + if ips == 0 { + return nil, errors.New("不在白名单内") + } + } + + return &AuthContext{ + Timeout: timeout, + Payload: Payload{ + channel.UserId, + }, + }, nil +} diff --git a/server/fwd/core/conn.go b/server/fwd/core/conn.go new file mode 100644 index 0000000..cc367c3 --- /dev/null +++ b/server/fwd/core/conn.go @@ -0,0 +1,67 @@ +package core + +import ( + "bufio" + "fmt" + "net" + "time" +) + +type Conn struct { + Conn net.Conn + Reader *bufio.Reader + Tag string + Protocol string + Dest *FwdAddr + Auth *AuthContext +} + +func (c Conn) Read(b []byte) (n int, err error) { + return c.Reader.Read(b) +} + +func (c Conn) Write(b []byte) (n int, err error) { + return c.Conn.Write(b) +} + +func (c Conn) Close() error { + return c.Conn.Close() +} + +func (c Conn) LocalAddr() net.Addr { + return c.Conn.LocalAddr() +} + +func (c Conn) RemoteAddr() net.Addr { + return c.Conn.RemoteAddr() +} + +func (c Conn) SetDeadline(t time.Time) error { + return c.Conn.SetDeadline(t) +} + +func (c Conn) SetReadDeadline(t time.Time) error { + return c.Conn.SetReadDeadline(t) +} + +func (c Conn) SetWriteDeadline(t time.Time) error { + return c.Conn.SetWriteDeadline(t) +} + +func (c Conn) DestAddr() net.Addr { + return c.Dest +} + +type FwdAddr struct { + IP net.IP + Port int + Domain string +} + +func (a FwdAddr) Network() string { + return "tcp" +} + +func (a FwdAddr) String() string { + return fmt.Sprintf("%s:%d", a.IP, a.Port) +} diff --git a/server/fwd/dispatcher/dispatch.go b/server/fwd/dispatcher/dispatch.go new file mode 100644 index 0000000..975ef25 --- /dev/null +++ b/server/fwd/dispatcher/dispatch.go @@ -0,0 +1,142 @@ +package dispatcher + +import ( + "context" + "log/slog" + "net" + "proxy-server/pkg/utils" + "proxy-server/server/fwd/core" + "proxy-server/server/fwd/socks" + "strconv" + "time" + + "github.com/pkg/errors" + "github.com/soheilhy/cmux" +) + +type Server struct { + ctx context.Context + cancel context.CancelFunc + Port uint16 + Conn chan *core.Conn +} + +func New(port uint16) (*Server, error) { + + if port == 0 { + return nil, errors.New("port is required") + } + + ctx, cancel := context.WithCancel(context.Background()) + return &Server{ + ctx, + cancel, + port, + make(chan *core.Conn), + }, nil +} + +func (s *Server) Close() { + s.cancel() +} + +func (s *Server) Run() error { + port := strconv.Itoa(int(s.Port)) + + ls, err := net.Listen("tcp", ":"+port) + if err != nil { + return errors.Wrap(err, "dispatcher 监听失败") + } + + m := cmux.New(ls) + m.SetReadTimeout(5 * time.Second) + + go func() { + <-s.ctx.Done() + close(s.Conn) + m.Close() + }() + + socksLs := m.Match(cmux.PrefixMatcher(string([]byte{0x05}))) + defer utils.Close(socksLs) + go func() { + err = s.acceptSocks(socksLs) + if err != nil { + slog.Error("dispatcher socks accept error", "err", err) + } + }() + + httpLs := m.Match(cmux.HTTP1Fast("PATCH")) + defer utils.Close(httpLs) + go func() { + err = s.acceptHttp(httpLs) + if err != nil { + slog.Error("dispatcher http accept error", "err", err) + } + }() + + err = m.Serve() + if err != nil { + return errors.Wrap(err, "dispatcher serve error") + } + + return nil +} + +func (s *Server) acceptHttp(ls net.Listener) error { + for { + conn, err := ls.Accept() + if err != nil { + if errors.Is(err, net.ErrClosed) { + return nil + } + var ne net.Error + if errors.As(err, &ne) && ne.Temporary() { + continue + } + return errors.Wrap(err, "dispatcher http accept error") + } + + go func() { + err := s.processHttp(conn) + if err != nil { + slog.Error("dispatcher http process error", "err", err) + } + }() + } +} + +func (s *Server) acceptSocks(ls net.Listener) error { + for { + conn, err := ls.Accept() + if err != nil { + if errors.Is(err, net.ErrClosed) { + return nil + } + var ne net.Error + if errors.As(err, &ne) && ne.Temporary() { + continue + } + return errors.Wrap(err, "dispatcher socks accept error") + } + + go func() { + conn, err := socks.Process(s.ctx, conn) + if err != nil { + slog.Error("处理 socks 连接失败", "err", err) + } + select { + case <-s.ctx.Done(): + utils.Close(conn) + case s.Conn <- conn: + } + }() + } +} + +func (s *Server) processHttp(conn net.Conn) error { + return nil +} + +type Conn struct { +} diff --git a/server/fwd/fwd.go b/server/fwd/fwd.go index c83f248..feede1d 100644 --- a/server/fwd/fwd.go +++ b/server/fwd/fwd.go @@ -8,11 +8,11 @@ import ( "log/slog" "net" "proxy-server/pkg/utils" - "proxy-server/server/fwd/socks" + "proxy-server/server/fwd/core" + "proxy-server/server/fwd/dispatcher" "proxy-server/server/pkg/env" "strconv" "sync" - "time" "github.com/pkg/errors" ) @@ -21,14 +21,17 @@ type Config struct { } type Service struct { - Config *Config - ctx context.Context - cancel context.CancelFunc - userConnMap sync.Map + Config *Config + ctx context.Context + cancel context.CancelFunc + userConnMap sync.Map + ctrlConnMap sync.Map + + fwdLesWg utils.CountWaitGroup ctrlConnWg utils.CountWaitGroup dataConnWg utils.CountWaitGroup - fwdLesWg utils.CountWaitGroup + userConnWg utils.CountWaitGroup } func New(config *Config) *Service { @@ -42,16 +45,17 @@ func New(config *Config) *Service { ctx: ctx, cancel: cancel, userConnMap: sync.Map{}, - ctrlConnWg: utils.CountWaitGroup{}, - dataConnWg: utils.CountWaitGroup{}, - fwdLesWg: utils.CountWaitGroup{}, + ctrlConnMap: sync.Map{}, + + fwdLesWg: utils.CountWaitGroup{}, + ctrlConnWg: utils.CountWaitGroup{}, + dataConnWg: utils.CountWaitGroup{}, + userConnWg: utils.CountWaitGroup{}, } } func (s *Service) Close() { - start := time.Now() s.cancel() - slog.Debug("退出服务", "duration", time.Since(start)) } func (s *Service) Run() { @@ -95,15 +99,28 @@ func (s *Service) Run() { s.Close() } + wg.Wait() + // 协程建立有先后顺序,不能乱,否则会泄露 + s.dataConnWg.Wait() + s.ctrlConnWg.Wait() + s.fwdLesWg.Wait() + s.userConnWg.Wait() + // 清理资源 s.userConnMap.Range(func(key, value any) bool { - conn := value.(socks.ProxyConn) + conn := value.(core.Conn) utils.Close(conn) - s.userConnMap.Delete(key) return true }) s.userConnMap.Clear() + s.ctrlConnMap.Range(func(key, value any) bool { + conn := value.(net.Conn) + utils.Close(conn) + return true + }) + s.ctrlConnMap.Clear() + s.ctrlConnWg.Wait() slog.Debug("控制通道连接已关闭") s.dataConnWg.Wait() @@ -125,38 +142,33 @@ func (s *Service) startCtrlTun() error { } defer utils.Close(ls) - // 等待连接 - connCh := utils.ChanConnAccept(s.ctx, ls) - defer close(connCh) - // 处理连接 + connCh := utils.ChanConnAccept(s.ctx, ls) for { select { case <-s.ctx.Done(): - slog.Debug("服务关闭 startCtrlTun") return nil case conn, ok := <-connCh: if !ok { - slog.Debug("结束处理连接,由于获取连接失败") return errors.New("获取连接失败") } s.ctrlConnWg.Add(1) go func() { defer s.ctrlConnWg.Done() - defer utils.Close(conn) err := s.processCtrlConn(conn) if err != nil { slog.Error("处理控制通道连接失败", "err", err) + utils.Close(conn) } }() } } } -func (s *Service) processCtrlConn(controller net.Conn) error { - slog.Debug("客户端连入", "addr", controller.RemoteAddr().String()) +func (s *Service) processCtrlConn(conn net.Conn) error { + slog.Debug("客户端连入", "addr", conn.RemoteAddr().String()) - reader := bufio.NewReader(controller) + reader := bufio.NewReader(conn) // 获取转发端口 portBuf, err := utils.ReadBuffer(reader, 2) @@ -165,63 +177,19 @@ func (s *Service) processCtrlConn(controller net.Conn) error { } port := binary.BigEndian.Uint16(portBuf) - // 开放转发端口 todo 混合转发 - slog.Debug("开放转发端口", "port", port) - proxy, err := socks.New(&socks.Config{ - Name: strconv.Itoa(int(port)), - Port: port, - AuthMethods: []socks.Authenticator{ - &UserPassAuthenticator{}, - &NoAuthAuthenticator{}, - }, - }) - if err != nil { - return errors.Wrap(err, "创建 socks 转发服务失败") - } - defer proxy.Close() - + // 开放转发端口 s.fwdLesWg.Add(1) go func() { defer s.fwdLesWg.Done() - err := proxy.Run() + err := s.startFwdTun(port) if err != nil { slog.Error("代理服务启动失败", "err", err) return } }() - // 等待用户连接 - wg := sync.WaitGroup{} - for loop := true; loop; { - select { - case <-s.ctx.Done(): - loop = false - case user, ok := <-proxy.Conn: - if !ok { - loop = false - err = errors.New("无法获取连接") - } - wg.Add(1) - go func() { - defer wg.Done() - - tag := user.Tag() - tagLen := len(tag) - tagBuf := make([]byte, 1+tagLen) - tagBuf[0] = byte(tagLen) - copy(tagBuf[1:], tag) - _, err := controller.Write(tagBuf) - if err != nil { - utils.Close(user) - slog.Error("向客户端发送 tag 失败", "err", err) - return - } - s.userConnMap.Store(tag, user) - }() - } - } - - wg.Wait() + // 记录控制连接 + s.ctrlConnMap.Store(port, conn) return nil } @@ -236,19 +204,14 @@ func (s *Service) startDataTun() error { } defer utils.Close(ls) - // 等待连接 - connCh := utils.ChanConnAccept(s.ctx, ls) - defer close(connCh) - // 处理连接 + connCh := utils.ChanConnAccept(s.ctx, ls) for { select { case <-s.ctx.Done(): - slog.Debug("服务关闭 startDataTun") return nil case conn, ok := <-connCh: if !ok { - slog.Debug("结束处理连接,由于获取连接失败") return errors.New("获取连接失败") } s.dataConnWg.Add(1) @@ -268,67 +231,60 @@ func (s *Service) processDataConn(client net.Conn) error { slog.Info("客户端准备接收数据 " + client.RemoteAddr().String()) // 读取 tag - tagLen, err := utils.ReadByte(client) - if err != nil { - return errors.Wrap(err, "从客户端获取 tag 失败") - } - tagBuf, err := utils.ReadBuffer(client, int(tagLen)) - if err != nil { - return errors.Wrap(err, "从客户端获取 tag 失败") - } - tag := string(tagBuf) - - // 找到用户连接 - var data socks.ProxyConn + var tag string select { case <-s.ctx.Done(): return nil default: - dataAny, ok := s.userConnMap.Load(tag) - if !ok { - return errors.New("查找用户连接失败") + tagLen, err := utils.ReadByte(client) + if err != nil { + return errors.Wrap(err, "从客户端获取 tag 失败") } - data = dataAny.(socks.ProxyConn) - defer func() { - s.userConnMap.Delete(tag) - utils.Close(data) - }() + tagBuf, err := utils.ReadBuffer(client, int(tagLen)) + if err != nil { + return errors.Wrap(err, "从客户端获取 tag 失败") + } + tag = string(tagBuf) } - // 响应用户 - user := data.Conn - err = socks.SendSuccess(user, client) - if err != nil { - // todo 考虑是否需要处理服务关闭后导致用户连接被关闭的情况 - return errors.Wrap(err, "向用户发送成功消息失败") + // 找到用户连接 + userAny, ok := s.userConnMap.Load(tag) + if !ok { + return errors.New("查找用户连接失败") } + user := userAny.(*core.Conn) + defer utils.Close(user) + defer s.userConnMap.Delete(tag) // 发送目标地址 - dest := data.Dest - destLen := len(dest) - destBuf := make([]byte, 1+destLen) - destBuf[0] = byte(destLen) - copy(destBuf[1:], dest) - _, err = client.Write(destBuf) - if err != nil { - return errors.Wrap(err, "向客户端发送目标地址失败") + select { + case <-s.ctx.Done(): + return nil + default: + dest := user.Dest.String() + destLen := len(dest) + destBuf := make([]byte, 1+destLen) + destBuf[0] = byte(destLen) + copy(destBuf[1:], dest) + _, err := client.Write(destBuf) + if err != nil { + return errors.Wrap(err, "向客户端发送目标地址失败") + } } // 数据转发 - slog.Info("开始数据转发 " + client.RemoteAddr().String() + " <-> " + data.Dest) - userPipeReader, userPipeWriter := io.Pipe() defer utils.Close(userPipeWriter) teeUser := io.TeeReader(user, userPipeWriter) go func() { - err := analysisAndLog(data, userPipeReader) + err := analysisAndLog(user, userPipeReader) if err != nil { slog.Error("数据解析失败", "err", err) } }() wg := sync.WaitGroup{} - wg.Add(1) + wg.Add(2) go func() { defer wg.Done() _, err := io.Copy(client, teeUser) @@ -336,7 +292,6 @@ func (s *Service) processDataConn(client net.Conn) error { slog.Error("数据转发失败 user->client", "err", err) } }() - wg.Add(1) go func() { defer wg.Done() _, err := io.Copy(user, client) @@ -348,8 +303,76 @@ func (s *Service) processDataConn(client net.Conn) error { } } }() - wg.Wait() - slog.Info("数据转发结束 " + client.RemoteAddr().String() + " <-> " + data.Dest) + select { + case <-s.ctx.Done(): + case <-utils.ChanWgWait(s.ctx, &wg): + } + return nil } + +func (s *Service) startFwdTun(port uint16) error { + slog.Debug("监听转发通道", "port", port) + + proxy, err := dispatcher.New(port) + if err != nil { + return errors.Wrap(err, "创建 socks 转发服务失败") + } + defer proxy.Close() + + go func() { + err := proxy.Run() + if err != nil { + slog.Error("代理服务异常退出", "err", err) + } + }() + + for { + select { + case <-s.ctx.Done(): + return nil + case conn := <-proxy.Conn: + s.userConnWg.Add(1) + go func() { + defer s.userConnWg.Done() + err := s.processUserConn(conn, port) + if err != nil { + slog.Error("处理用户连接失败", "err", err) + } + }() + } + } +} + +func (s *Service) processUserConn(conn *core.Conn, port uint16) error { + + // 记录用户连接 + s.userConnMap.Store(conn.Tag, conn) + + // 通知客户端建立数据通道 + ctrlConnAny, ok := s.ctrlConnMap.Load(port) + if !ok { + return errors.New("查找控制连接失败") + } + ctrlConn := ctrlConnAny.(net.Conn) + + // 发送 tag + select { + case <-s.ctx.Done(): + return nil + default: + tag := conn.Tag + tagLen := len(tag) + tagBuf := make([]byte, 1+tagLen) + tagBuf[0] = byte(tagLen) + copy(tagBuf[1:], tag) + _, err := ctrlConn.Write(tagBuf) + if err != nil { + return errors.Wrap(err, "向控制通道发送 tag 失败") + } + } + + return nil + +} diff --git a/server/fwd/http/http.go b/server/fwd/http/http.go index fc732a9..d02cfda 100644 --- a/server/fwd/http/http.go +++ b/server/fwd/http/http.go @@ -1,4 +1 @@ package http - -func Start() { -} diff --git a/server/fwd/socks/auth.go b/server/fwd/socks/auth.go deleted file mode 100644 index e672747..0000000 --- a/server/fwd/socks/auth.go +++ /dev/null @@ -1,87 +0,0 @@ -package socks - -import ( - "context" - "io" - "log/slog" - "proxy-server/pkg/utils" - "slices" - - "github.com/pkg/errors" -) - -type AuthMethod byte - -const ( - AuthVersion = byte(1) - AuthSuccess = byte(0) - AuthFailure = byte(1) - NoAuth = AuthMethod(0) - UserPassAuth = AuthMethod(2) - NoAcceptable = byte(0xFF) -) - -type Authenticator interface { - Method() AuthMethod - Authenticate(ctx context.Context, reader io.Reader, writer io.Writer) (*Authentication, error) -} - -// authenticate 执行认证流程 -func (s *Server) authenticate(reader io.Reader, writer io.Writer) (*Authentication, error) { - - // 版本检查 - err := checkVersion(reader) - if err != nil { - return nil, err - } - - // 获取客户端认证方式 - nAuth, err := utils.ReadByte(reader) - if err != nil { - return nil, err - } - methods, err := utils.ReadBuffer(reader, int(nAuth)) - if err != nil { - return nil, err - } - - // 认证客户端连接 - for _, authenticator := range s.config.AuthMethods { - method := authenticator.Method() - if slices.Contains(methods, byte(method)) { - slog.Debug("使用的认证方式", method) - - _, err := writer.Write([]byte{Version, byte(method)}) - if err != nil { - slog.Error("响应认证方式失败", err) - return nil, err - } - - ctx := context.WithValue(context.Background(), "service", s) - authContext, err := authenticator.Authenticate(ctx, reader, writer) - if err != nil { - return nil, err - } - return authContext, nil - } - } - - // 无适用的认证方式 - _, err = writer.Write([]byte{Version, NoAcceptable}) - if err != nil { - return nil, err - } - - return nil, errors.New("没有适用的认证方式") -} - -type Authentication struct { - Method AuthMethod - Timeout uint - Payload Payload - Data map[string]any -} - -type Payload struct { - ID uint -} diff --git a/server/fwd/socks/error.go b/server/fwd/socks/error.go deleted file mode 100644 index 3f3f24d..0000000 --- a/server/fwd/socks/error.go +++ /dev/null @@ -1,7 +0,0 @@ -package socks - -type ConfigError string - -func (c ConfigError) Error() string { - return string(c) -} diff --git a/server/fwd/socks/request.go b/server/fwd/socks/request.go deleted file mode 100644 index e639004..0000000 --- a/server/fwd/socks/request.go +++ /dev/null @@ -1,353 +0,0 @@ -package socks - -import ( - "context" - "fmt" - "io" - "log/slog" - "net" - "proxy-server/pkg/utils" - "strconv" - - "github.com/pkg/errors" -) - -const ( - ConnectCommand = byte(1) - BindCommand = byte(2) - AssociateCommand = byte(3) - ipv4Address = byte(1) - fqdnAddress = byte(3) - ipv6Address = byte(4) -) - -const ( - successReply byte = iota - serverFailure - ruleFailure - networkUnreachable - hostUnreachable - connectionRefused - ttlExpired - commandNotSupported - addrTypeNotSupported -) - -var ( - unrecognizedAddrType = fmt.Errorf("unrecognized address type") -) - -// AddressRewriter is used to rewrite a destination transparently -type AddressRewriter interface { - Rewrite(ctx context.Context, request *Request) (context.Context, *AddrSpec) -} - -// AddrSpec 地址 -type AddrSpec struct { - FQDN string - IP net.IP - Port int -} - -func (a AddrSpec) String() string { - if a.FQDN != "" { - return fmt.Sprintf("%s (%s):%d", a.FQDN, a.IP, a.Port) - } - return fmt.Sprintf("%s:%d", a.IP, a.Port) -} - -// Address returns a string suitable to dial; prefer returning IP-based -// address, fallback to FQDN -func (a AddrSpec) Address() string { - if 0 != len(a.IP) { - return net.JoinHostPort(a.IP.String(), strconv.Itoa(a.Port)) - } - return net.JoinHostPort(a.FQDN, strconv.Itoa(a.Port)) -} - -func (s *Server) request(reader io.Reader, writer io.Writer) (*Request, error) { - - // 检查版本 - err := checkVersion(reader) - if err != nil { - return nil, err - } - - // 检查连接命令 - command, err := utils.ReadByte(reader) - if err != nil { - return nil, err - } - - slog.Debug("客户端使用的连接指令:%v", command) - if command != ConnectCommand && command != BindCommand && command != AssociateCommand { - err = sendReply(writer, commandNotSupported, nil) - if err != nil { - return nil, err - } - return nil, errors.New("不支持该连接指令") - } - - // 跳过保留字段 rsv - _, err = utils.ReadByte(reader) - if err != nil { - return nil, err - } - - // 获取目标地址 - dest, err := s.parseTarget(reader, writer) - if err != nil { - return nil, err - } - - request := &Request{ - Version: Version, - Command: command, - DestAddr: dest, - bufConn: reader, - } - - return request, nil -} - -func (s *Server) parseTarget(reader io.Reader, writer io.Writer) (*AddrSpec, error) { - dest := &AddrSpec{} - - aTypeBuf := make([]byte, 1) - _, err := reader.Read(aTypeBuf) - if err != nil { - return nil, err - } - - switch aTypeBuf[0] { - - case ipv4Address: - addr := make([]byte, 4) - _, err := io.ReadFull(reader, addr) - if err != nil { - return nil, err - } - dest.IP = addr - - case ipv6Address: - addr := make([]byte, 16) - _, err := io.ReadFull(reader, addr) - if err != nil { - return nil, err - } - dest.IP = addr - - case fqdnAddress: - aLenBuf := make([]byte, 1) - _, err := reader.Read(aLenBuf) - if err != nil { - return nil, err - } - - fqdnBuff := make([]byte, int(aLenBuf[0])) - _, err = io.ReadFull(reader, fqdnBuff) - if err != nil { - return nil, err - } - dest.FQDN = string(fqdnBuff) - - // 域名解析 - addr, err := s.config.Resolver.Resolve(dest.FQDN) - if err != nil { - err := sendReply(writer, hostUnreachable, nil) - if err != nil { - return nil, fmt.Errorf("failed to send reply: %v", err) - } - return nil, fmt.Errorf("failed to resolve destination '%v': %v", dest.FQDN, err) - } - dest.IP = addr - - default: - err := sendReply(writer, addrTypeNotSupported, nil) - if err != nil { - return nil, err - } - return nil, unrecognizedAddrType - } - - portBuf := make([]byte, 2) - _, err = io.ReadFull(reader, portBuf) - if err != nil { - return nil, err - } - dest.Port = (int(portBuf[0]) << 8) | int(portBuf[1]) - - return dest, nil -} - -// A Request represents request received by a server -type Request struct { - // Protocol version - Version uint8 - // Requested command - Command uint8 - // Authentication provided during negotiation - Authentication *Authentication - // AddrSpec of the network that sent the request - RemoteAddr *AddrSpec - // AddrSpec of the desired destination - DestAddr *AddrSpec - // AddrSpec of the actual destination (might be affected by rewrite) - realDestAddr *AddrSpec - bufConn io.Reader -} - -func (s *Server) handle(req *Request, conn net.Conn) error { - ctx := context.Background() - - // 目标地址重写 - req.realDestAddr = req.DestAddr - if s.config.Rewriter != nil { - ctx, req.realDestAddr = s.config.Rewriter.Rewrite(ctx, req) - } - - // 根据协商方法建立连接 - switch req.Command { - case ConnectCommand: - return s.handleConnect(ctx, conn, req) - case BindCommand: - return s.handleBind(ctx, conn, req) - case AssociateCommand: - return s.handleAssociate(ctx, conn, req) - default: - return fmt.Errorf("unsupported command: %v", req.Command) - } -} - -func (s *Server) handleConnect(ctx context.Context, conn net.Conn, req *Request) error { - // 检查规则集约束 - s.config.Logger.Printf("检查约束规则\n") - if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok { - if err := sendReply(conn, ruleFailure, nil); err != nil { - return fmt.Errorf("failed to send reply: %v", err) - } - return fmt.Errorf("request to %v blocked by rules", req.DestAddr) - } else { - ctx = ctx_ - } - - slog.Info("需要向 " + req.DestAddr.Address() + " 建立连接") - select { - case <-s.ctx.Done(): - if conn != nil { - utils.Close(conn) - } - case s.Conn <- ProxyConn{ - req.Authentication.Payload.ID, - conn, - req.realDestAddr.Address(), - }: - } - return nil -} - -func (s *Server) handleBind(ctx context.Context, conn net.Conn, req *Request) error { - // Check if this is allowed - if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok { - if err := sendReply(conn, ruleFailure, nil); err != nil { - return fmt.Errorf("failed to send reply: %v", err) - } - return fmt.Errorf("bind to %v blocked by rules", req.DestAddr) - } else { - ctx = ctx_ - } - - // TODO: Support bind - if err := sendReply(conn, commandNotSupported, nil); err != nil { - return fmt.Errorf("failed to send reply: %v", err) - } - return nil -} - -func (s *Server) handleAssociate(ctx context.Context, conn net.Conn, req *Request) error { - // Check if this is allowed - if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok { - if err := sendReply(conn, ruleFailure, nil); err != nil { - return fmt.Errorf("failed to send reply: %v", err) - } - return fmt.Errorf("associate to %v blocked by rules", req.DestAddr) - } else { - ctx = ctx_ - } - - // TODO: Support associate - if err := sendReply(conn, commandNotSupported, nil); err != nil { - return fmt.Errorf("failed to send reply: %v", err) - } - return nil -} - -func sendReply(w io.Writer, resp uint8, addr *AddrSpec) error { - var addrType uint8 - var addrBody []byte - var addrPort uint16 - switch { - case addr == nil: - addrType = ipv4Address - addrBody = []byte{0, 0, 0, 0} - addrPort = 0 - - case addr.FQDN != "": - addrType = fqdnAddress - addrBody = append([]byte{byte(len(addr.FQDN))}, addr.FQDN...) - addrPort = uint16(addr.Port) - - case addr.IP.To4() != nil: - addrType = ipv4Address - addrBody = addr.IP.To4() - addrPort = uint16(addr.Port) - - case addr.IP.To16() != nil: - addrType = ipv6Address - addrBody = addr.IP.To16() - addrPort = uint16(addr.Port) - - default: - return fmt.Errorf("failed to format address: %v", addr) - } - - msg := make([]byte, 6+len(addrBody)) - msg[0] = Version - msg[1] = resp - msg[2] = 0 // Reserved - msg[3] = addrType - copy(msg[4:], addrBody) - msg[4+len(addrBody)] = byte(addrPort >> 8) - msg[4+len(addrBody)+1] = byte(addrPort & 0xff) - - _, err := w.Write(msg) - return err -} - -func SendSuccess(user net.Conn, target net.Conn) error { - local := target.LocalAddr().(*net.TCPAddr) - bind := AddrSpec{IP: local.IP, Port: local.Port} - err := sendReply(user, successReply, &bind) - if err != nil { - return err - } - return nil -} - -type ProxyConn struct { - Uid uint - // 用户连入的连接 - Conn net.Conn - // 用户目标地址 - Dest string -} - -func (d ProxyConn) Tag() string { - local := d.Conn.LocalAddr() - remote := d.Conn.RemoteAddr() - return fmt.Sprintf("%s-%s", remote, local) -} - -func (d ProxyConn) Close() error { - return d.Conn.Close() -} diff --git a/server/fwd/socks/resolver.go b/server/fwd/socks/resolver.go deleted file mode 100644 index 04f45c6..0000000 --- a/server/fwd/socks/resolver.go +++ /dev/null @@ -1,21 +0,0 @@ -package socks - -import ( - "net" -) - -// NameResolver 域名解析器 -type NameResolver interface { - Resolve(name string) (net.IP, error) -} - -// DNSResolver 使用系统 dns 服务解析域名 -type DNSResolver struct{} - -func (d DNSResolver) Resolve(name string) (net.IP, error) { - addr, err := net.ResolveIPAddr("ip", name) - if err != nil { - return nil, err - } - return addr.IP, err -} diff --git a/server/fwd/socks/ruleset.go b/server/fwd/socks/ruleset.go deleted file mode 100644 index 00a6686..0000000 --- a/server/fwd/socks/ruleset.go +++ /dev/null @@ -1,41 +0,0 @@ -package socks - -import ( - "context" -) - -// RuleSet is used to provide custom rules to allow or prohibit actions -type RuleSet interface { - Allow(ctx context.Context, req *Request) (context.Context, bool) -} - -// PermitAll returns a RuleSet which allows all types of connections -func PermitAll() RuleSet { - return &PermitCommand{true, true, true} -} - -// PermitNone returns a RuleSet which disallows all types of connections -func PermitNone() RuleSet { - return &PermitCommand{false, false, false} -} - -// PermitCommand is an implementation of the RuleSet which -// enables filtering supported commands -type PermitCommand struct { - EnableConnect bool - EnableBind bool - EnableAssociate bool -} - -func (p *PermitCommand) Allow(ctx context.Context, req *Request) (context.Context, bool) { - switch req.Command { - case ConnectCommand: - return ctx, p.EnableConnect - case BindCommand: - return ctx, p.EnableBind - case AssociateCommand: - return ctx, p.EnableAssociate - } - - return ctx, false -} diff --git a/server/fwd/socks/socks.go b/server/fwd/socks/socks.go index e9fdde6..9615bfa 100644 --- a/server/fwd/socks/socks.go +++ b/server/fwd/socks/socks.go @@ -3,212 +3,90 @@ package socks import ( "bufio" "context" + "encoding/binary" "fmt" "io" - "log" "log/slog" "net" - "os" "proxy-server/pkg/utils" - "strconv" - "time" + "proxy-server/server/fwd/core" + "slices" "github.com/pkg/errors" ) const ( - Version = byte(5) + Version = byte(5) + AuthVersion = byte(1) ) -type Config struct { - Name string +const ( + NoAuth = byte(0) + UserPassAuth = byte(2) + NoAcceptable = byte(0xFF) +) - Host string - Port uint16 +const ( + AuthSuccess = byte(0) + AuthFailure = byte(1) +) - // 认证方法 - AuthMethods []Authenticator +const ( + ConnectCommand = byte(1) + BindCommand = byte(2) + AssociateCommand = byte(3) +) - // 域名解析 - Resolver NameResolver +const ( + ipv4Address = byte(1) + fqdnAddress = byte(3) + ipv6Address = byte(4) +) - // 自定义认证规则 - Rules RuleSet - - // 地址重写 - Rewriter AddressRewriter - - // 用于 bind 和 associate - BindIP net.IP - - // Logger - Logger *log.Logger - - // 自定义连接流程 - Dial func(network, addr string) (net.Conn, error) -} - -type Server struct { - config *Config - ctx context.Context - cancel context.CancelFunc - wg utils.CountWaitGroup - Name string - Port uint16 - Conn chan ProxyConn -} - -// New 创建服务器 -func New(conf *Config) (*Server, error) { - if conf == nil { - conf = &Config{} - } - - if len(conf.AuthMethods) == 0 { - return nil, ConfigError("认证方法不能为空") - } - - if conf.Resolver == nil { - conf.Resolver = DNSResolver{} - } - - if conf.Rules == nil { - conf.Rules = PermitAll() - } - - if conf.Logger == nil { - conf.Logger = log.New(os.Stdout, "", log.LstdFlags) - } - - if conf.Dial == nil { - conf.Dial = func(network, addr string) (net.Conn, error) { - return net.Dial(network, addr) - } - } - - ctx, cancel := context.WithCancel(context.Background()) - return &Server{ - config: conf, - ctx: ctx, - cancel: cancel, - wg: utils.CountWaitGroup{}, - Name: conf.Name, - Port: conf.Port, - Conn: make(chan ProxyConn), - }, nil -} - -// Run 监听端口 -func (s *Server) Run() error { - slog.Debug("启动 socks5 代理服务") - - // 监听端口 - host := s.config.Host - port := s.config.Port - addr := net.JoinHostPort(host, strconv.Itoa(int(port))) - ls, err := net.Listen("tcp", addr) - if err != nil { - return errors.Wrap(err, "监听端口失败") - } - defer utils.Close(ls) - slog.Debug("正在监听端口", slog.Uint64("port", uint64(port))) - - // 处理连接 - connCh := utils.ChanConnAccept(s.ctx, ls) - defer close(connCh) - - err = nil - for loop := true; loop; { - select { - case <-s.ctx.Done(): - slog.Debug("socks 服务主动停止") - loop = false - case conn, ok := <-connCh: - if !ok { - err = errors.New("意外错误,无法获取连接") - loop = false - s.Close() - break - } - s.wg.Add(1) - go func() { - defer s.wg.Done() - // 连接要传出,不能在这里关闭连接 - err := s.process(conn) - if err != nil { - slog.Error("处理连接失败", err) - } - }() - } - } - - // 关闭服务 - timeout, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - wgCh := utils.ChanWgWait(timeout, &s.wg) - - err = nil - select { - case <-timeout.Done(): - err = errors.New("关闭超时(强制关闭)") - case <-wgCh: - } - - close(s.Conn) - return err -} - -// Close 关闭服务 -func (s *Server) Close() { - s.cancel() -} - -// process 建立连接 -func (s *Server) process(conn net.Conn) error { - slog.Info("收到来自" + conn.RemoteAddr().String() + "的连接") +const ( + successReply byte = iota + serverFailure + ruleFailure + networkUnreachable + hostUnreachable + connectionRefused + ttlExpired + commandNotSupported + addrTypeNotSupported +) +// Process 处理连接 +func Process(ctx context.Context, conn net.Conn) (*core.Conn, error) { reader := bufio.NewReader(conn) // 认证 - slog.Debug("开始认证流程") - authContext, err := s.authenticate(reader, conn) + auth, err := authenticate(ctx, reader, conn) if err != nil { - utils.Close(conn) - slog.Error("认证失败", err) - return err - } else { - slog.Debug("认证完成") + return nil, errors.Wrap(err, "认证失败") } // 处理连接请求 - slog.Debug("处理连接请求") - request, err := s.request(reader, conn) + request, err := request(ctx, reader, conn) if err != nil { - slog.Error("连接请求处理失败", err) - return err - } else { - slog.Debug("连接请求处理完成") + return nil, errors.Wrap(err, "处理连接请求失败") } - request.Authentication = authContext - user, ok := conn.RemoteAddr().(*net.TCPAddr) - if !ok { - return fmt.Errorf("获取用户地址失败") + // 代理连接 + if request.Command != ConnectCommand { + return nil, errors.New("不支持的连接指令") } - request.RemoteAddr = &AddrSpec{ - IP: user.IP, - Port: user.Port, - } + // 响应成功 + err = sendReply(conn, successReply, request.DestAddr) - // 处理请求 - slog.Debug("开始代理流量") - err = s.handle(request, conn) - if err != nil { - return err - } - - return nil + return &core.Conn{ + Conn: conn, + Reader: reader, + Protocol: "socks5", + Tag: conn.RemoteAddr().String() + "_" + conn.LocalAddr().String(), + Dest: request.DestAddr, + Auth: auth, + }, nil } // checkVersion 检查客户端版本 @@ -218,11 +96,280 @@ func checkVersion(reader io.Reader) error { return err } - slog.Debug("客户端请求版本", "version", version) - if version != Version { return errors.New("客户端版本不兼容") } return nil } + +// authenticate 执行认证流程 +func authenticate(ctx context.Context, reader *bufio.Reader, conn net.Conn) (*core.AuthContext, error) { + + // 版本检查 + err := checkVersion(reader) + if err != nil { + return nil, err + } + + // 获取客户端认证方式 + nAuth, err := utils.ReadByte(reader) + if err != nil { + return nil, err + } + methods, err := utils.ReadBuffer(reader, int(nAuth)) + if err != nil { + return nil, err + } + + // 密码模式 + if slices.Contains(methods, UserPassAuth) { + _, err := conn.Write([]byte{Version, byte(UserPassAuth)}) + if err != nil { + return nil, errors.Wrap(err, "响应认证方式失败") + } + + // 检查认证版本 + slog.Debug("验证认证版本") + v, err := utils.ReadByte(reader) + if err != nil { + return nil, errors.Wrap(err, "读取版本号失败") + } + if v != AuthVersion { + _, err := conn.Write([]byte{Version, AuthFailure}) + if err != nil { + return nil, errors.Wrap(err, "响应认证失败") + } + return nil, errors.New("认证版本参数不正确") + } + + // 读取账号 + slog.Debug("验证用户账号") + uLen, err := utils.ReadByte(reader) + if err != nil { + return nil, errors.Wrap(err, "读取用户名长度失败") + } + usernameBuf, err := utils.ReadBuffer(reader, int(uLen)) + if err != nil { + return nil, errors.Wrap(err, "读取用户名失败") + } + username := string(usernameBuf) + + // 读取密码 + pLen, err := utils.ReadByte(reader) + if err != nil { + return nil, errors.Wrap(err, "读取密码长度失败") + } + passwordBuf, err := utils.ReadBuffer(reader, int(pLen)) + if err != nil { + return nil, errors.Wrap(err, "读取密码失败") + } + password := string(passwordBuf) + + // 检查权限 + authContext, err := core.CheckPass(conn, username, password) + if err != nil { + return nil, errors.Wrap(err, "权限检查失败") + } + + // 响应认证成功 + _, err = conn.Write([]byte{AuthVersion, AuthSuccess}) + if err != nil { + return nil, errors.Wrap(err, "响应认证成功失败") + } + + return authContext, nil + } + + // 无认证 + if slices.Contains(methods, NoAuth) { + _, err = conn.Write([]byte{Version, NoAuth}) + if err != nil { + return nil, errors.Wrap(err, "响应认证方式失败") + } + + authContext, err := core.CheckIp(conn) + if err != nil { + return nil, errors.Wrap(err, "权限检查失败") + } + + return authContext, nil + } + + // 无适用的认证方式 + _, err = conn.Write([]byte{Version, NoAcceptable}) + if err != nil { + return nil, err + } + + return nil, errors.New("没有适用的认证方式") +} + +type Request struct { + Command uint8 + DestAddr *core.FwdAddr +} + +// request 处理连接请求 +func request(ctx context.Context, reader io.Reader, writer io.Writer) (*Request, error) { + + // 检查版本 + err := checkVersion(reader) + if err != nil { + return nil, err + } + + // 检查连接命令 + command, err := utils.ReadByte(reader) + if err != nil { + return nil, err + } + + if command != ConnectCommand && command != BindCommand && command != AssociateCommand { + err = sendReply(writer, commandNotSupported, nil) + if err != nil { + return nil, err + } + return nil, errors.New("不支持该连接指令") + } + + // 跳过保留字段 rsv + _, err = utils.ReadByte(reader) + if err != nil { + return nil, err + } + + // 获取目标地址 + dest, err := parseTarget(reader, writer) + if err != nil { + return nil, err + } + + request := &Request{ + Command: command, + DestAddr: dest, + } + + return request, nil +} + +func parseTarget(reader io.Reader, writer io.Writer) (*core.FwdAddr, error) { + dest := &core.FwdAddr{} + + aTypeBuf, err := utils.ReadByte(reader) + if err != nil { + return nil, err + } + + switch aTypeBuf { + + case ipv4Address: + addr := make([]byte, 4) + _, err := io.ReadFull(reader, addr) + if err != nil { + return nil, err + } + dest.IP = addr + + case ipv6Address: + addr := make([]byte, 16) + _, err := io.ReadFull(reader, addr) + if err != nil { + return nil, err + } + dest.IP = addr + + case fqdnAddress: + aLenBuf := make([]byte, 1) + _, err := reader.Read(aLenBuf) + if err != nil { + return nil, err + } + + fqdnBuff := make([]byte, int(aLenBuf[0])) + _, err = io.ReadFull(reader, fqdnBuff) + if err != nil { + return nil, err + } + dest.Domain = string(fqdnBuff) + + // 域名解析 + addr, err := net.ResolveIPAddr("ip", dest.Domain) + if err != nil { + err := sendReply(writer, hostUnreachable, nil) + if err != nil { + return nil, fmt.Errorf("failed to send reply: %v", err) + } + return nil, fmt.Errorf("failed to resolve destination '%v': %v", dest.Domain, err) + } + dest.IP = addr.IP + + default: + err := sendReply(writer, addrTypeNotSupported, nil) + if err != nil { + return nil, err + } + return nil, fmt.Errorf("unrecognized address type") + } + + portBuf := make([]byte, 2) + _, err = io.ReadFull(reader, portBuf) + if err != nil { + return nil, err + } + dest.Port = int(binary.BigEndian.Uint16(portBuf)) + + return dest, nil +} + +func sendReply(w io.Writer, resp uint8, addr *core.FwdAddr) error { + var addrType uint8 + var addrBody []byte + var addrPort uint16 + switch { + case addr == nil: + addrType = ipv4Address + addrBody = []byte{0, 0, 0, 0} + addrPort = 0 + + case addr.Domain != "": + addrType = fqdnAddress + addrBody = append([]byte{byte(len(addr.Domain))}, addr.Domain...) + addrPort = uint16(addr.Port) + + case addr.IP.To4() != nil: + addrType = ipv4Address + addrBody = addr.IP.To4() + addrPort = uint16(addr.Port) + + case addr.IP.To16() != nil: + addrType = ipv6Address + addrBody = addr.IP.To16() + addrPort = uint16(addr.Port) + + default: + return fmt.Errorf("failed to format address: %v", addr) + } + + msg := make([]byte, 6+len(addrBody)) + msg[0] = Version + msg[1] = resp + msg[2] = 0 // Reserved + msg[3] = addrType + copy(msg[4:], addrBody) + msg[4+len(addrBody)] = byte(addrPort >> 8) + msg[4+len(addrBody)+1] = byte(addrPort & 0xff) + + _, err := w.Write(msg) + return err +} + +func SendSuccess(user net.Conn, target net.Conn) error { + local := target.LocalAddr().(*net.TCPAddr) + bind := core.FwdAddr{IP: local.IP, Port: local.Port} + err := sendReply(user, successReply, &bind) + if err != nil { + return err + } + return nil +} diff --git a/server/web/models/channel.go b/server/models/channel.go similarity index 100% rename from server/web/models/channel.go rename to server/models/channel.go diff --git a/server/web/models/node.go b/server/models/node.go similarity index 100% rename from server/web/models/node.go rename to server/models/node.go diff --git a/server/web/models/user-ip.go b/server/models/user-ip.go similarity index 100% rename from server/web/models/user-ip.go rename to server/models/user-ip.go diff --git a/server/web/models/user.go b/server/models/user.go similarity index 100% rename from server/web/models/user.go rename to server/models/user.go diff --git a/server/server.go b/server/server.go index 00a5662..1ed0965 100644 --- a/server/server.go +++ b/server/server.go @@ -66,11 +66,9 @@ func Start() { timeout, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() - wgCh := utils.ChanWgWait(timeout, &wg) - close(wgCh) select { - case <-wgCh: + case <-utils.ChanWgWait(timeout, &wg): slog.Info("服务已退出") case <-timeout.Done(): slog.Warn("退出超时,强制退出") diff --git a/server/web/handlers/channel.go b/server/web/handlers/channel.go index 4bfff3e..bffe3c2 100644 --- a/server/web/handlers/channel.go +++ b/server/web/handlers/channel.go @@ -2,9 +2,9 @@ package handlers import ( "log/slog" + "proxy-server/server/models" "proxy-server/server/pkg/orm" "proxy-server/server/pkg/resp" - "proxy-server/server/web/models" "strings" "time" diff --git a/server/web/handlers/node.go b/server/web/handlers/node.go index 4272cb8..205f5bc 100644 --- a/server/web/handlers/node.go +++ b/server/web/handlers/node.go @@ -2,8 +2,8 @@ package handlers import ( "os" + "proxy-server/server/models" "proxy-server/server/pkg/orm" - "proxy-server/server/web/models" "github.com/gin-gonic/gin" "github.com/pkg/errors"