按协议判断连接权限,优化权限检查效率

This commit is contained in:
2025-03-08 11:40:52 +08:00
parent 5786ac9d99
commit f996a20823
11 changed files with 328 additions and 101 deletions

View File

@@ -1,21 +1,23 @@
## todo
可配置 logger不直接使用 slog
授权测试,两种协议,三种认证方式
连接计时数据清理,避免堆泄露
找一个其他方式即时关闭未成功建立数据通道的连接
排查下套接字重复的问题
鉴权时判断授权的协议
建立通道时,发送的 dst 和 tag 等信息,可以用字节表示而非 string提高效率
建立数据通道失败后,根据用户所选协议返回对应失败响应
测试跳过认证时的最大 qps需要注意单机连接数上限会导致连接失败
数据通道池化
可配配置环境变量
- 输出级别
- 退出等待时间
- 数据通道连接超时等待时间
- 目标地址连接超时等待时间
@@ -24,13 +26,13 @@
协程池化
需要测试,考虑是否切换到 gnet
数据通道池化
数据通道支持 tcp 多路复用(分离逻辑流)
👆 进阶黑魔法 multipath tcp + 多路复用
考虑一下连接安全性
切换到 gnet
## 开发相关

View File

@@ -47,29 +47,83 @@ create index user_ips_ip_address_index on user_ips (ip_address);
drop table if exists channels cascade;
create table channels (
id serial primary key,
user_id int not null references users (id)
user_id int not null references users (id)
on update cascade
on delete cascade,
node_id int not null references nodes (id) --
on update cascade --
on delete set null, -- 节点删除后,用户侧需要保留提取记录
node_port int,
protocol varchar(255),
node_id int not null references nodes (id) --
on update cascade --
on delete set null, -- 节点删除后,用户侧需要保留提取记录
user_addr varchar(255) not null, -- 快照数据
node_port int, -- 快照数据
auth_ip bool,
auth_pass bool,
protocol varchar(255),
username varchar(255) unique,
password varchar(255),
expiration timestamp not null,
expiration timestamp not null,
created_at timestamp default current_timestamp,
updated_at timestamp default current_timestamp,
deleted_at timestamp
);
create index channel_user_id_index on channels (user_id);
create index channel_node_id_index on channels (node_id);
create index channel_username_index on channels (username);
create index channel_user_addr_index on channels (user_addr);
create index channel_node_port_index on channels (node_port);
-- ====================
-- 填充数据
-- ====================
do
$$
declare
r_user_id int;
r_node_id int;
begin
-- 用户信息
insert into users (
username, password, email, phone, name
)
values (
'test', 'test', 'test@user.email', '12345678901', 'test_user'
)
returning id into r_user_id;
insert into user_ips (
user_id, ip_address
)
values (
r_user_id, '::1'
), (
r_user_id, '127.0.0.1'
);
-- 节点信息
insert into nodes (
name, version, fwd_port, provider, location
)
values (
'qwer', 1, 20001, 'test_provider', 'test_location'
)
returning id into r_node_id;
-- 权限信息
insert into channels (
user_id, node_id, user_addr, node_port, auth_ip, auth_pass,
protocol, username, password, expiration
)
values (
r_user_id, r_node_id, '::1', 20001, true, false,
'http', 'ip6http', 'asdf', now() + interval '1 year'
), (
r_user_id, r_node_id, '::1', 20001, true, false,
'socks5', 'ip6socks', 'asdf', now() + interval '1 year'
), (
r_user_id, r_node_id, '127.0.0.1', 20001, true, false,
'http', 'ip4http', 'asdf', now() + interval '1 year'
), (
r_user_id, r_node_id, '127.0.0.1', 20001, true, false,
'socks5', 'ip4socks', 'asdf', now() + interval '1 year'
);
end
$$;

View File

