优化 socks 解析流程
This commit is contained in:
16
README.md
16
README.md
@@ -18,12 +18,28 @@ fwd 使用自定义 context 实现在一个上下文中控制 cancel,errCh 和
|
||||
|
||||
网关根据代理节点对目标服务连接的反馈,决定向用户返回的 socks 响应
|
||||
|
||||
数据通道池化
|
||||
|
||||
### 长期
|
||||
|
||||
代理端口支持混合端口转发(支持 tcp_mux)
|
||||
|
||||
数据通道支持 tcp 多路复用(分离逻辑流)
|
||||
|
||||
👆 进阶黑魔法 multipath tcp + 多路复用
|
||||
|
||||
考虑一下连接安全性
|
||||
|
||||
内部接口 rtt 是否还有优化空间(当前30-300ms,根据内容大小增长)
|
||||
|
||||
### 代码清理
|
||||
|
||||
检查 slog 级别:
|
||||
|
||||
ERR: 除非有必要,否则全部 error 都使用 `errors.Wrap()` 包裹(如果下游有返回 err),并附带本层业务信息,return 到上层统一打印
|
||||
|
||||
其他级别日志就地打印,Info 只用来跟踪关键流程
|
||||
|
||||
## 开发相关
|
||||
|
||||
### 环境变量
|
||||
|
||||
@@ -8,12 +8,12 @@ import (
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
func ConnChan(ctx context.Context, ls net.Listener) chan net.Conn {
|
||||
func ChanConnAccept(ctx context.Context, ls net.Listener) chan net.Conn {
|
||||
connCh := make(chan net.Conn)
|
||||
go func() {
|
||||
for {
|
||||
conn, err := ls.Accept()
|
||||
if err != nil {
|
||||
if err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
slog.Error("接受连接失败", err)
|
||||
// 临时错误重试连接
|
||||
var ne net.Error
|
||||
@@ -35,7 +35,7 @@ func ConnChan(ctx context.Context, ls net.Listener) chan net.Conn {
|
||||
return connCh
|
||||
}
|
||||
|
||||
func WaitChan(ctx context.Context, wg *CountWaitGroup) chan struct{} {
|
||||
func ChanWgWait(ctx context.Context, wg *CountWaitGroup) chan struct{} {
|
||||
ch := make(chan struct{})
|
||||
go func() {
|
||||
wg.Wait()
|
||||
|
||||
@@ -8,9 +8,9 @@ import (
|
||||
"log/slog"
|
||||
"net"
|
||||
"proxy-server/pkg/utils"
|
||||
"proxy-server/server/fwd/socks"
|
||||
"proxy-server/server/pkg/env"
|
||||
"proxy-server/server/pkg/orm"
|
||||
"proxy-server/server/pkg/socks5"
|
||||
"proxy-server/server/web/app/models"
|
||||
"strconv"
|
||||
"time"
|
||||
@@ -23,7 +23,7 @@ type Config struct {
|
||||
|
||||
type Service struct {
|
||||
Config *Config
|
||||
connMap map[string]socks5.ProxyData
|
||||
connMap map[string]socks.ProxyData
|
||||
ctrlConnWg utils.CountWaitGroup
|
||||
dataConnWg utils.CountWaitGroup
|
||||
}
|
||||
@@ -36,7 +36,7 @@ func New(config *Config) *Service {
|
||||
|
||||
return &Service{
|
||||
Config: _config,
|
||||
connMap: make(map[string]socks5.ProxyData),
|
||||
connMap: make(map[string]socks.ProxyData),
|
||||
ctrlConnWg: utils.CountWaitGroup{},
|
||||
dataConnWg: utils.CountWaitGroup{},
|
||||
}
|
||||
@@ -95,7 +95,7 @@ func (s *Service) startCtrlTun(ctx context.Context, errCh chan error) {
|
||||
defer utils.Close(ls)
|
||||
|
||||
// 等待连接
|
||||
connCh := utils.ConnChan(ctx, ls)
|
||||
connCh := utils.ChanConnAccept(ctx, ls)
|
||||
defer close(connCh)
|
||||
|
||||
// 处理连接
|
||||
@@ -119,7 +119,7 @@ loop:
|
||||
timeout, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
procCh := utils.WaitChan(timeout, &s.ctrlConnWg)
|
||||
procCh := utils.ChanWgWait(timeout, &s.ctrlConnWg)
|
||||
defer close(procCh)
|
||||
|
||||
select {
|
||||
@@ -144,20 +144,19 @@ func (s *Service) processCtrlConn(controller net.Conn) {
|
||||
reader := bufio.NewReader(controller)
|
||||
|
||||
// 读取端口
|
||||
portBuf := make([]byte, 2)
|
||||
_, err := io.ReadFull(reader, portBuf)
|
||||
portBuf, err := utils.ReadBuffer(reader, 2)
|
||||
if err != nil {
|
||||
slog.Error("读取转发端口失败", "err", err)
|
||||
slog.Error("接收转发端口失败", "err", err)
|
||||
return
|
||||
}
|
||||
port := binary.BigEndian.Uint16(portBuf)
|
||||
|
||||
// 新建代理服务
|
||||
slog.Info("新建代理服务", "port", port)
|
||||
proxy, err := socks5.New(&socks5.Config{
|
||||
proxy, err := socks.New(&socks.Config{
|
||||
Name: strconv.Itoa(int(port)),
|
||||
Port: port,
|
||||
AuthMethods: []socks5.Authenticator{
|
||||
AuthMethods: []socks.Authenticator{
|
||||
&UserPassAuthenticator{},
|
||||
&NoAuthAuthenticator{},
|
||||
},
|
||||
@@ -207,7 +206,7 @@ func (s *Service) startDataTun(ctx context.Context, errCh chan error) {
|
||||
defer utils.Close(lData)
|
||||
|
||||
// 等待连接
|
||||
connCh := utils.ConnChan(ctx, lData)
|
||||
connCh := utils.ChanConnAccept(ctx, lData)
|
||||
defer close(connCh)
|
||||
|
||||
// 处理连接
|
||||
@@ -231,7 +230,7 @@ loop:
|
||||
timeout, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
procCh := utils.WaitChan(timeout, &s.dataConnWg)
|
||||
procCh := utils.ChanWgWait(timeout, &s.dataConnWg)
|
||||
defer close(procCh)
|
||||
|
||||
select {
|
||||
@@ -275,7 +274,7 @@ func (s *Service) processDataConn(client net.Conn) {
|
||||
// 响应用户
|
||||
user := data.Conn
|
||||
defer utils.Close(user)
|
||||
socks5.SendSuccess(user, client)
|
||||
socks.SendSuccess(user, client)
|
||||
|
||||
// 写入目标地址
|
||||
_, err = client.Write([]byte{byte(len(data.Dest))})
|
||||
@@ -313,11 +312,11 @@ func (s *Service) processDataConn(client net.Conn) {
|
||||
type NoAuthAuthenticator struct {
|
||||
}
|
||||
|
||||
func (a *NoAuthAuthenticator) Method() socks5.AuthMethod {
|
||||
return socks5.NoAuth
|
||||
func (a *NoAuthAuthenticator) Method() socks.AuthMethod {
|
||||
return socks.NoAuth
|
||||
}
|
||||
|
||||
func (a *NoAuthAuthenticator) Authenticate(ctx context.Context, reader io.Reader, writer io.Writer) (*socks5.AuthContext, error) {
|
||||
func (a *NoAuthAuthenticator) Authenticate(ctx context.Context, reader io.Reader, writer io.Writer) (*socks.AuthContext, error) {
|
||||
|
||||
// 获取用户地址
|
||||
conn, ok := writer.(net.Conn)
|
||||
@@ -332,7 +331,7 @@ func (a *NoAuthAuthenticator) Authenticate(ctx context.Context, reader io.Reader
|
||||
slog.Debug("用户的地址为 " + client)
|
||||
|
||||
// 获取服务
|
||||
server, ok := ctx.Value("service").(*socks5.Server)
|
||||
server, ok := ctx.Value("service").(*socks.Server)
|
||||
if !ok {
|
||||
return nil, errors.New("noAuth 认证失败,无法获取服务信息")
|
||||
}
|
||||
@@ -373,8 +372,8 @@ func (a *NoAuthAuthenticator) Authenticate(ctx context.Context, reader io.Reader
|
||||
}
|
||||
slog.Debug("权限剩余时间", slog.Uint64("timeout", uint64(timeout)))
|
||||
|
||||
return &socks5.AuthContext{
|
||||
Method: socks5.NoAuth,
|
||||
return &socks.AuthContext{
|
||||
Method: socks.NoAuth,
|
||||
Timeout: uint(timeout),
|
||||
Payload: nil,
|
||||
}, nil
|
||||
@@ -383,11 +382,11 @@ func (a *NoAuthAuthenticator) Authenticate(ctx context.Context, reader io.Reader
|
||||
type UserPassAuthenticator struct {
|
||||
}
|
||||
|
||||
func (a *UserPassAuthenticator) Method() socks5.AuthMethod {
|
||||
return socks5.UserPassAuth
|
||||
func (a *UserPassAuthenticator) Method() socks.AuthMethod {
|
||||
return socks.UserPassAuth
|
||||
}
|
||||
|
||||
func (a *UserPassAuthenticator) Authenticate(ctx context.Context, reader io.Reader, writer io.Writer) (*socks5.AuthContext, error) {
|
||||
func (a *UserPassAuthenticator) Authenticate(ctx context.Context, reader io.Reader, writer io.Writer) (*socks.AuthContext, error) {
|
||||
|
||||
// 检查认证版本
|
||||
slog.Debug("验证认证版本")
|
||||
@@ -395,8 +394,8 @@ func (a *UserPassAuthenticator) Authenticate(ctx context.Context, reader io.Read
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "读取版本号失败")
|
||||
}
|
||||
if v != socks5.AuthVersion {
|
||||
_, err := writer.Write([]byte{socks5.SocksVersion, socks5.AuthFailure})
|
||||
if v != socks.AuthVersion {
|
||||
_, err := writer.Write([]byte{socks.Version, socks.AuthFailure})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "响应认证失败")
|
||||
}
|
||||
@@ -484,14 +483,14 @@ func (a *UserPassAuthenticator) Authenticate(ctx context.Context, reader io.Read
|
||||
}
|
||||
|
||||
// 响应认证成功
|
||||
_, err = writer.Write([]byte{socks5.AuthVersion, socks5.AuthSuccess})
|
||||
_, err = writer.Write([]byte{socks.AuthVersion, socks.AuthSuccess})
|
||||
if err != nil {
|
||||
slog.Error("响应认证失败", "err", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &socks5.AuthContext{
|
||||
Method: socks5.UserPassAuth,
|
||||
return &socks.AuthContext{
|
||||
Method: socks.UserPassAuth,
|
||||
Timeout: uint(timeout),
|
||||
Payload: nil,
|
||||
}, nil
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"io"
|
||||
"net"
|
||||
"proxy-server/pkg/utils"
|
||||
"proxy-server/server/pkg/socks5"
|
||||
socks6 "proxy-server/server/fwd/socks"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
@@ -49,7 +49,7 @@ func TestNoAuthAuthenticator_Authenticate(t *testing.T) {
|
||||
name string
|
||||
args args
|
||||
wantWriter string
|
||||
want *socks5.AuthContext
|
||||
want *socks6.AuthContext
|
||||
wantErr bool
|
||||
}{
|
||||
// TODO: Add test cases.
|
||||
@@ -76,7 +76,7 @@ func TestNoAuthAuthenticator_Authenticate(t *testing.T) {
|
||||
func TestNoAuthAuthenticator_Method(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
want socks5.AuthMethod
|
||||
want socks6.AuthMethod
|
||||
}{
|
||||
// TODO: Add test cases.
|
||||
}
|
||||
@@ -93,7 +93,7 @@ func TestNoAuthAuthenticator_Method(t *testing.T) {
|
||||
func TestService_Run(t *testing.T) {
|
||||
type fields struct {
|
||||
Config *Config
|
||||
connMap map[string]socks5.ProxyData
|
||||
connMap map[string]socks6.ProxyData
|
||||
ctrlConnWg utils.CountWaitGroup
|
||||
dataConnWg utils.CountWaitGroup
|
||||
}
|
||||
@@ -124,7 +124,7 @@ func TestService_Run(t *testing.T) {
|
||||
func TestService_processCtrlConn(t *testing.T) {
|
||||
type fields struct {
|
||||
Config *Config
|
||||
connMap map[string]socks5.ProxyData
|
||||
connMap map[string]socks6.ProxyData
|
||||
ctrlConnWg utils.CountWaitGroup
|
||||
dataConnWg utils.CountWaitGroup
|
||||
}
|
||||
@@ -154,7 +154,7 @@ func TestService_processCtrlConn(t *testing.T) {
|
||||
func TestService_processDataConn(t *testing.T) {
|
||||
type fields struct {
|
||||
Config *Config
|
||||
connMap map[string]socks5.ProxyData
|
||||
connMap map[string]socks6.ProxyData
|
||||
ctrlConnWg utils.CountWaitGroup
|
||||
dataConnWg utils.CountWaitGroup
|
||||
}
|
||||
@@ -184,7 +184,7 @@ func TestService_processDataConn(t *testing.T) {
|
||||
func TestService_startCtrlTun(t *testing.T) {
|
||||
type fields struct {
|
||||
Config *Config
|
||||
connMap map[string]socks5.ProxyData
|
||||
connMap map[string]socks6.ProxyData
|
||||
ctrlConnWg utils.CountWaitGroup
|
||||
dataConnWg utils.CountWaitGroup
|
||||
}
|
||||
@@ -215,7 +215,7 @@ func TestService_startCtrlTun(t *testing.T) {
|
||||
func TestService_startDataTun(t *testing.T) {
|
||||
type fields struct {
|
||||
Config *Config
|
||||
connMap map[string]socks5.ProxyData
|
||||
connMap map[string]socks6.ProxyData
|
||||
ctrlConnWg utils.CountWaitGroup
|
||||
dataConnWg utils.CountWaitGroup
|
||||
}
|
||||
@@ -252,7 +252,7 @@ func TestUserPassAuthenticator_Authenticate(t *testing.T) {
|
||||
name string
|
||||
args args
|
||||
wantWriter string
|
||||
want *socks5.AuthContext
|
||||
want *socks6.AuthContext
|
||||
wantErr bool
|
||||
}{
|
||||
// TODO: Add test cases.
|
||||
@@ -279,7 +279,7 @@ func TestUserPassAuthenticator_Authenticate(t *testing.T) {
|
||||
func TestUserPassAuthenticator_Method(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
want socks5.AuthMethod
|
||||
want socks6.AuthMethod
|
||||
}{
|
||||
// TODO: Add test cases.
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package socks5
|
||||
package socks
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -27,7 +27,7 @@ type Authenticator interface {
|
||||
}
|
||||
|
||||
// authenticate 执行认证流程
|
||||
func (server *Server) authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error) {
|
||||
func (s *Server) authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error) {
|
||||
|
||||
// 版本检查
|
||||
err := checkVersion(reader)
|
||||
@@ -46,18 +46,18 @@ func (server *Server) authenticate(reader io.Reader, writer io.Writer) (*AuthCon
|
||||
}
|
||||
|
||||
// 认证客户端连接
|
||||
for _, authenticator := range server.config.AuthMethods {
|
||||
for _, authenticator := range s.config.AuthMethods {
|
||||
method := authenticator.Method()
|
||||
if slices.Contains(methods, byte(method)) {
|
||||
slog.Debug("使用的认证方式", method)
|
||||
|
||||
_, err := writer.Write([]byte{SocksVersion, byte(method)})
|
||||
_, err := writer.Write([]byte{Version, byte(method)})
|
||||
if err != nil {
|
||||
slog.Error("响应认证方式失败", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctx := context.WithValue(context.Background(), "service", server)
|
||||
ctx := context.WithValue(context.Background(), "service", s)
|
||||
authContext, err := authenticator.Authenticate(ctx, reader, writer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -67,7 +67,7 @@ func (server *Server) authenticate(reader io.Reader, writer io.Writer) (*AuthCon
|
||||
}
|
||||
|
||||
// 无适用的认证方式
|
||||
_, err = writer.Write([]byte{SocksVersion, NoAcceptable})
|
||||
_, err = writer.Write([]byte{Version, NoAcceptable})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package socks5
|
||||
package socks
|
||||
|
||||
type ConfigError string
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package socks5
|
||||
package socks
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -65,7 +65,7 @@ func (a AddrSpec) Address() string {
|
||||
return net.JoinHostPort(a.FQDN, strconv.Itoa(a.Port))
|
||||
}
|
||||
|
||||
func (server *Server) request(reader io.Reader, writer io.Writer) (*Request, error) {
|
||||
func (s *Server) request(reader io.Reader, writer io.Writer) (*Request, error) {
|
||||
|
||||
// 检查版本
|
||||
err := checkVersion(reader)
|
||||
@@ -95,13 +95,13 @@ func (server *Server) request(reader io.Reader, writer io.Writer) (*Request, err
|
||||
}
|
||||
|
||||
// 获取目标地址
|
||||
dest, err := server.parseTarget(reader, writer)
|
||||
dest, err := s.parseTarget(reader, writer)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
request := &Request{
|
||||
Version: SocksVersion,
|
||||
Version: Version,
|
||||
Command: command,
|
||||
DestAddr: dest,
|
||||
bufConn: reader,
|
||||
@@ -110,7 +110,7 @@ func (server *Server) request(reader io.Reader, writer io.Writer) (*Request, err
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func (server *Server) parseTarget(reader io.Reader, writer io.Writer) (*AddrSpec, error) {
|
||||
func (s *Server) parseTarget(reader io.Reader, writer io.Writer) (*AddrSpec, error) {
|
||||
dest := &AddrSpec{}
|
||||
|
||||
aTypeBuf := make([]byte, 1)
|
||||
@@ -152,7 +152,7 @@ func (server *Server) parseTarget(reader io.Reader, writer io.Writer) (*AddrSpec
|
||||
dest.FQDN = string(fqdnBuff)
|
||||
|
||||
// 域名解析
|
||||
addr, err := server.config.Resolver.Resolve(dest.FQDN)
|
||||
addr, err := s.config.Resolver.Resolve(dest.FQDN)
|
||||
if err != nil {
|
||||
err := sendReply(writer, hostUnreachable, nil)
|
||||
if err != nil {
|
||||
@@ -197,33 +197,33 @@ type Request struct {
|
||||
bufConn io.Reader
|
||||
}
|
||||
|
||||
func (server *Server) handle(req *Request, conn net.Conn) error {
|
||||
func (s *Server) handle(req *Request, conn net.Conn) error {
|
||||
ctx := context.Background()
|
||||
|
||||
// 目标地址重写
|
||||
req.realDestAddr = req.DestAddr
|
||||
if server.config.Rewriter != nil {
|
||||
ctx, req.realDestAddr = server.config.Rewriter.Rewrite(ctx, req)
|
||||
if s.config.Rewriter != nil {
|
||||
ctx, req.realDestAddr = s.config.Rewriter.Rewrite(ctx, req)
|
||||
}
|
||||
|
||||
// 根据协商方法建立连接
|
||||
switch req.Command {
|
||||
case ConnectCommand:
|
||||
return server.handleConnect(ctx, conn, req)
|
||||
return s.handleConnect(ctx, conn, req)
|
||||
case BindCommand:
|
||||
return server.handleBind(ctx, conn, req)
|
||||
return s.handleBind(ctx, conn, req)
|
||||
case AssociateCommand:
|
||||
return server.handleAssociate(ctx, conn, req)
|
||||
return s.handleAssociate(ctx, conn, req)
|
||||
default:
|
||||
return fmt.Errorf("unsupported command: %v", req.Command)
|
||||
}
|
||||
}
|
||||
|
||||
func (server *Server) handleConnect(ctx context.Context, conn net.Conn, req *Request) error {
|
||||
func (s *Server) handleConnect(ctx context.Context, conn net.Conn, req *Request) error {
|
||||
|
||||
// 检查规则集约束
|
||||
server.config.Logger.Printf("检查约束规则\n")
|
||||
if ctx_, ok := server.config.Rules.Allow(ctx, req); !ok {
|
||||
s.config.Logger.Printf("检查约束规则\n")
|
||||
if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok {
|
||||
if err := sendReply(conn, ruleFailure, nil); err != nil {
|
||||
return fmt.Errorf("failed to send reply: %v", err)
|
||||
}
|
||||
@@ -233,12 +233,12 @@ func (server *Server) handleConnect(ctx context.Context, conn net.Conn, req *Req
|
||||
}
|
||||
|
||||
slog.Info("需要向 " + req.DestAddr.Address() + " 建立连接")
|
||||
server.Conn <- ProxyData{conn, req.realDestAddr.Address()}
|
||||
s.Conn <- ProxyData{conn, req.realDestAddr.Address()}
|
||||
return nil
|
||||
|
||||
// 与目标服务器建立连接
|
||||
// server.config.Logger.Printf("与目标服务器建立连接\n")
|
||||
// dial := server.config.Dial
|
||||
// s.config.Logger.Printf("与目标服务器建立连接\n")
|
||||
// dial := s.config.Dial
|
||||
// target, err := dial("tcp", req.realDestAddr.Address())
|
||||
// if err != nil {
|
||||
// msg := err.Error()
|
||||
@@ -271,7 +271,7 @@ func (server *Server) handleConnect(ctx context.Context, conn net.Conn, req *Req
|
||||
// timeout := req.AuthContext.Timeout
|
||||
// slog.Debug("超时时间", "timeout", timeout)
|
||||
//
|
||||
// timeoutCtx, cancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Second)
|
||||
// timeoutCtx, cancel := ctx.WithTimeout(ctx, time.Duration(timeout)*time.Second)
|
||||
// defer cancel()
|
||||
//
|
||||
// // 代理流量
|
||||
@@ -304,9 +304,9 @@ func (server *Server) handleConnect(ctx context.Context, conn net.Conn, req *Req
|
||||
|
||||
}
|
||||
|
||||
func (server *Server) handleBind(ctx context.Context, conn net.Conn, req *Request) error {
|
||||
func (s *Server) handleBind(ctx context.Context, conn net.Conn, req *Request) error {
|
||||
// Check if this is allowed
|
||||
if ctx_, ok := server.config.Rules.Allow(ctx, req); !ok {
|
||||
if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok {
|
||||
if err := sendReply(conn, ruleFailure, nil); err != nil {
|
||||
return fmt.Errorf("failed to send reply: %v", err)
|
||||
}
|
||||
@@ -322,9 +322,9 @@ func (server *Server) handleBind(ctx context.Context, conn net.Conn, req *Reques
|
||||
return nil
|
||||
}
|
||||
|
||||
func (server *Server) handleAssociate(ctx context.Context, conn net.Conn, req *Request) error {
|
||||
func (s *Server) handleAssociate(ctx context.Context, conn net.Conn, req *Request) error {
|
||||
// Check if this is allowed
|
||||
if ctx_, ok := server.config.Rules.Allow(ctx, req); !ok {
|
||||
if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok {
|
||||
if err := sendReply(conn, ruleFailure, nil); err != nil {
|
||||
return fmt.Errorf("failed to send reply: %v", err)
|
||||
}
|
||||
@@ -370,7 +370,7 @@ func sendReply(w io.Writer, resp uint8, addr *AddrSpec) error {
|
||||
}
|
||||
|
||||
msg := make([]byte, 6+len(addrBody))
|
||||
msg[0] = SocksVersion
|
||||
msg[0] = Version
|
||||
msg[1] = resp
|
||||
msg[2] = 0 // Reserved
|
||||
msg[3] = addrType
|
||||
@@ -1,4 +1,4 @@
|
||||
package socks5
|
||||
package socks
|
||||
|
||||
import (
|
||||
"net"
|
||||
@@ -1,4 +1,4 @@
|
||||
package socks5
|
||||
package socks
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -1,8 +1,8 @@
|
||||
package socks5
|
||||
package socks
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
@@ -11,10 +11,13 @@ import (
|
||||
"os"
|
||||
"proxy-server/pkg/utils"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
const (
|
||||
SocksVersion = byte(5)
|
||||
Version = byte(5)
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
@@ -47,6 +50,9 @@ type Config struct {
|
||||
|
||||
type Server struct {
|
||||
config *Config
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
wg utils.CountWaitGroup
|
||||
Name string
|
||||
Port uint16
|
||||
Conn chan ProxyData
|
||||
@@ -76,8 +82,12 @@ func New(conf *Config) (*Server, error) {
|
||||
}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &Server{
|
||||
config: conf,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
wg: utils.CountWaitGroup{},
|
||||
Name: conf.Name,
|
||||
Port: conf.Port,
|
||||
Conn: make(chan ProxyData, 100),
|
||||
@@ -85,46 +95,83 @@ func New(conf *Config) (*Server, error) {
|
||||
}
|
||||
|
||||
// Run 监听端口
|
||||
func (server *Server) Run() error {
|
||||
host := server.config.Host
|
||||
port := server.config.Port
|
||||
addr := net.JoinHostPort(host, strconv.Itoa(int(port)))
|
||||
|
||||
func (s *Server) Run() error {
|
||||
slog.Info("启动 socks5 代理服务")
|
||||
|
||||
listener, err := net.Listen("tcp", addr)
|
||||
// 监听端口
|
||||
host := s.config.Host
|
||||
port := s.config.Port
|
||||
addr := net.JoinHostPort(host, strconv.Itoa(int(port)))
|
||||
ls, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, "监听端口失败")
|
||||
}
|
||||
defer utils.Close(listener)
|
||||
defer utils.Close(ls)
|
||||
slog.Info("正在监听端口", slog.Uint64("port", uint64(port)))
|
||||
|
||||
slog.Info("代理服务已启动,正在监听端口 " + addr)
|
||||
// 处理连接
|
||||
connCh := utils.ChanConnAccept(s.ctx, ls)
|
||||
defer close(connCh)
|
||||
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
slog.Error("客户端连接失败", err)
|
||||
continue
|
||||
}
|
||||
|
||||
go func() {
|
||||
err := server.serve(conn)
|
||||
if err != nil {
|
||||
slog.Error("连接异常退出", err)
|
||||
err = nil
|
||||
for loop := true; loop; {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
slog.Debug("服务主动停止")
|
||||
loop = false
|
||||
case conn, ok := <-connCh:
|
||||
if !ok {
|
||||
err = errors.New("意外错误,无法获取连接")
|
||||
loop = false
|
||||
break
|
||||
}
|
||||
}()
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
err := s.serve(conn)
|
||||
if err != nil {
|
||||
slog.Error("处理连接失败", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
s.Close()
|
||||
}
|
||||
|
||||
// 关闭服务
|
||||
timeout, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
wgCh := utils.ChanWgWait(timeout, &s.wg)
|
||||
|
||||
err = nil
|
||||
select {
|
||||
case <-timeout.Done():
|
||||
err = errors.New("关闭超时(强制关闭)")
|
||||
case <-wgCh:
|
||||
}
|
||||
|
||||
if s.Conn != nil {
|
||||
close(s.Conn)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Close 关闭服务
|
||||
func (s *Server) Close() {
|
||||
s.cancel()
|
||||
}
|
||||
|
||||
// serve 建立连接
|
||||
func (server *Server) serve(conn net.Conn) error {
|
||||
func (s *Server) serve(conn net.Conn) error {
|
||||
slog.Info("收到来自" + conn.RemoteAddr().String() + "的连接")
|
||||
|
||||
reader := bufio.NewReader(conn)
|
||||
|
||||
// 认证
|
||||
slog.Debug("开始认证流程")
|
||||
authContext, err := server.authenticate(reader, conn)
|
||||
authContext, err := s.authenticate(reader, conn)
|
||||
if err != nil {
|
||||
utils.Close(conn)
|
||||
slog.Error("认证失败", err)
|
||||
@@ -135,7 +182,7 @@ func (server *Server) serve(conn net.Conn) error {
|
||||
|
||||
// 处理连接请求
|
||||
slog.Debug("处理连接请求")
|
||||
request, err := server.request(reader, conn)
|
||||
request, err := s.request(reader, conn)
|
||||
if err != nil {
|
||||
slog.Error("连接请求处理失败", err)
|
||||
return err
|
||||
@@ -156,7 +203,7 @@ func (server *Server) serve(conn net.Conn) error {
|
||||
|
||||
// 处理请求
|
||||
slog.Debug("开始代理流量")
|
||||
err = server.handle(request, conn)
|
||||
err = s.handle(request, conn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -173,7 +220,7 @@ func checkVersion(reader io.Reader) error {
|
||||
|
||||
slog.Debug("客户端请求版本", "version", version)
|
||||
|
||||
if version != SocksVersion {
|
||||
if version != Version {
|
||||
return errors.New("客户端版本不兼容")
|
||||
}
|
||||
|
||||
135
template/service/service.go
Normal file
135
template/service/service.go
Normal file
@@ -0,0 +1,135 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Host string
|
||||
Port uint16
|
||||
CloseWait time.Duration
|
||||
}
|
||||
|
||||
type Server struct {
|
||||
config *Config
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
func New(conf *Config) (*Server, error) {
|
||||
|
||||
if conf.Host == "" {
|
||||
conf.Host = "localhost"
|
||||
}
|
||||
if conf.Port == 0 {
|
||||
return nil, errors.New("port is required")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
return &Server{
|
||||
conf,
|
||||
ctx,
|
||||
cancel,
|
||||
sync.WaitGroup{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Server) Run() error {
|
||||
|
||||
// start listen
|
||||
addr := net.JoinHostPort(s.config.Host, strconv.Itoa(int(s.config.Port)))
|
||||
ls, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
return errors.New("failed to listen")
|
||||
}
|
||||
defer closeRes(ls)
|
||||
|
||||
// wait accept
|
||||
connCh := make(chan net.Conn)
|
||||
defer close(connCh)
|
||||
go func() {
|
||||
for {
|
||||
conn, err := ls.Accept()
|
||||
if err != nil {
|
||||
if !errors.Is(err, net.ErrClosed) {
|
||||
log.Println("accept failed", err)
|
||||
}
|
||||
// retry on temporary error
|
||||
var ne net.Error
|
||||
if errors.As(err, &ne) && ne.Temporary() {
|
||||
continue
|
||||
}
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
closeRes(conn)
|
||||
return
|
||||
case connCh <- conn:
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// handle accept
|
||||
func() {
|
||||
for {
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
case conn, ok := <-connCh:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
_ = s.handle(conn)
|
||||
}()
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// close
|
||||
timeout, cancel := context.WithTimeout(context.Background(), s.config.CloseWait)
|
||||
defer cancel()
|
||||
|
||||
waitCh := make(chan struct{})
|
||||
defer close(waitCh)
|
||||
go func() {
|
||||
s.wg.Wait()
|
||||
select {
|
||||
case <-timeout.Done(): // waitCh may be closed
|
||||
case waitCh <- struct{}{}:
|
||||
}
|
||||
}()
|
||||
|
||||
err = nil
|
||||
select {
|
||||
case <-timeout.Done():
|
||||
err = errors.New("close timeout")
|
||||
case <-waitCh:
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Server) Close() {
|
||||
s.cancel()
|
||||
}
|
||||
|
||||
func (s *Server) handle(conn net.Conn) error {
|
||||
defer closeRes(conn)
|
||||
return nil
|
||||
}
|
||||
|
||||
func closeRes[T io.Closer](res T) {
|
||||
_ = res.Close()
|
||||
}
|
||||
@@ -2,7 +2,7 @@ package fwd
|
||||
|
||||
import (
|
||||
"net"
|
||||
"proxy-server/server/pkg/socks5"
|
||||
socks6 "proxy-server/server/fwd/socks"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -26,7 +26,7 @@ func fakeRequest() {
|
||||
}
|
||||
|
||||
// 发送认证请求
|
||||
_, err = conn.Write([]byte{socks5.SocksVersion, byte(1), byte(socks5.NoAuth)})
|
||||
_, err = conn.Write([]byte{socks6.Version, byte(1), byte(socks6.NoAuth)})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user