Initial commit
This commit is contained in:
607
env/lib/python3.10/site-packages/wagtail/api/v2/views.py
vendored
Normal file
607
env/lib/python3.10/site-packages/wagtail/api/v2/views.py
vendored
Normal file
@@ -0,0 +1,607 @@
|
||||
from collections import OrderedDict
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.exceptions import FieldDoesNotExist
|
||||
from django.http import Http404
|
||||
from django.shortcuts import redirect
|
||||
from django.urls import path, reverse
|
||||
from modelcluster.fields import ParentalKey
|
||||
from rest_framework import status
|
||||
from rest_framework.renderers import BrowsableAPIRenderer, JSONRenderer
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.viewsets import GenericViewSet
|
||||
|
||||
from wagtail.api import APIField
|
||||
from wagtail.models import Page, PageViewRestriction, Site
|
||||
|
||||
from .filters import (
|
||||
AncestorOfFilter,
|
||||
ChildOfFilter,
|
||||
DescendantOfFilter,
|
||||
FieldsFilter,
|
||||
LocaleFilter,
|
||||
OrderingFilter,
|
||||
SearchFilter,
|
||||
TranslationOfFilter,
|
||||
)
|
||||
from .pagination import WagtailPagination
|
||||
from .serializers import BaseSerializer, PageSerializer, get_serializer_class
|
||||
from .utils import (
|
||||
BadRequestError,
|
||||
get_object_detail_url,
|
||||
page_models_from_string,
|
||||
parse_fields_parameter,
|
||||
)
|
||||
|
||||
|
||||
class BaseAPIViewSet(GenericViewSet):
|
||||
renderer_classes = [JSONRenderer, BrowsableAPIRenderer]
|
||||
|
||||
pagination_class = WagtailPagination
|
||||
base_serializer_class = BaseSerializer
|
||||
filter_backends = []
|
||||
model = None # Set on subclass
|
||||
|
||||
known_query_parameters = frozenset(
|
||||
[
|
||||
"limit",
|
||||
"offset",
|
||||
"fields",
|
||||
"order",
|
||||
"search",
|
||||
"search_operator",
|
||||
# Used by jQuery for cache-busting. See #1671
|
||||
"_",
|
||||
# Required by BrowsableAPIRenderer
|
||||
"format",
|
||||
]
|
||||
)
|
||||
body_fields = ["id"]
|
||||
meta_fields = ["type", "detail_url"]
|
||||
listing_default_fields = ["id", "type", "detail_url"]
|
||||
nested_default_fields = ["id", "type", "detail_url"]
|
||||
detail_only_fields = []
|
||||
name = None # Set on subclass.
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# seen_types is a mapping of type name strings (format: "app_label.ModelName")
|
||||
# to model classes. When an object is serialised in the API, its model
|
||||
# is added to this mapping. This is used by the Admin API which appends a
|
||||
# summary of the used types to the response.
|
||||
self.seen_types = OrderedDict()
|
||||
|
||||
def get_queryset(self):
|
||||
return self.model.objects.all().order_by("id")
|
||||
|
||||
def listing_view(self, request):
|
||||
queryset = self.get_queryset()
|
||||
self.check_query_parameters(queryset)
|
||||
queryset = self.filter_queryset(queryset)
|
||||
queryset = self.paginate_queryset(queryset)
|
||||
serializer = self.get_serializer(queryset, many=True)
|
||||
return self.get_paginated_response(serializer.data)
|
||||
|
||||
def detail_view(self, request, pk):
|
||||
instance = self.get_object()
|
||||
serializer = self.get_serializer(instance)
|
||||
return Response(serializer.data)
|
||||
|
||||
def find_view(self, request):
|
||||
queryset = self.get_queryset()
|
||||
|
||||
try:
|
||||
obj = self.find_object(queryset, request)
|
||||
|
||||
if obj is None:
|
||||
raise self.model.DoesNotExist
|
||||
|
||||
except self.model.DoesNotExist:
|
||||
raise Http404("not found")
|
||||
|
||||
# Generate redirect
|
||||
url = get_object_detail_url(
|
||||
self.request.wagtailapi_router, request, self.model, obj.pk
|
||||
)
|
||||
|
||||
if url is None:
|
||||
# Shouldn't happen unless this endpoint isn't actually installed in the router
|
||||
raise Exception(
|
||||
"Cannot generate URL to detail view. Is '{}' installed in the API router?".format(
|
||||
self.__class__.__name__
|
||||
)
|
||||
)
|
||||
|
||||
return redirect(url)
|
||||
|
||||
def find_object(self, queryset, request):
|
||||
"""
|
||||
Override this to implement more find methods.
|
||||
"""
|
||||
if "id" in request.GET:
|
||||
return queryset.get(id=request.GET["id"])
|
||||
|
||||
def handle_exception(self, exc):
|
||||
if isinstance(exc, Http404):
|
||||
data = {"message": str(exc)}
|
||||
return Response(data, status=status.HTTP_404_NOT_FOUND)
|
||||
elif isinstance(exc, BadRequestError):
|
||||
data = {"message": str(exc)}
|
||||
return Response(data, status=status.HTTP_400_BAD_REQUEST)
|
||||
return super().handle_exception(exc)
|
||||
|
||||
@classmethod
|
||||
def _convert_api_fields(cls, fields):
|
||||
return [
|
||||
field if isinstance(field, APIField) else APIField(field)
|
||||
for field in fields
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def get_body_fields(cls, model):
|
||||
return cls._convert_api_fields(
|
||||
cls.body_fields + list(getattr(model, "api_fields", ()))
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_body_fields_names(cls, model):
|
||||
return [field.name for field in cls.get_body_fields(model)]
|
||||
|
||||
@classmethod
|
||||
def get_meta_fields(cls, model):
|
||||
return cls._convert_api_fields(
|
||||
cls.meta_fields + list(getattr(model, "api_meta_fields", ()))
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_meta_fields_names(cls, model):
|
||||
return [field.name for field in cls.get_meta_fields(model)]
|
||||
|
||||
@classmethod
|
||||
def get_field_serializer_overrides(cls, model):
|
||||
return {
|
||||
field.name: field.serializer
|
||||
for field in cls.get_body_fields(model) + cls.get_meta_fields(model)
|
||||
if field.serializer is not None
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_available_fields(cls, model, db_fields_only=False):
|
||||
"""
|
||||
Returns a list of all the fields that can be used in the API for the
|
||||
specified model class.
|
||||
|
||||
Setting db_fields_only to True will remove all fields that do not have
|
||||
an underlying column in the database (eg, type/detail_url and any custom
|
||||
fields that are callables)
|
||||
"""
|
||||
fields = cls.get_body_fields_names(model) + cls.get_meta_fields_names(model)
|
||||
|
||||
if db_fields_only:
|
||||
# Get list of available database fields then remove any fields in our
|
||||
# list that isn't a database field
|
||||
database_fields = set()
|
||||
for field in model._meta.get_fields():
|
||||
database_fields.add(field.name)
|
||||
|
||||
if hasattr(field, "attname"):
|
||||
database_fields.add(field.attname)
|
||||
|
||||
fields = [field for field in fields if field in database_fields]
|
||||
|
||||
return fields
|
||||
|
||||
@classmethod
|
||||
def get_detail_default_fields(cls, model):
|
||||
return cls.get_available_fields(model)
|
||||
|
||||
@classmethod
|
||||
def get_listing_default_fields(cls, model):
|
||||
return cls.listing_default_fields[:]
|
||||
|
||||
@classmethod
|
||||
def get_nested_default_fields(cls, model):
|
||||
return cls.nested_default_fields[:]
|
||||
|
||||
def check_query_parameters(self, queryset):
|
||||
"""
|
||||
Ensure that only valid query parameters are included in the URL.
|
||||
"""
|
||||
query_parameters = set(self.request.GET.keys())
|
||||
|
||||
# All query parameters must be either a database field or an operation
|
||||
allowed_query_parameters = set(
|
||||
self.get_available_fields(queryset.model, db_fields_only=True)
|
||||
).union(self.known_query_parameters)
|
||||
unknown_parameters = query_parameters - allowed_query_parameters
|
||||
if unknown_parameters:
|
||||
raise BadRequestError(
|
||||
"query parameter is not an operation or a recognised field: %s"
|
||||
% ", ".join(sorted(unknown_parameters))
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _get_serializer_class(
|
||||
cls, router, model, fields_config, show_details=False, nested=False
|
||||
):
|
||||
# Get all available fields
|
||||
body_fields = cls.get_body_fields_names(model)
|
||||
meta_fields = cls.get_meta_fields_names(model)
|
||||
all_fields = body_fields + meta_fields
|
||||
|
||||
# Remove any duplicates
|
||||
all_fields = list(OrderedDict.fromkeys(all_fields))
|
||||
|
||||
if not show_details:
|
||||
# Remove detail only fields
|
||||
for field in cls.detail_only_fields:
|
||||
try:
|
||||
all_fields.remove(field)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Get list of configured fields
|
||||
if show_details:
|
||||
fields = set(cls.get_detail_default_fields(model))
|
||||
elif nested:
|
||||
fields = set(cls.get_nested_default_fields(model))
|
||||
else:
|
||||
fields = set(cls.get_listing_default_fields(model))
|
||||
|
||||
# If first field is '*' start with all fields
|
||||
# If first field is '_' start with no fields
|
||||
if fields_config and fields_config[0][0] == "*":
|
||||
fields = set(all_fields)
|
||||
fields_config = fields_config[1:]
|
||||
elif fields_config and fields_config[0][0] == "_":
|
||||
fields = set()
|
||||
fields_config = fields_config[1:]
|
||||
|
||||
mentioned_fields = set()
|
||||
sub_fields = {}
|
||||
|
||||
for field_name, negated, field_sub_fields in fields_config:
|
||||
if negated:
|
||||
try:
|
||||
fields.remove(field_name)
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
fields.add(field_name)
|
||||
if field_sub_fields:
|
||||
sub_fields[field_name] = field_sub_fields
|
||||
|
||||
mentioned_fields.add(field_name)
|
||||
|
||||
unknown_fields = mentioned_fields - set(all_fields)
|
||||
|
||||
if unknown_fields:
|
||||
raise BadRequestError(
|
||||
"unknown fields: %s" % ", ".join(sorted(unknown_fields))
|
||||
)
|
||||
|
||||
# Build nested serialisers
|
||||
child_serializer_classes = {}
|
||||
|
||||
for field_name in fields:
|
||||
try:
|
||||
django_field = model._meta.get_field(field_name)
|
||||
except FieldDoesNotExist:
|
||||
django_field = None
|
||||
|
||||
if django_field and django_field.is_relation:
|
||||
child_sub_fields = sub_fields.get(field_name, [])
|
||||
|
||||
# Inline (aka "child") models should display all fields by default
|
||||
if isinstance(getattr(django_field, "field", None), ParentalKey):
|
||||
if not child_sub_fields or child_sub_fields[0][0] not in ["*", "_"]:
|
||||
child_sub_fields = list(child_sub_fields)
|
||||
child_sub_fields.insert(0, ("*", False, None))
|
||||
|
||||
# Get a serializer class for the related object
|
||||
child_model = django_field.related_model
|
||||
child_endpoint_class = router.get_model_endpoint(child_model)
|
||||
child_endpoint_class = (
|
||||
child_endpoint_class[1] if child_endpoint_class else BaseAPIViewSet
|
||||
)
|
||||
child_serializer_classes[
|
||||
field_name
|
||||
] = child_endpoint_class._get_serializer_class(
|
||||
router, child_model, child_sub_fields, nested=True
|
||||
)
|
||||
|
||||
else:
|
||||
if field_name in sub_fields:
|
||||
# Sub fields were given for a non-related field
|
||||
raise BadRequestError(
|
||||
"'%s' does not support nested fields" % field_name
|
||||
)
|
||||
|
||||
# Reorder fields so it matches the order of all_fields
|
||||
fields = [field for field in all_fields if field in fields]
|
||||
|
||||
field_serializer_overrides = {
|
||||
field[0]: field[1]
|
||||
for field in cls.get_field_serializer_overrides(model).items()
|
||||
if field[0] in fields
|
||||
}
|
||||
return get_serializer_class(
|
||||
model,
|
||||
fields,
|
||||
meta_fields=meta_fields,
|
||||
field_serializer_overrides=field_serializer_overrides,
|
||||
child_serializer_classes=child_serializer_classes,
|
||||
base=cls.base_serializer_class,
|
||||
)
|
||||
|
||||
def get_serializer_class(self):
|
||||
request = self.request
|
||||
|
||||
# Get model
|
||||
if self.action == "listing_view":
|
||||
model = self.get_queryset().model
|
||||
else:
|
||||
model = type(self.get_object())
|
||||
|
||||
# Fields
|
||||
if "fields" in request.GET:
|
||||
try:
|
||||
fields_config = parse_fields_parameter(request.GET["fields"])
|
||||
except ValueError as e:
|
||||
raise BadRequestError("fields error: %s" % str(e))
|
||||
else:
|
||||
# Use default fields
|
||||
fields_config = []
|
||||
|
||||
# Allow "detail_only" (eg parent) fields on detail view
|
||||
if self.action == "listing_view":
|
||||
show_details = False
|
||||
else:
|
||||
show_details = True
|
||||
|
||||
return self._get_serializer_class(
|
||||
self.request.wagtailapi_router,
|
||||
model,
|
||||
fields_config,
|
||||
show_details=show_details,
|
||||
)
|
||||
|
||||
def get_serializer_context(self):
|
||||
"""
|
||||
The serialization context differs between listing and detail views.
|
||||
"""
|
||||
return {
|
||||
"request": self.request,
|
||||
"view": self,
|
||||
"router": self.request.wagtailapi_router,
|
||||
}
|
||||
|
||||
def get_renderer_context(self):
|
||||
context = super().get_renderer_context()
|
||||
context["indent"] = 4
|
||||
return context
|
||||
|
||||
@classmethod
|
||||
def get_urlpatterns(cls):
|
||||
"""
|
||||
This returns a list of URL patterns for the endpoint
|
||||
"""
|
||||
return [
|
||||
path("", cls.as_view({"get": "listing_view"}), name="listing"),
|
||||
path("<int:pk>/", cls.as_view({"get": "detail_view"}), name="detail"),
|
||||
path("find/", cls.as_view({"get": "find_view"}), name="find"),
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def get_model_listing_urlpath(cls, model, namespace=""):
|
||||
if namespace:
|
||||
url_name = namespace + ":listing"
|
||||
else:
|
||||
url_name = "listing"
|
||||
|
||||
return reverse(url_name)
|
||||
|
||||
@classmethod
|
||||
def get_object_detail_urlpath(cls, model, pk, namespace=""):
|
||||
if namespace:
|
||||
url_name = namespace + ":detail"
|
||||
else:
|
||||
url_name = "detail"
|
||||
|
||||
return reverse(url_name, args=(pk,))
|
||||
|
||||
|
||||
class PagesAPIViewSet(BaseAPIViewSet):
|
||||
base_serializer_class = PageSerializer
|
||||
filter_backends = [
|
||||
FieldsFilter,
|
||||
ChildOfFilter,
|
||||
AncestorOfFilter,
|
||||
DescendantOfFilter,
|
||||
OrderingFilter,
|
||||
TranslationOfFilter,
|
||||
LocaleFilter,
|
||||
SearchFilter, # needs to be last, as SearchResults querysets cannot be filtered further
|
||||
]
|
||||
known_query_parameters = BaseAPIViewSet.known_query_parameters.union(
|
||||
[
|
||||
"type",
|
||||
"child_of",
|
||||
"ancestor_of",
|
||||
"descendant_of",
|
||||
"translation_of",
|
||||
"locale",
|
||||
"site",
|
||||
]
|
||||
)
|
||||
body_fields = BaseAPIViewSet.body_fields + [
|
||||
"title",
|
||||
]
|
||||
meta_fields = BaseAPIViewSet.meta_fields + [
|
||||
"html_url",
|
||||
"slug",
|
||||
"show_in_menus",
|
||||
"seo_title",
|
||||
"search_description",
|
||||
"first_published_at",
|
||||
"alias_of",
|
||||
"parent",
|
||||
"locale",
|
||||
]
|
||||
listing_default_fields = BaseAPIViewSet.listing_default_fields + [
|
||||
"title",
|
||||
"html_url",
|
||||
"slug",
|
||||
"first_published_at",
|
||||
]
|
||||
nested_default_fields = BaseAPIViewSet.nested_default_fields + [
|
||||
"title",
|
||||
]
|
||||
detail_only_fields = ["parent"]
|
||||
name = "pages"
|
||||
model = Page
|
||||
|
||||
@classmethod
|
||||
def get_detail_default_fields(cls, model):
|
||||
detail_default_fields = super().get_detail_default_fields(model)
|
||||
|
||||
# When i18n is disabled, remove "locale" from default fields
|
||||
if not getattr(settings, "WAGTAIL_I18N_ENABLED", False):
|
||||
detail_default_fields.remove("locale")
|
||||
|
||||
return detail_default_fields
|
||||
|
||||
@classmethod
|
||||
def get_listing_default_fields(cls, model):
|
||||
listing_default_fields = super().get_listing_default_fields(model)
|
||||
|
||||
# When i18n is enabled, add "locale" to default fields
|
||||
if getattr(settings, "WAGTAIL_I18N_ENABLED", False):
|
||||
listing_default_fields.append("locale")
|
||||
|
||||
return listing_default_fields
|
||||
|
||||
def get_root_page(self):
|
||||
"""
|
||||
Returns the page that is used when the `&child_of=root` filter is used.
|
||||
"""
|
||||
return Site.find_for_request(self.request).root_page
|
||||
|
||||
def get_base_queryset(self):
|
||||
"""
|
||||
Returns a queryset containing all pages that can be seen by this user.
|
||||
|
||||
This is used as the base for get_queryset and is also used to find the
|
||||
parent pages when using the child_of and descendant_of filters as well.
|
||||
"""
|
||||
|
||||
request = self.request
|
||||
|
||||
# Get all live pages
|
||||
queryset = Page.objects.all().live()
|
||||
|
||||
# Exclude pages that the user doesn't have access to
|
||||
restricted_pages = [
|
||||
restriction.page
|
||||
for restriction in PageViewRestriction.objects.all().select_related("page")
|
||||
if not restriction.accept_request(self.request)
|
||||
]
|
||||
|
||||
# Exclude the restricted pages and their descendants from the queryset
|
||||
for restricted_page in restricted_pages:
|
||||
queryset = queryset.not_descendant_of(restricted_page, inclusive=True)
|
||||
|
||||
# Check if we have a specific site to look for
|
||||
if "site" in request.GET:
|
||||
# Optionally allow querying by port
|
||||
if ":" in request.GET["site"]:
|
||||
(hostname, port) = request.GET["site"].split(":", 1)
|
||||
query = {
|
||||
"hostname": hostname,
|
||||
"port": port,
|
||||
}
|
||||
else:
|
||||
query = {
|
||||
"hostname": request.GET["site"],
|
||||
}
|
||||
try:
|
||||
site = Site.objects.get(**query)
|
||||
except Site.MultipleObjectsReturned:
|
||||
raise BadRequestError(
|
||||
"Your query returned multiple sites. Try adding a port number to your site filter."
|
||||
)
|
||||
else:
|
||||
# Otherwise, find the site from the request
|
||||
site = Site.find_for_request(self.request)
|
||||
|
||||
if site:
|
||||
base_queryset = queryset
|
||||
queryset = base_queryset.descendant_of(site.root_page, inclusive=True)
|
||||
|
||||
# If internationalisation is enabled, include pages from other language trees
|
||||
if getattr(settings, "WAGTAIL_I18N_ENABLED", False):
|
||||
for translation in site.root_page.get_translations():
|
||||
queryset |= base_queryset.descendant_of(translation, inclusive=True)
|
||||
|
||||
else:
|
||||
# No sites configured
|
||||
queryset = queryset.none()
|
||||
|
||||
return queryset
|
||||
|
||||
def get_queryset(self):
|
||||
request = self.request
|
||||
|
||||
# Allow pages to be filtered to a specific type
|
||||
try:
|
||||
models_type = request.GET.get("type", None)
|
||||
models = models_type and page_models_from_string(models_type) or []
|
||||
except (LookupError, ValueError):
|
||||
raise BadRequestError("type doesn't exist")
|
||||
|
||||
if not models:
|
||||
if self.model == Page:
|
||||
return self.get_base_queryset()
|
||||
else:
|
||||
return self.model.objects.filter(
|
||||
pk__in=self.get_base_queryset().values_list("pk", flat=True)
|
||||
)
|
||||
|
||||
elif len(models) == 1:
|
||||
# If a single page type has been specified, swap out the Page-based queryset for one based on
|
||||
# the specific page model so that we can filter on any custom APIFields defined on that model
|
||||
return models[0].objects.filter(
|
||||
pk__in=self.get_base_queryset().values_list("pk", flat=True)
|
||||
)
|
||||
|
||||
else: # len(models) > 1
|
||||
return self.get_base_queryset().type(*models)
|
||||
|
||||
def get_object(self):
|
||||
base = super().get_object()
|
||||
return base.specific
|
||||
|
||||
def find_object(self, queryset, request):
|
||||
site = Site.find_for_request(request)
|
||||
if "html_path" in request.GET and site is not None:
|
||||
path = request.GET["html_path"]
|
||||
path_components = [component for component in path.split("/") if component]
|
||||
|
||||
try:
|
||||
page, _, _ = site.root_page.specific.route(request, path_components)
|
||||
except Http404:
|
||||
return
|
||||
|
||||
if queryset.filter(id=page.id).exists():
|
||||
return page
|
||||
|
||||
return super().find_object(queryset, request)
|
||||
|
||||
def get_serializer_context(self):
|
||||
"""
|
||||
The serialization context differs between listing and detail views.
|
||||
"""
|
||||
context = super().get_serializer_context()
|
||||
context["base_queryset"] = self.get_base_queryset()
|
||||
return context
|
||||
Reference in New Issue
Block a user