package web import ( "platform/web/common" "slices" "strings" "platform/web/services" "github.com/gofiber/fiber/v2" ) func Permit(types []services.PayloadType, permissions ...string) fiber.Handler { return func(c *fiber.Ctx) error { // 获取令牌 var header = c.Get("Authorization") var token = strings.TrimPrefix(header, "Bearer ") if token == "" { return c.Status(fiber.StatusUnauthorized).JSON(common.ErrResp{ Error: true, Message: "没有权限", }) } // 验证令牌 auth, err := services.Session.Find(c.Context(), token) if err != nil { return c.Status(fiber.StatusUnauthorized).JSON(common.ErrResp{ Error: true, Message: "没有权限", }) } // 检查权限 // switch auth.Payload.Type { // case services.PayloadAdmin: // // 管理员不需要权限检查 // case services.PayloadUser: // if len(permissions) > 0 && !auth.AnyPermission(permissions...) { // return c.Status(fiber.StatusForbidden).JSON(common.ErrResp{ // Error: true, // Message: "拒绝访问", // }) // } // default: // return c.Status(fiber.StatusForbidden).JSON(common.ErrResp{ // Error: true, // Message: "拒绝访问", // }) // } if !slices.Contains(types, auth.Payload.Type) { return c.Status(fiber.StatusForbidden).JSON(common.ErrResp{ Error: true, Message: "拒绝访问", }) } if len(permissions) > 0 && !auth.AnyPermission(permissions...) { return c.Status(fiber.StatusForbidden).JSON(common.ErrResp{ Error: true, Message: "拒绝访问", }) } // 将认证信息存储在上下文中 c.Locals("auth", auth) c.Locals("access_token", token) // 存储原始令牌,便于后续操作 return c.Next() } } // PermitUser 创建针对单个路由的鉴权中间件 func PermitUser(permissions ...string) fiber.Handler { return func(c *fiber.Ctx) error { // 获取令牌 var header = c.Get("Authorization") var token = strings.TrimPrefix(header, "Bearer ") if token == "" { return c.Status(fiber.StatusUnauthorized).JSON(common.ErrResp{ Error: true, Message: "没有权限", }) } // 验证令牌 auth, err := services.Session.Find(c.Context(), token) if err != nil { return c.Status(fiber.StatusUnauthorized).JSON(common.ErrResp{ Error: true, Message: "没有权限", }) } // 检查权限 switch auth.Payload.Type { case services.PayloadAdmin: // 管理员不需要权限检查 case services.PayloadUser: if len(permissions) > 0 && !auth.AnyPermission(permissions...) { return c.Status(fiber.StatusForbidden).JSON(common.ErrResp{ Error: true, Message: "拒绝访问", }) } default: return c.Status(fiber.StatusForbidden).JSON(common.ErrResp{ Error: true, Message: "拒绝访问", }) } // 将认证信息存储在上下文中 c.Locals("auth", auth) c.Locals("access_token", token) // 存储原始令牌,便于后续操作 return c.Next() } } func PermitDevice(permissions ...string) fiber.Handler { return func(c *fiber.Ctx) error { // 获取令牌 var header = c.Get("Authorization") var token = strings.TrimPrefix(header, "Bearer ") if token == "" { return c.Status(fiber.StatusUnauthorized).JSON(common.ErrResp{ Error: true, Message: "没有权限", }) } // 验证令牌 auth, err := services.Session.Find(c.Context(), token) if err != nil { return c.Status(fiber.StatusUnauthorized).JSON(common.ErrResp{ Error: true, Message: "没有权限", }) } // 检查权限 switch auth.Payload.Type { case services.PayloadAdmin: // 管理员不需要权限检查 case services.PayloadClientPublic, services.PayloadClientConfidential: if len(permissions) > 0 && !auth.AnyPermission(permissions...) { return c.Status(fiber.StatusForbidden).JSON(common.ErrResp{ Error: true, Message: "拒绝访问", }) } default: return c.Status(fiber.StatusForbidden).JSON(common.ErrResp{ Error: true, Message: "拒绝访问", }) } // 将认证信息存储在上下文中 c.Locals("auth", auth) c.Locals("access_token", token) // 存储原始令牌,便于后续操作 return c.Next() } }