重构代理解析流程,引入端口混合协议转发
This commit is contained in:
136
server/fwd/core/auth.go
Normal file
136
server/fwd/core/auth.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net"
|
||||
"proxy-server/server/models"
|
||||
"proxy-server/server/pkg/orm"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type Payload struct {
|
||||
ID uint
|
||||
}
|
||||
|
||||
type AuthContext struct {
|
||||
Timeout float64
|
||||
Payload Payload
|
||||
Meta map[string]any
|
||||
}
|
||||
|
||||
func CheckIp(conn net.Conn) (*AuthContext, error) {
|
||||
|
||||
// 获取用户地址
|
||||
remoteAddr := conn.RemoteAddr().String()
|
||||
remoteHost, _, err := net.SplitHostPort(remoteAddr)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "noAuth 认证失败")
|
||||
}
|
||||
|
||||
// 获取服务端口
|
||||
localAddr := conn.LocalAddr().String()
|
||||
_, localPort, err := net.SplitHostPort(localAddr)
|
||||
|
||||
// 查询权限记录
|
||||
slog.Info("用户 " + remoteHost + " 请求连接到 " + localPort)
|
||||
var channels []models.Channel
|
||||
err = orm.DB.
|
||||
Joins("INNER JOIN public.nodes n ON channels.node_id = n.id AND n.name = ?", localPort).
|
||||
Joins("INNER JOIN public.users u ON channels.user_id = u.id").
|
||||
Joins("INNER JOIN public.user_ips ip ON u.id = ip.user_id AND ip.ip_address = ?", remoteHost).
|
||||
Where(&models.Channel{
|
||||
AuthIp: true,
|
||||
}).
|
||||
Find(&channels).Error
|
||||
if err != nil {
|
||||
return nil, errors.New("查询用户权限失败")
|
||||
}
|
||||
|
||||
// 记录应该只有一条
|
||||
channel, err := orm.MaySingle(channels)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "不在白名单内")
|
||||
}
|
||||
|
||||
// 检查是否需要密码认证
|
||||
if channel.AuthPass {
|
||||
return nil, errors.New("需要密码认证")
|
||||
}
|
||||
|
||||
// 检查权限是否过期
|
||||
timeout := channel.Expiration.Sub(time.Now()).Seconds()
|
||||
if timeout <= 0 {
|
||||
return nil, errors.New("权限已过期")
|
||||
}
|
||||
|
||||
return &AuthContext{
|
||||
Timeout: timeout,
|
||||
Payload: Payload{
|
||||
channel.UserId,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func CheckPass(conn net.Conn, username, password string) (*AuthContext, error) {
|
||||
|
||||
// 查询通道配置
|
||||
var channel models.Channel
|
||||
err := orm.DB.
|
||||
Where(&models.Channel{
|
||||
Username: username,
|
||||
AuthPass: true,
|
||||
}).
|
||||
First(&channel).Error
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "用户不存在")
|
||||
}
|
||||
|
||||
// 检查密码 todo 哈希
|
||||
if channel.Password != password {
|
||||
return nil, errors.New("密码错误")
|
||||
}
|
||||
|
||||
// 检查权限是否过期
|
||||
timeout := channel.Expiration.Sub(time.Now()).Seconds()
|
||||
if timeout <= 0 {
|
||||
return nil, errors.New("权限已过期")
|
||||
}
|
||||
|
||||
// 如果用户设置了双验证则检查 ip 是否在白名单中
|
||||
if channel.AuthIp {
|
||||
slog.Debug("验证用户 ip")
|
||||
|
||||
// 获取用户地址
|
||||
remoteAddr := conn.RemoteAddr().String()
|
||||
remoteHost, _, err := net.SplitHostPort(remoteAddr)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "无法获取连接信息")
|
||||
}
|
||||
|
||||
// 查询通道配置
|
||||
|
||||
var ips int64
|
||||
err = orm.DB.
|
||||
Where(&models.UserIp{
|
||||
UserId: channel.UserId,
|
||||
IpAddress: remoteHost,
|
||||
}).
|
||||
Count(&ips).Error
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "查询白名单失败")
|
||||
}
|
||||
|
||||
if ips == 0 {
|
||||
return nil, errors.New("不在白名单内")
|
||||
}
|
||||
}
|
||||
|
||||
return &AuthContext{
|
||||
Timeout: timeout,
|
||||
Payload: Payload{
|
||||
channel.UserId,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
Reference in New Issue
Block a user