优化错误处理,替换 errors.Wrap 为 fmt.Errorf

This commit is contained in:
2025-05-15 09:53:23 +08:00
parent 75569d2d6d
commit 8b7dc9e4ff
11 changed files with 80 additions and 92 deletions

View File

@@ -10,32 +10,19 @@ server/fwd: 服务端核心代码
- socks: socks5 处理器,负责处理 socks5 请求
- repo: 状态仓库,所有有状态数据都通过 repo 中的接口与外部服务交互
### 更新测试环境
1. 构建项目
2. 使用测试配置 `.env.test` 远程启动 docker
### 转发服务结束时资源清理
1. 关闭接听端口防止新连接接入user, data, ctrl
2. 通知并等待所有正在运行的 conn 处理协程全部关闭user, data, ctrl
3. 结束所有保存且未使用的 conn 连接user, ctrl
### 代码清理
检查 slog 级别:
ERR: 除非有必要,否则全部 error 都使用 `errors.Wrap()` 包裹(如果下游有返回 err并附带本层业务信息return 到上层统一打印
其他级别日志就地打印Info 只用来跟踪关键流程
### proxy.lock 文件格式
| mag_num(1) | name(16) |
|-------------|-----------|
| 魔法数,固定 0x72 | 服务名称uuid |
## 协议
## 底层协议设计
### 步骤说明
@@ -55,13 +42,13 @@ ERR: 除非有必要,否则全部 error 都使用 `errors.Wrap()` 包裹(如
#### 建立控制通道
客户端:
1.客户端发送
| id(4) |
|--------|
| 客户端 ID |
服务端:
2.服务端发送
| status(1) |
|-----------|
@@ -69,13 +56,13 @@ ERR: 除非有必要,否则全部 error 都使用 `errors.Wrap()` 包裹(如
#### 建立数据通道
服务端:
1.服务端发送
| tag(16) | dst_len(2) | dst_buf(n) |
|---------|------------|------------|
| 通道标识 | 目标地址长度 | 目标地址 |
客户端:
2.客户端发送
| tag(16) | status(1) |
|---------|-----------------------|

1
go.mod
View File

@@ -8,7 +8,6 @@ require (
github.com/joho/godotenv v1.5.1
github.com/lmittmann/tint v1.0.7
github.com/mattn/go-colorable v0.1.14
github.com/pkg/errors v0.9.1
github.com/soheilhy/cmux v0.1.5
gorm.io/driver/postgres v1.5.11
gorm.io/gen v0.3.26

2
go.sum
View File

@@ -84,8 +84,6 @@ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9G
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/pelletier/go-toml/v2 v2.2.3 h1:YmeHyLY8mFWbdkNWwpr+qIL2bEqT0o95WSdkNHvL12M=
github.com/pelletier/go-toml/v2 v2.2.3/go.mod h1:MfCQTFTvCcUyyvvwm1+G6H/jORL20Xlb6rzQu9GuUkc=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=

View File

@@ -5,7 +5,7 @@ import (
"log/slog"
"net"
"github.com/pkg/errors"
"errors"
)
func ChanConnAccept(ctx context.Context, ls net.Listener) chan net.Conn {

View File

@@ -3,13 +3,14 @@ package fwd
import (
"bufio"
"encoding/binary"
"fmt"
"io"
"log/slog"
"proxy-server/pkg/utils"
"proxy-server/server/fwd/core"
"strings"
"github.com/pkg/errors"
"errors"
)
func analysisAndLog(conn *core.Conn, reader io.Reader) error {
@@ -17,7 +18,7 @@ func analysisAndLog(conn *core.Conn, reader io.Reader) error {
domain, proto, err := sniffing(buf)
if err != nil {
err = errors.Wrap(err, "analysis sniffing error")
err = fmt.Errorf("sniffing error: %w", err)
} else {
slog.Debug(
"用户访问记录",
@@ -39,7 +40,7 @@ func analysisAndLog(conn *core.Conn, reader io.Reader) error {
func sniffing(reader *bufio.Reader) (string, string, error) {
peek, err := reader.Peek(8)
if err != nil {
return "", "", errors.Wrap(err, "sniffing peek error")
return "", "", fmt.Errorf("sniffing peek error: %w", err)
}
method, ok := isHttp(peek)
@@ -126,7 +127,7 @@ func analysisHttp(reader *bufio.Reader) (string, error) {
// reade top
top, err := httpReadLine(reader)
if err != nil {
return "", errors.Wrap(err, "analysis http read top error")
return "", fmt.Errorf("http read top error: %w", err)
}
// read header
@@ -153,7 +154,7 @@ func httpReadLine(reader *bufio.Reader) (line string, err error) {
for {
line, prefix, err := reader.ReadLine()
if err != nil {
return "", errors.Wrap(err, "analysis http read line error")
return "", fmt.Errorf("http read line error: %w", err)
}
lineStr.Write(line)
if !prefix {
@@ -168,13 +169,13 @@ func analysisTls(reader *bufio.Reader) (string, error) {
// tls record
_, err := utils.ReadBuffer(reader, 5)
if err != nil {
return "", errors.Wrap(err, "analysis https read head error")
return "", fmt.Errorf("https read head error: %w", err)
}
// tls type
hsType, err := reader.ReadByte()
if err != nil {
return "", errors.Wrap(err, "analysis https read hsType error")
return "", fmt.Errorf("https read hsType error: %w", err)
}
switch hsType {
@@ -183,59 +184,59 @@ func analysisTls(reader *bufio.Reader) (string, error) {
// length
_, err = utils.ReadBuffer(reader, 3)
if err != nil {
return "", errors.Wrap(err, "analysis https read tls length error")
return "", fmt.Errorf("https read tls length error: %w", err)
}
// version
_, err = utils.ReadBuffer(reader, 2)
if err != nil {
return "", errors.Wrap(err, "analysis https read version error")
return "", fmt.Errorf("https read version error: %w", err)
}
// random
_, err = utils.ReadBuffer(reader, 32)
if err != nil {
return "", errors.Wrap(err, "analysis https read random error")
return "", fmt.Errorf("https read random error: %w", err)
}
// session id length
sessionIdLen, err := reader.ReadByte()
if err != nil {
return "", errors.Wrap(err, "analysis https read sessionIdLen error")
return "", fmt.Errorf("https read sessionIdLen error: %w", err)
}
// session id
_, err = utils.ReadBuffer(reader, int(sessionIdLen))
if err != nil {
return "", errors.Wrap(err, "analysis https read sessionId error")
return "", fmt.Errorf("https read sessionId error: %w", err)
}
// cipher suites length
cLenBuf, err := utils.ReadBuffer(reader, 2)
if err != nil {
return "", errors.Wrap(err, "analysis https read cLen error")
return "", fmt.Errorf("https read cLen error: %w", err)
}
cLen := binary.BigEndian.Uint16(cLenBuf)
// cipher suites
_, err = utils.ReadBuffer(reader, int(cLen))
if err != nil {
return "", errors.Wrap(err, "analysis https read c error")
return "", fmt.Errorf("https read c error: %w", err)
}
// compression methods length
cmLen, err := reader.ReadByte()
if err != nil {
return "", errors.Wrap(err, "analysis https read cmLen error")
return "", fmt.Errorf("https read cmLen error: %w", err)
}
// compression methods
_, err = utils.ReadBuffer(reader, int(cmLen))
if err != nil {
return "", errors.Wrap(err, "analysis https read cm error")
return "", fmt.Errorf("https read cm error: %w", err)
}
// extensions length
eLenBuf, err := utils.ReadBuffer(reader, 2)
if err != nil {
return "", errors.Wrap(err, "analysis https read eLen error")
return "", fmt.Errorf("https read eLen error: %w", err)
}
eLen := binary.BigEndian.Uint16(eLenBuf)
@@ -247,14 +248,14 @@ func analysisTls(reader *bufio.Reader) (string, error) {
// extension type
eTypeBuf, err := utils.ReadBuffer(reader, 2)
if err != nil {
return "", errors.Wrap(err, "analysis https read extension type error")
return "", fmt.Errorf("https read extension type error: %w", err)
}
eType := binary.BigEndian.Uint16(eTypeBuf)
// extension length
eLenBuf, err := utils.ReadBuffer(reader, 2)
if err != nil {
return "", errors.Wrap(err, "analysis https read extension length error")
return "", fmt.Errorf("https read extension length error: %w", err)
}
eLen := binary.BigEndian.Uint16(eLenBuf)
@@ -263,23 +264,23 @@ func analysisTls(reader *bufio.Reader) (string, error) {
// server name list length
_, err = utils.ReadBuffer(reader, 2)
if err != nil {
return "", errors.Wrap(err, "analysis https read server name list length error")
return "", fmt.Errorf("https read server name list length error: %w", err)
}
// server name type
_, err = reader.ReadByte()
if err != nil {
return "", errors.Wrap(err, "analysis https read server name type error")
return "", fmt.Errorf("https read server name type error: %w", err)
}
// server name length
sLenBuf, err := utils.ReadBuffer(reader, 2)
if err != nil {
return "", errors.Wrap(err, "analysis https read server name length error")
return "", fmt.Errorf("https read server name length error: %w", err)
}
sLen := binary.BigEndian.Uint16(sLenBuf)
// server name
bytes, err := utils.ReadBuffer(reader, int(sLen))
if err != nil {
return "", errors.Wrap(err, "analysis https read server name error")
return "", fmt.Errorf("https read server name error: %w", err)
}
host = string(bytes)
@@ -289,7 +290,7 @@ func analysisTls(reader *bufio.Reader) (string, error) {
// other extension
_, err = utils.ReadBuffer(reader, int(eLen))
if err != nil {
return "", errors.Wrap(err, "analysis https read extension error")
return "", fmt.Errorf("https read extension error: %w", err)
}
}
i += 4 + int(eLen)

View File

@@ -1,6 +1,7 @@
package auth
import (
"fmt"
"log/slog"
"net"
"proxy-server/server/fwd/core"
@@ -9,7 +10,7 @@ import (
"strconv"
"time"
"github.com/pkg/errors"
"errors"
)
type Protocol string
@@ -25,7 +26,7 @@ func CheckIp(conn net.Conn, proto Protocol) (*core.AuthContext, error) {
remoteAddr := conn.RemoteAddr().String()
remoteHost, _, err := net.SplitHostPort(remoteAddr)
if err != nil {
return nil, errors.Wrap(err, "noAuth 认证失败")
return nil, fmt.Errorf("无法获取连接信息: %w", err)
}
// 获取服务端口
@@ -33,7 +34,7 @@ func CheckIp(conn net.Conn, proto Protocol) (*core.AuthContext, error) {
_, _localPort, err := net.SplitHostPort(localAddr)
localPort, err := strconv.Atoi(_localPort)
if err != nil {
return nil, errors.Wrap(err, "noAuth 认证失败")
return nil, fmt.Errorf("noAuth 认证失败: %w", err)
}
// 查询权限记录
@@ -51,7 +52,7 @@ func CheckIp(conn net.Conn, proto Protocol) (*core.AuthContext, error) {
// 记录应该只有一条
channel, err := orm.MaySingle(channels)
if err != nil {
return nil, errors.Wrap(err, "不在白名单内")
return nil, errors.New("不在白名单内")
}
// 检查是否需要密码认证
@@ -80,7 +81,7 @@ func CheckPass(conn net.Conn, proto Protocol, username, password string) (*core.
_, _localPort, err := net.SplitHostPort(localAddr)
localPort, err := strconv.Atoi(_localPort)
if err != nil {
return nil, errors.Wrap(err, "noAuth 认证失败")
return nil, fmt.Errorf("noAuth 认证失败: %w", err)
}
// 查询权限记录
@@ -92,7 +93,7 @@ func CheckPass(conn net.Conn, proto Protocol, username, password string) (*core.
Protocol: string(proto),
}).Error
if err != nil {
return nil, errors.Wrap(err, "用户不存在")
return nil, errors.New("用户不存在")
}
// 检查密码 todo 哈希
@@ -107,7 +108,7 @@ func CheckPass(conn net.Conn, proto Protocol, username, password string) (*core.
remoteAddr := conn.RemoteAddr().String()
remoteHost, _, err := net.SplitHostPort(remoteAddr)
if err != nil {
return nil, errors.Wrap(err, "无法获取连接信息")
return nil, fmt.Errorf("无法获取连接信息: %w", err)
}
// 查询权限记录

View File

@@ -1,6 +1,7 @@
package fwd
import (
"fmt"
"io"
"log/slog"
"net"
@@ -12,7 +13,7 @@ import (
"sync"
"time"
"github.com/pkg/errors"
"errors"
)
func (s *Service) startDataTun() error {
@@ -22,7 +23,7 @@ func (s *Service) startDataTun() error {
// 监听端口
ls, err := net.Listen("tcp", ":"+strconv.Itoa(int(dataPort)))
if err != nil {
return errors.Wrap(err, "监听数据通道失败")
return fmt.Errorf("监听数据通道失败: %w", err)
}
defer utils.Close(ls)
@@ -34,7 +35,7 @@ func (s *Service) startDataTun() error {
for {
conn, err := ls.Accept()
if err != nil {
return errors.Wrap(err, "监听数据通道失败")
return fmt.Errorf("监听数据通道失败: %w", err)
}
select {
@@ -60,17 +61,17 @@ func (s *Service) processDataConn(client net.Conn) error {
// 接收 status
status, err := utils.ReadByte(client)
if err != nil {
return errors.Wrap(err, "从客户端获取 status 失败")
return fmt.Errorf("从客户端获取 status 失败: %w", err)
}
// 接收 tag
tagLen, err := utils.ReadByte(client)
if err != nil {
return errors.Wrap(err, "从客户端获取 tag 失败")
return fmt.Errorf("从客户端获取 tag 失败: %w", err)
}
tagBuf, err := utils.ReadBuffer(client, int(tagLen))
if err != nil {
return errors.Wrap(err, "从客户端获取 tag 失败")
return fmt.Errorf("从客户端获取 tag 失败: %w", err)
}
tag := string(tagBuf)

View File

@@ -2,6 +2,7 @@ package dispatcher
import (
"context"
"fmt"
"log/slog"
"net"
"proxy-server/pkg/utils"
@@ -13,7 +14,7 @@ import (
"strings"
"time"
"github.com/pkg/errors"
"errors"
"github.com/soheilhy/cmux"
)
@@ -48,7 +49,7 @@ func (s *Server) Run() error {
ls, err := net.Listen("tcp", ":"+port)
if err != nil {
return errors.Wrap(err, "dispatcher 监听失败")
return fmt.Errorf("dispatcher 监听失败: %w", err)
}
defer utils.Close(ls)
@@ -83,7 +84,7 @@ func (s *Server) Run() error {
defer close(errCh)
err = m.Serve()
if err != nil {
err = errors.Wrap(err, "dispatcher serve error")
err = fmt.Errorf("dispatcher serve error: %w", err)
}
errCh <- err
}()
@@ -110,7 +111,7 @@ func (s *Server) acceptHttp(ls net.Listener) error {
if errors.As(err, &ne) && ne.Temporary() {
continue
}
return errors.Wrap(err, "dispatcher http accept error")
return fmt.Errorf("dispatcher http accept error: %w", err)
}
metrics.TimerStart.Store(conn, time.Now())
@@ -142,7 +143,7 @@ func (s *Server) acceptSocks(ls net.Listener) error {
if errors.As(err, &ne) && ne.Temporary() {
continue
}
return errors.Wrap(err, "dispatcher socks accept error")
return fmt.Errorf("dispatcher socks accept error: %w", err)
}
metrics.TimerStart.Store(conn, time.Now())

View File

@@ -4,6 +4,7 @@ import (
"bufio"
"context"
"encoding/base64"
"fmt"
"io"
"net"
"net/textproto"
@@ -12,7 +13,7 @@ import (
"proxy-server/server/fwd/core"
"strings"
"github.com/pkg/errors"
"errors"
)
type Request struct {
@@ -43,7 +44,7 @@ func Process(ctx context.Context, conn net.Conn) (*core.Conn, error) {
// 请求头
headers, err := textReader.ReadMIMEHeader()
if err != nil {
return nil, errors.Wrap(err, "解析请求头失败")
return nil, fmt.Errorf("解析请求头失败: %v", err)
}
// 验证账号
@@ -55,9 +56,9 @@ func Process(ctx context.Context, conn net.Conn) (*core.Conn, error) {
if authErr != nil {
_, err := conn.Write([]byte("HTTP/1.1 407 Proxy Authentication Required\r\n\r\n"))
if err != nil {
return nil, errors.Wrap(err, "响应 407 失败")
return nil, fmt.Errorf("响应 407 失败: %v", err)
}
return nil, errors.Wrap(authErr, "验证账号失败")
return nil, fmt.Errorf("验证账号失败: %v", authErr)
}
} else {
authParts := strings.Split(authInfo, " ")
@@ -69,16 +70,16 @@ func Process(ctx context.Context, conn net.Conn) (*core.Conn, error) {
}
authBytes, err := base64.URLEncoding.DecodeString(authParts[1])
if err != nil {
return nil, errors.Wrap(err, "解码认证信息失败")
return nil, fmt.Errorf("解码认证信息失败: %v", err)
}
authPair := strings.Split(string(authBytes), ":")
authCtx, authErr = auth.CheckPass(conn, auth.Http, authPair[0], authPair[1])
if authErr != nil {
_, err := conn.Write([]byte("HTTP/1.1 407 Proxy Authentication Required\r\n\r\n"))
if err != nil {
return nil, errors.Wrap(err, "响应 407 失败")
return nil, fmt.Errorf("响应 407 失败: %v", err)
}
return nil, errors.Wrap(authErr, "验证账号失败")
return nil, fmt.Errorf("验证账号失败: %v", authErr)
}
}
@@ -92,7 +93,7 @@ func Process(ctx context.Context, conn net.Conn) (*core.Conn, error) {
}
addr, err := net.ResolveTCPAddr("tcp", host)
if err != nil {
return nil, errors.Wrap(err, "解析 Host 失败")
return nil, fmt.Errorf("解析 Host 失败: %v", err)
}
request := &Request{
@@ -131,7 +132,7 @@ func processHttps(ctx context.Context, req *Request) (*core.Conn, error) {
// 响应 CONNECT
_, err := req.conn.Write([]byte("HTTP/1.1 200 Connection Established\r\n\r\n"))
if err != nil {
return nil, errors.Wrap(err, "响应 CONNECT 失败")
return nil, fmt.Errorf("响应 CONNECT 失败: %v", err)
}
return &core.Conn{
@@ -149,7 +150,7 @@ func processHttp(ctx context.Context, req *Request) (*core.Conn, error) {
// 修改请求头
rawUrl, err := url.Parse(req.uri)
if err != nil {
return nil, errors.Wrap(err, "解析请求地址失败")
return nil, fmt.Errorf("解析请求地址失败: %v", err)
}
rawUrl.Scheme = ""
rawUrl.Host = ""

View File

@@ -4,6 +4,7 @@ import (
"bufio"
"context"
"encoding/binary"
"errors"
"fmt"
"io"
"log/slog"
@@ -12,8 +13,6 @@ import (
"proxy-server/server/fwd/auth"
"proxy-server/server/fwd/core"
"slices"
"github.com/pkg/errors"
)
const (
@@ -63,18 +62,18 @@ func Process(ctx context.Context, conn net.Conn) (*core.Conn, error) {
// 认证
authCtx, err := authenticate(ctx, reader, conn)
if err != nil {
return nil, errors.Wrap(err, "认证失败")
return nil, fmt.Errorf("认证失败: %w", err)
}
// 处理连接请求
request, err := request(ctx, reader, conn)
if err != nil {
return nil, errors.Wrap(err, "处理连接请求失败")
return nil, fmt.Errorf("处理连接请求失败: %w", err)
}
// 代理连接
if request.Command != ConnectCommand {
return nil, errors.New("不支持的连接指令")
return nil, fmt.Errorf("不支持的连接指令: %d", request.Command)
}
// 响应成功
@@ -127,56 +126,56 @@ func authenticate(ctx context.Context, reader *bufio.Reader, conn net.Conn) (*co
if slices.Contains(methods, UserPassAuth) {
_, err := conn.Write([]byte{Version, byte(UserPassAuth)})
if err != nil {
return nil, errors.Wrap(err, "响应认证方式失败")
return nil, fmt.Errorf("响应认证方式失败: %w", err)
}
// 检查认证版本
slog.Debug("验证认证版本")
v, err := utils.ReadByte(reader)
if err != nil {
return nil, errors.Wrap(err, "读取版本号失败")
return nil, fmt.Errorf("读取版本号失败: %w", err)
}
if v != AuthVersion {
_, err := conn.Write([]byte{Version, AuthFailure})
if err != nil {
return nil, errors.Wrap(err, "响应认证失败")
return nil, fmt.Errorf("响应认证失败: %w", err)
}
return nil, errors.New("认证版本参数不正确")
return nil, fmt.Errorf("认证版本参数不正确: %w", err)
}
// 读取账号
slog.Debug("验证用户账号")
uLen, err := utils.ReadByte(reader)
if err != nil {
return nil, errors.Wrap(err, "读取用户名长度失败")
return nil, fmt.Errorf("读取用户名长度失败: %w", err)
}
usernameBuf, err := utils.ReadBuffer(reader, int(uLen))
if err != nil {
return nil, errors.Wrap(err, "读取用户名失败")
return nil, fmt.Errorf("读取用户名失败: %w", err)
}
username := string(usernameBuf)
// 读取密码
pLen, err := utils.ReadByte(reader)
if err != nil {
return nil, errors.Wrap(err, "读取密码长度失败")
return nil, fmt.Errorf("读取密码长度失败: %w", err)
}
passwordBuf, err := utils.ReadBuffer(reader, int(pLen))
if err != nil {
return nil, errors.Wrap(err, "读取密码失败")
return nil, fmt.Errorf("读取密码失败: %w", err)
}
password := string(passwordBuf)
// 检查权限
authContext, err := auth.CheckPass(conn, auth.Socks5, username, password)
if err != nil {
return nil, errors.Wrap(err, "权限检查失败")
return nil, fmt.Errorf("权限检查失败: %w", err)
}
// 响应认证成功
_, err = conn.Write([]byte{AuthVersion, AuthSuccess})
if err != nil {
return nil, errors.Wrap(err, "响应认证成功失败")
return nil, fmt.Errorf("响应认证成功失败: %w", err)
}
return authContext, nil
@@ -186,12 +185,12 @@ func authenticate(ctx context.Context, reader *bufio.Reader, conn net.Conn) (*co
if slices.Contains(methods, NoAuth) {
_, err = conn.Write([]byte{Version, NoAuth})
if err != nil {
return nil, errors.Wrap(err, "响应认证方式失败")
return nil, fmt.Errorf("响应认证方式失败: %w", err)
}
authCtx, err := auth.CheckIp(conn, auth.Socks5)
if err != nil {
return nil, errors.Wrap(err, "权限检查失败")
return nil, fmt.Errorf("权限检查失败: %w", err)
}
return authCtx, nil

View File

@@ -5,7 +5,7 @@ import (
"log/slog"
"proxy-server/server/pkg/env"
"github.com/pkg/errors"
"errors"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"