按协议判断连接权限,优化权限检查效率

This commit is contained in:
2025-03-08 11:40:52 +08:00
parent 5786ac9d99
commit f996a20823
11 changed files with 328 additions and 101 deletions

View File

@@ -1,21 +1,23 @@
## todo ## todo
可配置 logger不直接使用 slog
授权测试,两种协议,三种认证方式
连接计时数据清理,避免堆泄露
找一个其他方式即时关闭未成功建立数据通道的连接 找一个其他方式即时关闭未成功建立数据通道的连接
排查下套接字重复的问题 排查下套接字重复的问题
鉴权时判断授权的协议
建立通道时,发送的 dst 和 tag 等信息,可以用字节表示而非 string提高效率 建立通道时,发送的 dst 和 tag 等信息,可以用字节表示而非 string提高效率
建立数据通道失败后,根据用户所选协议返回对应失败响应 建立数据通道失败后,根据用户所选协议返回对应失败响应
测试跳过认证时的最大 qps需要注意单机连接数上限会导致连接失败
数据通道池化
可配配置环境变量 可配配置环境变量
- 输出级别
- 退出等待时间 - 退出等待时间
- 数据通道连接超时等待时间 - 数据通道连接超时等待时间
- 目标地址连接超时等待时间 - 目标地址连接超时等待时间
@@ -24,13 +26,13 @@
协程池化 协程池化
需要测试,考虑是否切换到 gnet 数据通道池化
数据通道支持 tcp 多路复用(分离逻辑流) 数据通道支持 tcp 多路复用(分离逻辑流)
👆 进阶黑魔法 multipath tcp + 多路复用 👆 进阶黑魔法 multipath tcp + 多路复用
考虑一下连接安全性 切换到 gnet
## 开发相关 ## 开发相关

View File

@@ -47,29 +47,83 @@ create index user_ips_ip_address_index on user_ips (ip_address);
drop table if exists channels cascade; drop table if exists channels cascade;
create table channels ( create table channels (
id serial primary key, id serial primary key,
user_id int not null references users (id) user_id int not null references users (id)
on update cascade on update cascade
on delete cascade, on delete cascade,
node_id int not null references nodes (id) -- node_id int not null references nodes (id) --
on update cascade -- on update cascade --
on delete set null, -- 节点删除后,用户侧需要保留提取记录 on delete set null, -- 节点删除后,用户侧需要保留提取记录
node_port int, user_addr varchar(255) not null, -- 快照数据
protocol varchar(255), node_port int, -- 快照数据
auth_ip bool, auth_ip bool,
auth_pass bool, auth_pass bool,
protocol varchar(255),
username varchar(255) unique, username varchar(255) unique,
password varchar(255), password varchar(255),
expiration timestamp not null, expiration timestamp not null,
created_at timestamp default current_timestamp, created_at timestamp default current_timestamp,
updated_at timestamp default current_timestamp, updated_at timestamp default current_timestamp,
deleted_at timestamp deleted_at timestamp
); );
create index channel_user_id_index on channels (user_id); create index channel_user_id_index on channels (user_id);
create index channel_node_id_index on channels (node_id); create index channel_node_id_index on channels (node_id);
create index channel_username_index on channels (username); create index channel_user_addr_index on channels (user_addr);
create index channel_node_port_index on channels (node_port);
-- ==================== -- ====================
-- 填充数据 -- 填充数据
-- ==================== -- ====================
do
$$
declare
r_user_id int;
r_node_id int;
begin
-- 用户信息
insert into users (
username, password, email, phone, name
)
values (
'test', 'test', 'test@user.email', '12345678901', 'test_user'
)
returning id into r_user_id;
insert into user_ips (
user_id, ip_address
)
values (
r_user_id, '::1'
), (
r_user_id, '127.0.0.1'
);
-- 节点信息
insert into nodes (
name, version, fwd_port, provider, location
)
values (
'qwer', 1, 20001, 'test_provider', 'test_location'
)
returning id into r_node_id;
-- 权限信息
insert into channels (
user_id, node_id, user_addr, node_port, auth_ip, auth_pass,
protocol, username, password, expiration
)
values (
r_user_id, r_node_id, '::1', 20001, true, false,
'http', 'ip6http', 'asdf', now() + interval '1 year'
), (
r_user_id, r_node_id, '::1', 20001, true, false,
'socks5', 'ip6socks', 'asdf', now() + interval '1 year'
), (
r_user_id, r_node_id, '127.0.0.1', 20001, true, false,
'http', 'ip4http', 'asdf', now() + interval '1 year'
), (
r_user_id, r_node_id, '127.0.0.1', 20001, true, false,
'socks5', 'ip4socks', 'asdf', now() + interval '1 year'
);
end
$$;

