按协议判断连接权限,优化权限检查效率
This commit is contained in:
16
README.md
16
README.md
@@ -1,21 +1,23 @@
|
|||||||
## todo
|
## todo
|
||||||
|
|
||||||
|
可配置 logger,不直接使用 slog
|
||||||
|
|
||||||
|
授权测试,两种协议,三种认证方式
|
||||||
|
|
||||||
|
连接计时数据清理,避免堆泄露
|
||||||
|
|
||||||
找一个其他方式即时关闭未成功建立数据通道的连接
|
找一个其他方式即时关闭未成功建立数据通道的连接
|
||||||
|
|
||||||
排查下套接字重复的问题
|
排查下套接字重复的问题
|
||||||
|
|
||||||
鉴权时判断授权的协议
|
|
||||||
|
|
||||||
建立通道时,发送的 dst 和 tag 等信息,可以用字节表示而非 string,提高效率
|
建立通道时,发送的 dst 和 tag 等信息,可以用字节表示而非 string,提高效率
|
||||||
|
|
||||||
建立数据通道失败后,根据用户所选协议返回对应失败响应
|
建立数据通道失败后,根据用户所选协议返回对应失败响应
|
||||||
|
|
||||||
测试跳过认证时的最大 qps(需要注意单机连接数上限,会导致连接失败)
|
|
||||||
|
|
||||||
数据通道池化
|
|
||||||
|
|
||||||
可配配置环境变量
|
可配配置环境变量
|
||||||
|
|
||||||
|
- 输出级别
|
||||||
- 退出等待时间
|
- 退出等待时间
|
||||||
- 数据通道连接超时等待时间
|
- 数据通道连接超时等待时间
|
||||||
- 目标地址连接超时等待时间
|
- 目标地址连接超时等待时间
|
||||||
@@ -24,13 +26,13 @@
|
|||||||
|
|
||||||
协程池化
|
协程池化
|
||||||
|
|
||||||
需要测试,考虑是否切换到 gnet
|
数据通道池化
|
||||||
|
|
||||||
数据通道支持 tcp 多路复用(分离逻辑流)
|
数据通道支持 tcp 多路复用(分离逻辑流)
|
||||||
|
|
||||||
👆 进阶黑魔法 multipath tcp + 多路复用
|
👆 进阶黑魔法 multipath tcp + 多路复用
|
||||||
|
|
||||||
考虑一下连接安全性
|
切换到 gnet
|
||||||
|
|
||||||
## 开发相关
|
## 开发相关
|
||||||
|
|
||||||
|
|||||||
@@ -47,29 +47,83 @@ create index user_ips_ip_address_index on user_ips (ip_address);
|
|||||||
drop table if exists channels cascade;
|
drop table if exists channels cascade;
|
||||||
create table channels (
|
create table channels (
|
||||||
id serial primary key,
|
id serial primary key,
|
||||||
user_id int not null references users (id)
|
user_id int not null references users (id)
|
||||||
on update cascade
|
on update cascade
|
||||||
on delete cascade,
|
on delete cascade,
|
||||||
node_id int not null references nodes (id) --
|
node_id int not null references nodes (id) --
|
||||||
on update cascade --
|
on update cascade --
|
||||||
on delete set null, -- 节点删除后,用户侧需要保留提取记录
|
on delete set null, -- 节点删除后,用户侧需要保留提取记录
|
||||||
node_port int,
|
user_addr varchar(255) not null, -- 快照数据
|
||||||
protocol varchar(255),
|
node_port int, -- 快照数据
|
||||||
auth_ip bool,
|
auth_ip bool,
|
||||||
auth_pass bool,
|
auth_pass bool,
|
||||||
|
protocol varchar(255),
|
||||||
username varchar(255) unique,
|
username varchar(255) unique,
|
||||||
password varchar(255),
|
password varchar(255),
|
||||||
expiration timestamp not null,
|
expiration timestamp not null,
|
||||||
created_at timestamp default current_timestamp,
|
created_at timestamp default current_timestamp,
|
||||||
updated_at timestamp default current_timestamp,
|
updated_at timestamp default current_timestamp,
|
||||||
deleted_at timestamp
|
deleted_at timestamp
|
||||||
);
|
);
|
||||||
create index channel_user_id_index on channels (user_id);
|
create index channel_user_id_index on channels (user_id);
|
||||||
create index channel_node_id_index on channels (node_id);
|
create index channel_node_id_index on channels (node_id);
|
||||||
create index channel_username_index on channels (username);
|
create index channel_user_addr_index on channels (user_addr);
|
||||||
|
create index channel_node_port_index on channels (node_port);
|
||||||
|
|
||||||
-- ====================
|
-- ====================
|
||||||
-- 填充数据
|
-- 填充数据
|
||||||
-- ====================
|
-- ====================
|
||||||
|
|
||||||
|
do
|
||||||
|
$$
|
||||||
|
declare
|
||||||
|
r_user_id int;
|
||||||
|
r_node_id int;
|
||||||
|
begin
|
||||||
|
-- 用户信息
|
||||||
|
insert into users (
|
||||||
|
username, password, email, phone, name
|
||||||
|
)
|
||||||
|
values (
|
||||||
|
'test', 'test', 'test@user.email', '12345678901', 'test_user'
|
||||||
|
)
|
||||||
|
returning id into r_user_id;
|
||||||
|
|
||||||
|
insert into user_ips (
|
||||||
|
user_id, ip_address
|
||||||
|
)
|
||||||
|
values (
|
||||||
|
r_user_id, '::1'
|
||||||
|
), (
|
||||||
|
r_user_id, '127.0.0.1'
|
||||||
|
);
|
||||||
|
|
||||||
|
-- 节点信息
|
||||||
|
insert into nodes (
|
||||||
|
name, version, fwd_port, provider, location
|
||||||
|
)
|
||||||
|
values (
|
||||||
|
'qwer', 1, 20001, 'test_provider', 'test_location'
|
||||||
|
)
|
||||||
|
returning id into r_node_id;
|
||||||
|
|
||||||
|
-- 权限信息
|
||||||
|
insert into channels (
|
||||||
|
user_id, node_id, user_addr, node_port, auth_ip, auth_pass,
|
||||||
|
protocol, username, password, expiration
|
||||||
|
)
|
||||||
|
values (
|
||||||
|
r_user_id, r_node_id, '::1', 20001, true, false,
|
||||||
|
'http', 'ip6http', 'asdf', now() + interval '1 year'
|
||||||
|
), (
|
||||||
|
r_user_id, r_node_id, '::1', 20001, true, false,
|
||||||
|
'socks5', 'ip6socks', 'asdf', now() + interval '1 year'
|
||||||
|
), (
|
||||||
|
r_user_id, r_node_id, '127.0.0.1', 20001, true, false,
|
||||||
|
'http', 'ip4http', 'asdf', now() + interval '1 year'
|
||||||
|
), (
|
||||||
|
r_user_id, r_node_id, '127.0.0.1', 20001, true, false,
|
||||||
|
'socks5', 'ip4socks', 'asdf', now() + interval '1 year'
|
||||||
|
);
|
||||||
|
end
|
||||||
|
$$;
|
||||||
|
|||||||
@@ -1,32 +1,25 @@
|
|||||||
package core
|
package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net"
|
"net"
|
||||||
models2 "proxy-server/server/pkg/models"
|
"proxy-server/server/fwd/core"
|
||||||
|
"proxy-server/server/pkg/models"
|
||||||
"proxy-server/server/pkg/orm"
|
"proxy-server/server/pkg/orm"
|
||||||
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Payload struct {
|
type Protocol string
|
||||||
ID uint
|
|
||||||
}
|
|
||||||
|
|
||||||
type AuthContext struct {
|
const (
|
||||||
Timeout float64
|
Socks5 = Protocol("socks5")
|
||||||
Payload Payload
|
Http = Protocol("http")
|
||||||
Meta map[string]any
|
)
|
||||||
}
|
|
||||||
|
|
||||||
func CheckIp(conn net.Conn) (*AuthContext, error) {
|
func CheckIp(conn net.Conn, proto Protocol) (*core.AuthContext, error) {
|
||||||
return &AuthContext{
|
|
||||||
Timeout: 0,
|
|
||||||
Payload: Payload{
|
|
||||||
1,
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
|
|
||||||
// 获取用户地址
|
// 获取用户地址
|
||||||
remoteAddr := conn.RemoteAddr().String()
|
remoteAddr := conn.RemoteAddr().String()
|
||||||
@@ -37,23 +30,24 @@ func CheckIp(conn net.Conn) (*AuthContext, error) {
|
|||||||
|
|
||||||
// 获取服务端口
|
// 获取服务端口
|
||||||
localAddr := conn.LocalAddr().String()
|
localAddr := conn.LocalAddr().String()
|
||||||
_, localPort, err := net.SplitHostPort(localAddr)
|
_, _localPort, err := net.SplitHostPort(localAddr)
|
||||||
|
localPort, err := strconv.Atoi(_localPort)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "noAuth 认证失败")
|
||||||
|
}
|
||||||
|
|
||||||
// 查询权限记录
|
// 查询权限记录
|
||||||
slog.Debug("用户 " + remoteHost + " 请求连接到 " + localPort)
|
slog.Debug("用户 " + remoteHost + " 请求连接到 " + _localPort)
|
||||||
var channels []models2.Channel
|
var channels []models.Channel
|
||||||
err = orm.DB.
|
err = orm.DB.Find(&channels, &models.Channel{
|
||||||
Joins("INNER JOIN public.nodes n ON channels.node_id = n.id AND n.name = ?", localPort).
|
AuthIp: true,
|
||||||
Joins("INNER JOIN public.users u ON channels.user_id = u.id").
|
UserAddr: remoteHost,
|
||||||
Joins("INNER JOIN public.user_ips ip ON u.id = ip.user_id AND ip.ip_address = ?", remoteHost).
|
NodePort: localPort,
|
||||||
Where(&models2.Channel{
|
Protocol: string(proto),
|
||||||
AuthIp: true,
|
}).Error
|
||||||
}).
|
|
||||||
Find(&channels).Error
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.New("查询用户权限失败")
|
return nil, errors.New("查询用户权限失败")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 记录应该只有一条
|
// 记录应该只有一条
|
||||||
channel, err := orm.MaySingle(channels)
|
channel, err := orm.MaySingle(channels)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -71,30 +65,32 @@ func CheckIp(conn net.Conn) (*AuthContext, error) {
|
|||||||
return nil, errors.New("权限已过期")
|
return nil, errors.New("权限已过期")
|
||||||
}
|
}
|
||||||
|
|
||||||
return &AuthContext{
|
return &core.AuthContext{
|
||||||
Timeout: timeout,
|
Timeout: timeout,
|
||||||
Payload: Payload{
|
Payload: core.Payload{
|
||||||
channel.UserId,
|
ID: channel.UserId,
|
||||||
},
|
},
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func CheckPass(conn net.Conn, username, password string) (*AuthContext, error) {
|
func CheckPass(conn net.Conn, proto Protocol, username, password string) (*core.AuthContext, error) {
|
||||||
return &AuthContext{
|
|
||||||
Timeout: 0,
|
|
||||||
Payload: Payload{
|
|
||||||
1,
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
|
|
||||||
// 查询通道配置
|
// 获取服务端口
|
||||||
var channel models2.Channel
|
localAddr := conn.LocalAddr().String()
|
||||||
err := orm.DB.
|
_, _localPort, err := net.SplitHostPort(localAddr)
|
||||||
Where(&models2.Channel{
|
localPort, err := strconv.Atoi(_localPort)
|
||||||
Username: username,
|
if err != nil {
|
||||||
AuthPass: true,
|
return nil, errors.Wrap(err, "noAuth 认证失败")
|
||||||
}).
|
}
|
||||||
First(&channel).Error
|
|
||||||
|
// 查询权限记录
|
||||||
|
var channel models.Channel
|
||||||
|
err = orm.DB.Take(&channel, &models.Channel{
|
||||||
|
AuthPass: true,
|
||||||
|
Username: username,
|
||||||
|
NodePort: localPort,
|
||||||
|
Protocol: string(proto),
|
||||||
|
}).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "用户不存在")
|
return nil, errors.Wrap(err, "用户不存在")
|
||||||
}
|
}
|
||||||
@@ -104,15 +100,8 @@ func CheckPass(conn net.Conn, username, password string) (*AuthContext, error) {
|
|||||||
return nil, errors.New("密码错误")
|
return nil, errors.New("密码错误")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 检查权限是否过期
|
|
||||||
timeout := channel.Expiration.Sub(time.Now()).Seconds()
|
|
||||||
if timeout <= 0 {
|
|
||||||
return nil, errors.New("权限已过期")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 如果用户设置了双验证则检查 ip 是否在白名单中
|
// 如果用户设置了双验证则检查 ip 是否在白名单中
|
||||||
if channel.AuthIp {
|
if channel.AuthIp {
|
||||||
slog.Debug("验证用户 ip")
|
|
||||||
|
|
||||||
// 获取用户地址
|
// 获取用户地址
|
||||||
remoteAddr := conn.RemoteAddr().String()
|
remoteAddr := conn.RemoteAddr().String()
|
||||||
@@ -121,28 +110,22 @@ func CheckPass(conn net.Conn, username, password string) (*AuthContext, error) {
|
|||||||
return nil, errors.Wrap(err, "无法获取连接信息")
|
return nil, errors.Wrap(err, "无法获取连接信息")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 查询通道配置
|
// 查询权限记录
|
||||||
|
if channel.UserAddr != remoteHost {
|
||||||
var ips int64
|
|
||||||
err = orm.DB.
|
|
||||||
Where(&models2.UserIp{
|
|
||||||
UserId: channel.UserId,
|
|
||||||
IpAddress: remoteHost,
|
|
||||||
}).
|
|
||||||
Count(&ips).Error
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "查询白名单失败")
|
|
||||||
}
|
|
||||||
|
|
||||||
if ips == 0 {
|
|
||||||
return nil, errors.New("不在白名单内")
|
return nil, errors.New("不在白名单内")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return &AuthContext{
|
// 检查权限是否过期
|
||||||
|
timeout := channel.Expiration.Sub(time.Now()).Seconds()
|
||||||
|
if timeout <= 0 {
|
||||||
|
return nil, errors.New("权限已过期")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &core.AuthContext{
|
||||||
Timeout: timeout,
|
Timeout: timeout,
|
||||||
Payload: Payload{
|
Payload: core.Payload{
|
||||||
channel.UserId,
|
ID: channel.UserId,
|
||||||
},
|
},
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
151
server/fwd/auth/auth_test.go
Normal file
151
server/fwd/auth/auth_test.go
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"proxy-server/server/fwd/core"
|
||||||
|
"proxy-server/server/pkg/orm"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type MockConn struct {
|
||||||
|
local net.Addr
|
||||||
|
remote net.Addr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m MockConn) LocalAddr() net.Addr {
|
||||||
|
return m.local
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m MockConn) RemoteAddr() net.Addr {
|
||||||
|
return m.remote
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m MockConn) Read(b []byte) (n int, err error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m MockConn) Write(b []byte) (n int, err error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m MockConn) Close() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m MockConn) SetDeadline(t time.Time) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m MockConn) SetReadDeadline(t time.Time) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m MockConn) SetWriteDeadline(t time.Time) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckIp(t *testing.T) {
|
||||||
|
|
||||||
|
orm.InitForTest(
|
||||||
|
"host=localhost port=5432 user=proxy password=proxy dbname=app sslmode=disable TimeZone=Asia/Shanghai",
|
||||||
|
)
|
||||||
|
|
||||||
|
type args struct {
|
||||||
|
conn net.Conn
|
||||||
|
proto Protocol
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want *core.AuthContext
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "test-ok",
|
||||||
|
args: args{
|
||||||
|
conn: MockConn{
|
||||||
|
remote: &net.TCPAddr{
|
||||||
|
IP: []byte{127, 0, 0, 1},
|
||||||
|
Port: 12345,
|
||||||
|
},
|
||||||
|
local: &net.TCPAddr{
|
||||||
|
IP: []byte{127, 0, 0, 1},
|
||||||
|
Port: 20001,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
proto: Http,
|
||||||
|
},
|
||||||
|
want: &core.AuthContext{
|
||||||
|
Payload: core.Payload{ID: 1},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := CheckIp(tt.args.conn, tt.args.proto)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("CheckIp() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got.Payload, tt.want.Payload) || !reflect.DeepEqual(got.Meta, tt.want.Meta) {
|
||||||
|
t.Errorf("CheckIp() got = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCheckPass(t *testing.T) {
|
||||||
|
|
||||||
|
orm.InitForTest(
|
||||||
|
"host=localhost port=5432 user=proxy password=proxy dbname=app sslmode=disable TimeZone=Asia/Shanghai",
|
||||||
|
)
|
||||||
|
|
||||||
|
type args struct {
|
||||||
|
conn net.Conn
|
||||||
|
username string
|
||||||
|
password string
|
||||||
|
proto Protocol
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want *core.AuthContext
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "test-ok",
|
||||||
|
args: args{
|
||||||
|
conn: MockConn{
|
||||||
|
remote: &net.TCPAddr{
|
||||||
|
IP: []byte{127, 0, 0, 1},
|
||||||
|
Port: 12345,
|
||||||
|
},
|
||||||
|
local: &net.TCPAddr{
|
||||||
|
IP: []byte{127, 0, 0, 1},
|
||||||
|
Port: 20001,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
proto: Http,
|
||||||
|
username: "ip4http",
|
||||||
|
password: "asdf",
|
||||||
|
},
|
||||||
|
want: &core.AuthContext{
|
||||||
|
Payload: core.Payload{ID: 1},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := CheckPass(tt.args.conn, tt.args.proto, tt.args.username, tt.args.password)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("CheckPass() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got.Payload, tt.want.Payload) || !reflect.DeepEqual(got.Meta, tt.want.Meta) {
|
||||||
|
t.Errorf("CheckPass() got = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -92,3 +92,13 @@ func (a FwdAddr) Network() string {
|
|||||||
func (a FwdAddr) String() string {
|
func (a FwdAddr) String() string {
|
||||||
return fmt.Sprintf("%s:%d", a.IP, a.Port)
|
return fmt.Sprintf("%s:%d", a.IP, a.Port)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type AuthContext struct {
|
||||||
|
Timeout float64
|
||||||
|
Payload Payload
|
||||||
|
Meta map[string]any
|
||||||
|
}
|
||||||
|
|
||||||
|
type Payload struct {
|
||||||
|
ID uint
|
||||||
|
}
|
||||||
|
|||||||
@@ -94,7 +94,7 @@ func (s *Service) processCtrlConn(conn net.Conn) error {
|
|||||||
|
|
||||||
// 检查客户端
|
// 检查客户端
|
||||||
var node models.Node
|
var node models.Node
|
||||||
err = orm.DB.First(&node, &models.Node{
|
err = orm.DB.Take(&node, &models.Node{
|
||||||
Name: name,
|
Name: name,
|
||||||
}).Error
|
}).Error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -118,7 +118,7 @@ func (s *Server) acceptHttp(ls net.Listener) error {
|
|||||||
go func() {
|
go func() {
|
||||||
user, err := http.Process(s.ctx, conn)
|
user, err := http.Process(s.ctx, conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("dispatcher http process error", "err", err)
|
slog.Error("处理 http 连接失败", "err", err)
|
||||||
utils.Close(conn)
|
utils.Close(conn)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/textproto"
|
"net/textproto"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"proxy-server/server/fwd/auth"
|
||||||
"proxy-server/server/fwd/core"
|
"proxy-server/server/fwd/core"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@@ -47,11 +48,16 @@ func Process(ctx context.Context, conn net.Conn) (*core.Conn, error) {
|
|||||||
|
|
||||||
// 验证账号
|
// 验证账号
|
||||||
authInfo := headers.Get("Proxy-Authorization")
|
authInfo := headers.Get("Proxy-Authorization")
|
||||||
var auth *core.AuthContext
|
var authCtx *core.AuthContext
|
||||||
|
var authErr error
|
||||||
if authInfo == "" {
|
if authInfo == "" {
|
||||||
auth, err = core.CheckIp(conn)
|
authCtx, authErr = auth.CheckIp(conn, auth.Http)
|
||||||
if err != nil {
|
if authErr != nil {
|
||||||
return nil, errors.Wrap(err, "验证账号失败")
|
_, err := conn.Write([]byte("HTTP/1.1 407 Proxy Authentication Required\r\n\r\n"))
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "响应 407 失败")
|
||||||
|
}
|
||||||
|
return nil, errors.Wrap(authErr, "验证账号失败")
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
authParts := strings.Split(authInfo, " ")
|
authParts := strings.Split(authInfo, " ")
|
||||||
@@ -66,7 +72,14 @@ func Process(ctx context.Context, conn net.Conn) (*core.Conn, error) {
|
|||||||
return nil, errors.Wrap(err, "解码认证信息失败")
|
return nil, errors.Wrap(err, "解码认证信息失败")
|
||||||
}
|
}
|
||||||
authPair := strings.Split(string(authBytes), ":")
|
authPair := strings.Split(string(authBytes), ":")
|
||||||
auth, err = core.CheckPass(conn, authPair[0], authPair[1])
|
authCtx, authErr = auth.CheckPass(conn, auth.Http, authPair[0], authPair[1])
|
||||||
|
if authErr != nil {
|
||||||
|
_, err := conn.Write([]byte("HTTP/1.1 407 Proxy Authentication Required\r\n\r\n"))
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "响应 407 失败")
|
||||||
|
}
|
||||||
|
return nil, errors.Wrap(authErr, "验证账号失败")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取 Host
|
// 获取 Host
|
||||||
@@ -94,7 +107,7 @@ func Process(ctx context.Context, conn net.Conn) (*core.Conn, error) {
|
|||||||
Port: addr.Port,
|
Port: addr.Port,
|
||||||
Domain: host,
|
Domain: host,
|
||||||
},
|
},
|
||||||
auth: auth,
|
auth: authCtx,
|
||||||
}
|
}
|
||||||
|
|
||||||
var user *core.Conn
|
var user *core.Conn
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"log/slog"
|
"log/slog"
|
||||||
"net"
|
"net"
|
||||||
"proxy-server/pkg/utils"
|
"proxy-server/pkg/utils"
|
||||||
|
"proxy-server/server/fwd/auth"
|
||||||
"proxy-server/server/fwd/core"
|
"proxy-server/server/fwd/core"
|
||||||
"slices"
|
"slices"
|
||||||
|
|
||||||
@@ -60,7 +61,7 @@ func Process(ctx context.Context, conn net.Conn) (*core.Conn, error) {
|
|||||||
reader := bufio.NewReader(conn)
|
reader := bufio.NewReader(conn)
|
||||||
|
|
||||||
// 认证
|
// 认证
|
||||||
auth, err := authenticate(ctx, reader, conn)
|
authCtx, err := authenticate(ctx, reader, conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "认证失败")
|
return nil, errors.Wrap(err, "认证失败")
|
||||||
}
|
}
|
||||||
@@ -85,7 +86,7 @@ func Process(ctx context.Context, conn net.Conn) (*core.Conn, error) {
|
|||||||
Protocol: "socks5",
|
Protocol: "socks5",
|
||||||
Tag: conn.RemoteAddr().String() + "_" + conn.LocalAddr().String(),
|
Tag: conn.RemoteAddr().String() + "_" + conn.LocalAddr().String(),
|
||||||
Dest: request.DestAddr,
|
Dest: request.DestAddr,
|
||||||
Auth: auth,
|
Auth: authCtx,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -167,7 +168,7 @@ func authenticate(ctx context.Context, reader *bufio.Reader, conn net.Conn) (*co
|
|||||||
password := string(passwordBuf)
|
password := string(passwordBuf)
|
||||||
|
|
||||||
// 检查权限
|
// 检查权限
|
||||||
authContext, err := core.CheckPass(conn, username, password)
|
authContext, err := auth.CheckPass(conn, auth.Socks5, username, password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "权限检查失败")
|
return nil, errors.Wrap(err, "权限检查失败")
|
||||||
}
|
}
|
||||||
@@ -188,12 +189,12 @@ func authenticate(ctx context.Context, reader *bufio.Reader, conn net.Conn) (*co
|
|||||||
return nil, errors.Wrap(err, "响应认证方式失败")
|
return nil, errors.Wrap(err, "响应认证方式失败")
|
||||||
}
|
}
|
||||||
|
|
||||||
authContext, err := core.CheckIp(conn)
|
authCtx, err := auth.CheckIp(conn, auth.Socks5)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.Wrap(err, "权限检查失败")
|
return nil, errors.Wrap(err, "权限检查失败")
|
||||||
}
|
}
|
||||||
|
|
||||||
return authContext, nil
|
return authCtx, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// 无适用的认证方式
|
// 无适用的认证方式
|
||||||
|
|||||||
@@ -11,10 +11,12 @@ type Channel struct {
|
|||||||
gorm.Model
|
gorm.Model
|
||||||
UserId uint
|
UserId uint
|
||||||
NodeId uint
|
NodeId uint
|
||||||
|
UserAddr string
|
||||||
|
NodePort int
|
||||||
|
AuthIp bool
|
||||||
|
AuthPass bool
|
||||||
Protocol string
|
Protocol string
|
||||||
Username string
|
Username string
|
||||||
Password string
|
Password string
|
||||||
AuthIp bool
|
|
||||||
AuthPass bool
|
|
||||||
Expiration time.Time
|
Expiration time.Time
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -29,6 +29,17 @@ func Init() {
|
|||||||
DB = db
|
DB = db
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func InitForTest(dsn string) {
|
||||||
|
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{
|
||||||
|
Logger: logger.Default,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
DB = db
|
||||||
|
}
|
||||||
|
|
||||||
func MaySingle[T any](results []T) (*T, error) {
|
func MaySingle[T any](results []T) (*T, error) {
|
||||||
rsLen := len(results)
|
rsLen := len(results)
|
||||||
if rsLen == 0 {
|
if rsLen == 0 {
|
||||||
|
|||||||
Reference in New Issue
Block a user