package fwd import ( "context" "io" "log/slog" "net" "proxy-server/pkg/utils" "proxy-server/server/fwd/socks" "proxy-server/server/pkg/orm" "proxy-server/server/web/models" "time" "github.com/pkg/errors" ) type NoAuthAuthenticator struct { } func (a *NoAuthAuthenticator) Method() socks.AuthMethod { return socks.NoAuth } func (a *NoAuthAuthenticator) Authenticate(ctx context.Context, reader io.Reader, writer io.Writer) (*socks.Authentication, error) { // 获取用户地址 conn, ok := writer.(net.Conn) if !ok { return nil, errors.New("noAuth 认证失败,无法获取连接信息") } addr := conn.RemoteAddr().String() client, _, err := net.SplitHostPort(addr) if err != nil { return nil, errors.Wrap(err, "noAuth 认证失败") } slog.Debug("用户的地址为 " + client) // 获取服务 server, ok := ctx.Value("service").(*socks.Server) if !ok { return nil, errors.New("noAuth 认证失败,无法获取服务信息") } node := server.Name slog.Debug("服务的名称为 " + server.Name) // 查询权限记录 slog.Info("用户 " + client + " 请求连接到 " + node) var channels []models.Channel err = orm.DB. Joins("INNER JOIN public.nodes n ON channels.node_id = n.id AND n.name = ?", node). Joins("INNER JOIN public.users u ON channels.user_id = u.id"). Joins("INNER JOIN public.user_ips ip ON u.id = ip.user_id AND ip.ip_address = ?", client). Where(&models.Channel{ AuthIp: true, }). Find(&channels).Error if err != nil { return nil, errors.New("noAuth 查询用户权限失败") } // 记录应该只有一条 channel, err := orm.MaySingle(channels) if err != nil { return nil, errors.Wrap(err, "noAuth 没有权限") } // 检查是否需要密码认证 if channel.AuthPass { return nil, errors.New("noAuth 没有权限,需要密码认证") } // 检查权限是否过期 timeout := channel.Expiration.Sub(time.Now()).Seconds() slog.Info("用户剩余时间", "timeout", timeout) if timeout <= 0 { return nil, errors.New("noAuth 权限已过期") } slog.Debug("权限剩余时间", slog.Uint64("timeout", uint64(timeout))) return &socks.Authentication{ Method: socks.NoAuth, Timeout: uint(timeout), Payload: socks.Payload{ ID: channel.UserId, }, }, nil } type UserPassAuthenticator struct { } func (a *UserPassAuthenticator) Method() socks.AuthMethod { return socks.UserPassAuth } func (a *UserPassAuthenticator) Authenticate(ctx context.Context, reader io.Reader, writer io.Writer) (*socks.Authentication, error) { // 检查认证版本 slog.Debug("验证认证版本") v, err := utils.ReadByte(reader) if err != nil { return nil, errors.Wrap(err, "读取版本号失败") } if v != socks.AuthVersion { _, err := writer.Write([]byte{socks.Version, socks.AuthFailure}) if err != nil { return nil, errors.Wrap(err, "响应认证失败") } return nil, errors.New("认证版本参数不正确") } // 读取账号 slog.Debug("验证用户账号") uLen, err := utils.ReadByte(reader) if err != nil { return nil, errors.Wrap(err, "读取用户名长度失败") } usernameBuf, err := utils.ReadBuffer(reader, int(uLen)) if err != nil { return nil, errors.Wrap(err, "读取用户名失败") } username := string(usernameBuf) // 读取密码 pLen, err := utils.ReadByte(reader) if err != nil { return nil, errors.Wrap(err, "读取密码长度失败") } passwordBuf, err := utils.ReadBuffer(reader, int(pLen)) if err != nil { return nil, errors.Wrap(err, "读取密码失败") } password := string(passwordBuf) // 查询通道配置 var channel models.Channel err = orm.DB. Where(&models.Channel{ Username: username, AuthPass: true, }). First(&channel).Error if err != nil { return nil, errors.Wrap(err, "查询用户失败") } // 检查密码 todo 哈希 if channel.Password != password { return nil, errors.New("密码错误") } // 检查权限是否过期 timeout := channel.Expiration.Sub(time.Now()).Seconds() slog.Info("用户剩余时间", "timeout", timeout) if timeout <= 0 { return nil, errors.New("权限已过期") } // 如果用户设置了双验证则检查 ip 是否在白名单中 if channel.AuthIp { slog.Debug("验证用户 ip") // 获取用户地址 conn, ok := writer.(net.Conn) if !ok { return nil, errors.New("无法获取连接信息") } addr := conn.RemoteAddr().String() client, _, err := net.SplitHostPort(addr) if err != nil { return nil, errors.Wrap(err, "无法获取连接信息") } // 查询通道配置 var ips []models.UserIp err = orm.DB. Where(&models.UserIp{ UserId: channel.UserId, IpAddress: client, }). Find(&ips).Error if err != nil { return nil, errors.Wrap(err, "查询用户 ip 失败") } // 检查是否在白名单中 if len(ips) == 0 { return nil, errors.New("没有权限") } } // 响应认证成功 _, err = writer.Write([]byte{socks.AuthVersion, socks.AuthSuccess}) if err != nil { slog.Error("响应认证失败", "err", err) return nil, err } return &socks.Authentication{ Method: socks.UserPassAuth, Timeout: uint(timeout), Payload: socks.Payload{ ID: channel.UserId, }, }, nil }