View File

@@ -1,32 +1,25 @@
package core package auth
import ( import (
"log/slog" "log/slog"
"net" "net"
models2 "proxy-server/server/pkg/models" "proxy-server/server/fwd/core"
"proxy-server/server/pkg/models"
"proxy-server/server/pkg/orm" "proxy-server/server/pkg/orm"
"strconv"
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
type Payload struct { type Protocol string
ID uint
}
type AuthContext struct { const (
Timeout float64 Socks5 = Protocol("socks5")
Payload Payload Http = Protocol("http")
Meta map[string]any )
}
func CheckIp(conn net.Conn) (*AuthContext, error) { func CheckIp(conn net.Conn, proto Protocol) (*core.AuthContext, error) {
return &AuthContext{
Timeout: 0,
Payload: Payload{
1,
},
}, nil
// 获取用户地址 // 获取用户地址
remoteAddr := conn.RemoteAddr().String() remoteAddr := conn.RemoteAddr().String()
@@ -37,23 +30,24 @@ func CheckIp(conn net.Conn) (*AuthContext, error) {
// 获取服务端口 // 获取服务端口
localAddr := conn.LocalAddr().String() localAddr := conn.LocalAddr().String()
_, localPort, err := net.SplitHostPort(localAddr) _, _localPort, err := net.SplitHostPort(localAddr)
localPort, err := strconv.Atoi(_localPort)
if err != nil {
return nil, errors.Wrap(err, "noAuth 认证失败")
}
// 查询权限记录 // 查询权限记录
slog.Debug("用户 " + remoteHost + " 请求连接到 " + localPort) slog.Debug("用户 " + remoteHost + " 请求连接到 " + _localPort)
var channels []models2.Channel var channels []models.Channel
err = orm.DB. err = orm.DB.Find(&channels, &models.Channel{
Joins("INNER JOIN public.nodes n ON channels.node_id = n.id AND n.name = ?", localPort). AuthIp: true,
Joins("INNER JOIN public.users u ON channels.user_id = u.id"). UserAddr: remoteHost,
Joins("INNER JOIN public.user_ips ip ON u.id = ip.user_id AND ip.ip_address = ?", remoteHost). NodePort: localPort,
Where(&models2.Channel{ Protocol: string(proto),
AuthIp: true, }).Error
}).
Find(&channels).Error
if err != nil { if err != nil {
return nil, errors.New("查询用户权限失败") return nil, errors.New("查询用户权限失败")
} }
// 记录应该只有一条 // 记录应该只有一条
channel, err := orm.MaySingle(channels) channel, err := orm.MaySingle(channels)
if err != nil { if err != nil {
@@ -71,30 +65,32 @@ func CheckIp(conn net.Conn) (*AuthContext, error) {
return nil, errors.New("权限已过期") return nil, errors.New("权限已过期")
} }
return &AuthContext{ return &core.AuthContext{
Timeout: timeout, Timeout: timeout,
Payload: Payload{ Payload: core.Payload{
channel.UserId, ID: channel.UserId,
}, },
}, nil }, nil
} }
func CheckPass(conn net.Conn, username, password string) (*AuthContext, error) { func CheckPass(conn net.Conn, proto Protocol, username, password string) (*core.AuthContext, error) {
return &AuthContext{
Timeout: 0,
Payload: Payload{
1,
},
}, nil
// 查询通道配置 // 获取服务端口
var channel models2.Channel localAddr := conn.LocalAddr().String()
err := orm.DB. _, _localPort, err := net.SplitHostPort(localAddr)
Where(&models2.Channel{ localPort, err := strconv.Atoi(_localPort)
Username: username, if err != nil {
AuthPass: true, return nil, errors.Wrap(err, "noAuth 认证失败")
}). }
First(&channel).Error
// 查询权限记录
var channel models.Channel
err = orm.DB.Take(&channel, &models.Channel{
AuthPass: true,
Username: username,
NodePort: localPort,
Protocol: string(proto),
}).Error
if err != nil { if err != nil {
return nil, errors.Wrap(err, "用户不存在") return nil, errors.Wrap(err, "用户不存在")
} }
@@ -104,15 +100,8 @@ func CheckPass(conn net.Conn, username, password string) (*AuthContext, error) {
return nil, errors.New("密码错误") return nil, errors.New("密码错误")
} }
// 检查权限是否过期
timeout := channel.Expiration.Sub(time.Now()).Seconds()
if timeout <= 0 {
return nil, errors.New("权限已过期")
}
// 如果用户设置了双验证则检查 ip 是否在白名单中 // 如果用户设置了双验证则检查 ip 是否在白名单中
if channel.AuthIp { if channel.AuthIp {
slog.Debug("验证用户 ip")
// 获取用户地址 // 获取用户地址
remoteAddr := conn.RemoteAddr().String() remoteAddr := conn.RemoteAddr().String()
@@ -121,28 +110,22 @@ func CheckPass(conn net.Conn, username, password string) (*AuthContext, error) {
return nil, errors.Wrap(err, "无法获取连接信息") return nil, errors.Wrap(err, "无法获取连接信息")
} }
// 查询通道配置 // 查询权限记录
if channel.UserAddr != remoteHost {
var ips int64
err = orm.DB.
Where(&models2.UserIp{
UserId: channel.UserId,
IpAddress: remoteHost,
}).
Count(&ips).Error
if err != nil {
return nil, errors.Wrap(err, "查询白名单失败")
}
if ips == 0 {
return nil, errors.New("不在白名单内") return nil, errors.New("不在白名单内")
} }
} }
return &AuthContext{ // 检查权限是否过期
timeout := channel.Expiration.Sub(time.Now()).Seconds()
if timeout <= 0 {
return nil, errors.New("权限已过期")
}
return &core.AuthContext{
Timeout: timeout, Timeout: timeout,
Payload: Payload{ Payload: core.Payload{
channel.UserId, ID: channel.UserId,
}, },
}, nil }, nil
} }

View File

@@ -0,0 +1,151 @@
package auth
import (
"net"
"proxy-server/server/fwd/core"
"proxy-server/server/pkg/orm"
"reflect"
"testing"
"time"
)
type MockConn struct {
local net.Addr
remote net.Addr
}
func (m MockConn) LocalAddr() net.Addr {
return m.local
}
func (m MockConn) RemoteAddr() net.Addr {
return m.remote
}
func (m MockConn) Read(b []byte) (n int, err error) {
return 0, nil
}
func (m MockConn) Write(b []byte) (n int, err error) {
return 0, nil
}
func (m MockConn) Close() error {
return nil
}
func (m MockConn) SetDeadline(t time.Time) error {
return nil
}
func (m MockConn) SetReadDeadline(t time.Time) error {
return nil
}
func (m MockConn) SetWriteDeadline(t time.Time) error {
return nil
}
func TestCheckIp(t *testing.T) {
orm.InitForTest(
"host=localhost port=5432 user=proxy password=proxy dbname=app sslmode=disable TimeZone=Asia/Shanghai",
)
type args struct {
conn net.Conn
proto Protocol
}
tests := []struct {
name string
args args
want *core.AuthContext
wantErr bool
}{
{
name: "test-ok",
args: args{
conn: MockConn{
remote: &net.TCPAddr{
IP: []byte{127, 0, 0, 1},
Port: 12345,
},
local: &net.TCPAddr{
IP: []byte{127, 0, 0, 1},
Port: 20001,
},
},
proto: Http,
},
want: &core.AuthContext{
Payload: core.Payload{ID: 1},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := CheckIp(tt.args.conn, tt.args.proto)
if (err != nil) != tt.wantErr {
t.Errorf("CheckIp() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got.Payload, tt.want.Payload) || !reflect.DeepEqual(got.Meta, tt.want.Meta) {
t.Errorf("CheckIp() got = %v, want %v", got, tt.want)
}
})
}
}
func TestCheckPass(t *testing.T) {
orm.InitForTest(
"host=localhost port=5432 user=proxy password=proxy dbname=app sslmode=disable TimeZone=Asia/Shanghai",
)
type args struct {
conn net.Conn
username string
password string
proto Protocol
}
tests := []struct {
name string
args args
want *core.AuthContext
wantErr bool
}{
{
name: "test-ok",
args: args{
conn: MockConn{
remote: &net.TCPAddr{
IP: []byte{127, 0, 0, 1},
Port: 12345,
},
local: &net.TCPAddr{
IP: []byte{127, 0, 0, 1},
Port: 20001,
},
},
proto: Http,
username: "ip4http",
password: "asdf",
},
want: &core.AuthContext{
Payload: core.Payload{ID: 1},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := CheckPass(tt.args.conn, tt.args.proto, tt.args.username, tt.args.password)
if (err != nil) != tt.wantErr {
t.Errorf("CheckPass() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got.Payload, tt.want.Payload) || !reflect.DeepEqual(got.Meta, tt.want.Meta) {
t.Errorf("CheckPass() got = %v, want %v", got, tt.want)
}
})
}
}

View File

@@ -92,3 +92,13 @@ func (a FwdAddr) Network() string {
func (a FwdAddr) String() string { func (a FwdAddr) String() string {
return fmt.Sprintf("%s:%d", a.IP, a.Port) return fmt.Sprintf("%s:%d", a.IP, a.Port)
} }
type AuthContext struct {
Timeout float64
Payload Payload
Meta map[string]any
}
type Payload struct {
ID uint
}

View File

@@ -94,7 +94,7 @@ func (s *Service) processCtrlConn(conn net.Conn) error {
// 检查客户端 // 检查客户端
var node models.Node var node models.Node
err = orm.DB.First(&node, &models.Node{ err = orm.DB.Take(&node, &models.Node{
Name: name, Name: name,
}).Error }).Error
if err != nil { if err != nil {

View File

@@ -118,7 +118,7 @@ func (s *Server) acceptHttp(ls net.Listener) error {
go func() { go func() {
user, err := http.Process(s.ctx, conn) user, err := http.Process(s.ctx, conn)
if err != nil { if err != nil {
slog.Error("dispatcher http process error", "err", err) slog.Error("处理 http 连接失败", "err", err)
utils.Close(conn) utils.Close(conn)
return return
} }

View File

@@ -8,6 +8,7 @@ import (
"net" "net"
"net/textproto" "net/textproto"
"net/url" "net/url"
"proxy-server/server/fwd/auth"
"proxy-server/server/fwd/core" "proxy-server/server/fwd/core"
"strings" "strings"
@@ -47,11 +48,16 @@ func Process(ctx context.Context, conn net.Conn) (*core.Conn, error) {
// 验证账号 // 验证账号
authInfo := headers.Get("Proxy-Authorization") authInfo := headers.Get("Proxy-Authorization")
var auth *core.AuthContext var authCtx *core.AuthContext
var authErr error
if authInfo == "" { if authInfo == "" {
auth, err = core.CheckIp(conn) authCtx, authErr = auth.CheckIp(conn, auth.Http)
if err != nil { if authErr != nil {
return nil, errors.Wrap(err, "验证账号失败") _, err := conn.Write([]byte("HTTP/1.1 407 Proxy Authentication Required\r\n\r\n"))
if err != nil {
return nil, errors.Wrap(err, "响应 407 失败")
}
return nil, errors.Wrap(authErr, "验证账号失败")
} }
} else { } else {
authParts := strings.Split(authInfo, " ") authParts := strings.Split(authInfo, " ")
@@ -66,7 +72,14 @@ func Process(ctx context.Context, conn net.Conn) (*core.Conn, error) {
return nil, errors.Wrap(err, "解码认证信息失败") return nil, errors.Wrap(err, "解码认证信息失败")
} }
authPair := strings.Split(string(authBytes), ":") authPair := strings.Split(string(authBytes), ":")
auth, err = core.CheckPass(conn, authPair[0], authPair[1]) authCtx, authErr = auth.CheckPass(conn, auth.Http, authPair[0], authPair[1])
if authErr != nil {
_, err := conn.Write([]byte("HTTP/1.1 407 Proxy Authentication Required\r\n\r\n"))
if err != nil {
return nil, errors.Wrap(err, "响应 407 失败")
}
return nil, errors.Wrap(authErr, "验证账号失败")
}
} }
// 获取 Host // 获取 Host
@@ -94,7 +107,7 @@ func Process(ctx context.Context, conn net.Conn) (*core.Conn, error) {
Port: addr.Port, Port: addr.Port,
Domain: host, Domain: host,
}, },
auth: auth, auth: authCtx,
} }
var user *core.Conn var user *core.Conn

View File

@@ -9,6 +9,7 @@ import (
"log/slog" "log/slog"
"net" "net"
"proxy-server/pkg/utils" "proxy-server/pkg/utils"
"proxy-server/server/fwd/auth"
"proxy-server/server/fwd/core" "proxy-server/server/fwd/core"
"slices" "slices"
@@ -60,7 +61,7 @@ func Process(ctx context.Context, conn net.Conn) (*core.Conn, error) {
reader := bufio.NewReader(conn) reader := bufio.NewReader(conn)
// 认证 // 认证
auth, err := authenticate(ctx, reader, conn) authCtx, err := authenticate(ctx, reader, conn)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "认证失败") return nil, errors.Wrap(err, "认证失败")
} }
@@ -85,7 +86,7 @@ func Process(ctx context.Context, conn net.Conn) (*core.Conn, error) {
Protocol: "socks5", Protocol: "socks5",
Tag: conn.RemoteAddr().String() + "_" + conn.LocalAddr().String(), Tag: conn.RemoteAddr().String() + "_" + conn.LocalAddr().String(),
Dest: request.DestAddr, Dest: request.DestAddr,
Auth: auth, Auth: authCtx,
}, nil }, nil
} }
@@ -167,7 +168,7 @@ func authenticate(ctx context.Context, reader *bufio.Reader, conn net.Conn) (*co
password := string(passwordBuf) password := string(passwordBuf)
// 检查权限 // 检查权限
authContext, err := core.CheckPass(conn, username, password) authContext, err := auth.CheckPass(conn, auth.Socks5, username, password)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "权限检查失败") return nil, errors.Wrap(err, "权限检查失败")
} }
@@ -188,12 +189,12 @@ func authenticate(ctx context.Context, reader *bufio.Reader, conn net.Conn) (*co
return nil, errors.Wrap(err, "响应认证方式失败") return nil, errors.Wrap(err, "响应认证方式失败")
} }
authContext, err := core.CheckIp(conn) authCtx, err := auth.CheckIp(conn, auth.Socks5)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "权限检查失败") return nil, errors.Wrap(err, "权限检查失败")
} }
return authContext, nil return authCtx, nil
} }
// 无适用的认证方式 // 无适用的认证方式

View File

@@ -11,10 +11,12 @@ type Channel struct {
gorm.Model gorm.Model
UserId uint UserId uint
NodeId uint NodeId uint
UserAddr string
NodePort int
AuthIp bool
AuthPass bool
Protocol string Protocol string
Username string Username string
Password string Password string
AuthIp bool
AuthPass bool
Expiration time.Time Expiration time.Time
} }

View File

@@ -29,6 +29,17 @@ func Init() {
DB = db DB = db
} }
func InitForTest(dsn string) {
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{
Logger: logger.Default,
})
if err != nil {
panic(err)
}
DB = db
}
func MaySingle[T any](results []T) (*T, error) { func MaySingle[T any](results []T) (*T, error) {
rsLen := len(results) rsLen := len(results)
if rsLen == 0 { if rsLen == 0 {