重构认证相关结构,更新认证流程,添加日志功能
This commit is contained in:
@@ -2,20 +2,14 @@
|
|||||||
|
|
||||||
监听进程信号,优雅关闭服务
|
监听进程信号,优雅关闭服务
|
||||||
|
|
||||||
加一个 log 包,实现全局日志格式控制
|
|
||||||
|
|
||||||
读取 conn 时加上超时机制
|
读取 conn 时加上超时机制
|
||||||
|
|
||||||
检查 ip 时需要判断同一 ip 的不同写法
|
检查 ip 时需要判断同一 ip 的不同写法
|
||||||
|
|
||||||
客户端断联后,服务端代理端口没有正确关闭
|
|
||||||
|
|
||||||
代理节点超时控制
|
代理节点超时控制
|
||||||
|
|
||||||
实现一个 socks context 以在子组件中获取 socks 相关信息
|
实现一个 socks context 以在子组件中获取 socks 相关信息
|
||||||
|
|
||||||
fwd 使用自定义 context 实现在一个上下文中控制 cancel,errCh 和其他自定义数据
|
|
||||||
|
|
||||||
网关根据代理节点对目标服务连接的反馈,决定向用户返回的 socks 响应
|
网关根据代理节点对目标服务连接的反馈,决定向用户返回的 socks 响应
|
||||||
|
|
||||||
数据通道池化
|
数据通道池化
|
||||||
|
|||||||
2
go.mod
2
go.mod
@@ -1,6 +1,6 @@
|
|||||||
module proxy-server
|
module proxy-server
|
||||||
|
|
||||||
go 1.23
|
go 1.24
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/gin-gonic/gin v1.10.0
|
github.com/gin-gonic/gin v1.10.0
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"proxy-server/server/pkg/env"
|
"proxy-server/server/pkg/env"
|
||||||
"proxy-server/server/pkg/orm"
|
"proxy-server/server/pkg/orm"
|
||||||
"proxy-server/server/web/app/models"
|
"proxy-server/server/web/app/models"
|
||||||
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -23,7 +24,7 @@ type Config struct {
|
|||||||
|
|
||||||
type Service struct {
|
type Service struct {
|
||||||
Config *Config
|
Config *Config
|
||||||
connMap map[string]socks.ProxyData
|
connMap map[string]socks.ProxyConn
|
||||||
ctrlConnWg utils.CountWaitGroup
|
ctrlConnWg utils.CountWaitGroup
|
||||||
dataConnWg utils.CountWaitGroup
|
dataConnWg utils.CountWaitGroup
|
||||||
}
|
}
|
||||||
@@ -36,7 +37,7 @@ func New(config *Config) *Service {
|
|||||||
|
|
||||||
return &Service{
|
return &Service{
|
||||||
Config: _config,
|
Config: _config,
|
||||||
connMap: make(map[string]socks.ProxyData),
|
connMap: make(map[string]socks.ProxyConn),
|
||||||
ctrlConnWg: utils.CountWaitGroup{},
|
ctrlConnWg: utils.CountWaitGroup{},
|
||||||
dataConnWg: utils.CountWaitGroup{},
|
dataConnWg: utils.CountWaitGroup{},
|
||||||
}
|
}
|
||||||
@@ -99,19 +100,22 @@ func (s *Service) startCtrlTun(ctx context.Context, errCh chan error) {
|
|||||||
defer close(connCh)
|
defer close(connCh)
|
||||||
|
|
||||||
// 处理连接
|
// 处理连接
|
||||||
loop:
|
for loop := true; loop; {
|
||||||
for {
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
slog.Debug("结束处理连接,由于上下文取消")
|
slog.Debug("结束处理连接,由于上下文取消")
|
||||||
break loop
|
loop = false
|
||||||
case conn, ok := <-connCh:
|
case conn, ok := <-connCh:
|
||||||
if !ok {
|
if !ok {
|
||||||
slog.Debug("结束处理连接,由于获取连接失败")
|
slog.Debug("结束处理连接,由于获取连接失败")
|
||||||
break loop
|
loop = false
|
||||||
}
|
}
|
||||||
s.ctrlConnWg.Add(1)
|
s.ctrlConnWg.Add(1)
|
||||||
go s.processCtrlConn(conn)
|
go func() {
|
||||||
|
defer s.ctrlConnWg.Done()
|
||||||
|
defer utils.Close(conn)
|
||||||
|
s.processCtrlConn(conn)
|
||||||
|
}()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -134,12 +138,7 @@ loop:
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) processCtrlConn(controller net.Conn) {
|
func (s *Service) processCtrlConn(controller net.Conn) {
|
||||||
defer func() {
|
slog.Info("收到客户端控制通道连接", "addr", controller.RemoteAddr().String())
|
||||||
s.ctrlConnWg.Done()
|
|
||||||
utils.Close(controller)
|
|
||||||
}()
|
|
||||||
|
|
||||||
slog.Info("收到客户端控制连接 " + controller.RemoteAddr().String())
|
|
||||||
|
|
||||||
reader := bufio.NewReader(controller)
|
reader := bufio.NewReader(controller)
|
||||||
|
|
||||||
@@ -161,32 +160,30 @@ func (s *Service) processCtrlConn(controller net.Conn) {
|
|||||||
&NoAuthAuthenticator{},
|
&NoAuthAuthenticator{},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
defer proxy.Close()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("代理服务创建失败", "err", err)
|
slog.Error("代理服务创建失败", "err", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
err := proxy.Run()
|
err := proxy.Run()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("代理服务建立失败", "err", err)
|
slog.Error("代理服务启动失败", "err", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
slog.Info("代理服务已建立", "port", port)
|
// 等待用户连接
|
||||||
for {
|
for {
|
||||||
user := <-proxy.Conn
|
user := <-proxy.Conn
|
||||||
tag := user.Tag()
|
tag := user.Tag()
|
||||||
_, err := controller.Write([]byte{byte(len(tag))})
|
tagBuf := make([]byte, len(tag)+1)
|
||||||
|
tagBuf[0] = byte(len(tag))
|
||||||
|
copy(tagBuf[1:], tag)
|
||||||
|
_, err := controller.Write(tagBuf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("write error", "err", err)
|
slog.Error("写入 tag 失败", "err", err)
|
||||||
return
|
utils.Close(user)
|
||||||
}
|
|
||||||
_, err = controller.Write([]byte(tag))
|
|
||||||
slog.Info("已通知客户端建立数据通道")
|
|
||||||
if err != nil {
|
|
||||||
slog.Error("write error", "err", err)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
s.connMap[tag] = user
|
s.connMap[tag] = user
|
||||||
@@ -222,7 +219,11 @@ loop:
|
|||||||
break loop
|
break loop
|
||||||
}
|
}
|
||||||
s.dataConnWg.Add(1)
|
s.dataConnWg.Add(1)
|
||||||
go s.processDataConn(conn)
|
go func() {
|
||||||
|
defer s.dataConnWg.Done()
|
||||||
|
defer utils.Close(conn)
|
||||||
|
s.processDataConn(conn)
|
||||||
|
}()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -245,10 +246,6 @@ loop:
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) processDataConn(client net.Conn) {
|
func (s *Service) processDataConn(client net.Conn) {
|
||||||
defer func() {
|
|
||||||
s.dataConnWg.Done()
|
|
||||||
utils.Close(client)
|
|
||||||
}()
|
|
||||||
slog.Info("已建立客户端数据通道 " + client.RemoteAddr().String())
|
slog.Info("已建立客户端数据通道 " + client.RemoteAddr().String())
|
||||||
|
|
||||||
// 读取 tag
|
// 读取 tag
|
||||||
@@ -277,14 +274,10 @@ func (s *Service) processDataConn(client net.Conn) {
|
|||||||
socks.SendSuccess(user, client)
|
socks.SendSuccess(user, client)
|
||||||
|
|
||||||
// 写入目标地址
|
// 写入目标地址
|
||||||
_, err = client.Write([]byte{byte(len(data.Dest))})
|
destBuf := slices.Insert([]byte(data.Dest), 0, byte(len(data.Dest)))
|
||||||
|
_, err = client.Write(destBuf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("写入目标地址失败", "err", err)
|
slog.Error("发送目标地址失败", "err", err)
|
||||||
return
|
|
||||||
}
|
|
||||||
_, err = client.Write([]byte(data.Dest))
|
|
||||||
if err != nil {
|
|
||||||
slog.Error("写入目标地址失败", "err", err)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -294,14 +287,14 @@ func (s *Service) processDataConn(client net.Conn) {
|
|||||||
go func() {
|
go func() {
|
||||||
_, err := io.Copy(client, user)
|
_, err := io.Copy(client, user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("processDataConn error c2u", "err", err)
|
slog.Error("processDataConn error u2c", "err", err)
|
||||||
}
|
}
|
||||||
errCh <- err
|
errCh <- err
|
||||||
}()
|
}()
|
||||||
go func() {
|
go func() {
|
||||||
_, err := io.Copy(user, client)
|
_, err := io.Copy(user, client)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("processDataConn error u2c", "err", err)
|
slog.Error("processDataConn error c2u", "err", err)
|
||||||
}
|
}
|
||||||
errCh <- err
|
errCh <- err
|
||||||
}()
|
}()
|
||||||
@@ -316,7 +309,7 @@ func (a *NoAuthAuthenticator) Method() socks.AuthMethod {
|
|||||||
return socks.NoAuth
|
return socks.NoAuth
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *NoAuthAuthenticator) Authenticate(ctx context.Context, reader io.Reader, writer io.Writer) (*socks.AuthContext, error) {
|
func (a *NoAuthAuthenticator) Authenticate(ctx context.Context, reader io.Reader, writer io.Writer) (*socks.Authentication, error) {
|
||||||
|
|
||||||
// 获取用户地址
|
// 获取用户地址
|
||||||
conn, ok := writer.(net.Conn)
|
conn, ok := writer.(net.Conn)
|
||||||
@@ -372,10 +365,12 @@ func (a *NoAuthAuthenticator) Authenticate(ctx context.Context, reader io.Reader
|
|||||||
}
|
}
|
||||||
slog.Debug("权限剩余时间", slog.Uint64("timeout", uint64(timeout)))
|
slog.Debug("权限剩余时间", slog.Uint64("timeout", uint64(timeout)))
|
||||||
|
|
||||||
return &socks.AuthContext{
|
return &socks.Authentication{
|
||||||
Method: socks.NoAuth,
|
Method: socks.NoAuth,
|
||||||
Timeout: uint(timeout),
|
Timeout: uint(timeout),
|
||||||
Payload: nil,
|
Payload: socks.Payload{
|
||||||
|
ID: channel.UserId,
|
||||||
|
},
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -386,7 +381,7 @@ func (a *UserPassAuthenticator) Method() socks.AuthMethod {
|
|||||||
return socks.UserPassAuth
|
return socks.UserPassAuth
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *UserPassAuthenticator) Authenticate(ctx context.Context, reader io.Reader, writer io.Writer) (*socks.AuthContext, error) {
|
func (a *UserPassAuthenticator) Authenticate(ctx context.Context, reader io.Reader, writer io.Writer) (*socks.Authentication, error) {
|
||||||
|
|
||||||
// 检查认证版本
|
// 检查认证版本
|
||||||
slog.Debug("验证认证版本")
|
slog.Debug("验证认证版本")
|
||||||
@@ -489,9 +484,11 @@ func (a *UserPassAuthenticator) Authenticate(ctx context.Context, reader io.Read
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &socks.AuthContext{
|
return &socks.Authentication{
|
||||||
Method: socks.UserPassAuth,
|
Method: socks.UserPassAuth,
|
||||||
Timeout: uint(timeout),
|
Timeout: uint(timeout),
|
||||||
Payload: nil,
|
Payload: socks.Payload{
|
||||||
|
ID: channel.UserId,
|
||||||
|
},
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
5
server/fwd/logs/logs.go
Normal file
5
server/fwd/logs/logs.go
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
package logs
|
||||||
|
|
||||||
|
func Write(str ...string) {
|
||||||
|
|
||||||
|
}
|
||||||
@@ -1,294 +0,0 @@
|
|||||||
package fwd
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"proxy-server/pkg/utils"
|
|
||||||
socks6 "proxy-server/server/fwd/socks"
|
|
||||||
"reflect"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestNew(t *testing.T) {
|
|
||||||
type args struct {
|
|
||||||
config *Config
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
args args
|
|
||||||
want *Service
|
|
||||||
}{
|
|
||||||
// TODO: Add test cases.
|
|
||||||
{
|
|
||||||
name: "server config nil",
|
|
||||||
args: args{
|
|
||||||
config: nil,
|
|
||||||
},
|
|
||||||
want: &Service{
|
|
||||||
Config: &Config{},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
if got := New(tt.args.config); !reflect.DeepEqual(got, tt.want) {
|
|
||||||
t.Errorf("New() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNoAuthAuthenticator_Authenticate(t *testing.T) {
|
|
||||||
type args struct {
|
|
||||||
ctx context.Context
|
|
||||||
reader io.Reader
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
args args
|
|
||||||
wantWriter string
|
|
||||||
want *socks6.AuthContext
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
// TODO: Add test cases.
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
a := &NoAuthAuthenticator{}
|
|
||||||
writer := &bytes.Buffer{}
|
|
||||||
got, err := a.Authenticate(tt.args.ctx, tt.args.reader, writer)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("Authenticate() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if gotWriter := writer.String(); gotWriter != tt.wantWriter {
|
|
||||||
t.Errorf("Authenticate() gotWriter = %v, want %v", gotWriter, tt.wantWriter)
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(got, tt.want) {
|
|
||||||
t.Errorf("Authenticate() got = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNoAuthAuthenticator_Method(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
want socks6.AuthMethod
|
|
||||||
}{
|
|
||||||
// TODO: Add test cases.
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
a := &NoAuthAuthenticator{}
|
|
||||||
if got := a.Method(); got != tt.want {
|
|
||||||
t.Errorf("Method() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestService_Run(t *testing.T) {
|
|
||||||
type fields struct {
|
|
||||||
Config *Config
|
|
||||||
connMap map[string]socks6.ProxyData
|
|
||||||
ctrlConnWg utils.CountWaitGroup
|
|
||||||
dataConnWg utils.CountWaitGroup
|
|
||||||
}
|
|
||||||
type args struct {
|
|
||||||
ctx context.Context
|
|
||||||
errCh chan error
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
fields fields
|
|
||||||
args args
|
|
||||||
}{
|
|
||||||
// TODO: Add test cases.
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
s := &Service{
|
|
||||||
Config: tt.fields.Config,
|
|
||||||
connMap: tt.fields.connMap,
|
|
||||||
ctrlConnWg: tt.fields.ctrlConnWg,
|
|
||||||
dataConnWg: tt.fields.dataConnWg,
|
|
||||||
}
|
|
||||||
s.Run(tt.args.ctx, tt.args.errCh)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestService_processCtrlConn(t *testing.T) {
|
|
||||||
type fields struct {
|
|
||||||
Config *Config
|
|
||||||
connMap map[string]socks6.ProxyData
|
|
||||||
ctrlConnWg utils.CountWaitGroup
|
|
||||||
dataConnWg utils.CountWaitGroup
|
|
||||||
}
|
|
||||||
type args struct {
|
|
||||||
controller net.Conn
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
fields fields
|
|
||||||
args args
|
|
||||||
}{
|
|
||||||
// TODO: Add test cases.
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
s := &Service{
|
|
||||||
Config: tt.fields.Config,
|
|
||||||
connMap: tt.fields.connMap,
|
|
||||||
ctrlConnWg: tt.fields.ctrlConnWg,
|
|
||||||
dataConnWg: tt.fields.dataConnWg,
|
|
||||||
}
|
|
||||||
s.processCtrlConn(tt.args.controller)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestService_processDataConn(t *testing.T) {
|
|
||||||
type fields struct {
|
|
||||||
Config *Config
|
|
||||||
connMap map[string]socks6.ProxyData
|
|
||||||
ctrlConnWg utils.CountWaitGroup
|
|
||||||
dataConnWg utils.CountWaitGroup
|
|
||||||
}
|
|
||||||
type args struct {
|
|
||||||
client net.Conn
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
fields fields
|
|
||||||
args args
|
|
||||||
}{
|
|
||||||
// TODO: Add test cases.
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
s := &Service{
|
|
||||||
Config: tt.fields.Config,
|
|
||||||
connMap: tt.fields.connMap,
|
|
||||||
ctrlConnWg: tt.fields.ctrlConnWg,
|
|
||||||
dataConnWg: tt.fields.dataConnWg,
|
|
||||||
}
|
|
||||||
s.processDataConn(tt.args.client)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestService_startCtrlTun(t *testing.T) {
|
|
||||||
type fields struct {
|
|
||||||
Config *Config
|
|
||||||
connMap map[string]socks6.ProxyData
|
|
||||||
ctrlConnWg utils.CountWaitGroup
|
|
||||||
dataConnWg utils.CountWaitGroup
|
|
||||||
}
|
|
||||||
type args struct {
|
|
||||||
ctx context.Context
|
|
||||||
errCh chan error
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
fields fields
|
|
||||||
args args
|
|
||||||
}{
|
|
||||||
// TODO: Add test cases.
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
s := &Service{
|
|
||||||
Config: tt.fields.Config,
|
|
||||||
connMap: tt.fields.connMap,
|
|
||||||
ctrlConnWg: tt.fields.ctrlConnWg,
|
|
||||||
dataConnWg: tt.fields.dataConnWg,
|
|
||||||
}
|
|
||||||
s.startCtrlTun(tt.args.ctx, tt.args.errCh)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestService_startDataTun(t *testing.T) {
|
|
||||||
type fields struct {
|
|
||||||
Config *Config
|
|
||||||
connMap map[string]socks6.ProxyData
|
|
||||||
ctrlConnWg utils.CountWaitGroup
|
|
||||||
dataConnWg utils.CountWaitGroup
|
|
||||||
}
|
|
||||||
type args struct {
|
|
||||||
ctx context.Context
|
|
||||||
errCh chan error
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
fields fields
|
|
||||||
args args
|
|
||||||
}{
|
|
||||||
// TODO: Add test cases.
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
s := &Service{
|
|
||||||
Config: tt.fields.Config,
|
|
||||||
connMap: tt.fields.connMap,
|
|
||||||
ctrlConnWg: tt.fields.ctrlConnWg,
|
|
||||||
dataConnWg: tt.fields.dataConnWg,
|
|
||||||
}
|
|
||||||
s.startDataTun(tt.args.ctx, tt.args.errCh)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUserPassAuthenticator_Authenticate(t *testing.T) {
|
|
||||||
type args struct {
|
|
||||||
ctx context.Context
|
|
||||||
reader io.Reader
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
args args
|
|
||||||
wantWriter string
|
|
||||||
want *socks6.AuthContext
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
// TODO: Add test cases.
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
a := &UserPassAuthenticator{}
|
|
||||||
writer := &bytes.Buffer{}
|
|
||||||
got, err := a.Authenticate(tt.args.ctx, tt.args.reader, writer)
|
|
||||||
if (err != nil) != tt.wantErr {
|
|
||||||
t.Errorf("Authenticate() error = %v, wantErr %v", err, tt.wantErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if gotWriter := writer.String(); gotWriter != tt.wantWriter {
|
|
||||||
t.Errorf("Authenticate() gotWriter = %v, want %v", gotWriter, tt.wantWriter)
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(got, tt.want) {
|
|
||||||
t.Errorf("Authenticate() got = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUserPassAuthenticator_Method(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
want socks6.AuthMethod
|
|
||||||
}{
|
|
||||||
// TODO: Add test cases.
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
a := &UserPassAuthenticator{}
|
|
||||||
if got := a.Method(); got != tt.want {
|
|
||||||
t.Errorf("Method() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -23,11 +23,11 @@ const (
|
|||||||
|
|
||||||
type Authenticator interface {
|
type Authenticator interface {
|
||||||
Method() AuthMethod
|
Method() AuthMethod
|
||||||
Authenticate(ctx context.Context, reader io.Reader, writer io.Writer) (*AuthContext, error)
|
Authenticate(ctx context.Context, reader io.Reader, writer io.Writer) (*Authentication, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// authenticate 执行认证流程
|
// authenticate 执行认证流程
|
||||||
func (s *Server) authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error) {
|
func (s *Server) authenticate(reader io.Reader, writer io.Writer) (*Authentication, error) {
|
||||||
|
|
||||||
// 版本检查
|
// 版本检查
|
||||||
err := checkVersion(reader)
|
err := checkVersion(reader)
|
||||||
@@ -75,8 +75,13 @@ func (s *Server) authenticate(reader io.Reader, writer io.Writer) (*AuthContext,
|
|||||||
return nil, errors.New("没有适用的认证方式")
|
return nil, errors.New("没有适用的认证方式")
|
||||||
}
|
}
|
||||||
|
|
||||||
type AuthContext struct {
|
type Authentication struct {
|
||||||
Method AuthMethod
|
Method AuthMethod
|
||||||
Timeout uint
|
Timeout uint
|
||||||
Payload map[string]any
|
Payload Payload
|
||||||
|
Data map[string]any
|
||||||
|
}
|
||||||
|
|
||||||
|
type Payload struct {
|
||||||
|
ID uint
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -49,6 +49,27 @@ type AddrSpec struct {
|
|||||||
Port int
|
Port int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a AddrSpec) Domain() []string {
|
||||||
|
if a.FQDN != "" {
|
||||||
|
return []string{a.FQDN}
|
||||||
|
}
|
||||||
|
|
||||||
|
var domain []string
|
||||||
|
|
||||||
|
ch := make(chan struct{})
|
||||||
|
defer close(ch)
|
||||||
|
go func() {
|
||||||
|
addr, err := net.LookupAddr(a.IP.String())
|
||||||
|
if err == nil {
|
||||||
|
domain = addr
|
||||||
|
}
|
||||||
|
ch <- struct{}{}
|
||||||
|
}()
|
||||||
|
<-ch
|
||||||
|
|
||||||
|
return domain
|
||||||
|
}
|
||||||
|
|
||||||
func (a AddrSpec) String() string {
|
func (a AddrSpec) String() string {
|
||||||
if a.FQDN != "" {
|
if a.FQDN != "" {
|
||||||
return fmt.Sprintf("%s (%s):%d", a.FQDN, a.IP, a.Port)
|
return fmt.Sprintf("%s (%s):%d", a.FQDN, a.IP, a.Port)
|
||||||
@@ -186,8 +207,8 @@ type Request struct {
|
|||||||
Version uint8
|
Version uint8
|
||||||
// Requested command
|
// Requested command
|
||||||
Command uint8
|
Command uint8
|
||||||
// AuthContext provided during negotiation
|
// Authentication provided during negotiation
|
||||||
AuthContext *AuthContext
|
Authentication *Authentication
|
||||||
// AddrSpec of the network that sent the request
|
// AddrSpec of the network that sent the request
|
||||||
RemoteAddr *AddrSpec
|
RemoteAddr *AddrSpec
|
||||||
// AddrSpec of the desired destination
|
// AddrSpec of the desired destination
|
||||||
@@ -220,7 +241,6 @@ func (s *Server) handle(req *Request, conn net.Conn) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleConnect(ctx context.Context, conn net.Conn, req *Request) error {
|
func (s *Server) handleConnect(ctx context.Context, conn net.Conn, req *Request) error {
|
||||||
|
|
||||||
// 检查规则集约束
|
// 检查规则集约束
|
||||||
s.config.Logger.Printf("检查约束规则\n")
|
s.config.Logger.Printf("检查约束规则\n")
|
||||||
if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok {
|
if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok {
|
||||||
@@ -233,75 +253,8 @@ func (s *Server) handleConnect(ctx context.Context, conn net.Conn, req *Request)
|
|||||||
}
|
}
|
||||||
|
|
||||||
slog.Info("需要向 " + req.DestAddr.Address() + " 建立连接")
|
slog.Info("需要向 " + req.DestAddr.Address() + " 建立连接")
|
||||||
s.Conn <- ProxyData{conn, req.realDestAddr.Address()}
|
s.Conn <- ProxyConn{conn, req.realDestAddr.Address()}
|
||||||
return nil
|
return nil
|
||||||
|
|
||||||
// 与目标服务器建立连接
|
|
||||||
// s.config.Logger.Printf("与目标服务器建立连接\n")
|
|
||||||
// dial := s.config.Dial
|
|
||||||
// target, err := dial("tcp", req.realDestAddr.Address())
|
|
||||||
// if err != nil {
|
|
||||||
// msg := err.Error()
|
|
||||||
// resp := hostUnreachable
|
|
||||||
// if strings.Contains(msg, "refused") {
|
|
||||||
// resp = connectionRefused
|
|
||||||
// } else if strings.Contains(msg, "network is unreachable") {
|
|
||||||
// resp = networkUnreachable
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// err := sendReply(Conn, resp, nil)
|
|
||||||
// if err != nil {
|
|
||||||
// return fmt.Errorf("failed to send reply: %v", err)
|
|
||||||
// }
|
|
||||||
// return fmt.Errorf("request to %v failed: %v", req.DestAddr, err)
|
|
||||||
// }
|
|
||||||
// defer closeConnection(target)
|
|
||||||
//
|
|
||||||
// // 正常响应
|
|
||||||
// slog.Info("连接成功,开始代理流量")
|
|
||||||
//
|
|
||||||
// local := target.LocalAddr().(*net.TCPAddr)
|
|
||||||
// bind := AddrSpec{IP: local.IP, Port: local.Port}
|
|
||||||
// err = sendReply(Conn, successReply, &bind)
|
|
||||||
// if err != nil {
|
|
||||||
// return fmt.Errorf("Failed to send reply: %v", err)
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// // 配置超时时间和行为
|
|
||||||
// timeout := req.AuthContext.Timeout
|
|
||||||
// slog.Debug("超时时间", "timeout", timeout)
|
|
||||||
//
|
|
||||||
// timeoutCtx, cancel := ctx.WithTimeout(ctx, time.Duration(timeout)*time.Second)
|
|
||||||
// defer cancel()
|
|
||||||
//
|
|
||||||
// // 代理流量
|
|
||||||
// errChan := make(chan error, 2)
|
|
||||||
// go func() {
|
|
||||||
// _, err = io.Copy(target, req.bufConn)
|
|
||||||
// errChan <- err
|
|
||||||
// }()
|
|
||||||
// go func() {
|
|
||||||
// _, err = io.Copy(Conn, target)
|
|
||||||
// errChan <- err
|
|
||||||
// }()
|
|
||||||
//
|
|
||||||
// for {
|
|
||||||
// select {
|
|
||||||
//
|
|
||||||
// case <-timeoutCtx.Done():
|
|
||||||
// slog.Debug("超时断开连接")
|
|
||||||
// // todo 根据 termination 执行不同的断开行为
|
|
||||||
// return nil
|
|
||||||
//
|
|
||||||
// case err := <-errChan:
|
|
||||||
// slog.Debug("主动断开连接")
|
|
||||||
// if err != nil {
|
|
||||||
// return errors.Wrap(err, "代理流量出现错误")
|
|
||||||
// }
|
|
||||||
// return nil
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleBind(ctx context.Context, conn net.Conn, req *Request) error {
|
func (s *Server) handleBind(ctx context.Context, conn net.Conn, req *Request) error {
|
||||||
@@ -391,15 +344,19 @@ func SendSuccess(user net.Conn, target net.Conn) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type ProxyData struct {
|
type ProxyConn struct {
|
||||||
// 用户连入的连接
|
// 用户连入的连接
|
||||||
Conn net.Conn
|
Conn net.Conn
|
||||||
// 用户目标地址
|
// 用户目标地址
|
||||||
Dest string
|
Dest string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d ProxyData) Tag() string {
|
func (d ProxyConn) Tag() string {
|
||||||
local := d.Conn.LocalAddr()
|
local := d.Conn.LocalAddr()
|
||||||
remote := d.Conn.RemoteAddr()
|
remote := d.Conn.RemoteAddr()
|
||||||
return fmt.Sprintf("%s-%s", remote, local)
|
return fmt.Sprintf("%s-%s", remote, local)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d ProxyConn) Close() error {
|
||||||
|
return d.Conn.Close()
|
||||||
|
}
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"proxy-server/pkg/utils"
|
"proxy-server/pkg/utils"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
@@ -55,7 +56,7 @@ type Server struct {
|
|||||||
wg utils.CountWaitGroup
|
wg utils.CountWaitGroup
|
||||||
Name string
|
Name string
|
||||||
Port uint16
|
Port uint16
|
||||||
Conn chan ProxyData
|
Conn chan ProxyConn
|
||||||
}
|
}
|
||||||
|
|
||||||
// New 创建服务器
|
// New 创建服务器
|
||||||
@@ -90,7 +91,7 @@ func New(conf *Config) (*Server, error) {
|
|||||||
wg: utils.CountWaitGroup{},
|
wg: utils.CountWaitGroup{},
|
||||||
Name: conf.Name,
|
Name: conf.Name,
|
||||||
Port: conf.Port,
|
Port: conf.Port,
|
||||||
Conn: make(chan ProxyData, 100),
|
Conn: make(chan ProxyConn, 100),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -128,7 +129,8 @@ func (s *Server) Run() error {
|
|||||||
s.wg.Add(1)
|
s.wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
defer s.wg.Done()
|
defer s.wg.Done()
|
||||||
err := s.serve(conn)
|
// 连接要传出,不能在这里关闭连接
|
||||||
|
err := s.process(conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("处理连接失败", err)
|
slog.Error("处理连接失败", err)
|
||||||
}
|
}
|
||||||
@@ -163,8 +165,8 @@ func (s *Server) Close() {
|
|||||||
s.cancel()
|
s.cancel()
|
||||||
}
|
}
|
||||||
|
|
||||||
// serve 建立连接
|
// process 建立连接
|
||||||
func (s *Server) serve(conn net.Conn) error {
|
func (s *Server) process(conn net.Conn) error {
|
||||||
slog.Info("收到来自" + conn.RemoteAddr().String() + "的连接")
|
slog.Info("收到来自" + conn.RemoteAddr().String() + "的连接")
|
||||||
|
|
||||||
reader := bufio.NewReader(conn)
|
reader := bufio.NewReader(conn)
|
||||||
@@ -190,15 +192,27 @@ func (s *Server) serve(conn net.Conn) error {
|
|||||||
slog.Debug("连接请求处理完成")
|
slog.Debug("连接请求处理完成")
|
||||||
}
|
}
|
||||||
|
|
||||||
request.AuthContext = authContext
|
// 记录日志
|
||||||
client, ok := conn.RemoteAddr().(*net.TCPAddr)
|
go func() {
|
||||||
|
slog.Info(
|
||||||
|
"用户访问记录",
|
||||||
|
slog.Uint64("uid", uint64(authContext.Payload.ID)),
|
||||||
|
slog.String("user", conn.RemoteAddr().String()),
|
||||||
|
slog.Any("node", conn.LocalAddr().String()),
|
||||||
|
slog.String("dest", request.DestAddr.Address()),
|
||||||
|
slog.String("domain", strings.Join(request.DestAddr.Domain(), ",")),
|
||||||
|
)
|
||||||
|
}()
|
||||||
|
|
||||||
|
request.Authentication = authContext
|
||||||
|
user, ok := conn.RemoteAddr().(*net.TCPAddr)
|
||||||
if !ok {
|
if !ok {
|
||||||
return fmt.Errorf("获取客户端地址失败")
|
return fmt.Errorf("获取用户地址失败")
|
||||||
}
|
}
|
||||||
|
|
||||||
request.RemoteAddr = &AddrSpec{
|
request.RemoteAddr = &AddrSpec{
|
||||||
IP: client.IP,
|
IP: user.IP,
|
||||||
Port: client.Port,
|
Port: user.Port,
|
||||||
}
|
}
|
||||||
|
|
||||||
// 处理请求
|
// 处理请求
|
||||||
@@ -1,40 +0,0 @@
|
|||||||
package fwd
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
socks6 "proxy-server/server/fwd/socks"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func BenchmarkNoAuth(b *testing.B) {
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkUserPassAuth(b *testing.B) {
|
|
||||||
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func fakeRequest() {
|
|
||||||
|
|
||||||
conn, err := net.Dial("tcp", "localhost:20001")
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 发送认证请求
|
|
||||||
_, err = conn.Write([]byte{socks6.Version, byte(1), byte(socks6.NoAuth)})
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 忽略返回
|
|
||||||
_, err = conn.Read(make([]byte, 2))
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
Reference in New Issue
Block a user