优化认证逻辑,无权限打印原因

This commit is contained in:
2025-04-28 10:05:44 +08:00
parent 129f842153
commit 370362b0d5
3 changed files with 32 additions and 92 deletions

View File

@@ -5,7 +5,6 @@ import (
"encoding/base64"
"errors"
"log/slog"
"platform/web/common"
q "platform/web/queries"
"slices"
"strings"
@@ -13,88 +12,9 @@ import (
"platform/web/services"
"github.com/gofiber/fiber/v2"
"golang.org/x/crypto/bcrypt"
)
func Permit(types []services.PayloadType, permissions ...string) fiber.Handler {
return func(c *fiber.Ctx) error {
// 获取令牌
var header = c.Get("Authorization")
var split = strings.Split(header, " ")
if len(split) != 2 {
return c.Status(fiber.StatusBadRequest).JSON(common.ErrResp{
Error: true,
Message: "无效的令牌",
})
}
var token = split[1]
if token == "" {
return c.Status(fiber.StatusBadRequest).JSON(common.ErrResp{
Error: true,
Message: "无效的令牌",
})
}
var auth *services.AuthContext
var err error
switch split[0] {
case "Bearer":
auth, err = authBearer(c.Context(), token)
case "Basic":
if !slices.Contains(types, services.PayloadClientConfidential) {
return c.Status(fiber.StatusUnauthorized).JSON(common.ErrResp{
Error: true,
Message: "没有权限",
})
}
auth, err = authBasic(c.Context(), token)
}
if err != nil {
return c.Status(fiber.StatusUnauthorized).JSON(common.ErrResp{
Error: true,
Message: "没有权限",
})
}
// 检查权限
if !slices.Contains(types, auth.Payload.Type) {
return c.Status(fiber.StatusForbidden).JSON(common.ErrResp{
Error: true,
Message: "拒绝访问",
})
}
if len(permissions) > 0 && !auth.AnyPermission(permissions...) {
return c.Status(fiber.StatusForbidden).JSON(common.ErrResp{
Error: true,
Message: "拒绝访问",
})
}
// 将认证信息存储在上下文中
c.Locals("auth", auth)
c.Locals("access_token", token) // 存储原始令牌,便于后续操作
return c.Next()
}
}
func PermitAll(permissions ...string) fiber.Handler {
return Permit([]services.PayloadType{
services.PayloadClientPublic,
services.PayloadClientConfidential,
services.PayloadUser,
services.PayloadAdmin,
}, permissions...)
}
func PermitDevice(permissions ...string) fiber.Handler {
return Permit([]services.PayloadType{
services.PayloadClientPublic,
services.PayloadClientConfidential,
services.PayloadAdmin,
}, permissions...)
}
func Protect(c *fiber.Ctx, types []services.PayloadType, permissions []string) (*services.AuthContext, error) {
// 获取令牌
var header = c.Get("Authorization")
@@ -115,33 +35,36 @@ func Protect(c *fiber.Ctx, types []services.PayloadType, permissions []string) (
case "Bearer":
auth, err = authBearer(c.Context(), token)
if err != nil {
slog.Debug("Bearer 认证失败")
return nil, fiber.NewError(fiber.StatusUnauthorized, "没有权限")
}
case "Basic":
if !slices.Contains(types, services.PayloadClientConfidential) {
return nil, fiber.NewError(fiber.StatusForbidden, "没有权限")
slog.Debug("禁止使用 Basic 认证方式")
return nil, fiber.NewError(fiber.StatusUnauthorized, "没有权限")
}
auth, err = authBasic(c.Context(), token)
if err != nil {
slog.Debug("Basic 认证失败")
return nil, fiber.NewError(fiber.StatusUnauthorized, "没有权限")
}
default:
return nil, fiber.NewError(fiber.StatusForbidden, "没有权限")
slog.Debug("无效的认证方式")
return nil, fiber.NewError(fiber.StatusUnauthorized, "没有权限")
}
// 检查权限
if !slices.Contains(types, auth.Payload.Type) {
return nil, fiber.NewError(fiber.StatusForbidden, "没有权限")
}
if len(permissions) > 0 && !auth.AnyPermission(permissions...) {
slog.Debug("无效的认证主体")
return nil, fiber.NewError(fiber.StatusForbidden, "没有权限")
}
// 将认证信息存储在上下文中
c.Locals("auth", auth)
c.Locals("access_token", token) // 存储原始令牌,便于后续操作
if len(permissions) > 0 && !auth.AnyPermission(permissions...) {
slog.Debug("无效的认证权限")
return nil, fiber.NewError(fiber.StatusForbidden, "没有权限")
}
return auth, nil
}
@@ -185,6 +108,22 @@ func authBasic(_ context.Context, token string) (*services.AuthContext, error) {
return nil, err
}
// 检查客户端状态
if client.Status != 1 {
return nil, errors.New("客户端已被禁用")
}
// 检查客户端类型
if client.Spec != 0 {
return nil, errors.New("客户端类型错误")
}
// 检查客户端密钥
var clientSecret = split[1]
if bcrypt.CompareHashAndPassword([]byte(client.ClientSecret), []byte(clientSecret)) != nil {
return nil, errors.New("客户端密钥错误")
}
// todo 查询客户端关联权限
// 组织授权信息(一次性请求)

View File

@@ -1,7 +1,6 @@
package web
import (
auth2 "platform/web/auth"
"platform/web/handlers"
"github.com/gofiber/fiber/v2"
@@ -29,7 +28,7 @@ func ApplyRouters(app *fiber.App) {
// 通道
channel := api.Group("/channel")
channel.Post("/create", handlers.CreateChannel)
channel.Post("/remove", auth2.PermitAll(), handlers.RemoveChannels)
channel.Post("/remove", handlers.RemoveChannels)
// 白名单
whitelist := api.Group("/whitelist")