@@ -1,32 +1,25 @@
package core
package auth
import (
"log/slog"
"net"
models2 "proxy-server/server/pkg/models"
"proxy-server/server/fwd/core"
"proxy-server/server/pkg/models"
"proxy-server/server/pkg/orm"
"strconv"
"time"
"github.com/pkg/errors"
)
type Payload struct {
ID uint
}
type Protocol string
type AuthContext struct {
Timeout float64
Payload Payload
Meta map[string]any
}
const (
Socks5 = Protocol("socks5")
Http = Protocol("http")
)
func CheckIp(conn net.Conn) (*AuthContext, error) {
return &AuthContext{
Timeout: 0,
Payload: Payload{
1,
},
}, nil
func CheckIp(conn net.Conn, proto Protocol) (*core.AuthContext, error) {
// 获取用户地址
remoteAddr := conn.RemoteAddr().String()
@@ -37,23 +30,24 @@ func CheckIp(conn net.Conn) (*AuthContext, error) {
// 获取服务端口
localAddr := conn.LocalAddr().String()
_, localPort, err := net.SplitHostPort(localAddr)
_, _localPort, err := net.SplitHostPort(localAddr)
localPort, err := strconv.Atoi(_localPort)
if err != nil {
return nil, errors.Wrap(err, "noAuth 认证失败")
}
// 查询权限记录
slog.Debug("用户 " + remoteHost + " 请求连接到 " + localPort)
var channels []models2.Channel
err = orm.DB.
Joins("INNER JOIN public.nodes n ON channels.node_id = n.id AND n.name = ?", localPort).
Joins("INNER JOIN public.users u ON channels.user_id = u.id").
Joins("INNER JOIN public.user_ips ip ON u.id = ip.user_id AND ip.ip_address = ?", remoteHost).
Where(&models2.Channel{
AuthIp: true,
}).
Find(&channels).Error
slog.Debug("用户 " + remoteHost + " 请求连接到 " + _localPort)
var channels []models.Channel
err = orm.DB.Find(&channels, &models.Channel{
AuthIp: true,
UserAddr: remoteHost,
NodePort: localPort,
Protocol: string(proto),
}).Error
if err != nil {
return nil, errors.New("查询用户权限失败")
}
// 记录应该只有一条
channel, err := orm.MaySingle(channels)
if err != nil {
@@ -71,30 +65,32 @@ func CheckIp(conn net.Conn) (*AuthContext, error) {
return nil, errors.New("权限已过期")
}
return &AuthContext{
return &core.AuthContext{
Timeout: timeout,
Payload: Payload{
channel.UserId,
Payload: core.Payload{
ID: channel.UserId,
},
}, nil
}
func CheckPass(conn net.Conn, username, password string) (*AuthContext, error) {
return &AuthContext{
Timeout: 0,
Payload: Payload{
1,
},
}, nil
func CheckPass(conn net.Conn, proto Protocol, username, password string) (*core.AuthContext, error) {
// 查询通道配置
var channel models2.Channel
err := orm.DB.
Where(&models2.Channel{
Username: username,
AuthPass: true,
}).
First(&channel).Error
// 获取服务端口
localAddr := conn.LocalAddr().String()
_, _localPort, err := net.SplitHostPort(localAddr)
localPort, err := strconv.Atoi(_localPort)
if err != nil {
return nil, errors.Wrap(err, "noAuth 认证失败")
}
// 查询权限记录
var channel models.Channel
err = orm.DB.Take(&channel, &models.Channel{
AuthPass: true,
Username: username,
NodePort: localPort,
Protocol: string(proto),
}).Error
if err != nil {
return nil, errors.Wrap(err, "用户不存在")
}
@@ -104,15 +100,8 @@ func CheckPass(conn net.Conn, username, password string) (*AuthContext, error) {
return nil, errors.New("密码错误")
}
// 检查权限是否过期
timeout := channel.Expiration.Sub(time.Now()).Seconds()
if timeout <= 0 {
return nil, errors.New("权限已过期")
}
// 如果用户设置了双验证则检查 ip 是否在白名单中
if channel.AuthIp {
slog.Debug("验证用户 ip")
// 获取用户地址
remoteAddr := conn.RemoteAddr().String()
@@ -121,28 +110,22 @@ func CheckPass(conn net.Conn, username, password string) (*AuthContext, error) {
return nil, errors.Wrap(err, "无法获取连接信息")
}
// 查询通道配置
var ips int64
err = orm.DB.
Where(&models2.UserIp{
UserId: channel.UserId,
IpAddress: remoteHost,
}).
Count(&ips).Error
if err != nil {
return nil, errors.Wrap(err, "查询白名单失败")
}
if ips == 0 {
// 查询权限记录
if channel.UserAddr != remoteHost {
return nil, errors.New("不在白名单内")
}
}
return &AuthContext{
// 检查权限是否过期
timeout := channel.Expiration.Sub(time.Now()).Seconds()
if timeout <= 0 {
return nil, errors.New("权限已过期")
}
return &core.AuthContext{
Timeout: timeout,
Payload: Payload{
channel.UserId,
Payload: core.Payload{
ID: channel.UserId,
},
}, nil
}

View File

@@ -0,0 +1,151 @@
package auth
import (
"net"
"proxy-server/server/fwd/core"
"proxy-server/server/pkg/orm"
"reflect"
"testing"
"time"
)
type MockConn struct {
local net.Addr
remote net.Addr
}
func (m MockConn) LocalAddr() net.Addr {
return m.local
}
func (m MockConn) RemoteAddr() net.Addr {
return m.remote
}
func (m MockConn) Read(b []byte) (n int, err error) {
return 0, nil
}
func (m MockConn) Write(b []byte) (n int, err error) {
return 0, nil
}
func (m MockConn) Close() error {
return nil
}
func (m MockConn) SetDeadline(t time.Time) error {
return nil
}
func (m MockConn) SetReadDeadline(t time.Time) error {
return nil
}
func (m MockConn) SetWriteDeadline(t time.Time) error {
return nil
}
func TestCheckIp(t *testing.T) {
orm.InitForTest(
"host=localhost port=5432 user=proxy password=proxy dbname=app sslmode=disable TimeZone=Asia/Shanghai",
)
type args struct {
conn net.Conn
proto Protocol
}
tests := []struct {
name string
args args
want *core.AuthContext
wantErr bool
}{
{
name: "test-ok",
args: args{
conn: MockConn{
remote: &net.TCPAddr{
IP: []byte{127, 0, 0, 1},
Port: 12345,
},
local: &net.TCPAddr{
IP: []byte{127, 0, 0, 1},
Port: 20001,
},
},
proto: Http,
},
want: &core.AuthContext{
Payload: core.Payload{ID: 1},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := CheckIp(tt.args.conn, tt.args.proto)
if (err != nil) != tt.wantErr {
t.Errorf("CheckIp() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got.Payload, tt.want.Payload) || !reflect.DeepEqual(got.Meta, tt.want.Meta) {
t.Errorf("CheckIp() got = %v, want %v", got, tt.want)
}
})
}
}
func TestCheckPass(t *testing.T) {
orm.InitForTest(
"host=localhost port=5432 user=proxy password=proxy dbname=app sslmode=disable TimeZone=Asia/Shanghai",
)
type args struct {
conn net.Conn
username string
password string
proto Protocol
}
tests := []struct {
name string
args args
want *core.AuthContext
wantErr bool
}{
{
name: "test-ok",
args: args{
conn: MockConn{
remote: &net.TCPAddr{
IP: []byte{127, 0, 0, 1},
Port: 12345,
},
local: &net.TCPAddr{
IP: []byte{127, 0, 0, 1},
Port: 20001,
},
},
proto: Http,
username: "ip4http",
password: "asdf",
},
want: &core.AuthContext{
Payload: core.Payload{ID: 1},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := CheckPass(tt.args.conn, tt.args.proto, tt.args.username, tt.args.password)
if (err != nil) != tt.wantErr {
t.Errorf("CheckPass() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got.Payload, tt.want.Payload) || !reflect.DeepEqual(got.Meta, tt.want.Meta) {
t.Errorf("CheckPass() got = %v, want %v", got, tt.want)
}
})
}
}

View File

@@ -92,3 +92,13 @@ func (a FwdAddr) Network() string {
func (a FwdAddr) String() string {
return fmt.Sprintf("%s:%d", a.IP, a.Port)
}
type AuthContext struct {
Timeout float64
Payload Payload
Meta map[string]any
}
type Payload struct {
ID uint
}

View File

@@ -94,7 +94,7 @@ func (s *Service) processCtrlConn(conn net.Conn) error {
// 检查客户端
var node models.Node
err = orm.DB.First(&node, &models.Node{
err = orm.DB.Take(&node, &models.Node{
Name: name,
}).Error
if err != nil {

View File

@@ -118,7 +118,7 @@ func (s *Server) acceptHttp(ls net.Listener) error {
go func() {
user, err := http.Process(s.ctx, conn)
if err != nil {
slog.Error("dispatcher http process error", "err", err)
slog.Error("处理 http 连接失败", "err", err)
utils.Close(conn)
return
}

View File

@@ -8,6 +8,7 @@ import (
"net"
"net/textproto"
"net/url"
"proxy-server/server/fwd/auth"
"proxy-server/server/fwd/core"
"strings"
@@ -47,11 +48,16 @@ func Process(ctx context.Context, conn net.Conn) (*core.Conn, error) {
// 验证账号
authInfo := headers.Get("Proxy-Authorization")
var auth *core.AuthContext
var authCtx *core.AuthContext
var authErr error
if authInfo == "" {
auth, err = core.CheckIp(conn)
if err != nil {
return nil, errors.Wrap(err, "验证账号失败")
authCtx, authErr = auth.CheckIp(conn, auth.Http)
if authErr != nil {
_, err := conn.Write([]byte("HTTP/1.1 407 Proxy Authentication Required\r\n\r\n"))
if err != nil {
return nil, errors.Wrap(err, "响应 407 失败")
}
return nil, errors.Wrap(authErr, "验证账号失败")
}
} else {
authParts := strings.Split(authInfo, " ")
@@ -66,7 +72,14 @@ func Process(ctx context.Context, conn net.Conn) (*core.Conn, error) {
return nil, errors.Wrap(err, "解码认证信息失败")
}
authPair := strings.Split(string(authBytes), ":")
auth, err = core.CheckPass(conn, authPair[0], authPair[1])
authCtx, authErr = auth.CheckPass(conn, auth.Http, authPair[0], authPair[1])
if authErr != nil {
_, err := conn.Write([]byte("HTTP/1.1 407 Proxy Authentication Required\r\n\r\n"))
if err != nil {
return nil, errors.Wrap(err, "响应 407 失败")
}
return nil, errors.Wrap(authErr, "验证账号失败")
}
}
// 获取 Host
@@ -94,7 +107,7 @@ func Process(ctx context.Context, conn net.Conn) (*core.Conn, error) {
Port: addr.Port,
Domain: host,
},
auth: auth,
auth: authCtx,
}
var user *core.Conn

View File

@@ -9,6 +9,7 @@ import (
"log/slog"
"net"
"proxy-server/pkg/utils"
"proxy-server/server/fwd/auth"
"proxy-server/server/fwd/core"
"slices"
@@ -60,7 +61,7 @@ func Process(ctx context.Context, conn net.Conn) (*core.Conn, error) {
reader := bufio.NewReader(conn)
// 认证
auth, err := authenticate(ctx, reader, conn)
authCtx, err := authenticate(ctx, reader, conn)
if err != nil {
return nil, errors.Wrap(err, "认证失败")
}
@@ -85,7 +86,7 @@ func Process(ctx context.Context, conn net.Conn) (*core.Conn, error) {
Protocol: "socks5",
Tag: conn.RemoteAddr().String() + "_" + conn.LocalAddr().String(),
Dest: request.DestAddr,
Auth: auth,
Auth: authCtx,
}, nil
}
@@ -167,7 +168,7 @@ func authenticate(ctx context.Context, reader *bufio.Reader, conn net.Conn) (*co
password := string(passwordBuf)
// 检查权限
authContext, err := core.CheckPass(conn, username, password)
authContext, err := auth.CheckPass(conn, auth.Socks5, username, password)
if err != nil {
return nil, errors.Wrap(err, "权限检查失败")
}
@@ -188,12 +189,12 @@ func authenticate(ctx context.Context, reader *bufio.Reader, conn net.Conn) (*co
return nil, errors.Wrap(err, "响应认证方式失败")
}
authContext, err := core.CheckIp(conn)
authCtx, err := auth.CheckIp(conn, auth.Socks5)
if err != nil {
return nil, errors.Wrap(err, "权限检查失败")
}
return authContext, nil
return authCtx, nil
}
// 无适用的认证方式

View File

@@ -11,10 +11,12 @@ type Channel struct {
gorm.Model
UserId uint
NodeId uint
UserAddr string
NodePort int
AuthIp bool
AuthPass bool
Protocol string
Username string
Password string
AuthIp bool
AuthPass bool
Expiration time.Time
}

View File

@@ -29,6 +29,17 @@ func Init() {
DB = db
}
func InitForTest(dsn string) {
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{
Logger: logger.Default,
})
if err != nil {
panic(err)
}
DB = db
}
func MaySingle[T any](results []T) (*T, error) {
rsLen := len(results)
if rsLen == 0 {