优化 socks 解析流程

This commit is contained in:
2025-02-26 13:56:56 +08:00
parent b50dc3d91c
commit 7ee4ded08c
12 changed files with 301 additions and 104 deletions

View File

@@ -18,12 +18,28 @@ fwd 使用自定义 context 实现在一个上下文中控制 cancelerrCh 和
网关根据代理节点对目标服务连接的反馈,决定向用户返回的 socks 响应 网关根据代理节点对目标服务连接的反馈,决定向用户返回的 socks 响应
数据通道池化
### 长期 ### 长期
代理端口支持混合端口转发(支持 tcp_mux
数据通道支持 tcp 多路复用(分离逻辑流)
👆 进阶黑魔法 multipath tcp + 多路复用
考虑一下连接安全性 考虑一下连接安全性
内部接口 rtt 是否还有优化空间当前30-300ms根据内容大小增长 内部接口 rtt 是否还有优化空间当前30-300ms根据内容大小增长
### 代码清理
检查 slog 级别:
ERR: 除非有必要,否则全部 error 都使用 `errors.Wrap()` 包裹(如果下游有返回 err并附带本层业务信息return 到上层统一打印
其他级别日志就地打印Info 只用来跟踪关键流程
## 开发相关 ## 开发相关
### 环境变量 ### 环境变量

View File

@@ -8,12 +8,12 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
) )
func ConnChan(ctx context.Context, ls net.Listener) chan net.Conn { func ChanConnAccept(ctx context.Context, ls net.Listener) chan net.Conn {
connCh := make(chan net.Conn) connCh := make(chan net.Conn)
go func() { go func() {
for { for {
conn, err := ls.Accept() conn, err := ls.Accept()
if err != nil { if err != nil && !errors.Is(err, net.ErrClosed) {
slog.Error("接受连接失败", err) slog.Error("接受连接失败", err)
// 临时错误重试连接 // 临时错误重试连接
var ne net.Error var ne net.Error
@@ -35,7 +35,7 @@ func ConnChan(ctx context.Context, ls net.Listener) chan net.Conn {
return connCh return connCh
} }
func WaitChan(ctx context.Context, wg *CountWaitGroup) chan struct{} { func ChanWgWait(ctx context.Context, wg *CountWaitGroup) chan struct{} {
ch := make(chan struct{}) ch := make(chan struct{})
go func() { go func() {
wg.Wait() wg.Wait()

View File

@@ -8,9 +8,9 @@ import (
"log/slog" "log/slog"
"net" "net"
"proxy-server/pkg/utils" "proxy-server/pkg/utils"
"proxy-server/server/fwd/socks"
"proxy-server/server/pkg/env" "proxy-server/server/pkg/env"
"proxy-server/server/pkg/orm" "proxy-server/server/pkg/orm"
"proxy-server/server/pkg/socks5"
"proxy-server/server/web/app/models" "proxy-server/server/web/app/models"
"strconv" "strconv"
"time" "time"
@@ -23,7 +23,7 @@ type Config struct {
type Service struct { type Service struct {
Config *Config Config *Config
connMap map[string]socks5.ProxyData connMap map[string]socks.ProxyData
ctrlConnWg utils.CountWaitGroup ctrlConnWg utils.CountWaitGroup
dataConnWg utils.CountWaitGroup dataConnWg utils.CountWaitGroup
} }
@@ -36,7 +36,7 @@ func New(config *Config) *Service {
return &Service{ return &Service{
Config: _config, Config: _config,
connMap: make(map[string]socks5.ProxyData), connMap: make(map[string]socks.ProxyData),
ctrlConnWg: utils.CountWaitGroup{}, ctrlConnWg: utils.CountWaitGroup{},
dataConnWg: utils.CountWaitGroup{}, dataConnWg: utils.CountWaitGroup{},
} }
@@ -95,7 +95,7 @@ func (s *Service) startCtrlTun(ctx context.Context, errCh chan error) {
defer utils.Close(ls) defer utils.Close(ls)
// 等待连接 // 等待连接
connCh := utils.ConnChan(ctx, ls) connCh := utils.ChanConnAccept(ctx, ls)
defer close(connCh) defer close(connCh)
// 处理连接 // 处理连接
@@ -119,7 +119,7 @@ loop:
timeout, cancel := context.WithTimeout(context.Background(), 5*time.Second) timeout, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
procCh := utils.WaitChan(timeout, &s.ctrlConnWg) procCh := utils.ChanWgWait(timeout, &s.ctrlConnWg)
defer close(procCh) defer close(procCh)
select { select {
@@ -144,20 +144,19 @@ func (s *Service) processCtrlConn(controller net.Conn) {
reader := bufio.NewReader(controller) reader := bufio.NewReader(controller)
// 读取端口 // 读取端口
portBuf := make([]byte, 2) portBuf, err := utils.ReadBuffer(reader, 2)
_, err := io.ReadFull(reader, portBuf)
if err != nil { if err != nil {
slog.Error("读取转发端口失败", "err", err) slog.Error("接收转发端口失败", "err", err)
return return
} }
port := binary.BigEndian.Uint16(portBuf) port := binary.BigEndian.Uint16(portBuf)
// 新建代理服务 // 新建代理服务
slog.Info("新建代理服务", "port", port) slog.Info("新建代理服务", "port", port)
proxy, err := socks5.New(&socks5.Config{ proxy, err := socks.New(&socks.Config{
Name: strconv.Itoa(int(port)), Name: strconv.Itoa(int(port)),
Port: port, Port: port,
AuthMethods: []socks5.Authenticator{ AuthMethods: []socks.Authenticator{
&UserPassAuthenticator{}, &UserPassAuthenticator{},
&NoAuthAuthenticator{}, &NoAuthAuthenticator{},
}, },
@@ -207,7 +206,7 @@ func (s *Service) startDataTun(ctx context.Context, errCh chan error) {
defer utils.Close(lData) defer utils.Close(lData)
// 等待连接 // 等待连接
connCh := utils.ConnChan(ctx, lData) connCh := utils.ChanConnAccept(ctx, lData)
defer close(connCh) defer close(connCh)
// 处理连接 // 处理连接
@@ -231,7 +230,7 @@ loop:
timeout, cancel := context.WithTimeout(context.Background(), 5*time.Second) timeout, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
procCh := utils.WaitChan(timeout, &s.dataConnWg) procCh := utils.ChanWgWait(timeout, &s.dataConnWg)
defer close(procCh) defer close(procCh)
select { select {
@@ -275,7 +274,7 @@ func (s *Service) processDataConn(client net.Conn) {
// 响应用户 // 响应用户
user := data.Conn user := data.Conn
defer utils.Close(user) defer utils.Close(user)
socks5.SendSuccess(user, client) socks.SendSuccess(user, client)
// 写入目标地址 // 写入目标地址
_, err = client.Write([]byte{byte(len(data.Dest))}) _, err = client.Write([]byte{byte(len(data.Dest))})
@@ -313,11 +312,11 @@ func (s *Service) processDataConn(client net.Conn) {
type NoAuthAuthenticator struct { type NoAuthAuthenticator struct {
} }
func (a *NoAuthAuthenticator) Method() socks5.AuthMethod { func (a *NoAuthAuthenticator) Method() socks.AuthMethod {
return socks5.NoAuth return socks.NoAuth
} }
func (a *NoAuthAuthenticator) Authenticate(ctx context.Context, reader io.Reader, writer io.Writer) (*socks5.AuthContext, error) { func (a *NoAuthAuthenticator) Authenticate(ctx context.Context, reader io.Reader, writer io.Writer) (*socks.AuthContext, error) {
// 获取用户地址 // 获取用户地址
conn, ok := writer.(net.Conn) conn, ok := writer.(net.Conn)
@@ -332,7 +331,7 @@ func (a *NoAuthAuthenticator) Authenticate(ctx context.Context, reader io.Reader
slog.Debug("用户的地址为 " + client) slog.Debug("用户的地址为 " + client)
// 获取服务 // 获取服务
server, ok := ctx.Value("service").(*socks5.Server) server, ok := ctx.Value("service").(*socks.Server)
if !ok { if !ok {
return nil, errors.New("noAuth 认证失败,无法获取服务信息") return nil, errors.New("noAuth 认证失败,无法获取服务信息")
} }
@@ -373,8 +372,8 @@ func (a *NoAuthAuthenticator) Authenticate(ctx context.Context, reader io.Reader
} }
slog.Debug("权限剩余时间", slog.Uint64("timeout", uint64(timeout))) slog.Debug("权限剩余时间", slog.Uint64("timeout", uint64(timeout)))
return &socks5.AuthContext{ return &socks.AuthContext{
Method: socks5.NoAuth, Method: socks.NoAuth,
Timeout: uint(timeout), Timeout: uint(timeout),
Payload: nil, Payload: nil,
}, nil }, nil
@@ -383,11 +382,11 @@ func (a *NoAuthAuthenticator) Authenticate(ctx context.Context, reader io.Reader
type UserPassAuthenticator struct { type UserPassAuthenticator struct {
} }
func (a *UserPassAuthenticator) Method() socks5.AuthMethod { func (a *UserPassAuthenticator) Method() socks.AuthMethod {
return socks5.UserPassAuth return socks.UserPassAuth
} }
func (a *UserPassAuthenticator) Authenticate(ctx context.Context, reader io.Reader, writer io.Writer) (*socks5.AuthContext, error) { func (a *UserPassAuthenticator) Authenticate(ctx context.Context, reader io.Reader, writer io.Writer) (*socks.AuthContext, error) {
// 检查认证版本 // 检查认证版本
slog.Debug("验证认证版本") slog.Debug("验证认证版本")
@@ -395,8 +394,8 @@ func (a *UserPassAuthenticator) Authenticate(ctx context.Context, reader io.Read
if err != nil { if err != nil {
return nil, errors.Wrap(err, "读取版本号失败") return nil, errors.Wrap(err, "读取版本号失败")
} }
if v != socks5.AuthVersion { if v != socks.AuthVersion {
_, err := writer.Write([]byte{socks5.SocksVersion, socks5.AuthFailure}) _, err := writer.Write([]byte{socks.Version, socks.AuthFailure})
if err != nil { if err != nil {
return nil, errors.Wrap(err, "响应认证失败") return nil, errors.Wrap(err, "响应认证失败")
} }
@@ -484,14 +483,14 @@ func (a *UserPassAuthenticator) Authenticate(ctx context.Context, reader io.Read
} }
// 响应认证成功 // 响应认证成功
_, err = writer.Write([]byte{socks5.AuthVersion, socks5.AuthSuccess}) _, err = writer.Write([]byte{socks.AuthVersion, socks.AuthSuccess})
if err != nil { if err != nil {
slog.Error("响应认证失败", "err", err) slog.Error("响应认证失败", "err", err)
return nil, err return nil, err
} }
return &socks5.AuthContext{ return &socks.AuthContext{
Method: socks5.UserPassAuth, Method: socks.UserPassAuth,
Timeout: uint(timeout), Timeout: uint(timeout),
Payload: nil, Payload: nil,
}, nil }, nil

View File

@@ -6,7 +6,7 @@ import (
"io" "io"
"net" "net"
"proxy-server/pkg/utils" "proxy-server/pkg/utils"
"proxy-server/server/pkg/socks5" socks6 "proxy-server/server/fwd/socks"
"reflect" "reflect"
"testing" "testing"
) )
@@ -49,7 +49,7 @@ func TestNoAuthAuthenticator_Authenticate(t *testing.T) {
name string name string
args args args args
wantWriter string wantWriter string
want *socks5.AuthContext want *socks6.AuthContext
wantErr bool wantErr bool
}{ }{
// TODO: Add test cases. // TODO: Add test cases.
@@ -76,7 +76,7 @@ func TestNoAuthAuthenticator_Authenticate(t *testing.T) {
func TestNoAuthAuthenticator_Method(t *testing.T) { func TestNoAuthAuthenticator_Method(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
want socks5.AuthMethod want socks6.AuthMethod
}{ }{
// TODO: Add test cases. // TODO: Add test cases.
} }
@@ -93,7 +93,7 @@ func TestNoAuthAuthenticator_Method(t *testing.T) {
func TestService_Run(t *testing.T) { func TestService_Run(t *testing.T) {
type fields struct { type fields struct {
Config *Config Config *Config
connMap map[string]socks5.ProxyData connMap map[string]socks6.ProxyData
ctrlConnWg utils.CountWaitGroup ctrlConnWg utils.CountWaitGroup
dataConnWg utils.CountWaitGroup dataConnWg utils.CountWaitGroup
} }
@@ -124,7 +124,7 @@ func TestService_Run(t *testing.T) {
func TestService_processCtrlConn(t *testing.T) { func TestService_processCtrlConn(t *testing.T) {
type fields struct { type fields struct {
Config *Config Config *Config
connMap map[string]socks5.ProxyData connMap map[string]socks6.ProxyData
ctrlConnWg utils.CountWaitGroup ctrlConnWg utils.CountWaitGroup
dataConnWg utils.CountWaitGroup dataConnWg utils.CountWaitGroup
} }
@@ -154,7 +154,7 @@ func TestService_processCtrlConn(t *testing.T) {
func TestService_processDataConn(t *testing.T) { func TestService_processDataConn(t *testing.T) {
type fields struct { type fields struct {
Config *Config Config *Config
connMap map[string]socks5.ProxyData connMap map[string]socks6.ProxyData
ctrlConnWg utils.CountWaitGroup ctrlConnWg utils.CountWaitGroup
dataConnWg utils.CountWaitGroup dataConnWg utils.CountWaitGroup
} }
@@ -184,7 +184,7 @@ func TestService_processDataConn(t *testing.T) {
func TestService_startCtrlTun(t *testing.T) { func TestService_startCtrlTun(t *testing.T) {
type fields struct { type fields struct {
Config *Config Config *Config
connMap map[string]socks5.ProxyData connMap map[string]socks6.ProxyData
ctrlConnWg utils.CountWaitGroup ctrlConnWg utils.CountWaitGroup
dataConnWg utils.CountWaitGroup dataConnWg utils.CountWaitGroup
} }
@@ -215,7 +215,7 @@ func TestService_startCtrlTun(t *testing.T) {
func TestService_startDataTun(t *testing.T) { func TestService_startDataTun(t *testing.T) {
type fields struct { type fields struct {
Config *Config Config *Config
connMap map[string]socks5.ProxyData connMap map[string]socks6.ProxyData
ctrlConnWg utils.CountWaitGroup ctrlConnWg utils.CountWaitGroup
dataConnWg utils.CountWaitGroup dataConnWg utils.CountWaitGroup
} }
@@ -252,7 +252,7 @@ func TestUserPassAuthenticator_Authenticate(t *testing.T) {
name string name string
args args args args
wantWriter string wantWriter string
want *socks5.AuthContext want *socks6.AuthContext
wantErr bool wantErr bool
}{ }{
// TODO: Add test cases. // TODO: Add test cases.
@@ -279,7 +279,7 @@ func TestUserPassAuthenticator_Authenticate(t *testing.T) {
func TestUserPassAuthenticator_Method(t *testing.T) { func TestUserPassAuthenticator_Method(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
want socks5.AuthMethod want socks6.AuthMethod
}{ }{
// TODO: Add test cases. // TODO: Add test cases.
} }

View File

@@ -1,4 +1,4 @@
package socks5 package socks
import ( import (
"context" "context"
@@ -27,7 +27,7 @@ type Authenticator interface {
} }
// authenticate 执行认证流程 // authenticate 执行认证流程
func (server *Server) authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error) { func (s *Server) authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error) {
// 版本检查 // 版本检查
err := checkVersion(reader) err := checkVersion(reader)
@@ -46,18 +46,18 @@ func (server *Server) authenticate(reader io.Reader, writer io.Writer) (*AuthCon
} }
// 认证客户端连接 // 认证客户端连接
for _, authenticator := range server.config.AuthMethods { for _, authenticator := range s.config.AuthMethods {
method := authenticator.Method() method := authenticator.Method()
if slices.Contains(methods, byte(method)) { if slices.Contains(methods, byte(method)) {
slog.Debug("使用的认证方式", method) slog.Debug("使用的认证方式", method)
_, err := writer.Write([]byte{SocksVersion, byte(method)}) _, err := writer.Write([]byte{Version, byte(method)})
if err != nil { if err != nil {
slog.Error("响应认证方式失败", err) slog.Error("响应认证方式失败", err)
return nil, err return nil, err
} }
ctx := context.WithValue(context.Background(), "service", server) ctx := context.WithValue(context.Background(), "service", s)
authContext, err := authenticator.Authenticate(ctx, reader, writer) authContext, err := authenticator.Authenticate(ctx, reader, writer)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -67,7 +67,7 @@ func (server *Server) authenticate(reader io.Reader, writer io.Writer) (*AuthCon
} }
// 无适用的认证方式 // 无适用的认证方式
_, err = writer.Write([]byte{SocksVersion, NoAcceptable}) _, err = writer.Write([]byte{Version, NoAcceptable})
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -1,4 +1,4 @@
package socks5 package socks
type ConfigError string type ConfigError string

View File

@@ -1,4 +1,4 @@
package socks5 package socks
import ( import (
"context" "context"
@@ -65,7 +65,7 @@ func (a AddrSpec) Address() string {
return net.JoinHostPort(a.FQDN, strconv.Itoa(a.Port)) return net.JoinHostPort(a.FQDN, strconv.Itoa(a.Port))
} }
func (server *Server) request(reader io.Reader, writer io.Writer) (*Request, error) { func (s *Server) request(reader io.Reader, writer io.Writer) (*Request, error) {
// 检查版本 // 检查版本
err := checkVersion(reader) err := checkVersion(reader)
@@ -95,13 +95,13 @@ func (server *Server) request(reader io.Reader, writer io.Writer) (*Request, err
} }
// 获取目标地址 // 获取目标地址
dest, err := server.parseTarget(reader, writer) dest, err := s.parseTarget(reader, writer)
if err != nil { if err != nil {
return nil, err return nil, err
} }
request := &Request{ request := &Request{
Version: SocksVersion, Version: Version,
Command: command, Command: command,
DestAddr: dest, DestAddr: dest,
bufConn: reader, bufConn: reader,
@@ -110,7 +110,7 @@ func (server *Server) request(reader io.Reader, writer io.Writer) (*Request, err
return request, nil return request, nil
} }
func (server *Server) parseTarget(reader io.Reader, writer io.Writer) (*AddrSpec, error) { func (s *Server) parseTarget(reader io.Reader, writer io.Writer) (*AddrSpec, error) {
dest := &AddrSpec{} dest := &AddrSpec{}
aTypeBuf := make([]byte, 1) aTypeBuf := make([]byte, 1)
@@ -152,7 +152,7 @@ func (server *Server) parseTarget(reader io.Reader, writer io.Writer) (*AddrSpec
dest.FQDN = string(fqdnBuff) dest.FQDN = string(fqdnBuff)
// 域名解析 // 域名解析
addr, err := server.config.Resolver.Resolve(dest.FQDN) addr, err := s.config.Resolver.Resolve(dest.FQDN)
if err != nil { if err != nil {
err := sendReply(writer, hostUnreachable, nil) err := sendReply(writer, hostUnreachable, nil)
if err != nil { if err != nil {
@@ -197,33 +197,33 @@ type Request struct {
bufConn io.Reader bufConn io.Reader
} }
func (server *Server) handle(req *Request, conn net.Conn) error { func (s *Server) handle(req *Request, conn net.Conn) error {
ctx := context.Background() ctx := context.Background()
// 目标地址重写 // 目标地址重写
req.realDestAddr = req.DestAddr req.realDestAddr = req.DestAddr
if server.config.Rewriter != nil { if s.config.Rewriter != nil {
ctx, req.realDestAddr = server.config.Rewriter.Rewrite(ctx, req) ctx, req.realDestAddr = s.config.Rewriter.Rewrite(ctx, req)
} }
// 根据协商方法建立连接 // 根据协商方法建立连接
switch req.Command { switch req.Command {
case ConnectCommand: case ConnectCommand:
return server.handleConnect(ctx, conn, req) return s.handleConnect(ctx, conn, req)
case BindCommand: case BindCommand:
return server.handleBind(ctx, conn, req) return s.handleBind(ctx, conn, req)
case AssociateCommand: case AssociateCommand:
return server.handleAssociate(ctx, conn, req) return s.handleAssociate(ctx, conn, req)
default: default:
return fmt.Errorf("unsupported command: %v", req.Command) return fmt.Errorf("unsupported command: %v", req.Command)
} }
} }
func (server *Server) handleConnect(ctx context.Context, conn net.Conn, req *Request) error { func (s *Server) handleConnect(ctx context.Context, conn net.Conn, req *Request) error {
// 检查规则集约束 // 检查规则集约束
server.config.Logger.Printf("检查约束规则\n") s.config.Logger.Printf("检查约束规则\n")
if ctx_, ok := server.config.Rules.Allow(ctx, req); !ok { if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok {
if err := sendReply(conn, ruleFailure, nil); err != nil { if err := sendReply(conn, ruleFailure, nil); err != nil {
return fmt.Errorf("failed to send reply: %v", err) return fmt.Errorf("failed to send reply: %v", err)
} }
@@ -233,12 +233,12 @@ func (server *Server) handleConnect(ctx context.Context, conn net.Conn, req *Req
} }
slog.Info("需要向 " + req.DestAddr.Address() + " 建立连接") slog.Info("需要向 " + req.DestAddr.Address() + " 建立连接")
server.Conn <- ProxyData{conn, req.realDestAddr.Address()} s.Conn <- ProxyData{conn, req.realDestAddr.Address()}
return nil return nil
// 与目标服务器建立连接 // 与目标服务器建立连接
// server.config.Logger.Printf("与目标服务器建立连接\n") // s.config.Logger.Printf("与目标服务器建立连接\n")
// dial := server.config.Dial // dial := s.config.Dial
// target, err := dial("tcp", req.realDestAddr.Address()) // target, err := dial("tcp", req.realDestAddr.Address())
// if err != nil { // if err != nil {
// msg := err.Error() // msg := err.Error()
@@ -271,7 +271,7 @@ func (server *Server) handleConnect(ctx context.Context, conn net.Conn, req *Req
// timeout := req.AuthContext.Timeout // timeout := req.AuthContext.Timeout
// slog.Debug("超时时间", "timeout", timeout) // slog.Debug("超时时间", "timeout", timeout)
// //
// timeoutCtx, cancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Second) // timeoutCtx, cancel := ctx.WithTimeout(ctx, time.Duration(timeout)*time.Second)
// defer cancel() // defer cancel()
// //
// // 代理流量 // // 代理流量
@@ -304,9 +304,9 @@ func (server *Server) handleConnect(ctx context.Context, conn net.Conn, req *Req
} }
func (server *Server) handleBind(ctx context.Context, conn net.Conn, req *Request) error { func (s *Server) handleBind(ctx context.Context, conn net.Conn, req *Request) error {
// Check if this is allowed // Check if this is allowed
if ctx_, ok := server.config.Rules.Allow(ctx, req); !ok { if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok {
if err := sendReply(conn, ruleFailure, nil); err != nil { if err := sendReply(conn, ruleFailure, nil); err != nil {
return fmt.Errorf("failed to send reply: %v", err) return fmt.Errorf("failed to send reply: %v", err)
} }
@@ -322,9 +322,9 @@ func (server *Server) handleBind(ctx context.Context, conn net.Conn, req *Reques
return nil return nil
} }
func (server *Server) handleAssociate(ctx context.Context, conn net.Conn, req *Request) error { func (s *Server) handleAssociate(ctx context.Context, conn net.Conn, req *Request) error {
// Check if this is allowed // Check if this is allowed
if ctx_, ok := server.config.Rules.Allow(ctx, req); !ok { if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok {
if err := sendReply(conn, ruleFailure, nil); err != nil { if err := sendReply(conn, ruleFailure, nil); err != nil {
return fmt.Errorf("failed to send reply: %v", err) return fmt.Errorf("failed to send reply: %v", err)
} }
@@ -370,7 +370,7 @@ func sendReply(w io.Writer, resp uint8, addr *AddrSpec) error {
} }
msg := make([]byte, 6+len(addrBody)) msg := make([]byte, 6+len(addrBody))
msg[0] = SocksVersion msg[0] = Version
msg[1] = resp msg[1] = resp
msg[2] = 0 // Reserved msg[2] = 0 // Reserved
msg[3] = addrType msg[3] = addrType

View File

@@ -1,4 +1,4 @@
package socks5 package socks
import ( import (
"net" "net"

View File

@@ -1,4 +1,4 @@
package socks5 package socks
import ( import (
"context" "context"

View File

@@ -1,8 +1,8 @@
package socks5 package socks
import ( import (
"bufio" "bufio"
"errors" "context"
"fmt" "fmt"
"io" "io"
"log" "log"
@@ -11,10 +11,13 @@ import (
"os" "os"
"proxy-server/pkg/utils" "proxy-server/pkg/utils"
"strconv" "strconv"
"time"
"github.com/pkg/errors"
) )
const ( const (
SocksVersion = byte(5) Version = byte(5)
) )
type Config struct { type Config struct {
@@ -47,6 +50,9 @@ type Config struct {
type Server struct { type Server struct {
config *Config config *Config
ctx context.Context
cancel context.CancelFunc
wg utils.CountWaitGroup
Name string Name string
Port uint16 Port uint16
Conn chan ProxyData Conn chan ProxyData
@@ -76,8 +82,12 @@ func New(conf *Config) (*Server, error) {
} }
} }
ctx, cancel := context.WithCancel(context.Background())
return &Server{ return &Server{
config: conf, config: conf,
ctx: ctx,
cancel: cancel,
wg: utils.CountWaitGroup{},
Name: conf.Name, Name: conf.Name,
Port: conf.Port, Port: conf.Port,
Conn: make(chan ProxyData, 100), Conn: make(chan ProxyData, 100),
@@ -85,46 +95,83 @@ func New(conf *Config) (*Server, error) {
} }
// Run 监听端口 // Run 监听端口
func (server *Server) Run() error { func (s *Server) Run() error {
host := server.config.Host
port := server.config.Port
addr := net.JoinHostPort(host, strconv.Itoa(int(port)))
slog.Info("启动 socks5 代理服务") slog.Info("启动 socks5 代理服务")
listener, err := net.Listen("tcp", addr) // 监听端口
host := s.config.Host
port := s.config.Port
addr := net.JoinHostPort(host, strconv.Itoa(int(port)))
ls, err := net.Listen("tcp", addr)
if err != nil { if err != nil {
return err return errors.Wrap(err, "监听端口失败")
} }
defer utils.Close(listener) defer utils.Close(ls)
slog.Info("正在监听端口", slog.Uint64("port", uint64(port)))
slog.Info("代理服务已启动,正在监听端口 " + addr) // 处理连接
connCh := utils.ChanConnAccept(s.ctx, ls)
defer close(connCh)
for { err = nil
conn, err := listener.Accept() for loop := true; loop; {
if err != nil { select {
slog.Error("客户端连接失败", err) case <-s.ctx.Done():
continue slog.Debug("服务主动停止")
} loop = false
case conn, ok := <-connCh:
go func() { if !ok {
err := server.serve(conn) err = errors.New("意外错误,无法获取连接")
if err != nil { loop = false
slog.Error("连接异常退出", err) break
} }
}() s.wg.Add(1)
go func() {
defer s.wg.Done()
err := s.serve(conn)
if err != nil {
slog.Error("处理连接失败", err)
}
}()
}
} }
if err != nil {
s.Close()
}
// 关闭服务
timeout, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
wgCh := utils.ChanWgWait(timeout, &s.wg)
err = nil
select {
case <-timeout.Done():
err = errors.New("关闭超时(强制关闭)")
case <-wgCh:
}
if s.Conn != nil {
close(s.Conn)
}
return err
}
// Close 关闭服务
func (s *Server) Close() {
s.cancel()
} }
// serve 建立连接 // serve 建立连接
func (server *Server) serve(conn net.Conn) error { func (s *Server) serve(conn net.Conn) error {
slog.Info("收到来自" + conn.RemoteAddr().String() + "的连接") slog.Info("收到来自" + conn.RemoteAddr().String() + "的连接")
reader := bufio.NewReader(conn) reader := bufio.NewReader(conn)
// 认证 // 认证
slog.Debug("开始认证流程") slog.Debug("开始认证流程")
authContext, err := server.authenticate(reader, conn) authContext, err := s.authenticate(reader, conn)
if err != nil { if err != nil {
utils.Close(conn) utils.Close(conn)
slog.Error("认证失败", err) slog.Error("认证失败", err)
@@ -135,7 +182,7 @@ func (server *Server) serve(conn net.Conn) error {
// 处理连接请求 // 处理连接请求
slog.Debug("处理连接请求") slog.Debug("处理连接请求")
request, err := server.request(reader, conn) request, err := s.request(reader, conn)
if err != nil { if err != nil {
slog.Error("连接请求处理失败", err) slog.Error("连接请求处理失败", err)
return err return err
@@ -156,7 +203,7 @@ func (server *Server) serve(conn net.Conn) error {
// 处理请求 // 处理请求
slog.Debug("开始代理流量") slog.Debug("开始代理流量")
err = server.handle(request, conn) err = s.handle(request, conn)
if err != nil { if err != nil {
return err return err
} }
@@ -173,7 +220,7 @@ func checkVersion(reader io.Reader) error {
slog.Debug("客户端请求版本", "version", version) slog.Debug("客户端请求版本", "version", version)
if version != SocksVersion { if version != Version {
return errors.New("客户端版本不兼容") return errors.New("客户端版本不兼容")
} }

135
template/service/service.go Normal file
View File

@@ -0,0 +1,135 @@
package service
import (
"context"
"errors"
"io"
"log"
"net"
"strconv"
"sync"
"time"
)
type Config struct {
Host string
Port uint16
CloseWait time.Duration
}
type Server struct {
config *Config
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
}
func New(conf *Config) (*Server, error) {
if conf.Host == "" {
conf.Host = "localhost"
}
if conf.Port == 0 {
return nil, errors.New("port is required")
}
ctx, cancel := context.WithCancel(context.Background())
return &Server{
conf,
ctx,
cancel,
sync.WaitGroup{},
}, nil
}
func (s *Server) Run() error {
// start listen
addr := net.JoinHostPort(s.config.Host, strconv.Itoa(int(s.config.Port)))
ls, err := net.Listen("tcp", addr)
if err != nil {
return errors.New("failed to listen")
}
defer closeRes(ls)
// wait accept
connCh := make(chan net.Conn)
defer close(connCh)
go func() {
for {
conn, err := ls.Accept()
if err != nil {
if !errors.Is(err, net.ErrClosed) {
log.Println("accept failed", err)
}
// retry on temporary error
var ne net.Error
if errors.As(err, &ne) && ne.Temporary() {
continue
}
return
}
select {
case <-s.ctx.Done():
closeRes(conn)
return
case connCh <- conn:
}
}
}()
// handle accept
func() {
for {
select {
case <-s.ctx.Done():
return
case conn, ok := <-connCh:
if !ok {
return
}
s.wg.Add(1)
go func() {
defer s.wg.Done()
_ = s.handle(conn)
}()
}
}
}()
// close
timeout, cancel := context.WithTimeout(context.Background(), s.config.CloseWait)
defer cancel()
waitCh := make(chan struct{})
defer close(waitCh)
go func() {
s.wg.Wait()
select {
case <-timeout.Done(): // waitCh may be closed
case waitCh <- struct{}{}:
}
}()
err = nil
select {
case <-timeout.Done():
err = errors.New("close timeout")
case <-waitCh:
}
return err
}
func (s *Server) Close() {
s.cancel()
}
func (s *Server) handle(conn net.Conn) error {
defer closeRes(conn)
return nil
}
func closeRes[T io.Closer](res T) {
_ = res.Close()
}

View File

@@ -2,7 +2,7 @@ package fwd
import ( import (
"net" "net"
"proxy-server/server/pkg/socks5" socks6 "proxy-server/server/fwd/socks"
"testing" "testing"
) )
@@ -26,7 +26,7 @@ func fakeRequest() {
} }
// 发送认证请求 // 发送认证请求
_, err = conn.Write([]byte{socks5.SocksVersion, byte(1), byte(socks5.NoAuth)}) _, err = conn.Write([]byte{socks6.Version, byte(1), byte(socks6.NoAuth)})
if err != nil { if err != nil {
panic(err) panic(err)
} }