package auth

import (
	"context"
	"encoding/base64"
	"errors"
	"log/slog"
	"slices"
	"strings"

	"github.com/gofiber/fiber/v2"
	"golang.org/x/crypto/bcrypt"

	q "platform/web/queries"
	"platform/web/services"
)

// Protect authenticates the request carried by c and then authorizes it.
//
// The Authorization header must have the form "<scheme> <credentials>"
// (exactly one space). The scheme is matched case-insensitively, as
// RFC 7235 §2.1 requires:
//   - Bearer: the credentials are resolved through services.Session.Find.
//   - Basic: "client_id:client_secret" credentials of a confidential client.
//     This scheme is only accepted when types contains
//     services.PayloadClientConfidential.
//
// After authentication, the caller's payload type must be one of types. When
// permissions is non-empty, the caller must also satisfy
// AuthContext.AnyPermission(permissions...).
//
// A missing or malformed header, an unknown scheme, or failed authentication
// yields a 401 *fiber.Error. An authenticated caller that fails the type or
// permission check yields a 403 *fiber.Error. On success the resolved
// AuthContext is returned.
func Protect(c *fiber.Ctx, types []services.PayloadType, permissions []string) (*services.AuthContext, error) {
	// Extract the scheme and credentials from the Authorization header.
	header := c.Get("Authorization")
	split := strings.Split(header, " ")
	if len(split) != 2 {
		return nil, fiber.NewError(fiber.StatusUnauthorized, "无效的令牌")
	}
	scheme, token := split[0], split[1]
	if token == "" {
		return nil, fiber.NewError(fiber.StatusUnauthorized, "无效的令牌")
	}

	var auth *services.AuthContext
	var err error
	switch {
	case strings.EqualFold(scheme, "Bearer"):
		auth, err = authBearer(c.Context(), token)
		// Guard against a (nil, nil) result so the checks below cannot panic.
		if err != nil || auth == nil {
			slog.Debug("Bearer 认证失败")
			return nil, fiber.NewError(fiber.StatusUnauthorized, "没有权限")
		}
	case strings.EqualFold(scheme, "Basic"):
		// Basic credentials identify confidential clients only; refuse them
		// outright when the route does not admit that payload type.
		if !slices.Contains(types, services.PayloadClientConfidential) {
			slog.Debug("禁止使用 Basic 认证方式")
			return nil, fiber.NewError(fiber.StatusUnauthorized, "没有权限")
		}
		auth, err = authBasic(c.Context(), token)
		if err != nil || auth == nil {
			slog.Debug("Basic 认证失败")
			return nil, fiber.NewError(fiber.StatusUnauthorized, "没有权限")
		}
	default:
		slog.Debug("无效的认证方式")
		return nil, fiber.NewError(fiber.StatusUnauthorized, "没有权限")
	}

	// Authorization: the caller is known, so failures from here on are 403.
	if !slices.Contains(types, auth.Payload.Type) {
		slog.Debug("无效的认证主体")
		return nil, fiber.NewError(fiber.StatusForbidden, "没有权限")
	}
	if len(permissions) > 0 && !auth.AnyPermission(permissions...) {
		slog.Debug("无效的认证权限")
		return nil, fiber.NewError(fiber.StatusForbidden, "没有权限")
	}
	return auth, nil
}

// authBearer resolves a bearer token to an AuthContext via
// services.Session.Find. A lookup error is logged at debug level and
// returned unchanged.
func authBearer(ctx context.Context, token string) (*services.AuthContext, error) {
	auth, err := services.Session.Find(ctx, token)
	if err != nil {
		slog.Debug(err.Error())
		return nil, err
	}
	return auth, nil
}

// authBasic authenticates a confidential client from HTTP Basic credentials.
//
// token is the credentials part of the header, base64("client_id:client_secret").
// The client must match ClientID, Spec 0, GrantClient true and Status 1, and
// the secret must match the stored bcrypt hash. The returned AuthContext has
// a PayloadClientConfidential payload and, for now, no permissions or metadata.
//
// NOTE(review): the context parameter is ignored, so the DB query is not
// cancelled with the request. Consider q.Client.WithContext(ctx) if the
// generated query API provides it.
func authBasic(_ context.Context, token string) (*services.AuthContext, error) {
	clientID, clientSecret, err := decodeBasicCredentials(token)
	if err != nil {
		return nil, err
	}

	// Look up the client record.
	client, err := q.Client.
		Where(
			q.Client.ClientID.Eq(clientID),
			q.Client.Spec.Eq(0),
			q.Client.GrantClient.Is(true),
			q.Client.Status.Eq(1)).
		Take()
	if err != nil {
		return nil, err
	}

	// Defensive re-checks of conditions the query already filters on.
	if client.Status != 1 {
		return nil, errors.New("客户端已被禁用")
	}
	if client.Spec != 0 {
		return nil, errors.New("客户端类型错误")
	}

	// Verify the client secret against the stored bcrypt hash.
	if bcrypt.CompareHashAndPassword([]byte(client.ClientSecret), []byte(clientSecret)) != nil {
		return nil, errors.New("客户端密钥错误")
	}

	// TODO: load the permissions associated with the client.

	// Assemble the auth context for this one-off request.
	return &services.AuthContext{
		Payload: services.Payload{
			Id:     client.ID,
			Type:   services.PayloadClientConfidential,
			Name:   client.Name,
			Avatar: client.Icon,
		},
		Permissions: nil,
		Metadata:    nil,
	}, nil
}

// decodeBasicCredentials decodes a Basic credentials string into its client
// ID and secret.
//
// RFC 7617 encodes "user-id:password" with standard base64. URL-safe base64
// is still accepted as a fallback so clients that relied on the previous
// decoding keep working. The decoded value is split at the first colon only,
// because the secret may itself contain colons.
func decodeBasicCredentials(token string) (string, string, error) {
	raw, err := base64.StdEncoding.DecodeString(token)
	if err != nil {
		alt, altErr := base64.URLEncoding.DecodeString(token)
		if altErr != nil {
			slog.Debug(err.Error())
			return "", "", err
		}
		raw = alt
	}

	clientID, clientSecret, ok := strings.Cut(string(raw), ":")
	if !ok {
		msg := "无法解析 Basic 认证信息"
		slog.Debug(msg)
		return "", "", errors.New(msg)
	}
	return clientID, clientSecret, nil
}