优化错误处理,替换 errors.Wrap 为 fmt.Errorf

This commit is contained in:
2025-05-15 09:53:23 +08:00
parent 75569d2d6d
commit 8b7dc9e4ff
11 changed files with 80 additions and 92 deletions

View File

@@ -10,32 +10,19 @@ server/fwd: 服务端核心代码
- socks: socks5 处理器,负责处理 socks5 请求 - socks: socks5 处理器,负责处理 socks5 请求
- repo: 状态仓库,所有有状态数据都通过 repo 中的接口与外部服务交互 - repo: 状态仓库,所有有状态数据都通过 repo 中的接口与外部服务交互
### 更新测试环境
1. 构建项目
2. 使用测试配置 `.env.test` 远程启动 docker
### 转发服务结束时资源清理 ### 转发服务结束时资源清理
1. 关闭接听端口防止新连接接入user, data, ctrl 1. 关闭接听端口防止新连接接入user, data, ctrl
2. 通知并等待所有正在运行的 conn 处理协程全部关闭user, data, ctrl 2. 通知并等待所有正在运行的 conn 处理协程全部关闭user, data, ctrl
3. 结束所有保存且未使用的 conn 连接user, ctrl 3. 结束所有保存且未使用的 conn 连接user, ctrl
### 代码清理
检查 slog 级别:
ERR: 除非有必要,否则全部 error 都使用 `errors.Wrap()` 包裹(如果下游有返回 err并附带本层业务信息return 到上层统一打印
其他级别日志就地打印Info 只用来跟踪关键流程
### proxy.lock 文件格式 ### proxy.lock 文件格式
| mag_num(1) | name(16) | | mag_num(1) | name(16) |
|-------------|-----------| |-------------|-----------|
| 魔法数,固定 0x72 | 服务名称uuid | | 魔法数,固定 0x72 | 服务名称uuid |
## 协议 ## 底层协议设计
### 步骤说明 ### 步骤说明
@@ -55,13 +42,13 @@ ERR: 除非有必要,否则全部 error 都使用 `errors.Wrap()` 包裹(如
#### 建立控制通道 #### 建立控制通道
客户端: 1.客户端发送
| id(4) | | id(4) |
|--------| |--------|
| 客户端 ID | | 客户端 ID |
服务端: 2.服务端发送
| status(1) | | status(1) |
|-----------| |-----------|
@@ -69,13 +56,13 @@ ERR: 除非有必要,否则全部 error 都使用 `errors.Wrap()` 包裹(如
#### 建立数据通道 #### 建立数据通道
服务端: 1.服务端发送
| tag(16) | dst_len(2) | dst_buf(n) | | tag(16) | dst_len(2) | dst_buf(n) |
|---------|------------|------------| |---------|------------|------------|
| 通道标识 | 目标地址长度 | 目标地址 | | 通道标识 | 目标地址长度 | 目标地址 |
客户端: 2.客户端发送
| tag(16) | status(1) | | tag(16) | status(1) |
|---------|-----------------------| |---------|-----------------------|

1
go.mod
View File

@@ -8,7 +8,6 @@ require (
github.com/joho/godotenv v1.5.1 github.com/joho/godotenv v1.5.1
github.com/lmittmann/tint v1.0.7 github.com/lmittmann/tint v1.0.7
github.com/mattn/go-colorable v0.1.14 github.com/mattn/go-colorable v0.1.14
github.com/pkg/errors v0.9.1
github.com/soheilhy/cmux v0.1.5 github.com/soheilhy/cmux v0.1.5
gorm.io/driver/postgres v1.5.11 gorm.io/driver/postgres v1.5.11
gorm.io/gen v0.3.26 gorm.io/gen v0.3.26

2
go.sum
View File

@@ -84,8 +84,6 @@ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9G
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/pelletier/go-toml/v2 v2.2.3 h1:YmeHyLY8mFWbdkNWwpr+qIL2bEqT0o95WSdkNHvL12M= github.com/pelletier/go-toml/v2 v2.2.3 h1:YmeHyLY8mFWbdkNWwpr+qIL2bEqT0o95WSdkNHvL12M=
github.com/pelletier/go-toml/v2 v2.2.3/go.mod h1:MfCQTFTvCcUyyvvwm1+G6H/jORL20Xlb6rzQu9GuUkc= github.com/pelletier/go-toml/v2 v2.2.3/go.mod h1:MfCQTFTvCcUyyvvwm1+G6H/jORL20Xlb6rzQu9GuUkc=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=

View File

@@ -5,7 +5,7 @@ import (
"log/slog" "log/slog"
"net" "net"
"github.com/pkg/errors" "errors"
) )
func ChanConnAccept(ctx context.Context, ls net.Listener) chan net.Conn { func ChanConnAccept(ctx context.Context, ls net.Listener) chan net.Conn {

View File

@@ -3,13 +3,14 @@ package fwd
import ( import (
"bufio" "bufio"
"encoding/binary" "encoding/binary"
"fmt"
"io" "io"
"log/slog" "log/slog"
"proxy-server/pkg/utils" "proxy-server/pkg/utils"
"proxy-server/server/fwd/core" "proxy-server/server/fwd/core"
"strings" "strings"
"github.com/pkg/errors" "errors"
) )
func analysisAndLog(conn *core.Conn, reader io.Reader) error { func analysisAndLog(conn *core.Conn, reader io.Reader) error {
@@ -17,7 +18,7 @@ func analysisAndLog(conn *core.Conn, reader io.Reader) error {
domain, proto, err := sniffing(buf) domain, proto, err := sniffing(buf)
if err != nil { if err != nil {
err = errors.Wrap(err, "analysis sniffing error") err = fmt.Errorf("sniffing error: %w", err)
} else { } else {
slog.Debug( slog.Debug(
"用户访问记录", "用户访问记录",
@@ -39,7 +40,7 @@ func analysisAndLog(conn *core.Conn, reader io.Reader) error {
func sniffing(reader *bufio.Reader) (string, string, error) { func sniffing(reader *bufio.Reader) (string, string, error) {
peek, err := reader.Peek(8) peek, err := reader.Peek(8)
if err != nil { if err != nil {
return "", "", errors.Wrap(err, "sniffing peek error") return "", "", fmt.Errorf("sniffing peek error: %w", err)
} }
method, ok := isHttp(peek) method, ok := isHttp(peek)
@@ -126,7 +127,7 @@ func analysisHttp(reader *bufio.Reader) (string, error) {
// reade top // reade top
top, err := httpReadLine(reader) top, err := httpReadLine(reader)
if err != nil { if err != nil {
return "", errors.Wrap(err, "analysis http read top error") return "", fmt.Errorf("http read top error: %w", err)
} }
// read header // read header
@@ -153,7 +154,7 @@ func httpReadLine(reader *bufio.Reader) (line string, err error) {
for { for {
line, prefix, err := reader.ReadLine() line, prefix, err := reader.ReadLine()
if err != nil { if err != nil {
return "", errors.Wrap(err, "analysis http read line error") return "", fmt.Errorf("http read line error: %w", err)
} }
lineStr.Write(line) lineStr.Write(line)
if !prefix { if !prefix {
@@ -168,13 +169,13 @@ func analysisTls(reader *bufio.Reader) (string, error) {
// tls record // tls record
_, err := utils.ReadBuffer(reader, 5) _, err := utils.ReadBuffer(reader, 5)
if err != nil { if err != nil {
return "", errors.Wrap(err, "analysis https read head error") return "", fmt.Errorf("https read head error: %w", err)
} }
// tls type // tls type
hsType, err := reader.ReadByte() hsType, err := reader.ReadByte()
if err != nil { if err != nil {
return "", errors.Wrap(err, "analysis https read hsType error") return "", fmt.Errorf("https read hsType error: %w", err)
} }
switch hsType { switch hsType {
@@ -183,59 +184,59 @@ func analysisTls(reader *bufio.Reader) (string, error) {
// length // length
_, err = utils.ReadBuffer(reader, 3) _, err = utils.ReadBuffer(reader, 3)
if err != nil { if err != nil {
return "", errors.Wrap(err, "analysis https read tls length error") return "", fmt.Errorf("https read tls length error: %w", err)
} }
// version // version
_, err = utils.ReadBuffer(reader, 2) _, err = utils.ReadBuffer(reader, 2)
if err != nil { if err != nil {
return "", errors.Wrap(err, "analysis https read version error") return "", fmt.Errorf("https read version error: %w", err)
} }
// random // random
_, err = utils.ReadBuffer(reader, 32) _, err = utils.ReadBuffer(reader, 32)
if err != nil { if err != nil {
return "", errors.Wrap(err, "analysis https read random error") return "", fmt.Errorf("https read random error: %w", err)
} }
// session id length // session id length
sessionIdLen, err := reader.ReadByte() sessionIdLen, err := reader.ReadByte()
if err != nil { if err != nil {
return "", errors.Wrap(err, "analysis https read sessionIdLen error") return "", fmt.Errorf("https read sessionIdLen error: %w", err)
} }
// session id // session id
_, err = utils.ReadBuffer(reader, int(sessionIdLen)) _, err = utils.ReadBuffer(reader, int(sessionIdLen))
if err != nil { if err != nil {
return "", errors.Wrap(err, "analysis https read sessionId error") return "", fmt.Errorf("https read sessionId error: %w", err)
} }
// cipher suites length // cipher suites length
cLenBuf, err := utils.ReadBuffer(reader, 2) cLenBuf, err := utils.ReadBuffer(reader, 2)
if err != nil { if err != nil {
return "", errors.Wrap(err, "analysis https read cLen error") return "", fmt.Errorf("https read cLen error: %w", err)
} }
cLen := binary.BigEndian.Uint16(cLenBuf) cLen := binary.BigEndian.Uint16(cLenBuf)
// cipher suites // cipher suites
_, err = utils.ReadBuffer(reader, int(cLen)) _, err = utils.ReadBuffer(reader, int(cLen))
if err != nil { if err != nil {
return "", errors.Wrap(err, "analysis https read c error") return "", fmt.Errorf("https read c error: %w", err)
} }
// compression methods length // compression methods length
cmLen, err := reader.ReadByte() cmLen, err := reader.ReadByte()
if err != nil { if err != nil {
return "", errors.Wrap(err, "analysis https read cmLen error") return "", fmt.Errorf("https read cmLen error: %w", err)
} }
// compression methods // compression methods
_, err = utils.ReadBuffer(reader, int(cmLen)) _, err = utils.ReadBuffer(reader, int(cmLen))
if err != nil { if err != nil {
return "", errors.Wrap(err, "analysis https read cm error") return "", fmt.Errorf("https read cm error: %w", err)
} }
// extensions length // extensions length
eLenBuf, err := utils.ReadBuffer(reader, 2) eLenBuf, err := utils.ReadBuffer(reader, 2)
if err != nil { if err != nil {
return "", errors.Wrap(err, "analysis https read eLen error") return "", fmt.Errorf("https read eLen error: %w", err)
} }
eLen := binary.BigEndian.Uint16(eLenBuf) eLen := binary.BigEndian.Uint16(eLenBuf)
@@ -247,14 +248,14 @@ func analysisTls(reader *bufio.Reader) (string, error) {
// extension type // extension type
eTypeBuf, err := utils.ReadBuffer(reader, 2) eTypeBuf, err := utils.ReadBuffer(reader, 2)
if err != nil { if err != nil {
return "", errors.Wrap(err, "analysis https read extension type error") return "", fmt.Errorf("https read extension type error: %w", err)
} }
eType := binary.BigEndian.Uint16(eTypeBuf) eType := binary.BigEndian.Uint16(eTypeBuf)
// extension length // extension length
eLenBuf, err := utils.ReadBuffer(reader, 2) eLenBuf, err := utils.ReadBuffer(reader, 2)
if err != nil { if err != nil {
return "", errors.Wrap(err, "analysis https read extension length error") return "", fmt.Errorf("https read extension length error: %w", err)
} }
eLen := binary.BigEndian.Uint16(eLenBuf) eLen := binary.BigEndian.Uint16(eLenBuf)
@@ -263,23 +264,23 @@ func analysisTls(reader *bufio.Reader) (string, error) {
// server name list length // server name list length
_, err = utils.ReadBuffer(reader, 2) _, err = utils.ReadBuffer(reader, 2)
if err != nil { if err != nil {
return "", errors.Wrap(err, "analysis https read server name list length error") return "", fmt.Errorf("https read server name list length error: %w", err)
} }
// server name type // server name type
_, err = reader.ReadByte() _, err = reader.ReadByte()
if err != nil { if err != nil {
return "", errors.Wrap(err, "analysis https read server name type error") return "", fmt.Errorf("https read server name type error: %w", err)
} }
// server name length // server name length
sLenBuf, err := utils.ReadBuffer(reader, 2) sLenBuf, err := utils.ReadBuffer(reader, 2)
if err != nil { if err != nil {
return "", errors.Wrap(err, "analysis https read server name length error") return "", fmt.Errorf("https read server name length error: %w", err)
} }
sLen := binary.BigEndian.Uint16(sLenBuf) sLen := binary.BigEndian.Uint16(sLenBuf)
// server name // server name
bytes, err := utils.ReadBuffer(reader, int(sLen)) bytes, err := utils.ReadBuffer(reader, int(sLen))
if err != nil { if err != nil {
return "", errors.Wrap(err, "analysis https read server name error") return "", fmt.Errorf("https read server name error: %w", err)
} }
host = string(bytes) host = string(bytes)
@@ -289,7 +290,7 @@ func analysisTls(reader *bufio.Reader) (string, error) {
// other extension // other extension
_, err = utils.ReadBuffer(reader, int(eLen)) _, err = utils.ReadBuffer(reader, int(eLen))
if err != nil { if err != nil {
return "", errors.Wrap(err, "analysis https read extension error") return "", fmt.Errorf("https read extension error: %w", err)
} }
} }
i += 4 + int(eLen) i += 4 + int(eLen)

View File

@@ -1,6 +1,7 @@
package auth package auth
import ( import (
"fmt"
"log/slog" "log/slog"
"net" "net"
"proxy-server/server/fwd/core" "proxy-server/server/fwd/core"
@@ -9,7 +10,7 @@ import (
"strconv" "strconv"
"time" "time"
"github.com/pkg/errors" "errors"
) )
type Protocol string type Protocol string
@@ -25,7 +26,7 @@ func CheckIp(conn net.Conn, proto Protocol) (*core.AuthContext, error) {
remoteAddr := conn.RemoteAddr().String() remoteAddr := conn.RemoteAddr().String()
remoteHost, _, err := net.SplitHostPort(remoteAddr) remoteHost, _, err := net.SplitHostPort(remoteAddr)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "noAuth 认证失败") return nil, fmt.Errorf("无法获取连接信息: %w", err)
} }
// 获取服务端口 // 获取服务端口
@@ -33,7 +34,7 @@ func CheckIp(conn net.Conn, proto Protocol) (*core.AuthContext, error) {
_, _localPort, err := net.SplitHostPort(localAddr) _, _localPort, err := net.SplitHostPort(localAddr)
localPort, err := strconv.Atoi(_localPort) localPort, err := strconv.Atoi(_localPort)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "noAuth 认证失败") return nil, fmt.Errorf("noAuth 认证失败: %w", err)
} }
// 查询权限记录 // 查询权限记录
@@ -51,7 +52,7 @@ func CheckIp(conn net.Conn, proto Protocol) (*core.AuthContext, error) {
// 记录应该只有一条 // 记录应该只有一条
channel, err := orm.MaySingle(channels) channel, err := orm.MaySingle(channels)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "不在白名单内") return nil, errors.New("不在白名单内")
} }
// 检查是否需要密码认证 // 检查是否需要密码认证
@@ -80,7 +81,7 @@ func CheckPass(conn net.Conn, proto Protocol, username, password string) (*core.
_, _localPort, err := net.SplitHostPort(localAddr) _, _localPort, err := net.SplitHostPort(localAddr)
localPort, err := strconv.Atoi(_localPort) localPort, err := strconv.Atoi(_localPort)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "noAuth 认证失败") return nil, fmt.Errorf("noAuth 认证失败: %w", err)
} }
// 查询权限记录 // 查询权限记录
@@ -92,7 +93,7 @@ func CheckPass(conn net.Conn, proto Protocol, username, password string) (*core.
Protocol: string(proto), Protocol: string(proto),
}).Error }).Error
if err != nil { if err != nil {
return nil, errors.Wrap(err, "用户不存在") return nil, errors.New("用户不存在")
} }
// 检查密码 todo 哈希 // 检查密码 todo 哈希
@@ -107,7 +108,7 @@ func CheckPass(conn net.Conn, proto Protocol, username, password string) (*core.
remoteAddr := conn.RemoteAddr().String() remoteAddr := conn.RemoteAddr().String()
remoteHost, _, err := net.SplitHostPort(remoteAddr) remoteHost, _, err := net.SplitHostPort(remoteAddr)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "无法获取连接信息") return nil, fmt.Errorf("无法获取连接信息: %w", err)
} }
// 查询权限记录 // 查询权限记录

View File

@@ -1,6 +1,7 @@
package fwd package fwd
import ( import (
"fmt"
"io" "io"
"log/slog" "log/slog"
"net" "net"
@@ -12,7 +13,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/pkg/errors" "errors"
) )
func (s *Service) startDataTun() error { func (s *Service) startDataTun() error {
@@ -22,7 +23,7 @@ func (s *Service) startDataTun() error {
// 监听端口 // 监听端口
ls, err := net.Listen("tcp", ":"+strconv.Itoa(int(dataPort))) ls, err := net.Listen("tcp", ":"+strconv.Itoa(int(dataPort)))
if err != nil { if err != nil {
return errors.Wrap(err, "监听数据通道失败") return fmt.Errorf("监听数据通道失败: %w", err)
} }
defer utils.Close(ls) defer utils.Close(ls)
@@ -34,7 +35,7 @@ func (s *Service) startDataTun() error {
for { for {
conn, err := ls.Accept() conn, err := ls.Accept()
if err != nil { if err != nil {
return errors.Wrap(err, "监听数据通道失败") return fmt.Errorf("监听数据通道失败: %w", err)
} }
select { select {
@@ -60,17 +61,17 @@ func (s *Service) processDataConn(client net.Conn) error {
// 接收 status // 接收 status
status, err := utils.ReadByte(client) status, err := utils.ReadByte(client)
if err != nil { if err != nil {
return errors.Wrap(err, "从客户端获取 status 失败") return fmt.Errorf("从客户端获取 status 失败: %w", err)
} }
// 接收 tag // 接收 tag
tagLen, err := utils.ReadByte(client) tagLen, err := utils.ReadByte(client)
if err != nil { if err != nil {
return errors.Wrap(err, "从客户端获取 tag 失败") return fmt.Errorf("从客户端获取 tag 失败: %w", err)
} }
tagBuf, err := utils.ReadBuffer(client, int(tagLen)) tagBuf, err := utils.ReadBuffer(client, int(tagLen))
if err != nil { if err != nil {
return errors.Wrap(err, "从客户端获取 tag 失败") return fmt.Errorf("从客户端获取 tag 失败: %w", err)
} }
tag := string(tagBuf) tag := string(tagBuf)

View File

@@ -2,6 +2,7 @@ package dispatcher
import ( import (
"context" "context"
"fmt"
"log/slog" "log/slog"
"net" "net"
"proxy-server/pkg/utils" "proxy-server/pkg/utils"
@@ -13,7 +14,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/pkg/errors" "errors"
"github.com/soheilhy/cmux" "github.com/soheilhy/cmux"
) )
@@ -48,7 +49,7 @@ func (s *Server) Run() error {
ls, err := net.Listen("tcp", ":"+port) ls, err := net.Listen("tcp", ":"+port)
if err != nil { if err != nil {
return errors.Wrap(err, "dispatcher 监听失败") return fmt.Errorf("dispatcher 监听失败: %w", err)
} }
defer utils.Close(ls) defer utils.Close(ls)
@@ -83,7 +84,7 @@ func (s *Server) Run() error {
defer close(errCh) defer close(errCh)
err = m.Serve() err = m.Serve()
if err != nil { if err != nil {
err = errors.Wrap(err, "dispatcher serve error") err = fmt.Errorf("dispatcher serve error: %w", err)
} }
errCh <- err errCh <- err
}() }()
@@ -110,7 +111,7 @@ func (s *Server) acceptHttp(ls net.Listener) error {
if errors.As(err, &ne) && ne.Temporary() { if errors.As(err, &ne) && ne.Temporary() {
continue continue
} }
return errors.Wrap(err, "dispatcher http accept error") return fmt.Errorf("dispatcher http accept error: %w", err)
} }
metrics.TimerStart.Store(conn, time.Now()) metrics.TimerStart.Store(conn, time.Now())
@@ -142,7 +143,7 @@ func (s *Server) acceptSocks(ls net.Listener) error {
if errors.As(err, &ne) && ne.Temporary() { if errors.As(err, &ne) && ne.Temporary() {
continue continue
} }
return errors.Wrap(err, "dispatcher socks accept error") return fmt.Errorf("dispatcher socks accept error: %w", err)
} }
metrics.TimerStart.Store(conn, time.Now()) metrics.TimerStart.Store(conn, time.Now())

View File

@@ -4,6 +4,7 @@ import (
"bufio" "bufio"
"context" "context"
"encoding/base64" "encoding/base64"
"fmt"
"io" "io"
"net" "net"
"net/textproto" "net/textproto"
@@ -12,7 +13,7 @@ import (
"proxy-server/server/fwd/core" "proxy-server/server/fwd/core"
"strings" "strings"
"github.com/pkg/errors" "errors"
) )
type Request struct { type Request struct {
@@ -43,7 +44,7 @@ func Process(ctx context.Context, conn net.Conn) (*core.Conn, error) {
// 请求头 // 请求头
headers, err := textReader.ReadMIMEHeader() headers, err := textReader.ReadMIMEHeader()
if err != nil { if err != nil {
return nil, errors.Wrap(err, "解析请求头失败") return nil, fmt.Errorf("解析请求头失败: %v", err)
} }
// 验证账号 // 验证账号
@@ -55,9 +56,9 @@ func Process(ctx context.Context, conn net.Conn) (*core.Conn, error) {
if authErr != nil { if authErr != nil {
_, err := conn.Write([]byte("HTTP/1.1 407 Proxy Authentication Required\r\n\r\n")) _, err := conn.Write([]byte("HTTP/1.1 407 Proxy Authentication Required\r\n\r\n"))
if err != nil { if err != nil {
return nil, errors.Wrap(err, "响应 407 失败") return nil, fmt.Errorf("响应 407 失败: %v", err)
} }
return nil, errors.Wrap(authErr, "验证账号失败") return nil, fmt.Errorf("验证账号失败: %v", authErr)
} }
} else { } else {
authParts := strings.Split(authInfo, " ") authParts := strings.Split(authInfo, " ")
@@ -69,16 +70,16 @@ func Process(ctx context.Context, conn net.Conn) (*core.Conn, error) {
} }
authBytes, err := base64.URLEncoding.DecodeString(authParts[1]) authBytes, err := base64.URLEncoding.DecodeString(authParts[1])
if err != nil { if err != nil {
return nil, errors.Wrap(err, "解码认证信息失败") return nil, fmt.Errorf("解码认证信息失败: %v", err)
} }
authPair := strings.Split(string(authBytes), ":") authPair := strings.Split(string(authBytes), ":")
authCtx, authErr = auth.CheckPass(conn, auth.Http, authPair[0], authPair[1]) authCtx, authErr = auth.CheckPass(conn, auth.Http, authPair[0], authPair[1])
if authErr != nil { if authErr != nil {
_, err := conn.Write([]byte("HTTP/1.1 407 Proxy Authentication Required\r\n\r\n")) _, err := conn.Write([]byte("HTTP/1.1 407 Proxy Authentication Required\r\n\r\n"))
if err != nil { if err != nil {
return nil, errors.Wrap(err, "响应 407 失败") return nil, fmt.Errorf("响应 407 失败: %v", err)
} }
return nil, errors.Wrap(authErr, "验证账号失败") return nil, fmt.Errorf("验证账号失败: %v", authErr)
} }
} }
@@ -92,7 +93,7 @@ func Process(ctx context.Context, conn net.Conn) (*core.Conn, error) {
} }
addr, err := net.ResolveTCPAddr("tcp", host) addr, err := net.ResolveTCPAddr("tcp", host)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "解析 Host 失败") return nil, fmt.Errorf("解析 Host 失败: %v", err)
} }
request := &Request{ request := &Request{
@@ -131,7 +132,7 @@ func processHttps(ctx context.Context, req *Request) (*core.Conn, error) {
// 响应 CONNECT // 响应 CONNECT
_, err := req.conn.Write([]byte("HTTP/1.1 200 Connection Established\r\n\r\n")) _, err := req.conn.Write([]byte("HTTP/1.1 200 Connection Established\r\n\r\n"))
if err != nil { if err != nil {
return nil, errors.Wrap(err, "响应 CONNECT 失败") return nil, fmt.Errorf("响应 CONNECT 失败: %v", err)
} }
return &core.Conn{ return &core.Conn{
@@ -149,7 +150,7 @@ func processHttp(ctx context.Context, req *Request) (*core.Conn, error) {
// 修改请求头 // 修改请求头
rawUrl, err := url.Parse(req.uri) rawUrl, err := url.Parse(req.uri)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "解析请求地址失败") return nil, fmt.Errorf("解析请求地址失败: %v", err)
} }
rawUrl.Scheme = "" rawUrl.Scheme = ""
rawUrl.Host = "" rawUrl.Host = ""

View File

@@ -4,6 +4,7 @@ import (
"bufio" "bufio"
"context" "context"
"encoding/binary" "encoding/binary"
"errors"
"fmt" "fmt"
"io" "io"
"log/slog" "log/slog"
@@ -12,8 +13,6 @@ import (
"proxy-server/server/fwd/auth" "proxy-server/server/fwd/auth"
"proxy-server/server/fwd/core" "proxy-server/server/fwd/core"
"slices" "slices"
"github.com/pkg/errors"
) )
const ( const (
@@ -63,18 +62,18 @@ func Process(ctx context.Context, conn net.Conn) (*core.Conn, error) {
// 认证 // 认证
authCtx, err := authenticate(ctx, reader, conn) authCtx, err := authenticate(ctx, reader, conn)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "认证失败") return nil, fmt.Errorf("认证失败: %w", err)
} }
// 处理连接请求 // 处理连接请求
request, err := request(ctx, reader, conn) request, err := request(ctx, reader, conn)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "处理连接请求失败") return nil, fmt.Errorf("处理连接请求失败: %w", err)
} }
// 代理连接 // 代理连接
if request.Command != ConnectCommand { if request.Command != ConnectCommand {
return nil, errors.New("不支持的连接指令") return nil, fmt.Errorf("不支持的连接指令: %d", request.Command)
} }
// 响应成功 // 响应成功
@@ -127,56 +126,56 @@ func authenticate(ctx context.Context, reader *bufio.Reader, conn net.Conn) (*co
if slices.Contains(methods, UserPassAuth) { if slices.Contains(methods, UserPassAuth) {
_, err := conn.Write([]byte{Version, byte(UserPassAuth)}) _, err := conn.Write([]byte{Version, byte(UserPassAuth)})
if err != nil { if err != nil {
return nil, errors.Wrap(err, "响应认证方式失败") return nil, fmt.Errorf("响应认证方式失败: %w", err)
} }
// 检查认证版本 // 检查认证版本
slog.Debug("验证认证版本") slog.Debug("验证认证版本")
v, err := utils.ReadByte(reader) v, err := utils.ReadByte(reader)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "读取版本号失败") return nil, fmt.Errorf("读取版本号失败: %w", err)
} }
if v != AuthVersion { if v != AuthVersion {
_, err := conn.Write([]byte{Version, AuthFailure}) _, err := conn.Write([]byte{Version, AuthFailure})
if err != nil { if err != nil {
return nil, errors.Wrap(err, "响应认证失败") return nil, fmt.Errorf("响应认证失败: %w", err)
} }
return nil, errors.New("认证版本参数不正确") return nil, fmt.Errorf("认证版本参数不正确: %w", err)
} }
// 读取账号 // 读取账号
slog.Debug("验证用户账号") slog.Debug("验证用户账号")
uLen, err := utils.ReadByte(reader) uLen, err := utils.ReadByte(reader)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "读取用户名长度失败") return nil, fmt.Errorf("读取用户名长度失败: %w", err)
} }
usernameBuf, err := utils.ReadBuffer(reader, int(uLen)) usernameBuf, err := utils.ReadBuffer(reader, int(uLen))
if err != nil { if err != nil {
return nil, errors.Wrap(err, "读取用户名失败") return nil, fmt.Errorf("读取用户名失败: %w", err)
} }
username := string(usernameBuf) username := string(usernameBuf)
// 读取密码 // 读取密码
pLen, err := utils.ReadByte(reader) pLen, err := utils.ReadByte(reader)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "读取密码长度失败") return nil, fmt.Errorf("读取密码长度失败: %w", err)
} }
passwordBuf, err := utils.ReadBuffer(reader, int(pLen)) passwordBuf, err := utils.ReadBuffer(reader, int(pLen))
if err != nil { if err != nil {
return nil, errors.Wrap(err, "读取密码失败") return nil, fmt.Errorf("读取密码失败: %w", err)
} }
password := string(passwordBuf) password := string(passwordBuf)
// 检查权限 // 检查权限
authContext, err := auth.CheckPass(conn, auth.Socks5, username, password) authContext, err := auth.CheckPass(conn, auth.Socks5, username, password)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "权限检查失败") return nil, fmt.Errorf("权限检查失败: %w", err)
} }
// 响应认证成功 // 响应认证成功
_, err = conn.Write([]byte{AuthVersion, AuthSuccess}) _, err = conn.Write([]byte{AuthVersion, AuthSuccess})
if err != nil { if err != nil {
return nil, errors.Wrap(err, "响应认证成功失败") return nil, fmt.Errorf("响应认证成功失败: %w", err)
} }
return authContext, nil return authContext, nil
@@ -186,12 +185,12 @@ func authenticate(ctx context.Context, reader *bufio.Reader, conn net.Conn) (*co
if slices.Contains(methods, NoAuth) { if slices.Contains(methods, NoAuth) {
_, err = conn.Write([]byte{Version, NoAuth}) _, err = conn.Write([]byte{Version, NoAuth})
if err != nil { if err != nil {
return nil, errors.Wrap(err, "响应认证方式失败") return nil, fmt.Errorf("响应认证方式失败: %w", err)
} }
authCtx, err := auth.CheckIp(conn, auth.Socks5) authCtx, err := auth.CheckIp(conn, auth.Socks5)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "权限检查失败") return nil, fmt.Errorf("权限检查失败: %w", err)
} }
return authCtx, nil return authCtx, nil

View File

@@ -5,7 +5,7 @@ import (
"log/slog" "log/slog"
"proxy-server/server/pkg/env" "proxy-server/server/pkg/env"
"github.com/pkg/errors" "errors"
"gorm.io/driver/postgres" "gorm.io/driver/postgres"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/logger" "gorm.io/gorm/logger"