优化 socks 解析流程

This commit is contained in:
2025-02-26 13:56:56 +08:00
parent b50dc3d91c
commit 7ee4ded08c
12 changed files with 301 additions and 104 deletions

View File

@@ -18,12 +18,28 @@ fwd 使用自定义 context 实现在一个上下文中控制 cancelerrCh 和
网关根据代理节点对目标服务连接的反馈,决定向用户返回的 socks 响应
数据通道池化
### 长期
代理端口支持混合端口转发(支持 tcp_mux
数据通道支持 tcp 多路复用(分离逻辑流)
👆 进阶黑魔法 multipath tcp + 多路复用
考虑一下连接安全性
内部接口 rtt 是否还有优化空间当前30-300ms根据内容大小增长
### 代码清理
检查 slog 级别:
ERR: 除非有必要,否则全部 error 都使用 `errors.Wrap()` 包裹(如果下游有返回 err并附带本层业务信息return 到上层统一打印
其他级别日志就地打印Info 只用来跟踪关键流程
## 开发相关
### 环境变量

View File

@@ -8,12 +8,12 @@ import (
"github.com/pkg/errors"
)
func ConnChan(ctx context.Context, ls net.Listener) chan net.Conn {
func ChanConnAccept(ctx context.Context, ls net.Listener) chan net.Conn {
connCh := make(chan net.Conn)
go func() {
for {
conn, err := ls.Accept()
if err != nil {
if err != nil && !errors.Is(err, net.ErrClosed) {
slog.Error("接受连接失败", err)
// 临时错误重试连接
var ne net.Error
@@ -35,7 +35,7 @@ func ConnChan(ctx context.Context, ls net.Listener) chan net.Conn {
return connCh
}
func WaitChan(ctx context.Context, wg *CountWaitGroup) chan struct{} {
func ChanWgWait(ctx context.Context, wg *CountWaitGroup) chan struct{} {
ch := make(chan struct{})
go func() {
wg.Wait()

View File

@@ -8,9 +8,9 @@ import (
"log/slog"
"net"
"proxy-server/pkg/utils"
"proxy-server/server/fwd/socks"
"proxy-server/server/pkg/env"
"proxy-server/server/pkg/orm"
"proxy-server/server/pkg/socks5"
"proxy-server/server/web/app/models"
"strconv"
"time"
@@ -23,7 +23,7 @@ type Config struct {
type Service struct {
Config *Config
connMap map[string]socks5.ProxyData
connMap map[string]socks.ProxyData
ctrlConnWg utils.CountWaitGroup
dataConnWg utils.CountWaitGroup
}
@@ -36,7 +36,7 @@ func New(config *Config) *Service {
return &Service{
Config: _config,
connMap: make(map[string]socks5.ProxyData),
connMap: make(map[string]socks.ProxyData),
ctrlConnWg: utils.CountWaitGroup{},
dataConnWg: utils.CountWaitGroup{},
}
@@ -95,7 +95,7 @@ func (s *Service) startCtrlTun(ctx context.Context, errCh chan error) {
defer utils.Close(ls)
// 等待连接
connCh := utils.ConnChan(ctx, ls)
connCh := utils.ChanConnAccept(ctx, ls)
defer close(connCh)
// 处理连接
@@ -119,7 +119,7 @@ loop:
timeout, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
procCh := utils.WaitChan(timeout, &s.ctrlConnWg)
procCh := utils.ChanWgWait(timeout, &s.ctrlConnWg)
defer close(procCh)
select {
@@ -144,20 +144,19 @@ func (s *Service) processCtrlConn(controller net.Conn) {
reader := bufio.NewReader(controller)
// 读取端口
portBuf := make([]byte, 2)
_, err := io.ReadFull(reader, portBuf)
portBuf, err := utils.ReadBuffer(reader, 2)
if err != nil {
slog.Error("读取转发端口失败", "err", err)
slog.Error("接收转发端口失败", "err", err)
return
}
port := binary.BigEndian.Uint16(portBuf)
// 新建代理服务
slog.Info("新建代理服务", "port", port)
proxy, err := socks5.New(&socks5.Config{
proxy, err := socks.New(&socks.Config{
Name: strconv.Itoa(int(port)),
Port: port,
AuthMethods: []socks5.Authenticator{
AuthMethods: []socks.Authenticator{
&UserPassAuthenticator{},
&NoAuthAuthenticator{},
},
@@ -207,7 +206,7 @@ func (s *Service) startDataTun(ctx context.Context, errCh chan error) {
defer utils.Close(lData)
// 等待连接
connCh := utils.ConnChan(ctx, lData)
connCh := utils.ChanConnAccept(ctx, lData)
defer close(connCh)
// 处理连接
@@ -231,7 +230,7 @@ loop:
timeout, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
procCh := utils.WaitChan(timeout, &s.dataConnWg)
procCh := utils.ChanWgWait(timeout, &s.dataConnWg)
defer close(procCh)
select {
@@ -275,7 +274,7 @@ func (s *Service) processDataConn(client net.Conn) {
// 响应用户
user := data.Conn
defer utils.Close(user)
socks5.SendSuccess(user, client)
socks.SendSuccess(user, client)
// 写入目标地址
_, err = client.Write([]byte{byte(len(data.Dest))})
@@ -313,11 +312,11 @@ func (s *Service) processDataConn(client net.Conn) {
type NoAuthAuthenticator struct {
}
func (a *NoAuthAuthenticator) Method() socks5.AuthMethod {
return socks5.NoAuth
func (a *NoAuthAuthenticator) Method() socks.AuthMethod {
return socks.NoAuth
}
func (a *NoAuthAuthenticator) Authenticate(ctx context.Context, reader io.Reader, writer io.Writer) (*socks5.AuthContext, error) {
func (a *NoAuthAuthenticator) Authenticate(ctx context.Context, reader io.Reader, writer io.Writer) (*socks.AuthContext, error) {
// 获取用户地址
conn, ok := writer.(net.Conn)
@@ -332,7 +331,7 @@ func (a *NoAuthAuthenticator) Authenticate(ctx context.Context, reader io.Reader
slog.Debug("用户的地址为 " + client)
// 获取服务
server, ok := ctx.Value("service").(*socks5.Server)
server, ok := ctx.Value("service").(*socks.Server)
if !ok {
return nil, errors.New("noAuth 认证失败,无法获取服务信息")
}
@@ -373,8 +372,8 @@ func (a *NoAuthAuthenticator) Authenticate(ctx context.Context, reader io.Reader
}
slog.Debug("权限剩余时间", slog.Uint64("timeout", uint64(timeout)))
return &socks5.AuthContext{
Method: socks5.NoAuth,
return &socks.AuthContext{
Method: socks.NoAuth,
Timeout: uint(timeout),
Payload: nil,
}, nil
@@ -383,11 +382,11 @@ func (a *NoAuthAuthenticator) Authenticate(ctx context.Context, reader io.Reader
type UserPassAuthenticator struct {
}
func (a *UserPassAuthenticator) Method() socks5.AuthMethod {
return socks5.UserPassAuth
func (a *UserPassAuthenticator) Method() socks.AuthMethod {
return socks.UserPassAuth
}
func (a *UserPassAuthenticator) Authenticate(ctx context.Context, reader io.Reader, writer io.Writer) (*socks5.AuthContext, error) {
func (a *UserPassAuthenticator) Authenticate(ctx context.Context, reader io.Reader, writer io.Writer) (*socks.AuthContext, error) {
// 检查认证版本
slog.Debug("验证认证版本")
@@ -395,8 +394,8 @@ func (a *UserPassAuthenticator) Authenticate(ctx context.Context, reader io.Read
if err != nil {
return nil, errors.Wrap(err, "读取版本号失败")
}
if v != socks5.AuthVersion {
_, err := writer.Write([]byte{socks5.SocksVersion, socks5.AuthFailure})
if v != socks.AuthVersion {
_, err := writer.Write([]byte{socks.Version, socks.AuthFailure})
if err != nil {
return nil, errors.Wrap(err, "响应认证失败")
}
@@ -484,14 +483,14 @@ func (a *UserPassAuthenticator) Authenticate(ctx context.Context, reader io.Read
}
// 响应认证成功
_, err = writer.Write([]byte{socks5.AuthVersion, socks5.AuthSuccess})
_, err = writer.Write([]byte{socks.AuthVersion, socks.AuthSuccess})
if err != nil {
slog.Error("响应认证失败", "err", err)
return nil, err
}
return &socks5.AuthContext{
Method: socks5.UserPassAuth,
return &socks.AuthContext{
Method: socks.UserPassAuth,
Timeout: uint(timeout),
Payload: nil,
}, nil

View File

@@ -6,7 +6,7 @@ import (
"io"
"net"
"proxy-server/pkg/utils"
"proxy-server/server/pkg/socks5"
socks6 "proxy-server/server/fwd/socks"
"reflect"
"testing"
)
@@ -49,7 +49,7 @@ func TestNoAuthAuthenticator_Authenticate(t *testing.T) {
name string
args args
wantWriter string
want *socks5.AuthContext
want *socks6.AuthContext
wantErr bool
}{
// TODO: Add test cases.
@@ -76,7 +76,7 @@ func TestNoAuthAuthenticator_Authenticate(t *testing.T) {
func TestNoAuthAuthenticator_Method(t *testing.T) {
tests := []struct {
name string
want socks5.AuthMethod
want socks6.AuthMethod
}{
// TODO: Add test cases.
}
@@ -93,7 +93,7 @@ func TestNoAuthAuthenticator_Method(t *testing.T) {
func TestService_Run(t *testing.T) {
type fields struct {
Config *Config
connMap map[string]socks5.ProxyData
connMap map[string]socks6.ProxyData
ctrlConnWg utils.CountWaitGroup
dataConnWg utils.CountWaitGroup
}
@@ -124,7 +124,7 @@ func TestService_Run(t *testing.T) {
func TestService_processCtrlConn(t *testing.T) {
type fields struct {
Config *Config
connMap map[string]socks5.ProxyData
connMap map[string]socks6.ProxyData
ctrlConnWg utils.CountWaitGroup
dataConnWg utils.CountWaitGroup
}
@@ -154,7 +154,7 @@ func TestService_processCtrlConn(t *testing.T) {
func TestService_processDataConn(t *testing.T) {
type fields struct {
Config *Config
connMap map[string]socks5.ProxyData
connMap map[string]socks6.ProxyData
ctrlConnWg utils.CountWaitGroup
dataConnWg utils.CountWaitGroup
}
@@ -184,7 +184,7 @@ func TestService_processDataConn(t *testing.T) {
func TestService_startCtrlTun(t *testing.T) {
type fields struct {
Config *Config
connMap map[string]socks5.ProxyData
connMap map[string]socks6.ProxyData
ctrlConnWg utils.CountWaitGroup
dataConnWg utils.CountWaitGroup
}
@@ -215,7 +215,7 @@ func TestService_startCtrlTun(t *testing.T) {
func TestService_startDataTun(t *testing.T) {
type fields struct {
Config *Config
connMap map[string]socks5.ProxyData
connMap map[string]socks6.ProxyData
ctrlConnWg utils.CountWaitGroup
dataConnWg utils.CountWaitGroup
}
@@ -252,7 +252,7 @@ func TestUserPassAuthenticator_Authenticate(t *testing.T) {
name string
args args
wantWriter string
want *socks5.AuthContext
want *socks6.AuthContext
wantErr bool
}{
// TODO: Add test cases.
@@ -279,7 +279,7 @@ func TestUserPassAuthenticator_Authenticate(t *testing.T) {
func TestUserPassAuthenticator_Method(t *testing.T) {
tests := []struct {
name string
want socks5.AuthMethod
want socks6.AuthMethod
}{
// TODO: Add test cases.
}

View File

@@ -1,4 +1,4 @@
package socks5
package socks
import (
"context"
@@ -27,7 +27,7 @@ type Authenticator interface {
}
// authenticate 执行认证流程
func (server *Server) authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error) {
func (s *Server) authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error) {
// 版本检查
err := checkVersion(reader)
@@ -46,18 +46,18 @@ func (server *Server) authenticate(reader io.Reader, writer io.Writer) (*AuthCon
}
// 认证客户端连接
for _, authenticator := range server.config.AuthMethods {
for _, authenticator := range s.config.AuthMethods {
method := authenticator.Method()
if slices.Contains(methods, byte(method)) {
slog.Debug("使用的认证方式", method)
_, err := writer.Write([]byte{SocksVersion, byte(method)})
_, err := writer.Write([]byte{Version, byte(method)})
if err != nil {
slog.Error("响应认证方式失败", err)
return nil, err
}
ctx := context.WithValue(context.Background(), "service", server)
ctx := context.WithValue(context.Background(), "service", s)
authContext, err := authenticator.Authenticate(ctx, reader, writer)
if err != nil {
return nil, err
@@ -67,7 +67,7 @@ func (server *Server) authenticate(reader io.Reader, writer io.Writer) (*AuthCon
}
// 无适用的认证方式
_, err = writer.Write([]byte{SocksVersion, NoAcceptable})
_, err = writer.Write([]byte{Version, NoAcceptable})
if err != nil {
return nil, err
}

View File

@@ -1,4 +1,4 @@
package socks5
package socks
type ConfigError string

View File

@@ -1,4 +1,4 @@
package socks5
package socks
import (
"context"
@@ -65,7 +65,7 @@ func (a AddrSpec) Address() string {
return net.JoinHostPort(a.FQDN, strconv.Itoa(a.Port))
}
func (server *Server) request(reader io.Reader, writer io.Writer) (*Request, error) {
func (s *Server) request(reader io.Reader, writer io.Writer) (*Request, error) {
// 检查版本
err := checkVersion(reader)
@@ -95,13 +95,13 @@ func (server *Server) request(reader io.Reader, writer io.Writer) (*Request, err
}
// 获取目标地址
dest, err := server.parseTarget(reader, writer)
dest, err := s.parseTarget(reader, writer)
if err != nil {
return nil, err
}
request := &Request{
Version: SocksVersion,
Version: Version,
Command: command,
DestAddr: dest,
bufConn: reader,
@@ -110,7 +110,7 @@ func (server *Server) request(reader io.Reader, writer io.Writer) (*Request, err
return request, nil
}
func (server *Server) parseTarget(reader io.Reader, writer io.Writer) (*AddrSpec, error) {
func (s *Server) parseTarget(reader io.Reader, writer io.Writer) (*AddrSpec, error) {
dest := &AddrSpec{}
aTypeBuf := make([]byte, 1)
@@ -152,7 +152,7 @@ func (server *Server) parseTarget(reader io.Reader, writer io.Writer) (*AddrSpec
dest.FQDN = string(fqdnBuff)
// 域名解析
addr, err := server.config.Resolver.Resolve(dest.FQDN)
addr, err := s.config.Resolver.Resolve(dest.FQDN)
if err != nil {
err := sendReply(writer, hostUnreachable, nil)
if err != nil {
@@ -197,33 +197,33 @@ type Request struct {
bufConn io.Reader
}
func (server *Server) handle(req *Request, conn net.Conn) error {
func (s *Server) handle(req *Request, conn net.Conn) error {
ctx := context.Background()
// 目标地址重写
req.realDestAddr = req.DestAddr
if server.config.Rewriter != nil {
ctx, req.realDestAddr = server.config.Rewriter.Rewrite(ctx, req)
if s.config.Rewriter != nil {
ctx, req.realDestAddr = s.config.Rewriter.Rewrite(ctx, req)
}
// 根据协商方法建立连接
switch req.Command {
case ConnectCommand:
return server.handleConnect(ctx, conn, req)
return s.handleConnect(ctx, conn, req)
case BindCommand:
return server.handleBind(ctx, conn, req)
return s.handleBind(ctx, conn, req)
case AssociateCommand:
return server.handleAssociate(ctx, conn, req)
return s.handleAssociate(ctx, conn, req)
default:
return fmt.Errorf("unsupported command: %v", req.Command)
}
}
func (server *Server) handleConnect(ctx context.Context, conn net.Conn, req *Request) error {
func (s *Server) handleConnect(ctx context.Context, conn net.Conn, req *Request) error {
// 检查规则集约束
server.config.Logger.Printf("检查约束规则\n")
if ctx_, ok := server.config.Rules.Allow(ctx, req); !ok {
s.config.Logger.Printf("检查约束规则\n")
if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok {
if err := sendReply(conn, ruleFailure, nil); err != nil {
return fmt.Errorf("failed to send reply: %v", err)
}
@@ -233,12 +233,12 @@ func (server *Server) handleConnect(ctx context.Context, conn net.Conn, req *Req
}
slog.Info("需要向 " + req.DestAddr.Address() + " 建立连接")
server.Conn <- ProxyData{conn, req.realDestAddr.Address()}
s.Conn <- ProxyData{conn, req.realDestAddr.Address()}
return nil
// 与目标服务器建立连接
// server.config.Logger.Printf("与目标服务器建立连接\n")
// dial := server.config.Dial
// s.config.Logger.Printf("与目标服务器建立连接\n")
// dial := s.config.Dial
// target, err := dial("tcp", req.realDestAddr.Address())
// if err != nil {
// msg := err.Error()
@@ -271,7 +271,7 @@ func (server *Server) handleConnect(ctx context.Context, conn net.Conn, req *Req
// timeout := req.AuthContext.Timeout
// slog.Debug("超时时间", "timeout", timeout)
//
// timeoutCtx, cancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Second)
// timeoutCtx, cancel := ctx.WithTimeout(ctx, time.Duration(timeout)*time.Second)
// defer cancel()
//
// // 代理流量
@@ -304,9 +304,9 @@ func (server *Server) handleConnect(ctx context.Context, conn net.Conn, req *Req
}
func (server *Server) handleBind(ctx context.Context, conn net.Conn, req *Request) error {
func (s *Server) handleBind(ctx context.Context, conn net.Conn, req *Request) error {
// Check if this is allowed
if ctx_, ok := server.config.Rules.Allow(ctx, req); !ok {
if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok {
if err := sendReply(conn, ruleFailure, nil); err != nil {
return fmt.Errorf("failed to send reply: %v", err)
}
@@ -322,9 +322,9 @@ func (server *Server) handleBind(ctx context.Context, conn net.Conn, req *Reques
return nil
}
func (server *Server) handleAssociate(ctx context.Context, conn net.Conn, req *Request) error {
func (s *Server) handleAssociate(ctx context.Context, conn net.Conn, req *Request) error {
// Check if this is allowed
if ctx_, ok := server.config.Rules.Allow(ctx, req); !ok {
if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok {
if err := sendReply(conn, ruleFailure, nil); err != nil {
return fmt.Errorf("failed to send reply: %v", err)
}
@@ -370,7 +370,7 @@ func sendReply(w io.Writer, resp uint8, addr *AddrSpec) error {
}
msg := make([]byte, 6+len(addrBody))
msg[0] = SocksVersion
msg[0] = Version
msg[1] = resp
msg[2] = 0 // Reserved
msg[3] = addrType

View File

@@ -1,4 +1,4 @@
package socks5
package socks
import (
"net"

View File

@@ -1,4 +1,4 @@
package socks5
package socks
import (
"context"

View File

@@ -1,8 +1,8 @@
package socks5
package socks
import (
"bufio"
"errors"
"context"
"fmt"
"io"
"log"
@@ -11,10 +11,13 @@ import (
"os"
"proxy-server/pkg/utils"
"strconv"
"time"
"github.com/pkg/errors"
)
const (
SocksVersion = byte(5)
Version = byte(5)
)
type Config struct {
@@ -47,6 +50,9 @@ type Config struct {
type Server struct {
config *Config
ctx context.Context
cancel context.CancelFunc
wg utils.CountWaitGroup
Name string
Port uint16
Conn chan ProxyData
@@ -76,8 +82,12 @@ func New(conf *Config) (*Server, error) {
}
}
ctx, cancel := context.WithCancel(context.Background())
return &Server{
config: conf,
ctx: ctx,
cancel: cancel,
wg: utils.CountWaitGroup{},
Name: conf.Name,
Port: conf.Port,
Conn: make(chan ProxyData, 100),
@@ -85,46 +95,83 @@ func New(conf *Config) (*Server, error) {
}
// Run 监听端口
func (server *Server) Run() error {
host := server.config.Host
port := server.config.Port
addr := net.JoinHostPort(host, strconv.Itoa(int(port)))
func (s *Server) Run() error {
slog.Info("启动 socks5 代理服务")
listener, err := net.Listen("tcp", addr)
// 监听端口
host := s.config.Host
port := s.config.Port
addr := net.JoinHostPort(host, strconv.Itoa(int(port)))
ls, err := net.Listen("tcp", addr)
if err != nil {
return err
return errors.Wrap(err, "监听端口失败")
}
defer utils.Close(listener)
defer utils.Close(ls)
slog.Info("正在监听端口", slog.Uint64("port", uint64(port)))
slog.Info("代理服务已启动,正在监听端口 " + addr)
// 处理连接
connCh := utils.ChanConnAccept(s.ctx, ls)
defer close(connCh)
for {
conn, err := listener.Accept()
if err != nil {
slog.Error("客户端连接失败", err)
continue
}
go func() {
err := server.serve(conn)
if err != nil {
slog.Error("连接异常退出", err)
err = nil
for loop := true; loop; {
select {
case <-s.ctx.Done():
slog.Debug("服务主动停止")
loop = false
case conn, ok := <-connCh:
if !ok {
err = errors.New("意外错误,无法获取连接")
loop = false
break
}
}()
s.wg.Add(1)
go func() {
defer s.wg.Done()
err := s.serve(conn)
if err != nil {
slog.Error("处理连接失败", err)
}
}()
}
}
if err != nil {
s.Close()
}
// 关闭服务
timeout, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
wgCh := utils.ChanWgWait(timeout, &s.wg)
err = nil
select {
case <-timeout.Done():
err = errors.New("关闭超时(强制关闭)")
case <-wgCh:
}
if s.Conn != nil {
close(s.Conn)
}
return err
}
// Close 关闭服务
func (s *Server) Close() {
s.cancel()
}
// serve 建立连接
func (server *Server) serve(conn net.Conn) error {
func (s *Server) serve(conn net.Conn) error {
slog.Info("收到来自" + conn.RemoteAddr().String() + "的连接")
reader := bufio.NewReader(conn)
// 认证
slog.Debug("开始认证流程")
authContext, err := server.authenticate(reader, conn)
authContext, err := s.authenticate(reader, conn)
if err != nil {
utils.Close(conn)
slog.Error("认证失败", err)
@@ -135,7 +182,7 @@ func (server *Server) serve(conn net.Conn) error {
// 处理连接请求
slog.Debug("处理连接请求")
request, err := server.request(reader, conn)
request, err := s.request(reader, conn)
if err != nil {
slog.Error("连接请求处理失败", err)
return err
@@ -156,7 +203,7 @@ func (server *Server) serve(conn net.Conn) error {
// 处理请求
slog.Debug("开始代理流量")
err = server.handle(request, conn)
err = s.handle(request, conn)
if err != nil {
return err
}
@@ -173,7 +220,7 @@ func checkVersion(reader io.Reader) error {
slog.Debug("客户端请求版本", "version", version)
if version != SocksVersion {
if version != Version {
return errors.New("客户端版本不兼容")
}

135
template/service/service.go Normal file
View File

@@ -0,0 +1,135 @@
package service
import (
"context"
"errors"
"io"
"log"
"net"
"strconv"
"sync"
"time"
)
type Config struct {
Host string
Port uint16
CloseWait time.Duration
}
type Server struct {
config *Config
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
}
func New(conf *Config) (*Server, error) {
if conf.Host == "" {
conf.Host = "localhost"
}
if conf.Port == 0 {
return nil, errors.New("port is required")
}
ctx, cancel := context.WithCancel(context.Background())
return &Server{
conf,
ctx,
cancel,
sync.WaitGroup{},
}, nil
}
func (s *Server) Run() error {
// start listen
addr := net.JoinHostPort(s.config.Host, strconv.Itoa(int(s.config.Port)))
ls, err := net.Listen("tcp", addr)
if err != nil {
return errors.New("failed to listen")
}
defer closeRes(ls)
// wait accept
connCh := make(chan net.Conn)
defer close(connCh)
go func() {
for {
conn, err := ls.Accept()
if err != nil {
if !errors.Is(err, net.ErrClosed) {
log.Println("accept failed", err)
}
// retry on temporary error
var ne net.Error
if errors.As(err, &ne) && ne.Temporary() {
continue
}
return
}
select {
case <-s.ctx.Done():
closeRes(conn)
return
case connCh <- conn:
}
}
}()
// handle accept
func() {
for {
select {
case <-s.ctx.Done():
return
case conn, ok := <-connCh:
if !ok {
return
}
s.wg.Add(1)
go func() {
defer s.wg.Done()
_ = s.handle(conn)
}()
}
}
}()
// close
timeout, cancel := context.WithTimeout(context.Background(), s.config.CloseWait)
defer cancel()
waitCh := make(chan struct{})
defer close(waitCh)
go func() {
s.wg.Wait()
select {
case <-timeout.Done(): // waitCh may be closed
case waitCh <- struct{}{}:
}
}()
err = nil
select {
case <-timeout.Done():
err = errors.New("close timeout")
case <-waitCh:
}
return err
}
func (s *Server) Close() {
s.cancel()
}
func (s *Server) handle(conn net.Conn) error {
defer closeRes(conn)
return nil
}
func closeRes[T io.Closer](res T) {
_ = res.Close()
}

View File

@@ -2,7 +2,7 @@ package fwd
import (
"net"
"proxy-server/server/pkg/socks5"
socks6 "proxy-server/server/fwd/socks"
"testing"
)
@@ -26,7 +26,7 @@ func fakeRequest() {
}
// 发送认证请求
_, err = conn.Write([]byte{socks5.SocksVersion, byte(1), byte(socks5.NoAuth)})
_, err = conn.Write([]byte{socks6.Version, byte(1), byte(socks6.NoAuth)})
if err != nil {
panic(err)
}