按协议判断连接权限,优化权限检查效率

This commit is contained in:
2025-03-08 11:40:52 +08:00
parent 5786ac9d99
commit f996a20823
11 changed files with 328 additions and 101 deletions

View File

@@ -1,32 +1,25 @@
package core
package auth
import (
"log/slog"
"net"
models2 "proxy-server/server/pkg/models"
"proxy-server/server/fwd/core"
"proxy-server/server/pkg/models"
"proxy-server/server/pkg/orm"
"strconv"
"time"
"github.com/pkg/errors"
)
type Payload struct {
ID uint
}
type Protocol string
type AuthContext struct {
Timeout float64
Payload Payload
Meta map[string]any
}
const (
Socks5 = Protocol("socks5")
Http = Protocol("http")
)
func CheckIp(conn net.Conn) (*AuthContext, error) {
return &AuthContext{
Timeout: 0,
Payload: Payload{
1,
},
}, nil
func CheckIp(conn net.Conn, proto Protocol) (*core.AuthContext, error) {
// 获取用户地址
remoteAddr := conn.RemoteAddr().String()
@@ -37,23 +30,24 @@ func CheckIp(conn net.Conn) (*AuthContext, error) {
// 获取服务端口
localAddr := conn.LocalAddr().String()
_, localPort, err := net.SplitHostPort(localAddr)
_, _localPort, err := net.SplitHostPort(localAddr)
localPort, err := strconv.Atoi(_localPort)
if err != nil {
return nil, errors.Wrap(err, "noAuth 认证失败")
}
// 查询权限记录
slog.Debug("用户 " + remoteHost + " 请求连接到 " + localPort)
var channels []models2.Channel
err = orm.DB.
Joins("INNER JOIN public.nodes n ON channels.node_id = n.id AND n.name = ?", localPort).
Joins("INNER JOIN public.users u ON channels.user_id = u.id").
Joins("INNER JOIN public.user_ips ip ON u.id = ip.user_id AND ip.ip_address = ?", remoteHost).
Where(&models2.Channel{
AuthIp: true,
}).
Find(&channels).Error
slog.Debug("用户 " + remoteHost + " 请求连接到 " + _localPort)
var channels []models.Channel
err = orm.DB.Find(&channels, &models.Channel{
AuthIp: true,
UserAddr: remoteHost,
NodePort: localPort,
Protocol: string(proto),
}).Error
if err != nil {
return nil, errors.New("查询用户权限失败")
}
// 记录应该只有一条
channel, err := orm.MaySingle(channels)
if err != nil {
@@ -71,30 +65,32 @@ func CheckIp(conn net.Conn) (*AuthContext, error) {
return nil, errors.New("权限已过期")
}
return &AuthContext{
return &core.AuthContext{
Timeout: timeout,
Payload: Payload{
channel.UserId,
Payload: core.Payload{
ID: channel.UserId,
},
}, nil
}
func CheckPass(conn net.Conn, username, password string) (*AuthContext, error) {
return &AuthContext{
Timeout: 0,
Payload: Payload{
1,
},
}, nil
func CheckPass(conn net.Conn, proto Protocol, username, password string) (*core.AuthContext, error) {
// 查询通道配置
var channel models2.Channel
err := orm.DB.
Where(&models2.Channel{
Username: username,
AuthPass: true,
}).
First(&channel).Error
// 获取服务端口
localAddr := conn.LocalAddr().String()
_, _localPort, err := net.SplitHostPort(localAddr)
localPort, err := strconv.Atoi(_localPort)
if err != nil {
return nil, errors.Wrap(err, "noAuth 认证失败")
}
// 查询权限记录
var channel models.Channel
err = orm.DB.Take(&channel, &models.Channel{
AuthPass: true,
Username: username,
NodePort: localPort,
Protocol: string(proto),
}).Error
if err != nil {
return nil, errors.Wrap(err, "用户不存在")
}
@@ -104,15 +100,8 @@ func CheckPass(conn net.Conn, username, password string) (*AuthContext, error) {
return nil, errors.New("密码错误")
}
// 检查权限是否过期
timeout := channel.Expiration.Sub(time.Now()).Seconds()
if timeout <= 0 {
return nil, errors.New("权限已过期")
}
// 如果用户设置了双验证则检查 ip 是否在白名单中
if channel.AuthIp {
slog.Debug("验证用户 ip")
// 获取用户地址
remoteAddr := conn.RemoteAddr().String()
@@ -121,28 +110,22 @@ func CheckPass(conn net.Conn, username, password string) (*AuthContext, error) {
return nil, errors.Wrap(err, "无法获取连接信息")
}
// 查询通道配置
var ips int64
err = orm.DB.
Where(&models2.UserIp{
UserId: channel.UserId,
IpAddress: remoteHost,
}).
Count(&ips).Error
if err != nil {
return nil, errors.Wrap(err, "查询白名单失败")
}
if ips == 0 {
// 查询权限记录
if channel.UserAddr != remoteHost {
return nil, errors.New("不在白名单内")
}
}
return &AuthContext{
// 检查权限是否过期
timeout := channel.Expiration.Sub(time.Now()).Seconds()
if timeout <= 0 {
return nil, errors.New("权限已过期")
}
return &core.AuthContext{
Timeout: timeout,
Payload: Payload{
channel.UserId,
Payload: core.Payload{
ID: channel.UserId,
},
}, nil
}

View File

@@ -0,0 +1,151 @@
package auth
import (
"net"
"proxy-server/server/fwd/core"
"proxy-server/server/pkg/orm"
"reflect"
"testing"
"time"
)
type MockConn struct {
local net.Addr
remote net.Addr
}
func (m MockConn) LocalAddr() net.Addr {
return m.local
}
func (m MockConn) RemoteAddr() net.Addr {
return m.remote
}
func (m MockConn) Read(b []byte) (n int, err error) {
return 0, nil
}
func (m MockConn) Write(b []byte) (n int, err error) {
return 0, nil
}
func (m MockConn) Close() error {
return nil
}
func (m MockConn) SetDeadline(t time.Time) error {
return nil
}
func (m MockConn) SetReadDeadline(t time.Time) error {
return nil
}
func (m MockConn) SetWriteDeadline(t time.Time) error {
return nil
}
func TestCheckIp(t *testing.T) {
orm.InitForTest(
"host=localhost port=5432 user=proxy password=proxy dbname=app sslmode=disable TimeZone=Asia/Shanghai",
)
type args struct {
conn net.Conn
proto Protocol
}
tests := []struct {
name string
args args
want *core.AuthContext
wantErr bool
}{
{
name: "test-ok",
args: args{
conn: MockConn{
remote: &net.TCPAddr{
IP: []byte{127, 0, 0, 1},
Port: 12345,
},
local: &net.TCPAddr{
IP: []byte{127, 0, 0, 1},
Port: 20001,
},
},
proto: Http,
},
want: &core.AuthContext{
Payload: core.Payload{ID: 1},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := CheckIp(tt.args.conn, tt.args.proto)
if (err != nil) != tt.wantErr {
t.Errorf("CheckIp() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got.Payload, tt.want.Payload) || !reflect.DeepEqual(got.Meta, tt.want.Meta) {
t.Errorf("CheckIp() got = %v, want %v", got, tt.want)
}
})
}
}
func TestCheckPass(t *testing.T) {
orm.InitForTest(
"host=localhost port=5432 user=proxy password=proxy dbname=app sslmode=disable TimeZone=Asia/Shanghai",
)
type args struct {
conn net.Conn
username string
password string
proto Protocol
}
tests := []struct {
name string
args args
want *core.AuthContext
wantErr bool
}{
{
name: "test-ok",
args: args{
conn: MockConn{
remote: &net.TCPAddr{
IP: []byte{127, 0, 0, 1},
Port: 12345,
},
local: &net.TCPAddr{
IP: []byte{127, 0, 0, 1},
Port: 20001,
},
},
proto: Http,
username: "ip4http",
password: "asdf",
},
want: &core.AuthContext{
Payload: core.Payload{ID: 1},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := CheckPass(tt.args.conn, tt.args.proto, tt.args.username, tt.args.password)
if (err != nil) != tt.wantErr {
t.Errorf("CheckPass() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got.Payload, tt.want.Payload) || !reflect.DeepEqual(got.Meta, tt.want.Meta) {
t.Errorf("CheckPass() got = %v, want %v", got, tt.want)
}
})
}
}

View File

@@ -92,3 +92,13 @@ func (a FwdAddr) Network() string {
func (a FwdAddr) String() string {
return fmt.Sprintf("%s:%d", a.IP, a.Port)
}
type AuthContext struct {
Timeout float64
Payload Payload
Meta map[string]any
}
type Payload struct {
ID uint
}

View File

@@ -94,7 +94,7 @@ func (s *Service) processCtrlConn(conn net.Conn) error {
// 检查客户端
var node models.Node
err = orm.DB.First(&node, &models.Node{
err = orm.DB.Take(&node, &models.Node{
Name: name,
}).Error
if err != nil {

View File

@@ -118,7 +118,7 @@ func (s *Server) acceptHttp(ls net.Listener) error {
go func() {
user, err := http.Process(s.ctx, conn)
if err != nil {
slog.Error("dispatcher http process error", "err", err)
slog.Error("处理 http 连接失败", "err", err)
utils.Close(conn)
return
}

View File

@@ -8,6 +8,7 @@ import (
"net"
"net/textproto"
"net/url"
"proxy-server/server/fwd/auth"
"proxy-server/server/fwd/core"
"strings"
@@ -47,11 +48,16 @@ func Process(ctx context.Context, conn net.Conn) (*core.Conn, error) {
// 验证账号
authInfo := headers.Get("Proxy-Authorization")
var auth *core.AuthContext
var authCtx *core.AuthContext
var authErr error
if authInfo == "" {
auth, err = core.CheckIp(conn)
if err != nil {
return nil, errors.Wrap(err, "验证账号失败")
authCtx, authErr = auth.CheckIp(conn, auth.Http)
if authErr != nil {
_, err := conn.Write([]byte("HTTP/1.1 407 Proxy Authentication Required\r\n\r\n"))
if err != nil {
return nil, errors.Wrap(err, "响应 407 失败")
}
return nil, errors.Wrap(authErr, "验证账号失败")
}
} else {
authParts := strings.Split(authInfo, " ")
@@ -66,7 +72,14 @@ func Process(ctx context.Context, conn net.Conn) (*core.Conn, error) {
return nil, errors.Wrap(err, "解码认证信息失败")
}
authPair := strings.Split(string(authBytes), ":")
auth, err = core.CheckPass(conn, authPair[0], authPair[1])
authCtx, authErr = auth.CheckPass(conn, auth.Http, authPair[0], authPair[1])
if authErr != nil {
_, err := conn.Write([]byte("HTTP/1.1 407 Proxy Authentication Required\r\n\r\n"))
if err != nil {
return nil, errors.Wrap(err, "响应 407 失败")
}
return nil, errors.Wrap(authErr, "验证账号失败")
}
}
// 获取 Host
@@ -94,7 +107,7 @@ func Process(ctx context.Context, conn net.Conn) (*core.Conn, error) {
Port: addr.Port,
Domain: host,
},
auth: auth,
auth: authCtx,
}
var user *core.Conn

View File

@@ -9,6 +9,7 @@ import (
"log/slog"
"net"
"proxy-server/pkg/utils"
"proxy-server/server/fwd/auth"
"proxy-server/server/fwd/core"
"slices"
@@ -60,7 +61,7 @@ func Process(ctx context.Context, conn net.Conn) (*core.Conn, error) {
reader := bufio.NewReader(conn)
// 认证
auth, err := authenticate(ctx, reader, conn)
authCtx, err := authenticate(ctx, reader, conn)
if err != nil {
return nil, errors.Wrap(err, "认证失败")
}
@@ -85,7 +86,7 @@ func Process(ctx context.Context, conn net.Conn) (*core.Conn, error) {
Protocol: "socks5",
Tag: conn.RemoteAddr().String() + "_" + conn.LocalAddr().String(),
Dest: request.DestAddr,
Auth: auth,
Auth: authCtx,
}, nil
}
@@ -167,7 +168,7 @@ func authenticate(ctx context.Context, reader *bufio.Reader, conn net.Conn) (*co
password := string(passwordBuf)
// 检查权限
authContext, err := core.CheckPass(conn, username, password)
authContext, err := auth.CheckPass(conn, auth.Socks5, username, password)
if err != nil {
return nil, errors.Wrap(err, "权限检查失败")
}
@@ -188,12 +189,12 @@ func authenticate(ctx context.Context, reader *bufio.Reader, conn net.Conn) (*co
return nil, errors.Wrap(err, "响应认证方式失败")
}
authContext, err := core.CheckIp(conn)
authCtx, err := auth.CheckIp(conn, auth.Socks5)
if err != nil {
return nil, errors.Wrap(err, "权限检查失败")
}
return authContext, nil
return authCtx, nil
}
// 无适用的认证方式

View File

@@ -11,10 +11,12 @@ type Channel struct {
gorm.Model
UserId uint
NodeId uint
UserAddr string
NodePort int
AuthIp bool
AuthPass bool
Protocol string
Username string
Password string
AuthIp bool
AuthPass bool
Expiration time.Time
}

View File

@@ -29,6 +29,17 @@ func Init() {
DB = db
}
func InitForTest(dsn string) {
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{
Logger: logger.Default,
})
if err != nil {
panic(err)
}
DB = db
}
func MaySingle[T any](results []T) (*T, error) {
rsLen := len(results)
if rsLen == 0 {