package auth import ( "context" "encoding/base64" "errors" "log/slog" client2 "platform/web/domains/client" q "platform/web/queries" "slices" "strings" "github.com/gofiber/fiber/v2" "golang.org/x/crypto/bcrypt" ) func Protect(c *fiber.Ctx, types []PayloadType, permissions []string) (*Context, error) { // 获取令牌 var header = c.Get("Authorization") var split = strings.Split(header, " ") if len(split) != 2 { return nil, fiber.NewError(fiber.StatusUnauthorized, "无效的令牌") } var token = split[1] if token == "" { return nil, fiber.NewError(fiber.StatusUnauthorized, "无效的令牌") } var auth *Context var err error switch split[0] { case "Bearer": auth, err = authBearer(c.Context(), token) if err != nil { slog.Debug("Bearer 认证失败", "err", err) return nil, fiber.NewError(fiber.StatusUnauthorized, "无效的令牌") } case "Basic": if !slices.Contains(types, PayloadSecuredServer) { slog.Debug("禁止使用 Basic 认证方式") return nil, fiber.NewError(fiber.StatusUnauthorized, "无效的令牌") } auth, err = authBasic(c.Context(), token) if err != nil { slog.Debug("Basic 认证失败", "err", err) return nil, fiber.NewError(fiber.StatusUnauthorized, "无效的令牌") } default: slog.Debug("无效的认证方式") return nil, fiber.NewError(fiber.StatusUnauthorized, "无效的令牌") } // 检查权限 if !slices.Contains(types, auth.Payload.Type) { slog.Debug("无效的认证主体") return nil, fiber.NewError(fiber.StatusForbidden, "没有权限") } if len(permissions) > 0 && !auth.AnyPermission(permissions...) { slog.Debug("无效的认证权限") return nil, fiber.NewError(fiber.StatusForbidden, "没有权限") } // 保存到上下文 Locals(c, auth) return auth, nil } func Locals(c *fiber.Ctx, auth *Context) { c.Locals("auth", auth) c.Locals("authtype", auth.Payload.Type.ToStr()) c.Locals("authid", auth.Payload.Id) } func authBearer(ctx context.Context, token string) (*Context, error) { auth, err := find(ctx, token) if err != nil { slog.Debug(err.Error()) return nil, err } return auth, nil } func authBasic(_ context.Context, token string) (*Context, error) { // 解析 Basic 认证信息 var base, err = base64.RawURLEncoding.DecodeString(token) if err != nil { slog.Debug(err.Error()) return nil, err } var split = strings.Split(string(base), ":") if len(split) != 2 { msg := "无法解析 Basic 认证信息" slog.Debug(msg) return nil, errors.New(msg) } var clientID = split[0] // 获取客户端信息 client, err := q.Client. Where( q.Client.ClientID.Eq(clientID), q.Client.Spec.In(int32(client2.SpecWeb), int32(client2.SpecTrusted)), q.Client.GrantClient.Is(true), q.Client.Status.Eq(1)). Take() if err != nil { return nil, err } // 检查客户端密钥 var clientSecret = split[1] if bcrypt.CompareHashAndPassword([]byte(client.ClientSecret), []byte(clientSecret)) != nil { return nil, errors.New("客户端密钥错误") } // todo 查询客户端关联权限 // 组织授权信息(一次性请求) return &Context{ Payload: Payload{ Id: client.ID, Type: PayloadSecuredServer, Name: client.Name, Avatar: client.Icon, }, Permissions: nil, Metadata: nil, }, nil }