优化 socks 解析流程
This commit is contained in:
16
README.md
16
README.md
@@ -18,12 +18,28 @@ fwd 使用自定义 context 实现在一个上下文中控制 cancel,errCh 和
|
|||||||
|
|
||||||
网关根据代理节点对目标服务连接的反馈,决定向用户返回的 socks 响应
|
网关根据代理节点对目标服务连接的反馈,决定向用户返回的 socks 响应
|
||||||
|
|
||||||
|
数据通道池化
|
||||||
|
|
||||||
### 长期
|
### 长期
|
||||||
|
|
||||||
|
代理端口支持混合端口转发(支持 tcp_mux)
|
||||||
|
|
||||||
|
数据通道支持 tcp 多路复用(分离逻辑流)
|
||||||
|
|
||||||
|
👆 进阶黑魔法 multipath tcp + 多路复用
|
||||||
|
|
||||||
考虑一下连接安全性
|
考虑一下连接安全性
|
||||||
|
|
||||||
内部接口 rtt 是否还有优化空间(当前30-300ms,根据内容大小增长)
|
内部接口 rtt 是否还有优化空间(当前30-300ms,根据内容大小增长)
|
||||||
|
|
||||||
|
### 代码清理
|
||||||
|
|
||||||
|
检查 slog 级别:
|
||||||
|
|
||||||
|
ERR: 除非有必要,否则全部 error 都使用 `errors.Wrap()` 包裹(如果下游有返回 err),并附带本层业务信息,return 到上层统一打印
|
||||||
|
|
||||||
|
其他级别日志就地打印,Info 只用来跟踪关键流程
|
||||||
|
|
||||||
## 开发相关
|
## 开发相关
|
||||||
|
|
||||||
### 环境变量
|
### 环境变量
|
||||||
|
|||||||
@@ -8,12 +8,12 @@ import (
|
|||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
func ConnChan(ctx context.Context, ls net.Listener) chan net.Conn {
|
func ChanConnAccept(ctx context.Context, ls net.Listener) chan net.Conn {
|
||||||
connCh := make(chan net.Conn)
|
connCh := make(chan net.Conn)
|
||||||
go func() {
|
go func() {
|
||||||
for {
|
for {
|
||||||
conn, err := ls.Accept()
|
conn, err := ls.Accept()
|
||||||
if err != nil {
|
if err != nil && !errors.Is(err, net.ErrClosed) {
|
||||||
slog.Error("接受连接失败", err)
|
slog.Error("接受连接失败", err)
|
||||||
// 临时错误重试连接
|
// 临时错误重试连接
|
||||||
var ne net.Error
|
var ne net.Error
|
||||||
@@ -35,7 +35,7 @@ func ConnChan(ctx context.Context, ls net.Listener) chan net.Conn {
|
|||||||
return connCh
|
return connCh
|
||||||
}
|
}
|
||||||
|
|
||||||
func WaitChan(ctx context.Context, wg *CountWaitGroup) chan struct{} {
|
func ChanWgWait(ctx context.Context, wg *CountWaitGroup) chan struct{} {
|
||||||
ch := make(chan struct{})
|
ch := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
|||||||
@@ -8,9 +8,9 @@ import (
|
|||||||
"log/slog"
|
"log/slog"
|
||||||
"net"
|
"net"
|
||||||
"proxy-server/pkg/utils"
|
"proxy-server/pkg/utils"
|
||||||
|
"proxy-server/server/fwd/socks"
|
||||||
"proxy-server/server/pkg/env"
|
"proxy-server/server/pkg/env"
|
||||||
"proxy-server/server/pkg/orm"
|
"proxy-server/server/pkg/orm"
|
||||||
"proxy-server/server/pkg/socks5"
|
|
||||||
"proxy-server/server/web/app/models"
|
"proxy-server/server/web/app/models"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
@@ -23,7 +23,7 @@ type Config struct {
|
|||||||
|
|
||||||
type Service struct {
|
type Service struct {
|
||||||
Config *Config
|
Config *Config
|
||||||
connMap map[string]socks5.ProxyData
|
connMap map[string]socks.ProxyData
|
||||||
ctrlConnWg utils.CountWaitGroup
|
ctrlConnWg utils.CountWaitGroup
|
||||||
dataConnWg utils.CountWaitGroup
|
dataConnWg utils.CountWaitGroup
|
||||||
}
|
}
|
||||||
@@ -36,7 +36,7 @@ func New(config *Config) *Service {
|
|||||||
|
|
||||||
return &Service{
|
return &Service{
|
||||||
Config: _config,
|
Config: _config,
|
||||||
connMap: make(map[string]socks5.ProxyData),
|
connMap: make(map[string]socks.ProxyData),
|
||||||
ctrlConnWg: utils.CountWaitGroup{},
|
ctrlConnWg: utils.CountWaitGroup{},
|
||||||
dataConnWg: utils.CountWaitGroup{},
|
dataConnWg: utils.CountWaitGroup{},
|
||||||
}
|
}
|
||||||
@@ -95,7 +95,7 @@ func (s *Service) startCtrlTun(ctx context.Context, errCh chan error) {
|
|||||||
defer utils.Close(ls)
|
defer utils.Close(ls)
|
||||||
|
|
||||||
// 等待连接
|
// 等待连接
|
||||||
connCh := utils.ConnChan(ctx, ls)
|
connCh := utils.ChanConnAccept(ctx, ls)
|
||||||
defer close(connCh)
|
defer close(connCh)
|
||||||
|
|
||||||
// 处理连接
|
// 处理连接
|
||||||
@@ -119,7 +119,7 @@ loop:
|
|||||||
timeout, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
timeout, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
procCh := utils.WaitChan(timeout, &s.ctrlConnWg)
|
procCh := utils.ChanWgWait(timeout, &s.ctrlConnWg)
|
||||||
defer close(procCh)
|
defer close(procCh)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
@@ -144,20 +144,19 @@ func (s *Service) processCtrlConn(controller net.Conn) {
|
|||||||
reader := bufio.NewReader(controller)
|
reader := bufio.NewReader(controller)
|
||||||
|
|
||||||
// 读取端口
|
// 读取端口
|
||||||
portBuf := make([]byte, 2)
|
portBuf, err := utils.ReadBuffer(reader, 2)
|
||||||
_, err := io.ReadFull(reader, portBuf)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("读取转发端口失败", "err", err)
|
slog.Error("接收转发端口失败", "err", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
port := binary.BigEndian.Uint16(portBuf)
|
port := binary.BigEndian.Uint16(portBuf)
|
||||||
|
|
||||||
// 新建代理服务
|
// 新建代理服务
|
||||||
slog.Info("新建代理服务", "port", port)
|
slog.Info("新建代理服务", "port", port)
|
||||||
proxy, err := socks5.New(&socks5.Config{
|
proxy, err := socks.New(&socks.Config{
|
||||||
Name: strconv.Itoa(int(port)),
|
Name: strconv.Itoa(int(port)),
|
||||||
Port: port,
|
Port: port,
|
||||||
AuthMethods: []socks5.Authenticator{
|
AuthMethods: []socks.Authenticator{
|
||||||
&UserPassAuthenticator{},
|
&UserPassAuthenticator{},
|
||||||
&NoAuthAuthenticator{},
|
&NoAuthAuthenticator{},
|
||||||
},
|
},
|
||||||
@@ -207,7 +206,7 @@ func (s *Service) startDataTun(ctx context.Context, errCh chan error) {
|
|||||||
defer utils.Close(lData)
|
defer utils.Close(lData)
|
||||||
|
|
||||||
// 等待连接
|
// 等待连接
|
||||||
connCh := utils.ConnChan(ctx, lData)
|
connCh := utils.ChanConnAccept(ctx, lData)
|
||||||
defer close(connCh)
|
defer close(connCh)
|
||||||
|
|
||||||
// 处理连接
|
// 处理连接
|
||||||
@@ -231,7 +230,7 @@ loop:
|
|||||||
timeout, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
timeout, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
procCh := utils.WaitChan(timeout, &s.dataConnWg)
|
procCh := utils.ChanWgWait(timeout, &s.dataConnWg)
|
||||||
defer close(procCh)
|
defer close(procCh)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
@@ -275,7 +274,7 @@ func (s *Service) processDataConn(client net.Conn) {
|
|||||||
// 响应用户
|
// 响应用户
|
||||||
user := data.Conn
|
user := data.Conn
|
||||||
defer utils.Close(user)
|
defer utils.Close(user)
|
||||||
socks5.SendSuccess(user, client)
|
socks.SendSuccess(user, client)
|
||||||
|
|
||||||
// 写入目标地址
|
// 写入目标地址
|
||||||
_, err = client.Write([]byte{byte(len(data.Dest))})
|
_, err = client.Write([]byte{byte(len(data.Dest))})
|
||||||
@@ -313,11 +312,11 @@ func (s *Service) processDataConn(client net.Conn) {
|
|||||||
type NoAuthAuthenticator struct {
|
type NoAuthAuthenticator struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *NoAuthAuthenticator) Method() socks5.AuthMethod {
|
func (a *NoAuthAuthenticator) Method() socks.AuthMethod {
|
||||||
return socks5.NoAuth
|
return socks.NoAuth
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *NoAuthAuthenticator) Authenticate(ctx context.Context, reader io.Reader, writer io.Writer) (*socks5.AuthContext, error) {
|
func (a *NoAuthAuthenticator) Authenticate(ctx context.Context, reader io.Reader, writer io.Writer) (*socks.AuthContext, error) {
|
||||||
|
|
||||||
// 获取用户地址
|
// 获取用户地址
|
||||||
conn, ok := writer.(net.Conn)
|
conn, ok := writer.(net.Conn)
|
||||||
@@ -332,7 +331,7 @@ func (a *NoAuthAuthenticator) Authenticate(ctx context.Context, reader io.Reader
|
|||||||
slog.Debug("用户的地址为 " + client)
|
slog.Debug("用户的地址为 " + client)
|
||||||
|
|
||||||
// 获取服务
|
// 获取服务
|
||||||
server, ok := ctx.Value("service").(*socks5.Server)
|
server, ok := ctx.Value("service").(*socks.Server)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errors.New("noAuth 认证失败,无法获取服务信息")
|
return nil, errors.New("noAuth 认证失败,无法获取服务信息")
|
||||||
}
|
}
|
||||||
@@ -373,8 +372,8 @@ func (a *NoAuthAuthenticator) Authenticate(ctx context.Context, reader io.Reader
|
|||||||
}
|
}
|
||||||
slog.Debug("权限剩余时间", slog.Uint64("timeout", uint64(timeout)))
|
slog.Debug("权限剩余时间", slog.Uint64("timeout", uint64(timeout)))
|
||||||
|
|
||||||
return &socks5.AuthContext{
|
return &socks.AuthContext{
|
||||||
Method: socks5.NoAuth,
|
Method: socks.NoAuth,
|
||||||
Timeout: uint(timeout),
|
Timeout: uint(timeout),
|
||||||
Payload: nil,
|
Payload: nil,
|
||||||
}, nil
|
}, nil
|
||||||
@@ -383,11 +382,11 @@ func (a *NoAuthAuthenticator) Authenticate(ctx context.Context, reader io.Reader
|
|||||||
type UserPassAuthenticator struct {
|
type UserPassAuthenticator struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *UserPassAuthenticator) Method() socks5.AuthMethod {
|
func (a *UserPassAuthenticator) Method() socks.AuthMethod {
|
||||||
return socks5.UserPassAuth
|
return socks.UserPassAuth
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *UserPassAuthenticator) Authenticate(ctx context.Context, reader io.Reader, writer io.Writer) (*socks5.AuthContext, error) {
|
func (a *UserPassAuthenticator) Authenticate(ctx context.Context, reader io.Reader, writer io.Writer) (*socks.AuthContext, error) {
|
||||||
|
|
||||||
// 检查认证版本
|
// 检查认证版本
|
||||||
slog.Debug("验证认证版本")
|
slog.Debug("验证认证版本")
|
||||||
@@ -395,8 +394,8 @@ func (a *UserPassAuthenticator) Authenticate(ctx context.Context, reader io.Read
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "读取版本号失败")
|
return nil, errors.Wrap(err, "读取版本号失败")
|
||||||
}
|
}
|
||||||
if v != socks5.AuthVersion {
|
if v != socks.AuthVersion {
|
||||||
_, err := writer.Write([]byte{socks5.SocksVersion, socks5.AuthFailure})
|
_, err := writer.Write([]byte{socks.Version, socks.AuthFailure})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "响应认证失败")
|
return nil, errors.Wrap(err, "响应认证失败")
|
||||||
}
|
}
|
||||||
@@ -484,14 +483,14 @@ func (a *UserPassAuthenticator) Authenticate(ctx context.Context, reader io.Read
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 响应认证成功
|
// 响应认证成功
|
||||||
_, err = writer.Write([]byte{socks5.AuthVersion, socks5.AuthSuccess})
|
_, err = writer.Write([]byte{socks.AuthVersion, socks.AuthSuccess})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("响应认证失败", "err", err)
|
slog.Error("响应认证失败", "err", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &socks5.AuthContext{
|
return &socks.AuthContext{
|
||||||
Method: socks5.UserPassAuth,
|
Method: socks.UserPassAuth,
|
||||||
Timeout: uint(timeout),
|
Timeout: uint(timeout),
|
||||||
Payload: nil,
|
Payload: nil,
|
||||||
}, nil
|
}, nil
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"proxy-server/pkg/utils"
|
"proxy-server/pkg/utils"
|
||||||
"proxy-server/server/pkg/socks5"
|
socks6 "proxy-server/server/fwd/socks"
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
@@ -49,7 +49,7 @@ func TestNoAuthAuthenticator_Authenticate(t *testing.T) {
|
|||||||
name string
|
name string
|
||||||
args args
|
args args
|
||||||
wantWriter string
|
wantWriter string
|
||||||
want *socks5.AuthContext
|
want *socks6.AuthContext
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
// TODO: Add test cases.
|
// TODO: Add test cases.
|
||||||
@@ -76,7 +76,7 @@ func TestNoAuthAuthenticator_Authenticate(t *testing.T) {
|
|||||||
func TestNoAuthAuthenticator_Method(t *testing.T) {
|
func TestNoAuthAuthenticator_Method(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
want socks5.AuthMethod
|
want socks6.AuthMethod
|
||||||
}{
|
}{
|
||||||
// TODO: Add test cases.
|
// TODO: Add test cases.
|
||||||
}
|
}
|
||||||
@@ -93,7 +93,7 @@ func TestNoAuthAuthenticator_Method(t *testing.T) {
|
|||||||
func TestService_Run(t *testing.T) {
|
func TestService_Run(t *testing.T) {
|
||||||
type fields struct {
|
type fields struct {
|
||||||
Config *Config
|
Config *Config
|
||||||
connMap map[string]socks5.ProxyData
|
connMap map[string]socks6.ProxyData
|
||||||
ctrlConnWg utils.CountWaitGroup
|
ctrlConnWg utils.CountWaitGroup
|
||||||
dataConnWg utils.CountWaitGroup
|
dataConnWg utils.CountWaitGroup
|
||||||
}
|
}
|
||||||
@@ -124,7 +124,7 @@ func TestService_Run(t *testing.T) {
|
|||||||
func TestService_processCtrlConn(t *testing.T) {
|
func TestService_processCtrlConn(t *testing.T) {
|
||||||
type fields struct {
|
type fields struct {
|
||||||
Config *Config
|
Config *Config
|
||||||
connMap map[string]socks5.ProxyData
|
connMap map[string]socks6.ProxyData
|
||||||
ctrlConnWg utils.CountWaitGroup
|
ctrlConnWg utils.CountWaitGroup
|
||||||
dataConnWg utils.CountWaitGroup
|
dataConnWg utils.CountWaitGroup
|
||||||
}
|
}
|
||||||
@@ -154,7 +154,7 @@ func TestService_processCtrlConn(t *testing.T) {
|
|||||||
func TestService_processDataConn(t *testing.T) {
|
func TestService_processDataConn(t *testing.T) {
|
||||||
type fields struct {
|
type fields struct {
|
||||||
Config *Config
|
Config *Config
|
||||||
connMap map[string]socks5.ProxyData
|
connMap map[string]socks6.ProxyData
|
||||||
ctrlConnWg utils.CountWaitGroup
|
ctrlConnWg utils.CountWaitGroup
|
||||||
dataConnWg utils.CountWaitGroup
|
dataConnWg utils.CountWaitGroup
|
||||||
}
|
}
|
||||||
@@ -184,7 +184,7 @@ func TestService_processDataConn(t *testing.T) {
|
|||||||
func TestService_startCtrlTun(t *testing.T) {
|
func TestService_startCtrlTun(t *testing.T) {
|
||||||
type fields struct {
|
type fields struct {
|
||||||
Config *Config
|
Config *Config
|
||||||
connMap map[string]socks5.ProxyData
|
connMap map[string]socks6.ProxyData
|
||||||
ctrlConnWg utils.CountWaitGroup
|
ctrlConnWg utils.CountWaitGroup
|
||||||
dataConnWg utils.CountWaitGroup
|
dataConnWg utils.CountWaitGroup
|
||||||
}
|
}
|
||||||
@@ -215,7 +215,7 @@ func TestService_startCtrlTun(t *testing.T) {
|
|||||||
func TestService_startDataTun(t *testing.T) {
|
func TestService_startDataTun(t *testing.T) {
|
||||||
type fields struct {
|
type fields struct {
|
||||||
Config *Config
|
Config *Config
|
||||||
connMap map[string]socks5.ProxyData
|
connMap map[string]socks6.ProxyData
|
||||||
ctrlConnWg utils.CountWaitGroup
|
ctrlConnWg utils.CountWaitGroup
|
||||||
dataConnWg utils.CountWaitGroup
|
dataConnWg utils.CountWaitGroup
|
||||||
}
|
}
|
||||||
@@ -252,7 +252,7 @@ func TestUserPassAuthenticator_Authenticate(t *testing.T) {
|
|||||||
name string
|
name string
|
||||||
args args
|
args args
|
||||||
wantWriter string
|
wantWriter string
|
||||||
want *socks5.AuthContext
|
want *socks6.AuthContext
|
||||||
wantErr bool
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
// TODO: Add test cases.
|
// TODO: Add test cases.
|
||||||
@@ -279,7 +279,7 @@ func TestUserPassAuthenticator_Authenticate(t *testing.T) {
|
|||||||
func TestUserPassAuthenticator_Method(t *testing.T) {
|
func TestUserPassAuthenticator_Method(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
want socks5.AuthMethod
|
want socks6.AuthMethod
|
||||||
}{
|
}{
|
||||||
// TODO: Add test cases.
|
// TODO: Add test cases.
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package socks5
|
package socks
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -27,7 +27,7 @@ type Authenticator interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// authenticate 执行认证流程
|
// authenticate 执行认证流程
|
||||||
func (server *Server) authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error) {
|
func (s *Server) authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error) {
|
||||||
|
|
||||||
// 版本检查
|
// 版本检查
|
||||||
err := checkVersion(reader)
|
err := checkVersion(reader)
|
||||||
@@ -46,18 +46,18 @@ func (server *Server) authenticate(reader io.Reader, writer io.Writer) (*AuthCon
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 认证客户端连接
|
// 认证客户端连接
|
||||||
for _, authenticator := range server.config.AuthMethods {
|
for _, authenticator := range s.config.AuthMethods {
|
||||||
method := authenticator.Method()
|
method := authenticator.Method()
|
||||||
if slices.Contains(methods, byte(method)) {
|
if slices.Contains(methods, byte(method)) {
|
||||||
slog.Debug("使用的认证方式", method)
|
slog.Debug("使用的认证方式", method)
|
||||||
|
|
||||||
_, err := writer.Write([]byte{SocksVersion, byte(method)})
|
_, err := writer.Write([]byte{Version, byte(method)})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("响应认证方式失败", err)
|
slog.Error("响应认证方式失败", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := context.WithValue(context.Background(), "service", server)
|
ctx := context.WithValue(context.Background(), "service", s)
|
||||||
authContext, err := authenticator.Authenticate(ctx, reader, writer)
|
authContext, err := authenticator.Authenticate(ctx, reader, writer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -67,7 +67,7 @@ func (server *Server) authenticate(reader io.Reader, writer io.Writer) (*AuthCon
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 无适用的认证方式
|
// 无适用的认证方式
|
||||||
_, err = writer.Write([]byte{SocksVersion, NoAcceptable})
|
_, err = writer.Write([]byte{Version, NoAcceptable})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package socks5
|
package socks
|
||||||
|
|
||||||
type ConfigError string
|
type ConfigError string
|
||||||
|
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package socks5
|
package socks
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -65,7 +65,7 @@ func (a AddrSpec) Address() string {
|
|||||||
return net.JoinHostPort(a.FQDN, strconv.Itoa(a.Port))
|
return net.JoinHostPort(a.FQDN, strconv.Itoa(a.Port))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (server *Server) request(reader io.Reader, writer io.Writer) (*Request, error) {
|
func (s *Server) request(reader io.Reader, writer io.Writer) (*Request, error) {
|
||||||
|
|
||||||
// 检查版本
|
// 检查版本
|
||||||
err := checkVersion(reader)
|
err := checkVersion(reader)
|
||||||
@@ -95,13 +95,13 @@ func (server *Server) request(reader io.Reader, writer io.Writer) (*Request, err
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 获取目标地址
|
// 获取目标地址
|
||||||
dest, err := server.parseTarget(reader, writer)
|
dest, err := s.parseTarget(reader, writer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
request := &Request{
|
request := &Request{
|
||||||
Version: SocksVersion,
|
Version: Version,
|
||||||
Command: command,
|
Command: command,
|
||||||
DestAddr: dest,
|
DestAddr: dest,
|
||||||
bufConn: reader,
|
bufConn: reader,
|
||||||
@@ -110,7 +110,7 @@ func (server *Server) request(reader io.Reader, writer io.Writer) (*Request, err
|
|||||||
return request, nil
|
return request, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (server *Server) parseTarget(reader io.Reader, writer io.Writer) (*AddrSpec, error) {
|
func (s *Server) parseTarget(reader io.Reader, writer io.Writer) (*AddrSpec, error) {
|
||||||
dest := &AddrSpec{}
|
dest := &AddrSpec{}
|
||||||
|
|
||||||
aTypeBuf := make([]byte, 1)
|
aTypeBuf := make([]byte, 1)
|
||||||
@@ -152,7 +152,7 @@ func (server *Server) parseTarget(reader io.Reader, writer io.Writer) (*AddrSpec
|
|||||||
dest.FQDN = string(fqdnBuff)
|
dest.FQDN = string(fqdnBuff)
|
||||||
|
|
||||||
// 域名解析
|
// 域名解析
|
||||||
addr, err := server.config.Resolver.Resolve(dest.FQDN)
|
addr, err := s.config.Resolver.Resolve(dest.FQDN)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err := sendReply(writer, hostUnreachable, nil)
|
err := sendReply(writer, hostUnreachable, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -197,33 +197,33 @@ type Request struct {
|
|||||||
bufConn io.Reader
|
bufConn io.Reader
|
||||||
}
|
}
|
||||||
|
|
||||||
func (server *Server) handle(req *Request, conn net.Conn) error {
|
func (s *Server) handle(req *Request, conn net.Conn) error {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
// 目标地址重写
|
// 目标地址重写
|
||||||
req.realDestAddr = req.DestAddr
|
req.realDestAddr = req.DestAddr
|
||||||
if server.config.Rewriter != nil {
|
if s.config.Rewriter != nil {
|
||||||
ctx, req.realDestAddr = server.config.Rewriter.Rewrite(ctx, req)
|
ctx, req.realDestAddr = s.config.Rewriter.Rewrite(ctx, req)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 根据协商方法建立连接
|
// 根据协商方法建立连接
|
||||||
switch req.Command {
|
switch req.Command {
|
||||||
case ConnectCommand:
|
case ConnectCommand:
|
||||||
return server.handleConnect(ctx, conn, req)
|
return s.handleConnect(ctx, conn, req)
|
||||||
case BindCommand:
|
case BindCommand:
|
||||||
return server.handleBind(ctx, conn, req)
|
return s.handleBind(ctx, conn, req)
|
||||||
case AssociateCommand:
|
case AssociateCommand:
|
||||||
return server.handleAssociate(ctx, conn, req)
|
return s.handleAssociate(ctx, conn, req)
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("unsupported command: %v", req.Command)
|
return fmt.Errorf("unsupported command: %v", req.Command)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (server *Server) handleConnect(ctx context.Context, conn net.Conn, req *Request) error {
|
func (s *Server) handleConnect(ctx context.Context, conn net.Conn, req *Request) error {
|
||||||
|
|
||||||
// 检查规则集约束
|
// 检查规则集约束
|
||||||
server.config.Logger.Printf("检查约束规则\n")
|
s.config.Logger.Printf("检查约束规则\n")
|
||||||
if ctx_, ok := server.config.Rules.Allow(ctx, req); !ok {
|
if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok {
|
||||||
if err := sendReply(conn, ruleFailure, nil); err != nil {
|
if err := sendReply(conn, ruleFailure, nil); err != nil {
|
||||||
return fmt.Errorf("failed to send reply: %v", err)
|
return fmt.Errorf("failed to send reply: %v", err)
|
||||||
}
|
}
|
||||||
@@ -233,12 +233,12 @@ func (server *Server) handleConnect(ctx context.Context, conn net.Conn, req *Req
|
|||||||
}
|
}
|
||||||
|
|
||||||
slog.Info("需要向 " + req.DestAddr.Address() + " 建立连接")
|
slog.Info("需要向 " + req.DestAddr.Address() + " 建立连接")
|
||||||
server.Conn <- ProxyData{conn, req.realDestAddr.Address()}
|
s.Conn <- ProxyData{conn, req.realDestAddr.Address()}
|
||||||
return nil
|
return nil
|
||||||
|
|
||||||
// 与目标服务器建立连接
|
// 与目标服务器建立连接
|
||||||
// server.config.Logger.Printf("与目标服务器建立连接\n")
|
// s.config.Logger.Printf("与目标服务器建立连接\n")
|
||||||
// dial := server.config.Dial
|
// dial := s.config.Dial
|
||||||
// target, err := dial("tcp", req.realDestAddr.Address())
|
// target, err := dial("tcp", req.realDestAddr.Address())
|
||||||
// if err != nil {
|
// if err != nil {
|
||||||
// msg := err.Error()
|
// msg := err.Error()
|
||||||
@@ -271,7 +271,7 @@ func (server *Server) handleConnect(ctx context.Context, conn net.Conn, req *Req
|
|||||||
// timeout := req.AuthContext.Timeout
|
// timeout := req.AuthContext.Timeout
|
||||||
// slog.Debug("超时时间", "timeout", timeout)
|
// slog.Debug("超时时间", "timeout", timeout)
|
||||||
//
|
//
|
||||||
// timeoutCtx, cancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Second)
|
// timeoutCtx, cancel := ctx.WithTimeout(ctx, time.Duration(timeout)*time.Second)
|
||||||
// defer cancel()
|
// defer cancel()
|
||||||
//
|
//
|
||||||
// // 代理流量
|
// // 代理流量
|
||||||
@@ -304,9 +304,9 @@ func (server *Server) handleConnect(ctx context.Context, conn net.Conn, req *Req
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (server *Server) handleBind(ctx context.Context, conn net.Conn, req *Request) error {
|
func (s *Server) handleBind(ctx context.Context, conn net.Conn, req *Request) error {
|
||||||
// Check if this is allowed
|
// Check if this is allowed
|
||||||
if ctx_, ok := server.config.Rules.Allow(ctx, req); !ok {
|
if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok {
|
||||||
if err := sendReply(conn, ruleFailure, nil); err != nil {
|
if err := sendReply(conn, ruleFailure, nil); err != nil {
|
||||||
return fmt.Errorf("failed to send reply: %v", err)
|
return fmt.Errorf("failed to send reply: %v", err)
|
||||||
}
|
}
|
||||||
@@ -322,9 +322,9 @@ func (server *Server) handleBind(ctx context.Context, conn net.Conn, req *Reques
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (server *Server) handleAssociate(ctx context.Context, conn net.Conn, req *Request) error {
|
func (s *Server) handleAssociate(ctx context.Context, conn net.Conn, req *Request) error {
|
||||||
// Check if this is allowed
|
// Check if this is allowed
|
||||||
if ctx_, ok := server.config.Rules.Allow(ctx, req); !ok {
|
if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok {
|
||||||
if err := sendReply(conn, ruleFailure, nil); err != nil {
|
if err := sendReply(conn, ruleFailure, nil); err != nil {
|
||||||
return fmt.Errorf("failed to send reply: %v", err)
|
return fmt.Errorf("failed to send reply: %v", err)
|
||||||
}
|
}
|
||||||
@@ -370,7 +370,7 @@ func sendReply(w io.Writer, resp uint8, addr *AddrSpec) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
msg := make([]byte, 6+len(addrBody))
|
msg := make([]byte, 6+len(addrBody))
|
||||||
msg[0] = SocksVersion
|
msg[0] = Version
|
||||||
msg[1] = resp
|
msg[1] = resp
|
||||||
msg[2] = 0 // Reserved
|
msg[2] = 0 // Reserved
|
||||||
msg[3] = addrType
|
msg[3] = addrType
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package socks5
|
package socks
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package socks5
|
package socks
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -1,8 +1,8 @@
|
|||||||
package socks5
|
package socks
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
"bufio"
|
||||||
"errors"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
@@ -11,10 +11,13 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"proxy-server/pkg/utils"
|
"proxy-server/pkg/utils"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
SocksVersion = byte(5)
|
Version = byte(5)
|
||||||
)
|
)
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
@@ -47,6 +50,9 @@ type Config struct {
|
|||||||
|
|
||||||
type Server struct {
|
type Server struct {
|
||||||
config *Config
|
config *Config
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
|
wg utils.CountWaitGroup
|
||||||
Name string
|
Name string
|
||||||
Port uint16
|
Port uint16
|
||||||
Conn chan ProxyData
|
Conn chan ProxyData
|
||||||
@@ -76,8 +82,12 @@ func New(conf *Config) (*Server, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
return &Server{
|
return &Server{
|
||||||
config: conf,
|
config: conf,
|
||||||
|
ctx: ctx,
|
||||||
|
cancel: cancel,
|
||||||
|
wg: utils.CountWaitGroup{},
|
||||||
Name: conf.Name,
|
Name: conf.Name,
|
||||||
Port: conf.Port,
|
Port: conf.Port,
|
||||||
Conn: make(chan ProxyData, 100),
|
Conn: make(chan ProxyData, 100),
|
||||||
@@ -85,46 +95,83 @@ func New(conf *Config) (*Server, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Run 监听端口
|
// Run 监听端口
|
||||||
func (server *Server) Run() error {
|
func (s *Server) Run() error {
|
||||||
host := server.config.Host
|
|
||||||
port := server.config.Port
|
|
||||||
addr := net.JoinHostPort(host, strconv.Itoa(int(port)))
|
|
||||||
|
|
||||||
slog.Info("启动 socks5 代理服务")
|
slog.Info("启动 socks5 代理服务")
|
||||||
|
|
||||||
listener, err := net.Listen("tcp", addr)
|
// 监听端口
|
||||||
|
host := s.config.Host
|
||||||
|
port := s.config.Port
|
||||||
|
addr := net.JoinHostPort(host, strconv.Itoa(int(port)))
|
||||||
|
ls, err := net.Listen("tcp", addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return errors.Wrap(err, "监听端口失败")
|
||||||
}
|
}
|
||||||
defer utils.Close(listener)
|
defer utils.Close(ls)
|
||||||
|
slog.Info("正在监听端口", slog.Uint64("port", uint64(port)))
|
||||||
|
|
||||||
slog.Info("代理服务已启动,正在监听端口 " + addr)
|
// 处理连接
|
||||||
|
connCh := utils.ChanConnAccept(s.ctx, ls)
|
||||||
|
defer close(connCh)
|
||||||
|
|
||||||
for {
|
err = nil
|
||||||
conn, err := listener.Accept()
|
for loop := true; loop; {
|
||||||
if err != nil {
|
select {
|
||||||
slog.Error("客户端连接失败", err)
|
case <-s.ctx.Done():
|
||||||
continue
|
slog.Debug("服务主动停止")
|
||||||
}
|
loop = false
|
||||||
|
case conn, ok := <-connCh:
|
||||||
go func() {
|
if !ok {
|
||||||
err := server.serve(conn)
|
err = errors.New("意外错误,无法获取连接")
|
||||||
if err != nil {
|
loop = false
|
||||||
slog.Error("连接异常退出", err)
|
break
|
||||||
}
|
}
|
||||||
}()
|
s.wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer s.wg.Done()
|
||||||
|
err := s.serve(conn)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("处理连接失败", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
if err != nil {
|
||||||
|
s.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// 关闭服务
|
||||||
|
timeout, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
wgCh := utils.ChanWgWait(timeout, &s.wg)
|
||||||
|
|
||||||
|
err = nil
|
||||||
|
select {
|
||||||
|
case <-timeout.Done():
|
||||||
|
err = errors.New("关闭超时(强制关闭)")
|
||||||
|
case <-wgCh:
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.Conn != nil {
|
||||||
|
close(s.Conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close 关闭服务
|
||||||
|
func (s *Server) Close() {
|
||||||
|
s.cancel()
|
||||||
}
|
}
|
||||||
|
|
||||||
// serve 建立连接
|
// serve 建立连接
|
||||||
func (server *Server) serve(conn net.Conn) error {
|
func (s *Server) serve(conn net.Conn) error {
|
||||||
slog.Info("收到来自" + conn.RemoteAddr().String() + "的连接")
|
slog.Info("收到来自" + conn.RemoteAddr().String() + "的连接")
|
||||||
|
|
||||||
reader := bufio.NewReader(conn)
|
reader := bufio.NewReader(conn)
|
||||||
|
|
||||||
// 认证
|
// 认证
|
||||||
slog.Debug("开始认证流程")
|
slog.Debug("开始认证流程")
|
||||||
authContext, err := server.authenticate(reader, conn)
|
authContext, err := s.authenticate(reader, conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
utils.Close(conn)
|
utils.Close(conn)
|
||||||
slog.Error("认证失败", err)
|
slog.Error("认证失败", err)
|
||||||
@@ -135,7 +182,7 @@ func (server *Server) serve(conn net.Conn) error {
|
|||||||
|
|
||||||
// 处理连接请求
|
// 处理连接请求
|
||||||
slog.Debug("处理连接请求")
|
slog.Debug("处理连接请求")
|
||||||
request, err := server.request(reader, conn)
|
request, err := s.request(reader, conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("连接请求处理失败", err)
|
slog.Error("连接请求处理失败", err)
|
||||||
return err
|
return err
|
||||||
@@ -156,7 +203,7 @@ func (server *Server) serve(conn net.Conn) error {
|
|||||||
|
|
||||||
// 处理请求
|
// 处理请求
|
||||||
slog.Debug("开始代理流量")
|
slog.Debug("开始代理流量")
|
||||||
err = server.handle(request, conn)
|
err = s.handle(request, conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -173,7 +220,7 @@ func checkVersion(reader io.Reader) error {
|
|||||||
|
|
||||||
slog.Debug("客户端请求版本", "version", version)
|
slog.Debug("客户端请求版本", "version", version)
|
||||||
|
|
||||||
if version != SocksVersion {
|
if version != Version {
|
||||||
return errors.New("客户端版本不兼容")
|
return errors.New("客户端版本不兼容")
|
||||||
}
|
}
|
||||||
|
|
||||||
135
template/service/service.go
Normal file
135
template/service/service.go
Normal file
@@ -0,0 +1,135 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"net"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Config struct {
|
||||||
|
Host string
|
||||||
|
Port uint16
|
||||||
|
CloseWait time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
type Server struct {
|
||||||
|
config *Config
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
|
wg sync.WaitGroup
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(conf *Config) (*Server, error) {
|
||||||
|
|
||||||
|
if conf.Host == "" {
|
||||||
|
conf.Host = "localhost"
|
||||||
|
}
|
||||||
|
if conf.Port == 0 {
|
||||||
|
return nil, errors.New("port is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
return &Server{
|
||||||
|
conf,
|
||||||
|
ctx,
|
||||||
|
cancel,
|
||||||
|
sync.WaitGroup{},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) Run() error {
|
||||||
|
|
||||||
|
// start listen
|
||||||
|
addr := net.JoinHostPort(s.config.Host, strconv.Itoa(int(s.config.Port)))
|
||||||
|
ls, err := net.Listen("tcp", addr)
|
||||||
|
if err != nil {
|
||||||
|
return errors.New("failed to listen")
|
||||||
|
}
|
||||||
|
defer closeRes(ls)
|
||||||
|
|
||||||
|
// wait accept
|
||||||
|
connCh := make(chan net.Conn)
|
||||||
|
defer close(connCh)
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
conn, err := ls.Accept()
|
||||||
|
if err != nil {
|
||||||
|
if !errors.Is(err, net.ErrClosed) {
|
||||||
|
log.Println("accept failed", err)
|
||||||
|
}
|
||||||
|
// retry on temporary error
|
||||||
|
var ne net.Error
|
||||||
|
if errors.As(err, &ne) && ne.Temporary() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-s.ctx.Done():
|
||||||
|
closeRes(conn)
|
||||||
|
return
|
||||||
|
case connCh <- conn:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// handle accept
|
||||||
|
func() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-s.ctx.Done():
|
||||||
|
return
|
||||||
|
case conn, ok := <-connCh:
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer s.wg.Done()
|
||||||
|
_ = s.handle(conn)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// close
|
||||||
|
timeout, cancel := context.WithTimeout(context.Background(), s.config.CloseWait)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
waitCh := make(chan struct{})
|
||||||
|
defer close(waitCh)
|
||||||
|
go func() {
|
||||||
|
s.wg.Wait()
|
||||||
|
select {
|
||||||
|
case <-timeout.Done(): // waitCh may be closed
|
||||||
|
case waitCh <- struct{}{}:
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
err = nil
|
||||||
|
select {
|
||||||
|
case <-timeout.Done():
|
||||||
|
err = errors.New("close timeout")
|
||||||
|
case <-waitCh:
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) Close() {
|
||||||
|
s.cancel()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) handle(conn net.Conn) error {
|
||||||
|
defer closeRes(conn)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func closeRes[T io.Closer](res T) {
|
||||||
|
_ = res.Close()
|
||||||
|
}
|
||||||
@@ -2,7 +2,7 @@ package fwd
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
"proxy-server/server/pkg/socks5"
|
socks6 "proxy-server/server/fwd/socks"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -26,7 +26,7 @@ func fakeRequest() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 发送认证请求
|
// 发送认证请求
|
||||||
_, err = conn.Write([]byte{socks5.SocksVersion, byte(1), byte(socks5.NoAuth)})
|
_, err = conn.Write([]byte{socks6.Version, byte(1), byte(socks6.NoAuth)})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user