package auth import ( "fmt" "log/slog" "net" "proxy-server/server/fwd/core" "proxy-server/server/fwd/repo" "proxy-server/server/pkg/orm" "strconv" "time" "errors" ) type Protocol string const ( Socks5 = Protocol("socks5") Http = Protocol("http") ) func CheckIp(conn net.Conn, proto Protocol) (*core.AuthContext, error) { // 获取用户地址 remoteAddr := conn.RemoteAddr().String() remoteHost, _, err := net.SplitHostPort(remoteAddr) if err != nil { return nil, fmt.Errorf("无法获取连接信息: %w", err) } // 获取服务端口 localAddr := conn.LocalAddr().String() _, _localPort, err := net.SplitHostPort(localAddr) localPort, err := strconv.Atoi(_localPort) if err != nil { return nil, fmt.Errorf("noAuth 认证失败: %w", err) } // 查询权限记录 slog.Debug("用户 " + remoteHost + " 请求连接到 " + _localPort) var channels []repo.Channel err = orm.DB.Find(&channels, &repo.Channel{ AuthIp: true, UserAddr: remoteHost, NodePort: localPort, Protocol: string(proto), }).Error if err != nil { return nil, errors.New("查询用户权限失败") } // 记录应该只有一条 channel, err := orm.MaySingle(channels) if err != nil { return nil, errors.New("不在白名单内") } // 检查是否需要密码认证 if channel.AuthPass { return nil, errors.New("需要密码认证") } // 检查权限是否过期 timeout := channel.Expiration.Sub(time.Now()).Seconds() if timeout <= 0 { return nil, errors.New("权限已过期") } return &core.AuthContext{ Timeout: timeout, Payload: core.Payload{ ID: channel.UserId, }, }, nil } func CheckPass(conn net.Conn, proto Protocol, username, password string) (*core.AuthContext, error) { // 获取服务端口 localAddr := conn.LocalAddr().String() _, _localPort, err := net.SplitHostPort(localAddr) localPort, err := strconv.Atoi(_localPort) if err != nil { return nil, fmt.Errorf("noAuth 认证失败: %w", err) } // 查询权限记录 var channel repo.Channel err = orm.DB.Take(&channel, &repo.Channel{ AuthPass: true, Username: username, NodePort: localPort, Protocol: string(proto), }).Error if err != nil { return nil, errors.New("用户不存在") } // 检查密码 todo 哈希 if channel.Password != password { return nil, errors.New("密码错误") } // 如果用户设置了双验证则检查 ip 是否在白名单中 if channel.AuthIp { // 获取用户地址 remoteAddr := conn.RemoteAddr().String() remoteHost, _, err := net.SplitHostPort(remoteAddr) if err != nil { return nil, fmt.Errorf("无法获取连接信息: %w", err) } // 查询权限记录 if channel.UserAddr != remoteHost { return nil, errors.New("不在白名单内") } } // 检查权限是否过期 timeout := channel.Expiration.Sub(time.Now()).Seconds() if timeout <= 0 { return nil, errors.New("权限已过期") } return &core.AuthContext{ Timeout: timeout, Payload: core.Payload{ ID: channel.UserId, }, }, nil }