优化错误处理,替换 errors.Wrap 为 fmt.Errorf
This commit is contained in:
23
README.md
23
README.md
@@ -10,32 +10,19 @@ server/fwd: 服务端核心代码
|
|||||||
- socks: socks5 处理器,负责处理 socks5 请求
|
- socks: socks5 处理器,负责处理 socks5 请求
|
||||||
- repo: 状态仓库,所有有状态数据都通过 repo 中的接口与外部服务交互
|
- repo: 状态仓库,所有有状态数据都通过 repo 中的接口与外部服务交互
|
||||||
|
|
||||||
### 更新测试环境
|
|
||||||
|
|
||||||
1. 构建项目
|
|
||||||
2. 使用测试配置 `.env.test` 远程启动 docker
|
|
||||||
|
|
||||||
### 转发服务结束时资源清理
|
### 转发服务结束时资源清理
|
||||||
|
|
||||||
1. 关闭接听端口,防止新连接接入(user, data, ctrl)
|
1. 关闭接听端口,防止新连接接入(user, data, ctrl)
|
||||||
2. 通知并等待所有正在运行的 conn 处理协程全部关闭(user, data, ctrl)
|
2. 通知并等待所有正在运行的 conn 处理协程全部关闭(user, data, ctrl)
|
||||||
3. 结束所有保存且未使用的 conn 连接(user, ctrl)
|
3. 结束所有保存且未使用的 conn 连接(user, ctrl)
|
||||||
|
|
||||||
### 代码清理
|
|
||||||
|
|
||||||
检查 slog 级别:
|
|
||||||
|
|
||||||
ERR: 除非有必要,否则全部 error 都使用 `errors.Wrap()` 包裹(如果下游有返回 err),并附带本层业务信息,return 到上层统一打印
|
|
||||||
|
|
||||||
其他级别日志就地打印,Info 只用来跟踪关键流程
|
|
||||||
|
|
||||||
### proxy.lock 文件格式
|
### proxy.lock 文件格式
|
||||||
|
|
||||||
| mag_num(1) | name(16) |
|
| mag_num(1) | name(16) |
|
||||||
|-------------|-----------|
|
|-------------|-----------|
|
||||||
| 魔法数,固定 0x72 | 服务名称,uuid |
|
| 魔法数,固定 0x72 | 服务名称,uuid |
|
||||||
|
|
||||||
## 协议
|
## 底层协议设计
|
||||||
|
|
||||||
### 步骤说明
|
### 步骤说明
|
||||||
|
|
||||||
@@ -55,13 +42,13 @@ ERR: 除非有必要,否则全部 error 都使用 `errors.Wrap()` 包裹(如
|
|||||||
|
|
||||||
#### 建立控制通道
|
#### 建立控制通道
|
||||||
|
|
||||||
客户端:
|
1.客户端发送:
|
||||||
|
|
||||||
| id(4) |
|
| id(4) |
|
||||||
|--------|
|
|--------|
|
||||||
| 客户端 ID |
|
| 客户端 ID |
|
||||||
|
|
||||||
服务端:
|
2.服务端发送:
|
||||||
|
|
||||||
| status(1) |
|
| status(1) |
|
||||||
|-----------|
|
|-----------|
|
||||||
@@ -69,13 +56,13 @@ ERR: 除非有必要,否则全部 error 都使用 `errors.Wrap()` 包裹(如
|
|||||||
|
|
||||||
#### 建立数据通道
|
#### 建立数据通道
|
||||||
|
|
||||||
服务端:
|
1.服务端发送:
|
||||||
|
|
||||||
| tag(16) | dst_len(2) | dst_buf(n) |
|
| tag(16) | dst_len(2) | dst_buf(n) |
|
||||||
|---------|------------|------------|
|
|---------|------------|------------|
|
||||||
| 通道标识 | 目标地址长度 | 目标地址 |
|
| 通道标识 | 目标地址长度 | 目标地址 |
|
||||||
|
|
||||||
客户端:
|
2.客户端发送:
|
||||||
|
|
||||||
| tag(16) | status(1) |
|
| tag(16) | status(1) |
|
||||||
|---------|-----------------------|
|
|---------|-----------------------|
|
||||||
|
|||||||
1
go.mod
1
go.mod
@@ -8,7 +8,6 @@ require (
|
|||||||
github.com/joho/godotenv v1.5.1
|
github.com/joho/godotenv v1.5.1
|
||||||
github.com/lmittmann/tint v1.0.7
|
github.com/lmittmann/tint v1.0.7
|
||||||
github.com/mattn/go-colorable v0.1.14
|
github.com/mattn/go-colorable v0.1.14
|
||||||
github.com/pkg/errors v0.9.1
|
|
||||||
github.com/soheilhy/cmux v0.1.5
|
github.com/soheilhy/cmux v0.1.5
|
||||||
gorm.io/driver/postgres v1.5.11
|
gorm.io/driver/postgres v1.5.11
|
||||||
gorm.io/gen v0.3.26
|
gorm.io/gen v0.3.26
|
||||||
|
|||||||
2
go.sum
2
go.sum
@@ -84,8 +84,6 @@ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9G
|
|||||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||||
github.com/pelletier/go-toml/v2 v2.2.3 h1:YmeHyLY8mFWbdkNWwpr+qIL2bEqT0o95WSdkNHvL12M=
|
github.com/pelletier/go-toml/v2 v2.2.3 h1:YmeHyLY8mFWbdkNWwpr+qIL2bEqT0o95WSdkNHvL12M=
|
||||||
github.com/pelletier/go-toml/v2 v2.2.3/go.mod h1:MfCQTFTvCcUyyvvwm1+G6H/jORL20Xlb6rzQu9GuUkc=
|
github.com/pelletier/go-toml/v2 v2.2.3/go.mod h1:MfCQTFTvCcUyyvvwm1+G6H/jORL20Xlb6rzQu9GuUkc=
|
||||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
|
||||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
|
||||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
|
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import (
|
|||||||
"log/slog"
|
"log/slog"
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
func ChanConnAccept(ctx context.Context, ls net.Listener) chan net.Conn {
|
func ChanConnAccept(ctx context.Context, ls net.Listener) chan net.Conn {
|
||||||
|
|||||||
@@ -3,13 +3,14 @@ package fwd
|
|||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"proxy-server/pkg/utils"
|
"proxy-server/pkg/utils"
|
||||||
"proxy-server/server/fwd/core"
|
"proxy-server/server/fwd/core"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
func analysisAndLog(conn *core.Conn, reader io.Reader) error {
|
func analysisAndLog(conn *core.Conn, reader io.Reader) error {
|
||||||
@@ -17,7 +18,7 @@ func analysisAndLog(conn *core.Conn, reader io.Reader) error {
|
|||||||
|
|
||||||
domain, proto, err := sniffing(buf)
|
domain, proto, err := sniffing(buf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = errors.Wrap(err, "analysis sniffing error")
|
err = fmt.Errorf("sniffing error: %w", err)
|
||||||
} else {
|
} else {
|
||||||
slog.Debug(
|
slog.Debug(
|
||||||
"用户访问记录",
|
"用户访问记录",
|
||||||
@@ -39,7 +40,7 @@ func analysisAndLog(conn *core.Conn, reader io.Reader) error {
|
|||||||
func sniffing(reader *bufio.Reader) (string, string, error) {
|
func sniffing(reader *bufio.Reader) (string, string, error) {
|
||||||
peek, err := reader.Peek(8)
|
peek, err := reader.Peek(8)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", errors.Wrap(err, "sniffing peek error")
|
return "", "", fmt.Errorf("sniffing peek error: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
method, ok := isHttp(peek)
|
method, ok := isHttp(peek)
|
||||||
@@ -126,7 +127,7 @@ func analysisHttp(reader *bufio.Reader) (string, error) {
|
|||||||
// reade top
|
// reade top
|
||||||
top, err := httpReadLine(reader)
|
top, err := httpReadLine(reader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", errors.Wrap(err, "analysis http read top error")
|
return "", fmt.Errorf("http read top error: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// read header
|
// read header
|
||||||
@@ -153,7 +154,7 @@ func httpReadLine(reader *bufio.Reader) (line string, err error) {
|
|||||||
for {
|
for {
|
||||||
line, prefix, err := reader.ReadLine()
|
line, prefix, err := reader.ReadLine()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", errors.Wrap(err, "analysis http read line error")
|
return "", fmt.Errorf("http read line error: %w", err)
|
||||||
}
|
}
|
||||||
lineStr.Write(line)
|
lineStr.Write(line)
|
||||||
if !prefix {
|
if !prefix {
|
||||||
@@ -168,13 +169,13 @@ func analysisTls(reader *bufio.Reader) (string, error) {
|
|||||||
// tls record
|
// tls record
|
||||||
_, err := utils.ReadBuffer(reader, 5)
|
_, err := utils.ReadBuffer(reader, 5)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", errors.Wrap(err, "analysis https read head error")
|
return "", fmt.Errorf("https read head error: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// tls type
|
// tls type
|
||||||
hsType, err := reader.ReadByte()
|
hsType, err := reader.ReadByte()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", errors.Wrap(err, "analysis https read hsType error")
|
return "", fmt.Errorf("https read hsType error: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
switch hsType {
|
switch hsType {
|
||||||
@@ -183,59 +184,59 @@ func analysisTls(reader *bufio.Reader) (string, error) {
|
|||||||
// length
|
// length
|
||||||
_, err = utils.ReadBuffer(reader, 3)
|
_, err = utils.ReadBuffer(reader, 3)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", errors.Wrap(err, "analysis https read tls length error")
|
return "", fmt.Errorf("https read tls length error: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// version
|
// version
|
||||||
_, err = utils.ReadBuffer(reader, 2)
|
_, err = utils.ReadBuffer(reader, 2)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", errors.Wrap(err, "analysis https read version error")
|
return "", fmt.Errorf("https read version error: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// random
|
// random
|
||||||
_, err = utils.ReadBuffer(reader, 32)
|
_, err = utils.ReadBuffer(reader, 32)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", errors.Wrap(err, "analysis https read random error")
|
return "", fmt.Errorf("https read random error: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// session id length
|
// session id length
|
||||||
sessionIdLen, err := reader.ReadByte()
|
sessionIdLen, err := reader.ReadByte()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", errors.Wrap(err, "analysis https read sessionIdLen error")
|
return "", fmt.Errorf("https read sessionIdLen error: %w", err)
|
||||||
}
|
}
|
||||||
// session id
|
// session id
|
||||||
_, err = utils.ReadBuffer(reader, int(sessionIdLen))
|
_, err = utils.ReadBuffer(reader, int(sessionIdLen))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", errors.Wrap(err, "analysis https read sessionId error")
|
return "", fmt.Errorf("https read sessionId error: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// cipher suites length
|
// cipher suites length
|
||||||
cLenBuf, err := utils.ReadBuffer(reader, 2)
|
cLenBuf, err := utils.ReadBuffer(reader, 2)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", errors.Wrap(err, "analysis https read cLen error")
|
return "", fmt.Errorf("https read cLen error: %w", err)
|
||||||
}
|
}
|
||||||
cLen := binary.BigEndian.Uint16(cLenBuf)
|
cLen := binary.BigEndian.Uint16(cLenBuf)
|
||||||
// cipher suites
|
// cipher suites
|
||||||
_, err = utils.ReadBuffer(reader, int(cLen))
|
_, err = utils.ReadBuffer(reader, int(cLen))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", errors.Wrap(err, "analysis https read c error")
|
return "", fmt.Errorf("https read c error: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// compression methods length
|
// compression methods length
|
||||||
cmLen, err := reader.ReadByte()
|
cmLen, err := reader.ReadByte()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", errors.Wrap(err, "analysis https read cmLen error")
|
return "", fmt.Errorf("https read cmLen error: %w", err)
|
||||||
}
|
}
|
||||||
// compression methods
|
// compression methods
|
||||||
_, err = utils.ReadBuffer(reader, int(cmLen))
|
_, err = utils.ReadBuffer(reader, int(cmLen))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", errors.Wrap(err, "analysis https read cm error")
|
return "", fmt.Errorf("https read cm error: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// extensions length
|
// extensions length
|
||||||
eLenBuf, err := utils.ReadBuffer(reader, 2)
|
eLenBuf, err := utils.ReadBuffer(reader, 2)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", errors.Wrap(err, "analysis https read eLen error")
|
return "", fmt.Errorf("https read eLen error: %w", err)
|
||||||
}
|
}
|
||||||
eLen := binary.BigEndian.Uint16(eLenBuf)
|
eLen := binary.BigEndian.Uint16(eLenBuf)
|
||||||
|
|
||||||
@@ -247,14 +248,14 @@ func analysisTls(reader *bufio.Reader) (string, error) {
|
|||||||
// extension type
|
// extension type
|
||||||
eTypeBuf, err := utils.ReadBuffer(reader, 2)
|
eTypeBuf, err := utils.ReadBuffer(reader, 2)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", errors.Wrap(err, "analysis https read extension type error")
|
return "", fmt.Errorf("https read extension type error: %w", err)
|
||||||
}
|
}
|
||||||
eType := binary.BigEndian.Uint16(eTypeBuf)
|
eType := binary.BigEndian.Uint16(eTypeBuf)
|
||||||
|
|
||||||
// extension length
|
// extension length
|
||||||
eLenBuf, err := utils.ReadBuffer(reader, 2)
|
eLenBuf, err := utils.ReadBuffer(reader, 2)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", errors.Wrap(err, "analysis https read extension length error")
|
return "", fmt.Errorf("https read extension length error: %w", err)
|
||||||
}
|
}
|
||||||
eLen := binary.BigEndian.Uint16(eLenBuf)
|
eLen := binary.BigEndian.Uint16(eLenBuf)
|
||||||
|
|
||||||
@@ -263,23 +264,23 @@ func analysisTls(reader *bufio.Reader) (string, error) {
|
|||||||
// server name list length
|
// server name list length
|
||||||
_, err = utils.ReadBuffer(reader, 2)
|
_, err = utils.ReadBuffer(reader, 2)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", errors.Wrap(err, "analysis https read server name list length error")
|
return "", fmt.Errorf("https read server name list length error: %w", err)
|
||||||
}
|
}
|
||||||
// server name type
|
// server name type
|
||||||
_, err = reader.ReadByte()
|
_, err = reader.ReadByte()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", errors.Wrap(err, "analysis https read server name type error")
|
return "", fmt.Errorf("https read server name type error: %w", err)
|
||||||
}
|
}
|
||||||
// server name length
|
// server name length
|
||||||
sLenBuf, err := utils.ReadBuffer(reader, 2)
|
sLenBuf, err := utils.ReadBuffer(reader, 2)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", errors.Wrap(err, "analysis https read server name length error")
|
return "", fmt.Errorf("https read server name length error: %w", err)
|
||||||
}
|
}
|
||||||
sLen := binary.BigEndian.Uint16(sLenBuf)
|
sLen := binary.BigEndian.Uint16(sLenBuf)
|
||||||
// server name
|
// server name
|
||||||
bytes, err := utils.ReadBuffer(reader, int(sLen))
|
bytes, err := utils.ReadBuffer(reader, int(sLen))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", errors.Wrap(err, "analysis https read server name error")
|
return "", fmt.Errorf("https read server name error: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
host = string(bytes)
|
host = string(bytes)
|
||||||
@@ -289,7 +290,7 @@ func analysisTls(reader *bufio.Reader) (string, error) {
|
|||||||
// other extension
|
// other extension
|
||||||
_, err = utils.ReadBuffer(reader, int(eLen))
|
_, err = utils.ReadBuffer(reader, int(eLen))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", errors.Wrap(err, "analysis https read extension error")
|
return "", fmt.Errorf("https read extension error: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
i += 4 + int(eLen)
|
i += 4 + int(eLen)
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package auth
|
package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net"
|
"net"
|
||||||
"proxy-server/server/fwd/core"
|
"proxy-server/server/fwd/core"
|
||||||
@@ -9,7 +10,7 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Protocol string
|
type Protocol string
|
||||||
@@ -25,7 +26,7 @@ func CheckIp(conn net.Conn, proto Protocol) (*core.AuthContext, error) {
|
|||||||
remoteAddr := conn.RemoteAddr().String()
|
remoteAddr := conn.RemoteAddr().String()
|
||||||
remoteHost, _, err := net.SplitHostPort(remoteAddr)
|
remoteHost, _, err := net.SplitHostPort(remoteAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "noAuth 认证失败")
|
return nil, fmt.Errorf("无法获取连接信息: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取服务端口
|
// 获取服务端口
|
||||||
@@ -33,7 +34,7 @@ func CheckIp(conn net.Conn, proto Protocol) (*core.AuthContext, error) {
|
|||||||
_, _localPort, err := net.SplitHostPort(localAddr)
|
_, _localPort, err := net.SplitHostPort(localAddr)
|
||||||
localPort, err := strconv.Atoi(_localPort)
|
localPort, err := strconv.Atoi(_localPort)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "noAuth 认证失败")
|
return nil, fmt.Errorf("noAuth 认证失败: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 查询权限记录
|
// 查询权限记录
|
||||||
@@ -51,7 +52,7 @@ func CheckIp(conn net.Conn, proto Protocol) (*core.AuthContext, error) {
|
|||||||
// 记录应该只有一条
|
// 记录应该只有一条
|
||||||
channel, err := orm.MaySingle(channels)
|
channel, err := orm.MaySingle(channels)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "不在白名单内")
|
return nil, errors.New("不在白名单内")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 检查是否需要密码认证
|
// 检查是否需要密码认证
|
||||||
@@ -80,7 +81,7 @@ func CheckPass(conn net.Conn, proto Protocol, username, password string) (*core.
|
|||||||
_, _localPort, err := net.SplitHostPort(localAddr)
|
_, _localPort, err := net.SplitHostPort(localAddr)
|
||||||
localPort, err := strconv.Atoi(_localPort)
|
localPort, err := strconv.Atoi(_localPort)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "noAuth 认证失败")
|
return nil, fmt.Errorf("noAuth 认证失败: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 查询权限记录
|
// 查询权限记录
|
||||||
@@ -92,7 +93,7 @@ func CheckPass(conn net.Conn, proto Protocol, username, password string) (*core.
|
|||||||
Protocol: string(proto),
|
Protocol: string(proto),
|
||||||
}).Error
|
}).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "用户不存在")
|
return nil, errors.New("用户不存在")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 检查密码 todo 哈希
|
// 检查密码 todo 哈希
|
||||||
@@ -107,7 +108,7 @@ func CheckPass(conn net.Conn, proto Protocol, username, password string) (*core.
|
|||||||
remoteAddr := conn.RemoteAddr().String()
|
remoteAddr := conn.RemoteAddr().String()
|
||||||
remoteHost, _, err := net.SplitHostPort(remoteAddr)
|
remoteHost, _, err := net.SplitHostPort(remoteAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "无法获取连接信息")
|
return nil, fmt.Errorf("无法获取连接信息: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 查询权限记录
|
// 查询权限记录
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package fwd
|
package fwd
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net"
|
"net"
|
||||||
@@ -12,7 +13,7 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (s *Service) startDataTun() error {
|
func (s *Service) startDataTun() error {
|
||||||
@@ -22,7 +23,7 @@ func (s *Service) startDataTun() error {
|
|||||||
// 监听端口
|
// 监听端口
|
||||||
ls, err := net.Listen("tcp", ":"+strconv.Itoa(int(dataPort)))
|
ls, err := net.Listen("tcp", ":"+strconv.Itoa(int(dataPort)))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "监听数据通道失败")
|
return fmt.Errorf("监听数据通道失败: %w", err)
|
||||||
}
|
}
|
||||||
defer utils.Close(ls)
|
defer utils.Close(ls)
|
||||||
|
|
||||||
@@ -34,7 +35,7 @@ func (s *Service) startDataTun() error {
|
|||||||
for {
|
for {
|
||||||
conn, err := ls.Accept()
|
conn, err := ls.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "监听数据通道失败")
|
return fmt.Errorf("监听数据通道失败: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
@@ -60,17 +61,17 @@ func (s *Service) processDataConn(client net.Conn) error {
|
|||||||
// 接收 status
|
// 接收 status
|
||||||
status, err := utils.ReadByte(client)
|
status, err := utils.ReadByte(client)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "从客户端获取 status 失败")
|
return fmt.Errorf("从客户端获取 status 失败: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 接收 tag
|
// 接收 tag
|
||||||
tagLen, err := utils.ReadByte(client)
|
tagLen, err := utils.ReadByte(client)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "从客户端获取 tag 失败")
|
return fmt.Errorf("从客户端获取 tag 失败: %w", err)
|
||||||
}
|
}
|
||||||
tagBuf, err := utils.ReadBuffer(client, int(tagLen))
|
tagBuf, err := utils.ReadBuffer(client, int(tagLen))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "从客户端获取 tag 失败")
|
return fmt.Errorf("从客户端获取 tag 失败: %w", err)
|
||||||
}
|
}
|
||||||
tag := string(tagBuf)
|
tag := string(tagBuf)
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package dispatcher
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net"
|
"net"
|
||||||
"proxy-server/pkg/utils"
|
"proxy-server/pkg/utils"
|
||||||
@@ -13,7 +14,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"errors"
|
||||||
"github.com/soheilhy/cmux"
|
"github.com/soheilhy/cmux"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -48,7 +49,7 @@ func (s *Server) Run() error {
|
|||||||
|
|
||||||
ls, err := net.Listen("tcp", ":"+port)
|
ls, err := net.Listen("tcp", ":"+port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.Wrap(err, "dispatcher 监听失败")
|
return fmt.Errorf("dispatcher 监听失败: %w", err)
|
||||||
}
|
}
|
||||||
defer utils.Close(ls)
|
defer utils.Close(ls)
|
||||||
|
|
||||||
@@ -83,7 +84,7 @@ func (s *Server) Run() error {
|
|||||||
defer close(errCh)
|
defer close(errCh)
|
||||||
err = m.Serve()
|
err = m.Serve()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = errors.Wrap(err, "dispatcher serve error")
|
err = fmt.Errorf("dispatcher serve error: %w", err)
|
||||||
}
|
}
|
||||||
errCh <- err
|
errCh <- err
|
||||||
}()
|
}()
|
||||||
@@ -110,7 +111,7 @@ func (s *Server) acceptHttp(ls net.Listener) error {
|
|||||||
if errors.As(err, &ne) && ne.Temporary() {
|
if errors.As(err, &ne) && ne.Temporary() {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
return errors.Wrap(err, "dispatcher http accept error")
|
return fmt.Errorf("dispatcher http accept error: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
metrics.TimerStart.Store(conn, time.Now())
|
metrics.TimerStart.Store(conn, time.Now())
|
||||||
@@ -142,7 +143,7 @@ func (s *Server) acceptSocks(ls net.Listener) error {
|
|||||||
if errors.As(err, &ne) && ne.Temporary() {
|
if errors.As(err, &ne) && ne.Temporary() {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
return errors.Wrap(err, "dispatcher socks accept error")
|
return fmt.Errorf("dispatcher socks accept error: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
metrics.TimerStart.Store(conn, time.Now())
|
metrics.TimerStart.Store(conn, time.Now())
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"bufio"
|
"bufio"
|
||||||
"context"
|
"context"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/textproto"
|
"net/textproto"
|
||||||
@@ -12,7 +13,7 @@ import (
|
|||||||
"proxy-server/server/fwd/core"
|
"proxy-server/server/fwd/core"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Request struct {
|
type Request struct {
|
||||||
@@ -43,7 +44,7 @@ func Process(ctx context.Context, conn net.Conn) (*core.Conn, error) {
|
|||||||
// 请求头
|
// 请求头
|
||||||
headers, err := textReader.ReadMIMEHeader()
|
headers, err := textReader.ReadMIMEHeader()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "解析请求头失败")
|
return nil, fmt.Errorf("解析请求头失败: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 验证账号
|
// 验证账号
|
||||||
@@ -55,9 +56,9 @@ func Process(ctx context.Context, conn net.Conn) (*core.Conn, error) {
|
|||||||
if authErr != nil {
|
if authErr != nil {
|
||||||
_, err := conn.Write([]byte("HTTP/1.1 407 Proxy Authentication Required\r\n\r\n"))
|
_, err := conn.Write([]byte("HTTP/1.1 407 Proxy Authentication Required\r\n\r\n"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "响应 407 失败")
|
return nil, fmt.Errorf("响应 407 失败: %v", err)
|
||||||
}
|
}
|
||||||
return nil, errors.Wrap(authErr, "验证账号失败")
|
return nil, fmt.Errorf("验证账号失败: %v", authErr)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
authParts := strings.Split(authInfo, " ")
|
authParts := strings.Split(authInfo, " ")
|
||||||
@@ -69,16 +70,16 @@ func Process(ctx context.Context, conn net.Conn) (*core.Conn, error) {
|
|||||||
}
|
}
|
||||||
authBytes, err := base64.URLEncoding.DecodeString(authParts[1])
|
authBytes, err := base64.URLEncoding.DecodeString(authParts[1])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "解码认证信息失败")
|
return nil, fmt.Errorf("解码认证信息失败: %v", err)
|
||||||
}
|
}
|
||||||
authPair := strings.Split(string(authBytes), ":")
|
authPair := strings.Split(string(authBytes), ":")
|
||||||
authCtx, authErr = auth.CheckPass(conn, auth.Http, authPair[0], authPair[1])
|
authCtx, authErr = auth.CheckPass(conn, auth.Http, authPair[0], authPair[1])
|
||||||
if authErr != nil {
|
if authErr != nil {
|
||||||
_, err := conn.Write([]byte("HTTP/1.1 407 Proxy Authentication Required\r\n\r\n"))
|
_, err := conn.Write([]byte("HTTP/1.1 407 Proxy Authentication Required\r\n\r\n"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "响应 407 失败")
|
return nil, fmt.Errorf("响应 407 失败: %v", err)
|
||||||
}
|
}
|
||||||
return nil, errors.Wrap(authErr, "验证账号失败")
|
return nil, fmt.Errorf("验证账号失败: %v", authErr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -92,7 +93,7 @@ func Process(ctx context.Context, conn net.Conn) (*core.Conn, error) {
|
|||||||
}
|
}
|
||||||
addr, err := net.ResolveTCPAddr("tcp", host)
|
addr, err := net.ResolveTCPAddr("tcp", host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "解析 Host 失败")
|
return nil, fmt.Errorf("解析 Host 失败: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
request := &Request{
|
request := &Request{
|
||||||
@@ -131,7 +132,7 @@ func processHttps(ctx context.Context, req *Request) (*core.Conn, error) {
|
|||||||
// 响应 CONNECT
|
// 响应 CONNECT
|
||||||
_, err := req.conn.Write([]byte("HTTP/1.1 200 Connection Established\r\n\r\n"))
|
_, err := req.conn.Write([]byte("HTTP/1.1 200 Connection Established\r\n\r\n"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "响应 CONNECT 失败")
|
return nil, fmt.Errorf("响应 CONNECT 失败: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &core.Conn{
|
return &core.Conn{
|
||||||
@@ -149,7 +150,7 @@ func processHttp(ctx context.Context, req *Request) (*core.Conn, error) {
|
|||||||
// 修改请求头
|
// 修改请求头
|
||||||
rawUrl, err := url.Parse(req.uri)
|
rawUrl, err := url.Parse(req.uri)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "解析请求地址失败")
|
return nil, fmt.Errorf("解析请求地址失败: %v", err)
|
||||||
}
|
}
|
||||||
rawUrl.Scheme = ""
|
rawUrl.Scheme = ""
|
||||||
rawUrl.Host = ""
|
rawUrl.Host = ""
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"bufio"
|
"bufio"
|
||||||
"context"
|
"context"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
@@ -12,8 +13,6 @@ import (
|
|||||||
"proxy-server/server/fwd/auth"
|
"proxy-server/server/fwd/auth"
|
||||||
"proxy-server/server/fwd/core"
|
"proxy-server/server/fwd/core"
|
||||||
"slices"
|
"slices"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -63,18 +62,18 @@ func Process(ctx context.Context, conn net.Conn) (*core.Conn, error) {
|
|||||||
// 认证
|
// 认证
|
||||||
authCtx, err := authenticate(ctx, reader, conn)
|
authCtx, err := authenticate(ctx, reader, conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "认证失败")
|
return nil, fmt.Errorf("认证失败: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 处理连接请求
|
// 处理连接请求
|
||||||
request, err := request(ctx, reader, conn)
|
request, err := request(ctx, reader, conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "处理连接请求失败")
|
return nil, fmt.Errorf("处理连接请求失败: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 代理连接
|
// 代理连接
|
||||||
if request.Command != ConnectCommand {
|
if request.Command != ConnectCommand {
|
||||||
return nil, errors.New("不支持的连接指令")
|
return nil, fmt.Errorf("不支持的连接指令: %d", request.Command)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 响应成功
|
// 响应成功
|
||||||
@@ -127,56 +126,56 @@ func authenticate(ctx context.Context, reader *bufio.Reader, conn net.Conn) (*co
|
|||||||
if slices.Contains(methods, UserPassAuth) {
|
if slices.Contains(methods, UserPassAuth) {
|
||||||
_, err := conn.Write([]byte{Version, byte(UserPassAuth)})
|
_, err := conn.Write([]byte{Version, byte(UserPassAuth)})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "响应认证方式失败")
|
return nil, fmt.Errorf("响应认证方式失败: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 检查认证版本
|
// 检查认证版本
|
||||||
slog.Debug("验证认证版本")
|
slog.Debug("验证认证版本")
|
||||||
v, err := utils.ReadByte(reader)
|
v, err := utils.ReadByte(reader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "读取版本号失败")
|
return nil, fmt.Errorf("读取版本号失败: %w", err)
|
||||||
}
|
}
|
||||||
if v != AuthVersion {
|
if v != AuthVersion {
|
||||||
_, err := conn.Write([]byte{Version, AuthFailure})
|
_, err := conn.Write([]byte{Version, AuthFailure})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "响应认证失败")
|
return nil, fmt.Errorf("响应认证失败: %w", err)
|
||||||
}
|
}
|
||||||
return nil, errors.New("认证版本参数不正确")
|
return nil, fmt.Errorf("认证版本参数不正确: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 读取账号
|
// 读取账号
|
||||||
slog.Debug("验证用户账号")
|
slog.Debug("验证用户账号")
|
||||||
uLen, err := utils.ReadByte(reader)
|
uLen, err := utils.ReadByte(reader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "读取用户名长度失败")
|
return nil, fmt.Errorf("读取用户名长度失败: %w", err)
|
||||||
}
|
}
|
||||||
usernameBuf, err := utils.ReadBuffer(reader, int(uLen))
|
usernameBuf, err := utils.ReadBuffer(reader, int(uLen))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "读取用户名失败")
|
return nil, fmt.Errorf("读取用户名失败: %w", err)
|
||||||
}
|
}
|
||||||
username := string(usernameBuf)
|
username := string(usernameBuf)
|
||||||
|
|
||||||
// 读取密码
|
// 读取密码
|
||||||
pLen, err := utils.ReadByte(reader)
|
pLen, err := utils.ReadByte(reader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "读取密码长度失败")
|
return nil, fmt.Errorf("读取密码长度失败: %w", err)
|
||||||
}
|
}
|
||||||
passwordBuf, err := utils.ReadBuffer(reader, int(pLen))
|
passwordBuf, err := utils.ReadBuffer(reader, int(pLen))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "读取密码失败")
|
return nil, fmt.Errorf("读取密码失败: %w", err)
|
||||||
}
|
}
|
||||||
password := string(passwordBuf)
|
password := string(passwordBuf)
|
||||||
|
|
||||||
// 检查权限
|
// 检查权限
|
||||||
authContext, err := auth.CheckPass(conn, auth.Socks5, username, password)
|
authContext, err := auth.CheckPass(conn, auth.Socks5, username, password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "权限检查失败")
|
return nil, fmt.Errorf("权限检查失败: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 响应认证成功
|
// 响应认证成功
|
||||||
_, err = conn.Write([]byte{AuthVersion, AuthSuccess})
|
_, err = conn.Write([]byte{AuthVersion, AuthSuccess})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "响应认证成功失败")
|
return nil, fmt.Errorf("响应认证成功失败: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return authContext, nil
|
return authContext, nil
|
||||||
@@ -186,12 +185,12 @@ func authenticate(ctx context.Context, reader *bufio.Reader, conn net.Conn) (*co
|
|||||||
if slices.Contains(methods, NoAuth) {
|
if slices.Contains(methods, NoAuth) {
|
||||||
_, err = conn.Write([]byte{Version, NoAuth})
|
_, err = conn.Write([]byte{Version, NoAuth})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "响应认证方式失败")
|
return nil, fmt.Errorf("响应认证方式失败: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
authCtx, err := auth.CheckIp(conn, auth.Socks5)
|
authCtx, err := auth.CheckIp(conn, auth.Socks5)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "权限检查失败")
|
return nil, fmt.Errorf("权限检查失败: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return authCtx, nil
|
return authCtx, nil
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import (
|
|||||||
"log/slog"
|
"log/slog"
|
||||||
"proxy-server/server/pkg/env"
|
"proxy-server/server/pkg/env"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"errors"
|
||||||
"gorm.io/driver/postgres"
|
"gorm.io/driver/postgres"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"gorm.io/gorm/logger"
|
"gorm.io/gorm/logger"
|
||||||
|
|||||||
Reference in New Issue
Block a user