133 lines
2.9 KiB
Go
133 lines
2.9 KiB
Go
package auth
|
|
|
|
import (
|
|
"fmt"
|
|
"log/slog"
|
|
"net"
|
|
"proxy-server/server/fwd/core"
|
|
"proxy-server/server/fwd/repo"
|
|
"proxy-server/server/pkg/orm"
|
|
"strconv"
|
|
"time"
|
|
|
|
"errors"
|
|
)
|
|
|
|
type Protocol string
|
|
|
|
const (
|
|
Socks5 = Protocol("socks5")
|
|
Http = Protocol("http")
|
|
)
|
|
|
|
func CheckIp(conn net.Conn, proto Protocol) (*core.AuthContext, error) {
|
|
|
|
// 获取用户地址
|
|
remoteAddr := conn.RemoteAddr().String()
|
|
remoteHost, _, err := net.SplitHostPort(remoteAddr)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("无法获取连接信息: %w", err)
|
|
}
|
|
|
|
// 获取服务端口
|
|
localAddr := conn.LocalAddr().String()
|
|
_, _localPort, err := net.SplitHostPort(localAddr)
|
|
localPort, err := strconv.Atoi(_localPort)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("noAuth 认证失败: %w", err)
|
|
}
|
|
|
|
// 查询权限记录
|
|
slog.Debug("用户 " + remoteHost + " 请求连接到 " + _localPort)
|
|
var channels []repo.Channel
|
|
err = orm.DB.Find(&channels, &repo.Channel{
|
|
AuthIp: true,
|
|
UserAddr: remoteHost,
|
|
NodePort: localPort,
|
|
Protocol: string(proto),
|
|
}).Error
|
|
if err != nil {
|
|
return nil, errors.New("查询用户权限失败")
|
|
}
|
|
// 记录应该只有一条
|
|
channel, err := orm.MaySingle(channels)
|
|
if err != nil {
|
|
return nil, errors.New("不在白名单内")
|
|
}
|
|
|
|
// 检查是否需要密码认证
|
|
if channel.AuthPass {
|
|
return nil, errors.New("需要密码认证")
|
|
}
|
|
|
|
// 检查权限是否过期
|
|
timeout := channel.Expiration.Sub(time.Now()).Seconds()
|
|
if timeout <= 0 {
|
|
return nil, errors.New("权限已过期")
|
|
}
|
|
|
|
return &core.AuthContext{
|
|
Timeout: timeout,
|
|
Payload: core.Payload{
|
|
ID: channel.UserId,
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
func CheckPass(conn net.Conn, proto Protocol, username, password string) (*core.AuthContext, error) {
|
|
|
|
// 获取服务端口
|
|
localAddr := conn.LocalAddr().String()
|
|
_, _localPort, err := net.SplitHostPort(localAddr)
|
|
localPort, err := strconv.Atoi(_localPort)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("noAuth 认证失败: %w", err)
|
|
}
|
|
|
|
// 查询权限记录
|
|
var channel repo.Channel
|
|
err = orm.DB.Take(&channel, &repo.Channel{
|
|
AuthPass: true,
|
|
Username: username,
|
|
NodePort: localPort,
|
|
Protocol: string(proto),
|
|
}).Error
|
|
if err != nil {
|
|
return nil, errors.New("用户不存在")
|
|
}
|
|
|
|
// 检查密码 todo 哈希
|
|
if channel.Password != password {
|
|
return nil, errors.New("密码错误")
|
|
}
|
|
|
|
// 如果用户设置了双验证则检查 ip 是否在白名单中
|
|
if channel.AuthIp {
|
|
|
|
// 获取用户地址
|
|
remoteAddr := conn.RemoteAddr().String()
|
|
remoteHost, _, err := net.SplitHostPort(remoteAddr)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("无法获取连接信息: %w", err)
|
|
}
|
|
|
|
// 查询权限记录
|
|
if channel.UserAddr != remoteHost {
|
|
return nil, errors.New("不在白名单内")
|
|
}
|
|
}
|
|
|
|
// 检查权限是否过期
|
|
timeout := channel.Expiration.Sub(time.Now()).Seconds()
|
|
if timeout <= 0 {
|
|
return nil, errors.New("权限已过期")
|
|
}
|
|
|
|
return &core.AuthContext{
|
|
Timeout: timeout,
|
|
Payload: core.Payload{
|
|
ID: channel.UserId,
|
|
},
|
|
}, nil
|
|
}
|