package web

import (
	"context"
	"encoding/base64"
	"errors"
	"log/slog"
	"slices"
	"strings"

	"github.com/gofiber/fiber/v2"
	"platform/web/common"
	q "platform/web/queries"
	"platform/web/services"
)

// Permit returns a Fiber middleware that authenticates a request from its
// Authorization header and then authorizes it against the allowed payload
// types and the required permissions.
//
// Two schemes are supported, matched case-insensitively:
//   - "Bearer": the token is resolved through authBearer.
//   - "Basic": the token is resolved through authBasic, and only when types
//     contains services.PayloadClientConfidential.
//
// Failure responses (all encoded as common.ErrResp):
//   - 400 when the header is not of the form "<scheme> <token>";
//   - 401 when the scheme is unsupported, Basic is not allowed for this
//     route, or authentication fails;
//   - 403 when the authenticated payload type is not in types, or when
//     permissions is non-empty and auth.AnyPermission reports false.
//
// On success the auth context is stored under c.Locals("auth") and the raw
// token under c.Locals("access_token") before the next handler runs.
func Permit(types []services.PayloadType, permissions ...string) fiber.Handler {
	return func(c *fiber.Ctx) error {
		reject := func(status int, message string) error {
			return c.Status(status).JSON(common.ErrResp{
				Error:   true,
				Message: message,
			})
		}

		// Extract the scheme and token. Fields tolerates repeated whitespace
		// and never yields empty elements, so an empty token is rejected here
		// as well.
		fields := strings.Fields(c.Get("Authorization"))
		if len(fields) != 2 {
			return reject(fiber.StatusBadRequest, "无效的令牌")
		}
		scheme, token := fields[0], fields[1]

		var (
			auth *services.AuthContext
			err  error
		)
		switch strings.ToLower(scheme) {
		case "bearer":
			auth, err = authBearer(c.Context(), token)
		case "basic":
			if !slices.Contains(types, services.PayloadClientConfidential) {
				return reject(fiber.StatusUnauthorized, "没有权限")
			}
			auth, err = authBasic(c.Context(), token)
		default:
			// Without this branch auth stays nil with a nil error and the
			// type check below would dereference a nil pointer.
			return reject(fiber.StatusUnauthorized, "没有权限")
		}
		if err != nil || auth == nil {
			return reject(fiber.StatusUnauthorized, "没有权限")
		}

		// Authorization: payload type first, then permissions (if any were
		// requested for this route).
		if !slices.Contains(types, auth.Payload.Type) {
			return reject(fiber.StatusForbidden, "拒绝访问")
		}
		if len(permissions) > 0 && !auth.AnyPermission(permissions...) {
			return reject(fiber.StatusForbidden, "拒绝访问")
		}

		// Expose the auth context and the raw token to downstream handlers.
		c.Locals("auth", auth)
		c.Locals("access_token", token)

		return c.Next()
	}
}

// PermitAll returns a Permit middleware that accepts the public-client,
// confidential-client, user and admin payload types.
func PermitAll(permissions ...string) fiber.Handler {
	return Permit([]services.PayloadType{
		services.PayloadClientPublic,
		services.PayloadClientConfidential,
		services.PayloadUser,
		services.PayloadAdmin,
	}, permissions...)
}

// PermitUser returns a per-route Permit middleware that accepts the user and
// admin payload types.
func PermitUser(permissions ...string) fiber.Handler {
	return Permit([]services.PayloadType{
		services.PayloadUser,
		services.PayloadAdmin,
	}, permissions...)
}

// PermitDevice returns a Permit middleware that accepts the public-client,
// confidential-client and admin payload types.
func PermitDevice(permissions ...string) fiber.Handler {
	return Permit([]services.PayloadType{
		services.PayloadClientPublic,
		services.PayloadClientConfidential,
		services.PayloadAdmin,
	}, permissions...)
}

// PermitPublic returns a Permit middleware that accepts the public-client and
// admin payload types.
func PermitPublic(permissions ...string) fiber.Handler {
	return Permit([]services.PayloadType{
		services.PayloadClientPublic,
		services.PayloadAdmin,
	}, permissions...)
}

// PermitConfidential returns a Permit middleware that accepts the
// confidential-client and admin payload types.
func PermitConfidential(permissions ...string) fiber.Handler {
	return Permit([]services.PayloadType{
		services.PayloadClientConfidential,
		services.PayloadAdmin,
	}, permissions...)
}

// authBearer resolves a bearer token to an auth context through
// services.Session.Find. Lookup errors are logged at debug level and returned
// to the caller.
func authBearer(ctx context.Context, token string) (*services.AuthContext, error) {
	auth, err := services.Session.Find(ctx, token)
	if err != nil {
		slog.Debug(err.Error())
		return nil, err
	}
	return auth, nil
}

// authBasic resolves Basic credentials ("client-id:secret", base64-encoded)
// to a one-off auth context of type services.PayloadClientConfidential.
//
// The token is decoded with the standard base64 alphabet; the URL-safe
// alphabet is accepted as a fallback for callers that relied on the previous
// behaviour. The credentials are split on the first colon only, so the secret
// part may itself contain colons.
//
// It returns an error when the token cannot be decoded, contains no colon, or
// no matching client record is found.
func authBasic(ctx context.Context, token string) (*services.AuthContext, error) {
	// Decode the Basic credentials.
	raw, err := base64.StdEncoding.DecodeString(token)
	if err != nil {
		var urlErr error
		if raw, urlErr = base64.URLEncoding.DecodeString(token); urlErr != nil {
			slog.Debug(err.Error())
			return nil, err
		}
	}

	// NOTE(review): the secret part is discarded and never compared against
	// stored client credentials in this function, so knowing a valid client
	// ID is enough to authenticate. Confirm whether verification happens
	// elsewhere; if not, it must be added here.
	clientID, _, found := strings.Cut(string(raw), ":")
	if !found {
		msg := "无法解析 Basic 认证信息"
		slog.Debug(msg)
		return nil, errors.New(msg)
	}

	// Look up the client record.
	// NOTE(review): ctx is not passed to the query, so request cancellation
	// presumably does not reach the database call — confirm and thread it
	// through if the query API supports it. The meaning of Spec == 0 and
	// Status == 1 is defined outside this file.
	client, err := q.Client.
		Where(
			q.Client.ClientID.Eq(clientID),
			q.Client.Spec.Eq(0),
			q.Client.GrantClient.Is(true),
			q.Client.Status.Eq(1)).
		Take()
	if err != nil {
		return nil, err
	}

	// TODO: load the permissions associated with the client.

	// Build the auth context for this single request.
	return &services.AuthContext{
		Payload: services.Payload{
			Id:     client.ID,
			Type:   services.PayloadClientConfidential,
			Name:   client.Name,
			Avatar: client.Icon,
		},
		Permissions: nil,
		Metadata:    nil,
	}, nil
}