Initial commit

This commit is contained in:
2024-08-27 20:33:44 +02:00
commit 1f1832267d
14794 changed files with 1599592 additions and 0 deletions

View File

@@ -0,0 +1 @@
from .conf import APIField # noqa: F401

View File

@@ -0,0 +1,10 @@
class APIField:
def __init__(self, name, serializer=None):
self.name = name
self.serializer = serializer
def __hash__(self):
return hash(self.name)
def __repr__(self):
return f"<APIField {self.name}>"

View File

@@ -0,0 +1,22 @@
from django.apps import AppConfig, apps
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from django.utils.translation import gettext_lazy as _
class WagtailAPIV2AppConfig(AppConfig):
name = "wagtail.api.v2"
label = "wagtailapi_v2"
verbose_name = _("Wagtail API v2")
def ready(self):
# Install cache purging signal handlers
if getattr(settings, "WAGTAILAPI_USE_FRONTENDCACHE", False):
if apps.is_installed("wagtail.contrib.frontend_cache"):
from wagtail.api.v2.signal_handlers import register_signal_handlers
register_signal_handlers()
else:
raise ImproperlyConfigured(
"The setting 'WAGTAILAPI_USE_FRONTENDCACHE' is True but 'wagtail.contrib.frontend_cache' is not in INSTALLED_APPS."
)

View File

@@ -0,0 +1,307 @@
from django.conf import settings
from django.db import models
from django.shortcuts import get_object_or_404
from rest_framework.filters import BaseFilterBackend
from taggit.managers import TaggableManager
from wagtail.models import Locale, Page
from wagtail.search.backends import get_search_backend
from wagtail.search.backends.base import FilterFieldError, OrderByFieldError
from .utils import BadRequestError, parse_boolean
class FieldsFilter(BaseFilterBackend):
def filter_queryset(self, request, queryset, view):
"""
This performs field level filtering on the result set
Eg: ?title=James Joyce
"""
fields = set(view.get_available_fields(queryset.model, db_fields_only=True))
# Locale is a database field, but we provide a separate filter for it
if "locale" in fields:
fields.remove("locale")
for field_name, value in request.GET.items():
if field_name in fields:
try:
field = queryset.model._meta.get_field(field_name)
except LookupError:
field = None
# Convert value into python
try:
if isinstance(
field, (models.BooleanField, models.NullBooleanField)
):
value = parse_boolean(value)
elif isinstance(field, (models.IntegerField, models.AutoField)):
value = int(value)
elif isinstance(field, models.ForeignKey):
value = field.target_field.get_prep_value(value)
except ValueError as e:
raise BadRequestError(
"field filter error. '%s' is not a valid value for %s (%s)"
% (value, field_name, str(e))
)
if "\x00" in str(value):
raise BadRequestError(
"field filter error. null characters are not allowed for %s"
% field_name
)
if isinstance(field, TaggableManager):
for tag in value.split(","):
queryset = queryset.filter(**{field_name + "__name": tag})
# Stick a message on the queryset to indicate that tag filtering has been performed
# This will let the do_search method know that it must raise an error as searching
# and tag filtering at the same time is not supported
queryset._filtered_by_tag = True
else:
queryset = queryset.filter(**{field_name: value})
return queryset
class OrderingFilter(BaseFilterBackend):
def filter_queryset(self, request, queryset, view):
"""
This applies ordering to the result set with support for multiple fields.
Eg: ?order=title or ?order=title,created_at
It also supports reverse ordering
Eg: ?order=-title
And random ordering
Eg: ?order=random
"""
if "order" in request.GET:
order_by_list = request.GET["order"].split(",")
# Random ordering
if "random" in order_by_list:
if len(order_by_list) > 1:
raise BadRequestError(
"random ordering cannot be combined with other fields"
)
# Prevent ordering by random with offset
if "offset" in request.GET:
raise BadRequestError(
"random ordering with offset is not supported"
)
return queryset.order_by("?")
order_by_fields = []
for order_by in order_by_list:
# Check if reverse ordering is set
if order_by.startswith("-"):
reverse_order = True
order_by = order_by[1:]
else:
reverse_order = False
# Add ordering
if order_by in view.get_available_fields(queryset.model):
order_by_fields.append(order_by)
else:
# Unknown field
raise BadRequestError(
"cannot order by '%s' (unknown field)" % order_by
)
# Apply ordering to the queryset
queryset = queryset.order_by(*order_by_fields)
# Reverse order if needed
if reverse_order:
queryset = queryset.reverse()
return queryset
class SearchFilter(BaseFilterBackend):
def filter_queryset(self, request, queryset, view):
"""
This performs a full-text search on the result set
Eg: ?search=James Joyce
"""
search_enabled = getattr(settings, "WAGTAILAPI_SEARCH_ENABLED", True)
if "search" in request.GET:
if not search_enabled:
raise BadRequestError("search is disabled")
# Searching and filtering by tag at the same time is not supported
if getattr(queryset, "_filtered_by_tag", False):
raise BadRequestError(
"filtering by tag with a search query is not supported"
)
search_query = request.GET["search"]
search_operator = request.GET.get("search_operator", None)
order_by_relevance = "order" not in request.GET
sb = get_search_backend()
try:
queryset = sb.search(
search_query,
queryset,
operator=search_operator,
order_by_relevance=order_by_relevance,
)
except FilterFieldError as e:
raise BadRequestError(
"cannot filter by '{}' while searching (field is not indexed)".format(
e.field_name
)
)
except OrderByFieldError as e:
raise BadRequestError(
"cannot order by '{}' while searching (field is not indexed)".format(
e.field_name
)
)
return queryset
class ChildOfFilter(BaseFilterBackend):
"""
Implements the ?child_of filter used to filter the results to only contain
pages that are direct children of the specified page.
"""
def filter_queryset(self, request, queryset, view):
if "child_of" in request.GET:
try:
parent_page_id = int(request.GET["child_of"])
if parent_page_id < 0:
raise ValueError()
parent_page = view.get_base_queryset().get(id=parent_page_id)
except ValueError:
if request.GET["child_of"] == "root":
parent_page = view.get_root_page()
else:
raise BadRequestError("child_of must be a positive integer")
except Page.DoesNotExist:
raise BadRequestError("parent page doesn't exist")
queryset = queryset.child_of(parent_page)
# Save the parent page on the queryset. This is required for the page
# explorer, which needs to pass the parent page into
# `construct_explorer_page_queryset` hook functions
queryset._filtered_by_child_of = parent_page
return queryset
class AncestorOfFilter(BaseFilterBackend):
"""
Implements the ?ancestor filter which limits the set of pages to a
particular branch of the page tree.
"""
def filter_queryset(self, request, queryset, view):
if "ancestor_of" in request.GET:
try:
descendant_page_id = int(request.GET["ancestor_of"])
if descendant_page_id < 0:
raise ValueError()
descendant_page = view.get_base_queryset().get(id=descendant_page_id)
except ValueError:
raise BadRequestError("ancestor_of must be a positive integer")
except Page.DoesNotExist:
raise BadRequestError("descendant page doesn't exist")
queryset = queryset.ancestor_of(descendant_page)
return queryset
class DescendantOfFilter(BaseFilterBackend):
"""
Implements the ?descendant_of filter which limits the set of pages to a
particular branch of the page tree.
"""
def filter_queryset(self, request, queryset, view):
if "descendant_of" in request.GET:
if hasattr(queryset, "_filtered_by_child_of"):
raise BadRequestError(
"filtering by descendant_of with child_of is not supported"
)
try:
parent_page_id = int(request.GET["descendant_of"])
if parent_page_id < 0:
raise ValueError()
parent_page = view.get_base_queryset().get(id=parent_page_id)
except ValueError:
if request.GET["descendant_of"] == "root":
parent_page = view.get_root_page()
else:
raise BadRequestError("descendant_of must be a positive integer")
except Page.DoesNotExist:
raise BadRequestError("ancestor page doesn't exist")
queryset = queryset.descendant_of(parent_page)
return queryset
class TranslationOfFilter(BaseFilterBackend):
"""
Implements the ?translation_of filter which limits the set of pages to translations
of a page.
"""
def filter_queryset(self, request, queryset, view):
if "translation_of" in request.GET:
try:
page_id = int(request.GET["translation_of"])
if page_id < 0:
raise ValueError()
page = view.get_base_queryset().get(id=page_id)
except ValueError:
if request.GET["translation_of"] == "root":
page = view.get_root_page()
else:
raise BadRequestError("translation_of must be a positive integer")
except Page.DoesNotExist:
raise BadRequestError("translation_of page doesn't exist")
_filtered_by_child_of = getattr(queryset, "_filtered_by_child_of", None)
queryset = queryset.translation_of(page)
if _filtered_by_child_of:
queryset._filtered_by_child_of = _filtered_by_child_of
return queryset
class LocaleFilter(BaseFilterBackend):
"""
Implements the ?locale filter which limits the set of pages to a
particular locale.
"""
def filter_queryset(self, request, queryset, view):
if "locale" in request.GET:
_filtered_by_child_of = getattr(queryset, "_filtered_by_child_of", None)
locale = get_object_or_404(Locale, language_code=request.GET["locale"])
queryset = queryset.filter(locale=locale)
if _filtered_by_child_of:
queryset._filtered_by_child_of = _filtered_by_child_of
return queryset

View File

@@ -0,0 +1,53 @@
from collections import OrderedDict
from django.conf import settings
from rest_framework.pagination import BasePagination
from rest_framework.response import Response
from .utils import BadRequestError
class WagtailPagination(BasePagination):
def paginate_queryset(self, queryset, request, view=None):
limit_max = getattr(settings, "WAGTAILAPI_LIMIT_MAX", 20)
try:
offset = int(request.GET.get("offset", 0))
if offset < 0:
raise ValueError()
except ValueError:
raise BadRequestError("offset must be a positive integer")
try:
limit_default = 20 if not limit_max else min(20, limit_max)
limit = int(request.GET.get("limit", limit_default))
if limit < 0:
raise ValueError()
except ValueError:
raise BadRequestError("limit must be a positive integer")
if limit_max and limit > limit_max:
raise BadRequestError("limit cannot be higher than %d" % limit_max)
start = offset
stop = offset + limit
self.view = view
self.total_count = queryset.count()
return queryset[start:stop]
def get_paginated_response(self, data):
data = OrderedDict(
[
(
"meta",
OrderedDict(
[
("total_count", self.total_count),
]
),
),
("items", data),
]
)
return Response(data)

View File

@@ -0,0 +1,95 @@
import functools
from django.urls import include, re_path
from wagtail.utils.urlpatterns import decorate_urlpatterns
class WagtailAPIRouter:
"""
A class that provides routing and cross-linking for a collection
of API endpoints
"""
def __init__(self, url_namespace):
self.url_namespace = url_namespace
self._endpoints = {}
def register_endpoint(self, name, class_):
self._endpoints[name] = class_
def get_model_endpoint(self, model):
"""
Finds the endpoint in the API that represents a model
Returns a (name, endpoint_class) tuple. Or None if an
endpoint is not found.
"""
for name, class_ in self._endpoints.items():
if issubclass(model, class_.model):
return name, class_
def get_model_listing_urlpath(self, model):
"""
Returns a URL path (excluding scheme and hostname) to the listing
page of a model
Returns None if the model is not represented by any endpoints.
"""
endpoint = self.get_model_endpoint(model)
if endpoint:
endpoint_name, endpoint_class = endpoint[0], endpoint[1]
url_namespace = self.url_namespace + ":" + endpoint_name
return endpoint_class.get_model_listing_urlpath(
model, namespace=url_namespace
)
def get_object_detail_urlpath(self, model, pk):
"""
Returns a URL path (excluding scheme and hostname) to the detail
page of an object.
Returns None if the object is not represented by any endpoints.
"""
endpoint = self.get_model_endpoint(model)
if endpoint:
endpoint_name, endpoint_class = endpoint[0], endpoint[1]
url_namespace = self.url_namespace + ":" + endpoint_name
return endpoint_class.get_object_detail_urlpath(
model, pk, namespace=url_namespace
)
def wrap_view(self, func):
@functools.wraps(func)
def wrapped(request, *args, **kwargs):
request.wagtailapi_router = self
return func(request, *args, **kwargs)
return wrapped
def get_urlpatterns(self):
urlpatterns = []
for name, class_ in self._endpoints.items():
pattern = re_path(
rf"^{name}/",
include((class_.get_urlpatterns(), name), namespace=name),
)
urlpatterns.append(pattern)
decorate_urlpatterns(urlpatterns, self.wrap_view)
return urlpatterns
@property
def urls(self):
"""
A shortcut to allow quick registration of the API in a URLconf.
Use with Django's include() function:
path('api/', include(myapi.urls)),
"""
return self.get_urlpatterns(), self.url_namespace, self.url_namespace

View File

@@ -0,0 +1,419 @@
from collections import OrderedDict
from django.urls.exceptions import NoReverseMatch
from modelcluster.models import get_all_child_relations
from rest_framework import relations, serializers
from rest_framework.fields import Field, SkipField
from taggit.managers import _TaggableManager
from wagtail import fields as wagtailcore_fields
from .utils import get_object_detail_url
class TypeField(Field):
"""
Serializes the "type" field of each object.
Example:
"type": "wagtailimages.Image"
"""
def get_attribute(self, instance):
return instance
def to_representation(self, obj):
name = type(obj)._meta.app_label + "." + type(obj).__name__
self.context["view"].seen_types[name] = type(obj)
return name
class DetailUrlField(Field):
"""
Serializes the "detail_url" field of each object.
Example:
"detail_url": "http://api.example.com/v1/images/1/"
"""
def get_attribute(self, instance):
url = get_object_detail_url(
self.context["router"], self.context["request"], type(instance), instance.pk
)
if url:
return url
else:
# Hide the detail_url field if the object doesn't have an endpoint
raise SkipField
def to_representation(self, url):
return url
class PageHtmlUrlField(Field):
"""
Serializes the "html_url" field for pages.
Example:
"html_url": "http://www.example.com/blog/blog-post/"
"""
def get_attribute(self, instance):
return instance
def to_representation(self, page):
try:
return page.full_url
except NoReverseMatch:
return None
class PageTypeField(Field):
"""
Serializes the "type" field for pages.
This takes into account the fact that we sometimes may not have the "specific"
page object by calling "page.specific_class" instead of looking at the object's
type.
Example:
"type": "blog.BlogPage"
"""
def get_attribute(self, instance):
return instance
def to_representation(self, page):
if page.specific_class is None:
return None
name = page.specific_class._meta.app_label + "." + page.specific_class.__name__
self.context["view"].seen_types[name] = page.specific_class
return name
class PageLocaleField(Field):
"""
Serializes the "locale" field for pages.
"""
def get_attribute(self, instance):
return instance
def to_representation(self, page):
return page.locale.language_code
class RelatedField(relations.RelatedField):
"""
Serializes related objects (eg, foreign keys).
Example:
"feed_image": {
"id": 1,
"meta": {
"type": "wagtailimages.Image",
"detail_url": "http://api.example.com/v1/images/1/"
}
}
"""
def __init__(self, *args, **kwargs):
self.serializer_class = kwargs.pop("serializer_class")
super().__init__(*args, **kwargs)
def to_representation(self, value):
serializer = self.serializer_class(context=self.context)
return serializer.to_representation(value)
class PageParentField(relations.RelatedField):
"""
Serializes the "parent" field on Page objects.
Pages don't have a "parent" field so some extra logic is needed to find the
parent page. That logic is implemented in this class.
The representation is the same as the RelatedField class.
"""
def get_attribute(self, instance):
parent = instance.get_parent()
if self.context["base_queryset"].filter(id=parent.id).exists():
return parent
def to_representation(self, value):
serializer_class = get_serializer_class(
value.__class__,
["id", "type", "detail_url", "html_url", "title"],
meta_fields=["type", "detail_url", "html_url"],
base=PageSerializer,
)
serializer = serializer_class(context=self.context)
return serializer.to_representation(value)
class PageAliasOfField(relations.RelatedField):
"""
Serializes the "alias_of" field on Page objects.
"""
def get_attribute(self, instance):
return instance.alias_of
def to_representation(self, value):
serializer_class = get_serializer_class(
value.__class__,
["id", "type", "detail_url", "html_url", "title"],
meta_fields=["type", "detail_url", "html_url"],
base=PageSerializer,
)
serializer = serializer_class(context=self.context)
return serializer.to_representation(value)
class ChildRelationField(Field):
"""
Serializes child relations.
Child relations are any model that is related to a Page using a ParentalKey.
They are used for repeated fields on a page such as carousel items or related
links.
Child objects are part of the pages content so we nest them. The relation is
represented as a list of objects.
Example:
"carousel_items": [
{
"id": 1,
"meta": {
"type": "demo.MyCarouselItem"
},
"title": "First carousel item",
"image": {
"id": 1,
"meta": {
"type": "wagtailimages.Image",
"detail_url": "http://api.example.com/v1/images/1/"
}
}
},
{
"id": 2,
"meta": {
"type": "demo.MyCarouselItem"
},
"title": "Second carousel item (no image)",
"image": null
}
]
"""
def __init__(self, *args, **kwargs):
self.serializer_class = kwargs.pop("serializer_class")
super().__init__(*args, **kwargs)
def to_representation(self, value):
serializer = self.serializer_class(context=self.context)
return [
serializer.to_representation(child_object) for child_object in value.all()
]
class StreamField(Field):
"""
Serializes StreamField values.
Stream fields are stored in JSON format in the database. We reuse that in
the API.
Example:
"body": [
{
"type": "heading",
"value": {
"text": "Hello world!",
"size": "h1"
}
},
{
"type": "paragraph",
"value": "Some content"
}
{
"type": "image",
"value": 1
}
]
Where "heading" is a struct block containing "text" and "size" fields, and
"paragraph" is a simple text block.
Note that foreign keys are represented slightly differently in stream fields
to other parts of the API. In stream fields, a foreign key is represented
by an integer (the ID of the related object) but elsewhere in the API,
foreign objects are nested objects with id and meta as attributes.
"""
def to_representation(self, value):
return value.stream_block.get_api_representation(value, self.context)
class TagsField(Field):
"""
Serializes django-taggit TaggableManager fields.
These fields are a common way to link tags to objects in Wagtail. The API
serializes these as a list of strings taken from the name attribute of each
tag.
Example:
"tags": ["bird", "wagtail"]
"""
def to_representation(self, value):
return list(value.all().order_by("name").values_list("name", flat=True))
class BaseSerializer(serializers.ModelSerializer):
# Add StreamField to serializer_field_mapping
serializer_field_mapping = (
serializers.ModelSerializer.serializer_field_mapping.copy()
)
serializer_field_mapping.update(
{
wagtailcore_fields.StreamField: StreamField,
}
)
serializer_related_field = RelatedField
# Meta fields
type = TypeField(read_only=True)
detail_url = DetailUrlField(read_only=True)
def to_representation(self, instance):
data = OrderedDict()
fields = [field for field in self.fields.values() if not field.write_only]
# Split meta fields from core fields
meta_fields = [
field for field in fields if field.field_name in self.meta_fields
]
fields = [field for field in fields if field.field_name not in self.meta_fields]
# Make sure id is always first. This will be filled in later
if "id" in [field.field_name for field in fields]:
data["id"] = None
# Serialise meta fields
meta = OrderedDict()
for field in meta_fields:
try:
attribute = field.get_attribute(instance)
except SkipField:
continue
if attribute is None:
# We skip `to_representation` for `None` values so that
# fields do not have to explicitly deal with that case.
meta[field.field_name] = None
else:
meta[field.field_name] = field.to_representation(attribute)
if meta:
data["meta"] = meta
# Serialise core fields
for field in fields:
try:
if field.field_name == "admin_display_title":
instance = instance.specific_deferred
attribute = field.get_attribute(instance)
except SkipField:
continue
if attribute is None:
# We skip `to_representation` for `None` values so that
# fields do not have to explicitly deal with that case.
data[field.field_name] = None
else:
data[field.field_name] = field.to_representation(attribute)
return data
def build_property_field(self, field_name, model_class):
# TaggableManager is not a Django field so it gets treated as a property
field = getattr(model_class, field_name)
if isinstance(field, _TaggableManager):
return TagsField, {}
return super().build_property_field(field_name, model_class)
def build_relational_field(self, field_name, relation_info):
field_class, field_kwargs = super().build_relational_field(
field_name, relation_info
)
field_kwargs["serializer_class"] = self.child_serializer_classes[field_name]
return field_class, field_kwargs
class PageSerializer(BaseSerializer):
type = PageTypeField(read_only=True)
locale = PageLocaleField(read_only=True)
html_url = PageHtmlUrlField(read_only=True)
parent = PageParentField(read_only=True)
alias_of = PageAliasOfField(read_only=True)
def build_relational_field(self, field_name, relation_info):
# Find all relation fields that point to child class and make them use
# the ChildRelationField class.
if relation_info.to_many:
model = getattr(self.Meta, "model")
child_relations = {
child_relation.field.remote_field.related_name: child_relation.related_model
for child_relation in get_all_child_relations(model)
}
if (
field_name in child_relations
and field_name in self.child_serializer_classes
):
return ChildRelationField, {
"serializer_class": self.child_serializer_classes[field_name]
}
return super().build_relational_field(field_name, relation_info)
def get_serializer_class(
model,
field_names,
meta_fields,
field_serializer_overrides=None,
child_serializer_classes=None,
base=BaseSerializer,
):
model_ = model
class Meta:
model = model_
fields = list(field_names)
attrs = {
"Meta": Meta,
"meta_fields": list(meta_fields),
"child_serializer_classes": child_serializer_classes or {},
}
if field_serializer_overrides:
attrs.update(field_serializer_overrides)
return type(str(model_.__name__ + "Serializer"), (base,), attrs)

View File

@@ -0,0 +1,61 @@
from django.db.models.signals import post_delete, post_save
from django.urls import reverse
from wagtail.contrib.frontend_cache.utils import purge_url_from_cache
from wagtail.documents import get_document_model
from wagtail.images import get_image_model
from wagtail.models import get_page_models
from wagtail.signals import page_published, page_unpublished
from .utils import get_base_url
def purge_page_from_cache(instance, **kwargs):
base_url = get_base_url()
purge_url_from_cache(
base_url + reverse("wagtailapi_v2:pages:detail", args=(instance.id,))
)
def purge_image_from_cache(instance, **kwargs):
if not kwargs.get("created", False):
base_url = get_base_url()
purge_url_from_cache(
base_url + reverse("wagtailapi_v2:images:detail", args=(instance.id,))
)
def purge_document_from_cache(instance, **kwargs):
if not kwargs.get("created", False):
base_url = get_base_url()
purge_url_from_cache(
base_url + reverse("wagtailapi_v2:documents:detail", args=(instance.id,))
)
def register_signal_handlers():
Image = get_image_model()
Document = get_document_model()
for model in get_page_models():
page_published.connect(purge_page_from_cache, sender=model)
page_unpublished.connect(purge_page_from_cache, sender=model)
post_save.connect(purge_image_from_cache, sender=Image)
post_delete.connect(purge_image_from_cache, sender=Image)
post_save.connect(purge_document_from_cache, sender=Document)
post_delete.connect(purge_document_from_cache, sender=Document)
def unregister_signal_handlers():
Image = get_image_model()
Document = get_document_model()
for model in get_page_models():
page_published.disconnect(purge_page_from_cache, sender=model)
page_unpublished.disconnect(purge_page_from_cache, sender=model)
post_save.disconnect(purge_image_from_cache, sender=Image)
post_delete.disconnect(purge_image_from_cache, sender=Image)
post_save.disconnect(purge_document_from_cache, sender=Document)
post_delete.disconnect(purge_document_from_cache, sender=Document)

View File

@@ -0,0 +1,615 @@
import json
from unittest import mock
from django.test import TestCase, TransactionTestCase
from django.test.utils import override_settings
from django.urls import reverse
from wagtail.api.v2 import signal_handlers
from wagtail.documents import get_document_model
class TestDocumentListing(TestCase):
fixtures = ["demosite.json"]
def get_response(self, **params):
return self.client.get(reverse("wagtailapi_v2:documents:listing"), params)
def get_document_id_list(self, content):
return [document["id"] for document in content["items"]]
# BASIC TESTS
def test_basic(self):
response = self.get_response()
self.assertEqual(response.status_code, 200)
self.assertEqual(response["Content-type"], "application/json")
# Will crash if the JSON is invalid
content = json.loads(response.content.decode("UTF-8"))
# Check that the meta section is there
self.assertIn("meta", content)
self.assertIsInstance(content["meta"], dict)
# Check that the total count is there and correct
self.assertIn("total_count", content["meta"])
self.assertIsInstance(content["meta"]["total_count"], int)
self.assertEqual(
content["meta"]["total_count"], get_document_model().objects.count()
)
# Check that the items section is there
self.assertIn("items", content)
self.assertIsInstance(content["items"], list)
# Check that each document has a meta section with type and detail_url attributes
for document in content["items"]:
self.assertIn("meta", document)
self.assertIsInstance(document["meta"], dict)
self.assertEqual(
set(document["meta"].keys()),
{"type", "detail_url", "download_url", "tags"},
)
# Type should always be wagtaildocs.Document
self.assertEqual(document["meta"]["type"], "wagtaildocs.Document")
# Check detail_url
self.assertEqual(
document["meta"]["detail_url"],
"http://localhost/api/main/documents/%d/" % document["id"],
)
# Check download_url
self.assertTrue(
document["meta"]["download_url"].startswith(
"http://localhost/documents/%d/" % document["id"]
)
)
# FIELDS
def test_fields_default(self):
response = self.get_response()
content = json.loads(response.content.decode("UTF-8"))
for document in content["items"]:
self.assertEqual(set(document.keys()), {"id", "meta", "title"})
self.assertEqual(
set(document["meta"].keys()),
{"type", "detail_url", "download_url", "tags"},
)
def test_fields(self):
response = self.get_response(fields="title")
content = json.loads(response.content.decode("UTF-8"))
for document in content["items"]:
self.assertEqual(set(document.keys()), {"id", "meta", "title"})
self.assertEqual(
set(document["meta"].keys()),
{"type", "detail_url", "download_url", "tags"},
)
def test_remove_fields(self):
response = self.get_response(fields="-title")
content = json.loads(response.content.decode("UTF-8"))
for document in content["items"]:
self.assertEqual(set(document.keys()), {"id", "meta"})
def test_remove_meta_fields(self):
response = self.get_response(fields="-download_url")
content = json.loads(response.content.decode("UTF-8"))
for document in content["items"]:
self.assertEqual(set(document.keys()), {"id", "meta", "title"})
self.assertEqual(
set(document["meta"].keys()), {"type", "detail_url", "tags"}
)
def test_remove_all_meta_fields(self):
response = self.get_response(fields="-type,-detail_url,-tags,-download_url")
content = json.loads(response.content.decode("UTF-8"))
for document in content["items"]:
self.assertEqual(set(document.keys()), {"id", "title"})
def test_remove_id_field(self):
response = self.get_response(fields="-id")
content = json.loads(response.content.decode("UTF-8"))
for document in content["items"]:
self.assertEqual(set(document.keys()), {"meta", "title"})
def test_all_fields(self):
response = self.get_response(fields="*")
content = json.loads(response.content.decode("UTF-8"))
for document in content["items"]:
self.assertEqual(set(document.keys()), {"id", "meta", "title"})
self.assertEqual(
set(document["meta"].keys()),
{"type", "detail_url", "tags", "download_url"},
)
def test_all_fields_then_remove_something(self):
response = self.get_response(fields="*,-title,-download_url")
content = json.loads(response.content.decode("UTF-8"))
for document in content["items"]:
self.assertEqual(set(document.keys()), {"id", "meta"})
self.assertEqual(
set(document["meta"].keys()), {"type", "detail_url", "tags"}
)
def test_fields_tags(self):
response = self.get_response(fields="tags")
content = json.loads(response.content.decode("UTF-8"))
for document in content["items"]:
self.assertIsInstance(document["meta"]["tags"], list)
def test_star_in_wrong_position_gives_error(self):
response = self.get_response(fields="title,*")
content = json.loads(response.content.decode("UTF-8"))
self.assertEqual(response.status_code, 400)
self.assertEqual(
content, {"message": "fields error: '*' must be in the first position"}
)
def test_fields_which_are_not_in_api_fields_gives_error(self):
response = self.get_response(fields="uploaded_by_user")
content = json.loads(response.content.decode("UTF-8"))
self.assertEqual(response.status_code, 400)
self.assertEqual(content, {"message": "unknown fields: uploaded_by_user"})
def test_fields_unknown_field_gives_error(self):
response = self.get_response(fields="123,title,abc")
content = json.loads(response.content.decode("UTF-8"))
self.assertEqual(response.status_code, 400)
self.assertEqual(content, {"message": "unknown fields: 123, abc"})
def test_fields_remove_unknown_field_gives_error(self):
response = self.get_response(fields="-123,-title,-abc")
content = json.loads(response.content.decode("UTF-8"))
self.assertEqual(response.status_code, 400)
self.assertEqual(content, {"message": "unknown fields: 123, abc"})
# FILTERING
def test_filtering_exact_filter(self):
response = self.get_response(title="James Joyce")
content = json.loads(response.content.decode("UTF-8"))
document_id_list = self.get_document_id_list(content)
self.assertEqual(document_id_list, [2])
def test_filtering_on_id(self):
response = self.get_response(id=10)
content = json.loads(response.content.decode("UTF-8"))
document_id_list = self.get_document_id_list(content)
self.assertEqual(document_id_list, [10])
def test_filtering_tags(self):
get_document_model().objects.get(id=3).tags.add("test")
response = self.get_response(tags="test")
content = json.loads(response.content.decode("UTF-8"))
document_id_list = self.get_document_id_list(content)
self.assertEqual(document_id_list, [3])
def test_filtering_unknown_field_gives_error(self):
response = self.get_response(not_a_field="abc")
content = json.loads(response.content.decode("UTF-8"))
self.assertEqual(response.status_code, 400)
self.assertEqual(
content,
{
"message": "query parameter is not an operation or a recognised field: not_a_field"
},
)
# ORDERING
def test_ordering_by_title(self):
response = self.get_response(order="title")
content = json.loads(response.content.decode("UTF-8"))
document_id_list = self.get_document_id_list(content)
self.assertEqual(document_id_list, [3, 12, 10, 2, 7, 8, 5, 4, 1, 11, 9, 6])
def test_ordering_by_title_backwards(self):
response = self.get_response(order="-title")
content = json.loads(response.content.decode("UTF-8"))
document_id_list = self.get_document_id_list(content)
self.assertEqual(document_id_list, [6, 9, 11, 1, 4, 5, 8, 7, 2, 10, 12, 3])
def test_ordering_by_random(self):
response_1 = self.get_response(order="random")
content_1 = json.loads(response_1.content.decode("UTF-8"))
document_id_list_1 = self.get_document_id_list(content_1)
response_2 = self.get_response(order="random")
content_2 = json.loads(response_2.content.decode("UTF-8"))
document_id_list_2 = self.get_document_id_list(content_2)
self.assertNotEqual(document_id_list_1, document_id_list_2)
def test_ordering_by_random_backwards_gives_error(self):
response = self.get_response(order="-random")
content = json.loads(response.content.decode("UTF-8"))
self.assertEqual(response.status_code, 400)
self.assertEqual(
content, {"message": "cannot order by 'random' (unknown field)"}
)
def test_ordering_by_random_with_offset_gives_error(self):
response = self.get_response(order="random", offset=10)
content = json.loads(response.content.decode("UTF-8"))
self.assertEqual(response.status_code, 400)
self.assertEqual(
content, {"message": "random ordering with offset is not supported"}
)
def test_ordering_by_unknown_field_gives_error(self):
response = self.get_response(order="not_a_field")
content = json.loads(response.content.decode("UTF-8"))
self.assertEqual(response.status_code, 400)
self.assertEqual(
content, {"message": "cannot order by 'not_a_field' (unknown field)"}
)
# LIMIT
def test_limit_only_two_items_returned(self):
response = self.get_response(limit=2)
content = json.loads(response.content.decode("UTF-8"))
self.assertEqual(len(content["items"]), 2)
def test_limit_total_count(self):
response = self.get_response(limit=2)
content = json.loads(response.content.decode("UTF-8"))
# The total count must not be affected by "limit"
self.assertEqual(
content["meta"]["total_count"], get_document_model().objects.count()
)
def test_limit_not_integer_gives_error(self):
response = self.get_response(limit="abc")
content = json.loads(response.content.decode("UTF-8"))
self.assertEqual(response.status_code, 400)
self.assertEqual(content, {"message": "limit must be a positive integer"})
def test_limit_too_high_gives_error(self):
response = self.get_response(limit=1000)
content = json.loads(response.content.decode("UTF-8"))
self.assertEqual(response.status_code, 400)
self.assertEqual(content, {"message": "limit cannot be higher than 20"})
@override_settings(WAGTAILAPI_LIMIT_MAX=None)
def test_limit_max_none_gives_no_errors(self):
response = self.get_response(limit=1000000)
content = json.loads(response.content.decode("UTF-8"))
self.assertEqual(response.status_code, 200)
self.assertEqual(len(content["items"]), get_document_model().objects.count())
@override_settings(WAGTAILAPI_LIMIT_MAX=10)
def test_limit_maximum_can_be_changed(self):
response = self.get_response(limit=20)
content = json.loads(response.content.decode("UTF-8"))
self.assertEqual(response.status_code, 400)
self.assertEqual(content, {"message": "limit cannot be higher than 10"})
@override_settings(WAGTAILAPI_LIMIT_MAX=2)
def test_limit_default_changes_with_max(self):
# The default limit is 20. If WAGTAILAPI_LIMIT_MAX is less than that,
# the default should change accordingly.
response = self.get_response()
content = json.loads(response.content.decode("UTF-8"))
self.assertEqual(len(content["items"]), 2)
# OFFSET
def test_offset_5_usually_appears_5th_in_list(self):
response = self.get_response()
content = json.loads(response.content.decode("UTF-8"))
document_id_list = self.get_document_id_list(content)
self.assertEqual(document_id_list.index(5), 4)
def test_offset_5_moves_after_offset(self):
response = self.get_response(offset=4)
content = json.loads(response.content.decode("UTF-8"))
document_id_list = self.get_document_id_list(content)
self.assertEqual(document_id_list.index(5), 0)
def test_offset_total_count(self):
response = self.get_response(offset=10)
content = json.loads(response.content.decode("UTF-8"))
# The total count must not be affected by "offset"
self.assertEqual(
content["meta"]["total_count"], get_document_model().objects.count()
)
def test_offset_not_integer_gives_error(self):
response = self.get_response(offset="abc")
content = json.loads(response.content.decode("UTF-8"))
self.assertEqual(response.status_code, 400)
self.assertEqual(content, {"message": "offset must be a positive integer"})
class TestDocumentListingSearch(TransactionTestCase):
fixtures = ["demosite.json"]
def get_response(self, **params):
return self.client.get(reverse("wagtailapi_v2:documents:listing"), params)
def get_document_id_list(self, content):
return [document["id"] for document in content["items"]]
def test_search_for_james_joyce(self):
response = self.get_response(search="james")
content = json.loads(response.content.decode("UTF-8"))
document_id_list = self.get_document_id_list(content)
self.assertEqual(set(document_id_list), {2})
def test_search_with_order(self):
response = self.get_response(search="james", order="title")
content = json.loads(response.content.decode("UTF-8"))
document_id_list = self.get_document_id_list(content)
self.assertEqual(document_id_list, [2])
@override_settings(WAGTAILAPI_SEARCH_ENABLED=False)
def test_search_when_disabled_gives_error(self):
response = self.get_response(search="james")
content = json.loads(response.content.decode("UTF-8"))
self.assertEqual(response.status_code, 400)
self.assertEqual(content, {"message": "search is disabled"})
def test_search_when_filtering_by_tag_gives_error(self):
response = self.get_response(search="james", tags="wagtail")
content = json.loads(response.content.decode("UTF-8"))
self.assertEqual(response.status_code, 400)
self.assertEqual(
content,
{"message": "filtering by tag with a search query is not supported"},
)
class TestDocumentDetail(TestCase):
fixtures = ["demosite.json"]
def get_response(self, image_id, **params):
return self.client.get(
reverse("wagtailapi_v2:documents:detail", args=(image_id,)), params
)
def test_basic(self):
response = self.get_response(1)
self.assertEqual(response.status_code, 200)
self.assertEqual(response["Content-type"], "application/json")
# Will crash if the JSON is invalid
content = json.loads(response.content.decode("UTF-8"))
# Check the id field
self.assertIn("id", content)
self.assertEqual(content["id"], 1)
# Check that the meta section is there
self.assertIn("meta", content)
self.assertIsInstance(content["meta"], dict)
# Check the meta type
self.assertIn("type", content["meta"])
self.assertEqual(content["meta"]["type"], "wagtaildocs.Document")
# Check the meta detail_url
self.assertIn("detail_url", content["meta"])
self.assertEqual(
content["meta"]["detail_url"], "http://localhost/api/main/documents/1/"
)
# Check the meta download_url
self.assertIn("download_url", content["meta"])
self.assertEqual(
content["meta"]["download_url"],
"http://localhost/documents/1/wagtail_by_markyharky.jpg",
)
# Check the title field
self.assertIn("title", content)
self.assertEqual(content["title"], "Wagtail by mark Harkin")
# Check the tags field
self.assertIn("tags", content["meta"])
self.assertEqual(content["meta"]["tags"], [])
def test_tags(self):
get_document_model().objects.get(id=1).tags.add("hello")
get_document_model().objects.get(id=1).tags.add("world")
response = self.get_response(1)
content = json.loads(response.content.decode("UTF-8"))
self.assertIn("tags", content["meta"])
self.assertEqual(content["meta"]["tags"], ["hello", "world"])
@override_settings(WAGTAILAPI_BASE_URL="http://api.example.com/")
def test_download_url_with_custom_base_url(self):
response = self.get_response(1)
content = json.loads(response.content.decode("UTF-8"))
self.assertIn("download_url", content["meta"])
self.assertEqual(
content["meta"]["download_url"],
"http://api.example.com/documents/1/wagtail_by_markyharky.jpg",
)
# FIELDS
def test_remove_fields(self):
response = self.get_response(2, fields="-title")
content = json.loads(response.content.decode("UTF-8"))
self.assertIn("id", set(content.keys()))
self.assertNotIn("title", set(content.keys()))
def test_remove_meta_fields(self):
response = self.get_response(2, fields="-download_url")
content = json.loads(response.content.decode("UTF-8"))
self.assertIn("detail_url", set(content["meta"].keys()))
self.assertNotIn("download_url", set(content["meta"].keys()))
def test_remove_id_field(self):
response = self.get_response(2, fields="-id")
content = json.loads(response.content.decode("UTF-8"))
self.assertIn("title", set(content.keys()))
self.assertNotIn("id", set(content.keys()))
def test_remove_all_fields(self):
response = self.get_response(2, fields="_,id,type")
content = json.loads(response.content.decode("UTF-8"))
self.assertEqual(set(content.keys()), {"id", "meta"})
self.assertEqual(set(content["meta"].keys()), {"type"})
def test_star_in_wrong_position_gives_error(self):
response = self.get_response(2, fields="title,*")
content = json.loads(response.content.decode("UTF-8"))
self.assertEqual(response.status_code, 400)
self.assertEqual(
content, {"message": "fields error: '*' must be in the first position"}
)
def test_fields_which_are_not_in_api_fields_gives_error(self):
response = self.get_response(2, fields="path")
content = json.loads(response.content.decode("UTF-8"))
self.assertEqual(response.status_code, 400)
self.assertEqual(content, {"message": "unknown fields: path"})
def test_fields_unknown_field_gives_error(self):
response = self.get_response(2, fields="123,title,abc")
content = json.loads(response.content.decode("UTF-8"))
self.assertEqual(response.status_code, 400)
self.assertEqual(content, {"message": "unknown fields: 123, abc"})
def test_fields_remove_unknown_field_gives_error(self):
response = self.get_response(2, fields="-123,-title,-abc")
content = json.loads(response.content.decode("UTF-8"))
self.assertEqual(response.status_code, 400)
self.assertEqual(content, {"message": "unknown fields: 123, abc"})
def test_nested_fields_on_non_relational_field_gives_error(self):
response = self.get_response(2, fields="title(foo,bar)")
content = json.loads(response.content.decode("UTF-8"))
self.assertEqual(response.status_code, 400)
self.assertEqual(content, {"message": "'title' does not support nested fields"})
class TestDocumentFind(TestCase):
fixtures = ["demosite.json"]
def get_response(self, **params):
return self.client.get(reverse("wagtailapi_v2:documents:find"), params)
def test_without_parameters(self):
response = self.get_response()
self.assertEqual(response.status_code, 404)
self.assertEqual(response["Content-type"], "application/json")
# Will crash if the JSON is invalid
content = json.loads(response.content.decode("UTF-8"))
self.assertEqual(content, {"message": "not found"})
def test_find_by_id(self):
response = self.get_response(id=5)
self.assertRedirects(
response,
"http://localhost" + reverse("wagtailapi_v2:documents:detail", args=[5]),
fetch_redirect_response=False,
)
def test_find_by_id_nonexistent(self):
response = self.get_response(id=1234)
self.assertEqual(response.status_code, 404)
self.assertEqual(response["Content-type"], "application/json")
# Will crash if the JSON is invalid
content = json.loads(response.content.decode("UTF-8"))
self.assertEqual(content, {"message": "not found"})
@override_settings(
WAGTAILFRONTENDCACHE={
"varnish": {
"BACKEND": "wagtail.contrib.frontend_cache.backends.HTTPBackend",
"LOCATION": "http://localhost:8000",
},
},
WAGTAILAPI_BASE_URL="http://api.example.com",
)
@mock.patch("wagtail.contrib.frontend_cache.backends.http.HTTPBackend.purge")
class TestDocumentCacheInvalidation(TestCase):
fixtures = ["demosite.json"]
@classmethod
def setUpClass(cls):
super().setUpClass()
signal_handlers.register_signal_handlers()
@classmethod
def tearDownClass(cls):
super().tearDownClass()
signal_handlers.unregister_signal_handlers()
def test_resave_document_purges(self, purge):
get_document_model().objects.get(id=5).save()
purge.assert_any_call("http://api.example.com/api/main/documents/5/")
def test_delete_document_purges(self, purge):
get_document_model().objects.get(id=5).delete()
purge.assert_any_call("http://api.example.com/api/main/documents/5/")

View File

@@ -0,0 +1,607 @@
import json
from unittest import mock
from django.test import TestCase, TransactionTestCase
from django.test.utils import override_settings
from django.urls import reverse
from wagtail.api.v2 import signal_handlers
from wagtail.images import get_image_model
class TestImageListing(TestCase):
fixtures = ["demosite.json"]
def get_response(self, **params):
return self.client.get(reverse("wagtailapi_v2:images:listing"), params)
def get_image_id_list(self, content):
return [image["id"] for image in content["items"]]
# BASIC TESTS
def test_basic(self):
response = self.get_response()
self.assertEqual(response.status_code, 200)
self.assertEqual(response["Content-type"], "application/json")
# Will crash if the JSON is invalid
content = json.loads(response.content.decode("UTF-8"))
# Check that the meta section is there
self.assertIn("meta", content)
self.assertIsInstance(content["meta"], dict)
# Check that the total count is there and correct
self.assertIn("total_count", content["meta"])
self.assertIsInstance(content["meta"]["total_count"], int)
self.assertEqual(
content["meta"]["total_count"], get_image_model().objects.count()
)
# Check that the items section is there
self.assertIn("items", content)
self.assertIsInstance(content["items"], list)
# Check that each image has a meta section with type and detail_url attributes
for image in content["items"]:
self.assertIn("meta", image)
self.assertIsInstance(image["meta"], dict)
self.assertEqual(
set(image["meta"].keys()),
{"type", "detail_url", "tags", "download_url"},
)
# Type should always be wagtailimages.Image
self.assertEqual(image["meta"]["type"], "wagtailimages.Image")
# Check detail url
self.assertEqual(
image["meta"]["detail_url"],
"http://localhost/api/main/images/%d/" % image["id"],
)
# FIELDS
def test_fields_default(self):
response = self.get_response()
content = json.loads(response.content.decode("UTF-8"))
for image in content["items"]:
self.assertEqual(set(image.keys()), {"id", "meta", "title"})
self.assertEqual(
set(image["meta"].keys()),
{"type", "detail_url", "tags", "download_url"},
)
def test_fields(self):
response = self.get_response(fields="width,height")
content = json.loads(response.content.decode("UTF-8"))
for image in content["items"]:
self.assertEqual(
set(image.keys()), {"id", "meta", "title", "width", "height"}
)
self.assertEqual(
set(image["meta"].keys()),
{"type", "detail_url", "tags", "download_url"},
)
def test_remove_fields(self):
response = self.get_response(fields="-title")
content = json.loads(response.content.decode("UTF-8"))
for image in content["items"]:
self.assertEqual(set(image.keys()), {"id", "meta"})
def test_remove_meta_fields(self):
response = self.get_response(fields="-tags")
content = json.loads(response.content.decode("UTF-8"))
for image in content["items"]:
self.assertEqual(set(image.keys()), {"id", "meta", "title"})
self.assertEqual(
set(image["meta"].keys()), {"type", "detail_url", "download_url"}
)
def test_remove_all_meta_fields(self):
response = self.get_response(fields="-type,-detail_url,-tags")
content = json.loads(response.content.decode("UTF-8"))
for image in content["items"]:
self.assertEqual(set(image.keys()), {"id", "title", "meta"})
def test_remove_id_field(self):
response = self.get_response(fields="-id")
content = json.loads(response.content.decode("UTF-8"))
for image in content["items"]:
self.assertEqual(set(image.keys()), {"meta", "title"})
def test_all_fields(self):
response = self.get_response(fields="*")
content = json.loads(response.content.decode("UTF-8"))
for image in content["items"]:
self.assertEqual(
set(image.keys()), {"id", "meta", "title", "width", "height"}
)
self.assertEqual(
set(image["meta"].keys()),
{"type", "detail_url", "tags", "download_url"},
)
def test_all_fields_then_remove_something(self):
response = self.get_response(fields="*,-title,-tags")
content = json.loads(response.content.decode("UTF-8"))
for image in content["items"]:
self.assertEqual(set(image.keys()), {"id", "meta", "width", "height"})
self.assertEqual(
set(image["meta"].keys()), {"type", "detail_url", "download_url"}
)
def test_fields_tags(self):
response = self.get_response(fields="tags")
content = json.loads(response.content.decode("UTF-8"))
for image in content["items"]:
self.assertEqual(set(image.keys()), {"id", "meta", "title"})
self.assertEqual(set(image.keys()), {"id", "meta", "title"})
self.assertEqual(
set(image["meta"].keys()),
{"type", "detail_url", "tags", "download_url"},
)
self.assertIsInstance(image["meta"]["tags"], list)
def test_star_in_wrong_position_gives_error(self):
response = self.get_response(fields="title,*")
content = json.loads(response.content.decode("UTF-8"))
self.assertEqual(response.status_code, 400)
self.assertEqual(
content, {"message": "fields error: '*' must be in the first position"}
)
def test_fields_which_are_not_in_api_fields_gives_error(self):
response = self.get_response(fields="uploaded_by_user")
content = json.loads(response.content.decode("UTF-8"))
self.assertEqual(response.status_code, 400)
self.assertEqual(content, {"message": "unknown fields: uploaded_by_user"})
def test_fields_unknown_field_gives_error(self):
response = self.get_response(fields="123,title,abc")
content = json.loads(response.content.decode("UTF-8"))
self.assertEqual(response.status_code, 400)
self.assertEqual(content, {"message": "unknown fields: 123, abc"})
def test_fields_remove_unknown_field_gives_error(self):
response = self.get_response(fields="-123,-title,-abc")
content = json.loads(response.content.decode("UTF-8"))
self.assertEqual(response.status_code, 400)
self.assertEqual(content, {"message": "unknown fields: 123, abc"})
# FILTERING
def test_filtering_exact_filter(self):
response = self.get_response(title="James Joyce")
content = json.loads(response.content.decode("UTF-8"))
image_id_list = self.get_image_id_list(content)
self.assertEqual(image_id_list, [5])
def test_filtering_on_id(self):
response = self.get_response(id=10)
content = json.loads(response.content.decode("UTF-8"))
image_id_list = self.get_image_id_list(content)
self.assertEqual(image_id_list, [10])
def test_filtering_tags(self):
get_image_model().objects.get(id=6).tags.add("test")
response = self.get_response(tags="test")
content = json.loads(response.content.decode("UTF-8"))
image_id_list = self.get_image_id_list(content)
self.assertEqual(image_id_list, [6])
def test_filtering_unknown_field_gives_error(self):
response = self.get_response(not_a_field="abc")
content = json.loads(response.content.decode("UTF-8"))
self.assertEqual(response.status_code, 400)
self.assertEqual(
content,
{
"message": "query parameter is not an operation or a recognised field: not_a_field"
},
)
# ORDERING
def test_ordering_by_title(self):
response = self.get_response(order="title")
content = json.loads(response.content.decode("UTF-8"))
image_id_list = self.get_image_id_list(content)
self.assertEqual(image_id_list, [6, 15, 13, 5, 10, 11, 8, 7, 4, 14, 12, 9])
def test_ordering_by_title_backwards(self):
response = self.get_response(order="-title")
content = json.loads(response.content.decode("UTF-8"))
image_id_list = self.get_image_id_list(content)
self.assertEqual(image_id_list, [9, 12, 14, 4, 7, 8, 11, 10, 5, 13, 15, 6])
def test_ordering_by_random(self):
response_1 = self.get_response(order="random")
content_1 = json.loads(response_1.content.decode("UTF-8"))
image_id_list_1 = self.get_image_id_list(content_1)
response_2 = self.get_response(order="random")
content_2 = json.loads(response_2.content.decode("UTF-8"))
image_id_list_2 = self.get_image_id_list(content_2)
self.assertNotEqual(image_id_list_1, image_id_list_2)
def test_ordering_by_random_backwards_gives_error(self):
response = self.get_response(order="-random")
content = json.loads(response.content.decode("UTF-8"))
self.assertEqual(response.status_code, 400)
self.assertEqual(
content, {"message": "cannot order by 'random' (unknown field)"}
)
def test_ordering_by_random_with_offset_gives_error(self):
response = self.get_response(order="random", offset=10)
content = json.loads(response.content.decode("UTF-8"))
self.assertEqual(response.status_code, 400)
self.assertEqual(
content, {"message": "random ordering with offset is not supported"}
)
def test_ordering_by_unknown_field_gives_error(self):
response = self.get_response(order="not_a_field")
content = json.loads(response.content.decode("UTF-8"))
self.assertEqual(response.status_code, 400)
self.assertEqual(
content, {"message": "cannot order by 'not_a_field' (unknown field)"}
)
# LIMIT
def test_limit_only_two_items_returned(self):
response = self.get_response(limit=2)
content = json.loads(response.content.decode("UTF-8"))
self.assertEqual(len(content["items"]), 2)
def test_limit_total_count(self):
response = self.get_response(limit=2)
content = json.loads(response.content.decode("UTF-8"))
# The total count must not be affected by "limit"
self.assertEqual(
content["meta"]["total_count"], get_image_model().objects.count()
)
def test_limit_not_integer_gives_error(self):
response = self.get_response(limit="abc")
content = json.loads(response.content.decode("UTF-8"))
self.assertEqual(response.status_code, 400)
self.assertEqual(content, {"message": "limit must be a positive integer"})
def test_limit_too_high_gives_error(self):
response = self.get_response(limit=1000)
content = json.loads(response.content.decode("UTF-8"))
self.assertEqual(response.status_code, 400)
self.assertEqual(content, {"message": "limit cannot be higher than 20"})
@override_settings(WAGTAILAPI_LIMIT_MAX=None)
def test_limit_max_none_gives_no_errors(self):
response = self.get_response(limit=1000000)
content = json.loads(response.content.decode("UTF-8"))
self.assertEqual(response.status_code, 200)
self.assertEqual(len(content["items"]), get_image_model().objects.count())
@override_settings(WAGTAILAPI_LIMIT_MAX=10)
def test_limit_maximum_can_be_changed(self):
response = self.get_response(limit=20)
content = json.loads(response.content.decode("UTF-8"))
self.assertEqual(response.status_code, 400)
self.assertEqual(content, {"message": "limit cannot be higher than 10"})
@override_settings(WAGTAILAPI_LIMIT_MAX=2)
def test_limit_default_changes_with_max(self):
# The default limit is 20. If WAGTAILAPI_LIMIT_MAX is less than that,
# the default should change accordingly.
response = self.get_response()
content = json.loads(response.content.decode("UTF-8"))
self.assertEqual(len(content["items"]), 2)
# OFFSET
def test_offset_10_usually_appears_7th_in_list(self):
response = self.get_response()
content = json.loads(response.content.decode("UTF-8"))
image_id_list = self.get_image_id_list(content)
self.assertEqual(image_id_list.index(10), 6)
def test_offset_10_moves_after_offset(self):
response = self.get_response(offset=4)
content = json.loads(response.content.decode("UTF-8"))
image_id_list = self.get_image_id_list(content)
self.assertEqual(image_id_list.index(10), 2)
def test_offset_total_count(self):
response = self.get_response(offset=10)
content = json.loads(response.content.decode("UTF-8"))
# The total count must not be affected by "offset"
self.assertEqual(
content["meta"]["total_count"], get_image_model().objects.count()
)
def test_offset_not_integer_gives_error(self):
response = self.get_response(offset="abc")
content = json.loads(response.content.decode("UTF-8"))
self.assertEqual(response.status_code, 400)
self.assertEqual(content, {"message": "offset must be a positive integer"})
class TestImageListingSearch(TransactionTestCase):
fixtures = ["demosite.json"]
def get_response(self, **params):
return self.client.get(reverse("wagtailapi_v2:images:listing"), params)
def get_image_id_list(self, content):
return [image["id"] for image in content["items"]]
def test_search_for_james_joyce(self):
response = self.get_response(search="james")
content = json.loads(response.content.decode("UTF-8"))
image_id_list = self.get_image_id_list(content)
self.assertEqual(set(image_id_list), {5})
def test_search_with_order(self):
response = self.get_response(search="james", order="title")
content = json.loads(response.content.decode("UTF-8"))
image_id_list = self.get_image_id_list(content)
self.assertEqual(image_id_list, [5])
@override_settings(WAGTAILAPI_SEARCH_ENABLED=False)
def test_search_when_disabled_gives_error(self):
response = self.get_response(search="james")
content = json.loads(response.content.decode("UTF-8"))
self.assertEqual(response.status_code, 400)
self.assertEqual(content, {"message": "search is disabled"})
def test_search_when_filtering_by_tag_gives_error(self):
response = self.get_response(search="james", tags="wagtail")
content = json.loads(response.content.decode("UTF-8"))
self.assertEqual(response.status_code, 400)
self.assertEqual(
content,
{"message": "filtering by tag with a search query is not supported"},
)
class TestImageDetail(TestCase):
fixtures = ["demosite.json"]
def get_response(self, image_id, **params):
return self.client.get(
reverse("wagtailapi_v2:images:detail", args=(image_id,)), params
)
def test_basic(self):
response = self.get_response(5)
self.assertEqual(response.status_code, 200)
self.assertEqual(response["Content-type"], "application/json")
# Will crash if the JSON is invalid
content = json.loads(response.content.decode("UTF-8"))
# Check the id field
self.assertIn("id", content)
self.assertEqual(content["id"], 5)
# Check that the meta section is there
self.assertIn("meta", content)
self.assertIsInstance(content["meta"], dict)
# Check the meta type
self.assertIn("type", content["meta"])
self.assertEqual(content["meta"]["type"], "wagtailimages.Image")
# Check the meta detail_url
self.assertIn("detail_url", content["meta"])
self.assertEqual(
content["meta"]["detail_url"], "http://localhost/api/main/images/5/"
)
# Check the title field
self.assertIn("title", content)
self.assertEqual(content["title"], "James Joyce")
# Check the width and height fields
self.assertIn("width", content)
self.assertIn("height", content)
self.assertEqual(content["width"], 500)
self.assertEqual(content["height"], 392)
# Check the tags field
self.assertIn("tags", content["meta"])
self.assertEqual(content["meta"]["tags"], [])
def test_tags(self):
image = get_image_model().objects.get(id=5)
image.tags.add("hello")
image.tags.add("world")
response = self.get_response(5)
content = json.loads(response.content.decode("UTF-8"))
self.assertIn("tags", content["meta"])
self.assertEqual(content["meta"]["tags"], ["hello", "world"])
# FIELDS
def test_remove_fields(self):
response = self.get_response(5, fields="-title")
content = json.loads(response.content.decode("UTF-8"))
self.assertIn("id", set(content.keys()))
self.assertNotIn("title", set(content.keys()))
def test_remove_meta_fields(self):
response = self.get_response(5, fields="-type")
content = json.loads(response.content.decode("UTF-8"))
self.assertIn("detail_url", set(content["meta"].keys()))
self.assertNotIn("type", set(content["meta"].keys()))
def test_remove_id_field(self):
response = self.get_response(5, fields="-id")
content = json.loads(response.content.decode("UTF-8"))
self.assertIn("title", set(content.keys()))
self.assertNotIn("id", set(content.keys()))
def test_remove_all_fields(self):
response = self.get_response(5, fields="_,id,type")
content = json.loads(response.content.decode("UTF-8"))
self.assertEqual(set(content.keys()), {"id", "meta"})
self.assertEqual(set(content["meta"].keys()), {"type"})
def test_star_in_wrong_position_gives_error(self):
response = self.get_response(5, fields="title,*")
content = json.loads(response.content.decode("UTF-8"))
self.assertEqual(response.status_code, 400)
self.assertEqual(
content, {"message": "fields error: '*' must be in the first position"}
)
def test_fields_which_are_not_in_api_fields_gives_error(self):
response = self.get_response(5, fields="path")
content = json.loads(response.content.decode("UTF-8"))
self.assertEqual(response.status_code, 400)
self.assertEqual(content, {"message": "unknown fields: path"})
def test_fields_unknown_field_gives_error(self):
response = self.get_response(5, fields="123,title,abc")
content = json.loads(response.content.decode("UTF-8"))
self.assertEqual(response.status_code, 400)
self.assertEqual(content, {"message": "unknown fields: 123, abc"})
def test_fields_remove_unknown_field_gives_error(self):
response = self.get_response(5, fields="-123,-title,-abc")
content = json.loads(response.content.decode("UTF-8"))
self.assertEqual(response.status_code, 400)
self.assertEqual(content, {"message": "unknown fields: 123, abc"})
def test_nested_fields_on_non_relational_field_gives_error(self):
response = self.get_response(5, fields="title(foo,bar)")
content = json.loads(response.content.decode("UTF-8"))
self.assertEqual(response.status_code, 400)
self.assertEqual(content, {"message": "'title' does not support nested fields"})
class TestImageFind(TestCase):
fixtures = ["demosite.json"]
def get_response(self, **params):
return self.client.get(reverse("wagtailapi_v2:images:find"), params)
def test_without_parameters(self):
response = self.get_response()
self.assertEqual(response.status_code, 404)
self.assertEqual(response["Content-type"], "application/json")
# Will crash if the JSON is invalid
content = json.loads(response.content.decode("UTF-8"))
self.assertEqual(content, {"message": "not found"})
def test_find_by_id(self):
response = self.get_response(id=5)
self.assertRedirects(
response,
"http://localhost" + reverse("wagtailapi_v2:images:detail", args=[5]),
fetch_redirect_response=False,
)
def test_find_by_id_nonexistent(self):
response = self.get_response(id=1234)
self.assertEqual(response.status_code, 404)
self.assertEqual(response["Content-type"], "application/json")
# Will crash if the JSON is invalid
content = json.loads(response.content.decode("UTF-8"))
self.assertEqual(content, {"message": "not found"})
@override_settings(
WAGTAILFRONTENDCACHE={
"varnish": {
"BACKEND": "wagtail.contrib.frontend_cache.backends.HTTPBackend",
"LOCATION": "http://localhost:8000",
},
},
WAGTAILAPI_BASE_URL="http://api.example.com",
)
@mock.patch("wagtail.contrib.frontend_cache.backends.http.HTTPBackend.purge")
class TestImageCacheInvalidation(TestCase):
fixtures = ["demosite.json"]
@classmethod
def setUpClass(cls):
super().setUpClass()
signal_handlers.register_signal_handlers()
@classmethod
def tearDownClass(cls):
super().tearDownClass()
signal_handlers.unregister_signal_handlers()
def test_resave_image_purges(self, purge):
get_image_model().objects.get(id=5).save()
purge.assert_any_call("http://api.example.com/api/main/images/5/")
def test_delete_image_purges(self, purge):
get_image_model().objects.get(id=5).delete()
purge.assert_any_call("http://api.example.com/api/main/images/5/")

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,401 @@
from django.test import RequestFactory, TestCase, override_settings
from django.utils.encoding import force_bytes
from wagtail.models import Site
from ..utils import (
FieldsParameterParseError,
get_base_url,
parse_boolean,
parse_fields_parameter,
)
class DynamicBaseUrl:
def __str__(self):
return "https://www.example.com"
def __bytes__(self):
return force_bytes(self.__str__())
def decode(self, *args, **kwargs):
return self.__bytes__().decode(*args, **kwargs)
class TestGetBaseUrl(TestCase):
def setUp(self):
Site.objects.all().delete()
def prepare_site(self):
return Site.objects.get_or_create(
hostname="other.example.com",
port=8080,
root_page_id=1,
is_default_site=True,
)[0]
def clear_cached_site(self, request):
del request._wagtail_site
def test_get_base_url_unset(self):
self.assertIsNone(get_base_url())
def test_get_base_url_from_request(self):
# base url for siteless request should be None
request = RequestFactory().get("/")
self.assertIsNone(Site.find_for_request(request))
self.assertIsNone(get_base_url(request))
# base url for request with a site should be based on the site's details
site = self.prepare_site()
self.clear_cached_site(request)
self.assertEqual(site, Site.find_for_request(request))
self.assertEqual(get_base_url(request), "http://other.example.com:8080")
# port 443 should indicate https without a port
site.port = 443
site.save()
self.clear_cached_site(request)
self.assertEqual(get_base_url(request), "https://other.example.com")
# port 80 should indicate http without a port
site.port = 80
site.save()
self.clear_cached_site(request)
self.assertEqual(get_base_url(request), "http://other.example.com")
@override_settings(WAGTAILAPI_BASE_URL="https://bar.example.com")
def test_get_base_url_prefers_setting(self):
request = RequestFactory().get("/")
site = self.prepare_site()
self.assertEqual(site, Site.find_for_request(request))
self.assertEqual(get_base_url(request), "https://bar.example.com")
with override_settings(WAGTAILAPI_BASE_URL=None):
self.assertEqual(get_base_url(request), "http://other.example.com:8080")
@override_settings(WAGTAILAPI_BASE_URL="https://bar.example.com")
def test_get_base_url_from_setting_string(self):
self.assertEqual(get_base_url(), "https://bar.example.com")
@override_settings(WAGTAILAPI_BASE_URL=b"https://baz.example.com")
def test_get_base_url_from_setting_bytes(self):
self.assertEqual(get_base_url(), "https://baz.example.com")
@override_settings(WAGTAILAPI_BASE_URL=DynamicBaseUrl())
def test_get_base_url_from_setting_object(self):
self.assertEqual(get_base_url(), "https://www.example.com")
class TestParseFieldsParameter(TestCase):
# GOOD STUFF
def test_valid_single_field(self):
parsed = parse_fields_parameter("test")
self.assertEqual(
parsed,
[
("test", False, None),
],
)
def test_valid_multiple_fields(self):
parsed = parse_fields_parameter("test,another_test")
self.assertEqual(
parsed,
[
("test", False, None),
("another_test", False, None),
],
)
def test_valid_negated_field(self):
parsed = parse_fields_parameter("-test")
self.assertEqual(
parsed,
[
("test", True, None),
],
)
def test_valid_nested_fields(self):
parsed = parse_fields_parameter("test(foo,bar)")
self.assertEqual(
parsed,
[
(
"test",
False,
[
("foo", False, None),
("bar", False, None),
],
),
],
)
def test_valid_star_field(self):
parsed = parse_fields_parameter("*,-test")
self.assertEqual(
parsed,
[
("*", False, None),
("test", True, None),
],
)
def test_valid_star_with_additional_field(self):
# Note: '*,test' is not allowed but '*,test(foo)' is
parsed = parse_fields_parameter("*,test(foo)")
self.assertEqual(
parsed,
[
("*", False, None),
(
"test",
False,
[
("foo", False, None),
],
),
],
)
def test_valid_underscore_field(self):
parsed = parse_fields_parameter("_,test")
self.assertEqual(
parsed,
[
("_", False, None),
("test", False, None),
],
)
def test_valid_field_with_underscore_in_middle(self):
parsed = parse_fields_parameter("a_test")
self.assertEqual(
parsed,
[
("a_test", False, None),
],
)
def test_valid_negated_field_with_underscore_in_middle(self):
parsed = parse_fields_parameter("-a_test")
self.assertEqual(
parsed,
[
("a_test", True, None),
],
)
def test_valid_field_with_underscore_at_beginning(self):
parsed = parse_fields_parameter("_test")
self.assertEqual(
parsed,
[
("_test", False, None),
],
)
def test_valid_field_with_underscore_at_end(self):
parsed = parse_fields_parameter("test_")
self.assertEqual(
parsed,
[
("test_", False, None),
],
)
# BAD STUFF
def test_invalid_char(self):
with self.assertRaises(FieldsParameterParseError) as e:
parse_fields_parameter("test#")
self.assertEqual(str(e.exception), "unexpected char '#' at position 4")
def test_invalid_whitespace_before_identifier(self):
with self.assertRaises(FieldsParameterParseError) as e:
parse_fields_parameter(" test")
self.assertEqual(str(e.exception), "unexpected whitespace at position 0")
def test_invalid_whitespace_after_identifier(self):
with self.assertRaises(FieldsParameterParseError) as e:
parse_fields_parameter("test ")
self.assertEqual(str(e.exception), "unexpected whitespace at position 4")
def test_invalid_whitespace_after_comma(self):
with self.assertRaises(FieldsParameterParseError) as e:
parse_fields_parameter("test, test")
self.assertEqual(str(e.exception), "unexpected whitespace at position 5")
def test_invalid_whitespace_before_comma(self):
with self.assertRaises(FieldsParameterParseError) as e:
parse_fields_parameter("test ,test")
self.assertEqual(str(e.exception), "unexpected whitespace at position 4")
def test_invalid_unexpected_negation_operator(self):
with self.assertRaises(FieldsParameterParseError) as e:
parse_fields_parameter("test-")
self.assertEqual(str(e.exception), "unexpected char '-' at position 4")
def test_invalid_unexpected_open_bracket(self):
with self.assertRaises(FieldsParameterParseError) as e:
parse_fields_parameter("test,(foo)")
self.assertEqual(str(e.exception), "unexpected char '(' at position 5")
def test_invalid_unexpected_close_bracket(self):
with self.assertRaises(FieldsParameterParseError) as e:
parse_fields_parameter("test)")
self.assertEqual(str(e.exception), "unexpected char ')' at position 4")
def test_invalid_unexpected_comma_in_middle(self):
with self.assertRaises(FieldsParameterParseError) as e:
parse_fields_parameter("test,,foo")
self.assertEqual(str(e.exception), "unexpected char ',' at position 5")
def test_invalid_unexpected_comma_at_end(self):
with self.assertRaises(FieldsParameterParseError) as e:
parse_fields_parameter("test,foo,")
self.assertEqual(str(e.exception), "unexpected char ',' at position 9")
def test_invalid_unclosed_bracket(self):
with self.assertRaises(FieldsParameterParseError) as e:
parse_fields_parameter("test(foo")
self.assertEqual(
str(e.exception),
"unexpected end of input (did you miss out a close bracket?)",
)
def test_invalid_subfields_on_negated_field(self):
with self.assertRaises(FieldsParameterParseError) as e:
parse_fields_parameter("-test(foo)")
self.assertEqual(str(e.exception), "unexpected char '(' at position 5")
def test_invalid_star_field_in_wrong_position(self):
with self.assertRaises(FieldsParameterParseError) as e:
parse_fields_parameter("test,*")
self.assertEqual(str(e.exception), "'*' must be in the first position")
def test_invalid_negated_star(self):
with self.assertRaises(FieldsParameterParseError) as e:
parse_fields_parameter("-*")
self.assertEqual(str(e.exception), "'*' cannot be negated")
def test_invalid_star_with_nesting(self):
with self.assertRaises(FieldsParameterParseError) as e:
parse_fields_parameter("*(foo,bar)")
self.assertEqual(str(e.exception), "unexpected char '(' at position 1")
def test_invalid_star_with_chars_after(self):
with self.assertRaises(FieldsParameterParseError) as e:
parse_fields_parameter("*foo")
self.assertEqual(str(e.exception), "unexpected char 'f' at position 1")
def test_invalid_star_with_chars_before(self):
with self.assertRaises(FieldsParameterParseError) as e:
parse_fields_parameter("foo*")
self.assertEqual(str(e.exception), "unexpected char '*' at position 3")
def test_invalid_star_with_additional_field(self):
with self.assertRaises(FieldsParameterParseError) as e:
parse_fields_parameter("*,foo")
self.assertEqual(
str(e.exception), "additional fields with '*' doesn't make sense"
)
def test_invalid_underscore_in_wrong_position(self):
with self.assertRaises(FieldsParameterParseError) as e:
parse_fields_parameter("test,_")
self.assertEqual(str(e.exception), "'_' must be in the first position")
def test_invalid_negated_underscore(self):
with self.assertRaises(FieldsParameterParseError) as e:
parse_fields_parameter("-_")
self.assertEqual(str(e.exception), "'_' cannot be negated")
def test_invalid_underscore_with_nesting(self):
with self.assertRaises(FieldsParameterParseError) as e:
parse_fields_parameter("_(foo,bar)")
self.assertEqual(str(e.exception), "unexpected char '(' at position 1")
def test_invalid_underscore_with_negated_field(self):
with self.assertRaises(FieldsParameterParseError) as e:
parse_fields_parameter("_,-foo")
self.assertEqual(str(e.exception), "negated fields with '_' doesn't make sense")
def test_invalid_star_and_underscore(self):
with self.assertRaises(FieldsParameterParseError) as e:
parse_fields_parameter("*,_")
self.assertEqual(str(e.exception), "'_' must be in the first position")
class TestParseBoolean(TestCase):
# GOOD STUFF
def test_valid_true(self):
parsed = parse_boolean("true")
self.assertIs(parsed, True)
def test_valid_false(self):
parsed = parse_boolean("false")
self.assertIs(parsed, False)
def test_valid_1(self):
parsed = parse_boolean("1")
self.assertIs(parsed, True)
def test_valid_0(self):
parsed = parse_boolean("0")
self.assertIs(parsed, False)
# BAD STUFF
def test_invalid(self):
with self.assertRaises(ValueError) as e:
parse_boolean("foo")
self.assertEqual(str(e.exception), "expected 'true' or 'false', got 'foo'")
def test_invalid_integer(self):
with self.assertRaises(ValueError) as e:
parse_boolean("2")
self.assertEqual(str(e.exception), "expected 'true' or 'false', got '2'")

View File

@@ -0,0 +1,270 @@
from urllib.parse import urlsplit
from django.conf import settings
from django.utils.encoding import force_str
from wagtail.coreutils import resolve_model_string
from wagtail.models import Page, Site
class BadRequestError(Exception):
pass
def get_base_url(request=None):
base_url = getattr(settings, "WAGTAILAPI_BASE_URL", None)
if base_url is None and request:
site = Site.find_for_request(request)
if site:
base_url = site.root_url
if base_url:
# We only want the scheme and netloc
base_url_parsed = urlsplit(force_str(base_url))
return base_url_parsed.scheme + "://" + base_url_parsed.netloc
def get_full_url(request, path):
if path.startswith(("http://", "https://")):
return path
base_url = get_base_url(request) or ""
return base_url + path
def get_object_detail_url(router, request, model, pk):
url_path = router.get_object_detail_urlpath(model, pk)
if url_path:
return get_full_url(request, url_path)
def page_models_from_string(string):
page_models = []
for sub_string in string.split(","):
page_model = resolve_model_string(sub_string)
if not issubclass(page_model, Page):
raise ValueError("Model is not a page")
page_models.append(page_model)
return tuple(page_models)
class FieldsParameterParseError(ValueError):
pass
def parse_fields_parameter(fields_str):
"""
Parses the ?fields= GET parameter. As this parameter is supposed to be used
by developers, the syntax is quite tight (eg, not allowing any whitespace).
Having a strict syntax allows us to extend the it at a later date with less
chance of breaking anyone's code.
This function takes a string and returns a list of tuples representing each
top-level field. Each tuple contains three items:
- The name of the field (string)
- Whether the field has been negated (boolean)
- A list of nested fields if there are any, None otherwise
Some examples of how this function works:
>>> parse_fields_parameter("foo")
[
('foo', False, None),
]
>>> parse_fields_parameter("foo,bar")
[
('foo', False, None),
('bar', False, None),
]
>>> parse_fields_parameter("-foo")
[
('foo', True, None),
]
>>> parse_fields_parameter("foo(bar,baz)")
[
('foo', False, [
('bar', False, None),
('baz', False, None),
]),
]
It raises a FieldsParameterParseError (subclass of ValueError) if it
encounters a syntax error
"""
def get_position(current_str):
return len(fields_str) - len(current_str)
def parse_field_identifier(fields_str):
first_char = True
negated = False
ident = ""
while fields_str:
char = fields_str[0]
if char in ["(", ")", ","]:
if not ident:
raise FieldsParameterParseError(
"unexpected char '%s' at position %d"
% (char, get_position(fields_str))
)
if ident in ["*", "_"] and char == "(":
# * and _ cannot have nested fields
raise FieldsParameterParseError(
"unexpected char '%s' at position %d"
% (char, get_position(fields_str))
)
return ident, negated, fields_str
elif char == "-":
if not first_char:
raise FieldsParameterParseError(
"unexpected char '%s' at position %d"
% (char, get_position(fields_str))
)
negated = True
elif char in ["*", "_"]:
if ident and char == "*":
raise FieldsParameterParseError(
"unexpected char '%s' at position %d"
% (char, get_position(fields_str))
)
ident += char
elif char.isalnum() or char == "_":
if ident == "*":
# * can only be on its own
raise FieldsParameterParseError(
"unexpected char '%s' at position %d"
% (char, get_position(fields_str))
)
ident += char
elif char.isspace():
raise FieldsParameterParseError(
"unexpected whitespace at position %d" % get_position(fields_str)
)
else:
raise FieldsParameterParseError(
"unexpected char '%s' at position %d"
% (char, get_position(fields_str))
)
first_char = False
fields_str = fields_str[1:]
return ident, negated, fields_str
def parse_fields(fields_str, expect_close_bracket=False):
first_ident = None
is_first = True
fields = []
while fields_str:
sub_fields = None
ident, negated, fields_str = parse_field_identifier(fields_str)
# Some checks specific to '*' and '_'
if ident in ["*", "_"]:
if not is_first:
raise FieldsParameterParseError(
"'%s' must be in the first position" % ident
)
if negated:
raise FieldsParameterParseError("'%s' cannot be negated" % ident)
if fields_str and fields_str[0] == "(":
if negated:
# Negated fields cannot contain subfields
raise FieldsParameterParseError(
"unexpected char '(' at position %d" % get_position(fields_str)
)
sub_fields, fields_str = parse_fields(
fields_str[1:], expect_close_bracket=True
)
if is_first:
first_ident = ident
else:
# Negated fields can't be used with '_'
if first_ident == "_" and negated:
# _,foo is allowed but _,-foo is not
raise FieldsParameterParseError(
"negated fields with '_' doesn't make sense"
)
# Additional fields without sub fields can't be used with '*'
if first_ident == "*" and not negated and not sub_fields:
# *,foo(bar) and *,-foo are allowed but *,foo is not
raise FieldsParameterParseError(
"additional fields with '*' doesn't make sense"
)
fields.append((ident, negated, sub_fields))
if fields_str and fields_str[0] == ")":
if not expect_close_bracket:
raise FieldsParameterParseError(
"unexpected char ')' at position %d" % get_position(fields_str)
)
return fields, fields_str[1:]
if fields_str and fields_str[0] == ",":
fields_str = fields_str[1:]
# A comma can not exist immediately before another comma or the end of the string
if not fields_str or fields_str[0] == ",":
raise FieldsParameterParseError(
"unexpected char ',' at position %d" % get_position(fields_str)
)
is_first = False
if expect_close_bracket:
# This parser should've exited with a close bracket but instead we
# hit the end of the input. Raise an error
raise FieldsParameterParseError(
"unexpected end of input (did you miss out a close bracket?)"
)
return fields, fields_str
fields, _ = parse_fields(fields_str)
return fields
def parse_boolean(value):
"""
Parses strings into booleans using the following mapping (case-sensitive):
'true' => True
'false' => False
'1' => True
'0' => False
"""
if value in ["true", "1"]:
return True
elif value in ["false", "0"]:
return False
else:
raise ValueError("expected 'true' or 'false', got '%s'" % value)

View File

@@ -0,0 +1,607 @@
from collections import OrderedDict
from django.conf import settings
from django.core.exceptions import FieldDoesNotExist
from django.http import Http404
from django.shortcuts import redirect
from django.urls import path, reverse
from modelcluster.fields import ParentalKey
from rest_framework import status
from rest_framework.renderers import BrowsableAPIRenderer, JSONRenderer
from rest_framework.response import Response
from rest_framework.viewsets import GenericViewSet
from wagtail.api import APIField
from wagtail.models import Page, PageViewRestriction, Site
from .filters import (
AncestorOfFilter,
ChildOfFilter,
DescendantOfFilter,
FieldsFilter,
LocaleFilter,
OrderingFilter,
SearchFilter,
TranslationOfFilter,
)
from .pagination import WagtailPagination
from .serializers import BaseSerializer, PageSerializer, get_serializer_class
from .utils import (
BadRequestError,
get_object_detail_url,
page_models_from_string,
parse_fields_parameter,
)
class BaseAPIViewSet(GenericViewSet):
renderer_classes = [JSONRenderer, BrowsableAPIRenderer]
pagination_class = WagtailPagination
base_serializer_class = BaseSerializer
filter_backends = []
model = None # Set on subclass
known_query_parameters = frozenset(
[
"limit",
"offset",
"fields",
"order",
"search",
"search_operator",
# Used by jQuery for cache-busting. See #1671
"_",
# Required by BrowsableAPIRenderer
"format",
]
)
body_fields = ["id"]
meta_fields = ["type", "detail_url"]
listing_default_fields = ["id", "type", "detail_url"]
nested_default_fields = ["id", "type", "detail_url"]
detail_only_fields = []
name = None # Set on subclass.
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# seen_types is a mapping of type name strings (format: "app_label.ModelName")
# to model classes. When an object is serialised in the API, its model
# is added to this mapping. This is used by the Admin API which appends a
# summary of the used types to the response.
self.seen_types = OrderedDict()
def get_queryset(self):
return self.model.objects.all().order_by("id")
def listing_view(self, request):
queryset = self.get_queryset()
self.check_query_parameters(queryset)
queryset = self.filter_queryset(queryset)
queryset = self.paginate_queryset(queryset)
serializer = self.get_serializer(queryset, many=True)
return self.get_paginated_response(serializer.data)
def detail_view(self, request, pk):
instance = self.get_object()
serializer = self.get_serializer(instance)
return Response(serializer.data)
def find_view(self, request):
queryset = self.get_queryset()
try:
obj = self.find_object(queryset, request)
if obj is None:
raise self.model.DoesNotExist
except self.model.DoesNotExist:
raise Http404("not found")
# Generate redirect
url = get_object_detail_url(
self.request.wagtailapi_router, request, self.model, obj.pk
)
if url is None:
# Shouldn't happen unless this endpoint isn't actually installed in the router
raise Exception(
"Cannot generate URL to detail view. Is '{}' installed in the API router?".format(
self.__class__.__name__
)
)
return redirect(url)
def find_object(self, queryset, request):
"""
Override this to implement more find methods.
"""
if "id" in request.GET:
return queryset.get(id=request.GET["id"])
def handle_exception(self, exc):
if isinstance(exc, Http404):
data = {"message": str(exc)}
return Response(data, status=status.HTTP_404_NOT_FOUND)
elif isinstance(exc, BadRequestError):
data = {"message": str(exc)}
return Response(data, status=status.HTTP_400_BAD_REQUEST)
return super().handle_exception(exc)
@classmethod
def _convert_api_fields(cls, fields):
return [
field if isinstance(field, APIField) else APIField(field)
for field in fields
]
@classmethod
def get_body_fields(cls, model):
return cls._convert_api_fields(
cls.body_fields + list(getattr(model, "api_fields", ()))
)
@classmethod
def get_body_fields_names(cls, model):
return [field.name for field in cls.get_body_fields(model)]
@classmethod
def get_meta_fields(cls, model):
return cls._convert_api_fields(
cls.meta_fields + list(getattr(model, "api_meta_fields", ()))
)
@classmethod
def get_meta_fields_names(cls, model):
return [field.name for field in cls.get_meta_fields(model)]
@classmethod
def get_field_serializer_overrides(cls, model):
return {
field.name: field.serializer
for field in cls.get_body_fields(model) + cls.get_meta_fields(model)
if field.serializer is not None
}
@classmethod
def get_available_fields(cls, model, db_fields_only=False):
"""
Returns a list of all the fields that can be used in the API for the
specified model class.
Setting db_fields_only to True will remove all fields that do not have
an underlying column in the database (eg, type/detail_url and any custom
fields that are callables)
"""
fields = cls.get_body_fields_names(model) + cls.get_meta_fields_names(model)
if db_fields_only:
# Get list of available database fields then remove any fields in our
# list that isn't a database field
database_fields = set()
for field in model._meta.get_fields():
database_fields.add(field.name)
if hasattr(field, "attname"):
database_fields.add(field.attname)
fields = [field for field in fields if field in database_fields]
return fields
@classmethod
def get_detail_default_fields(cls, model):
return cls.get_available_fields(model)
@classmethod
def get_listing_default_fields(cls, model):
return cls.listing_default_fields[:]
@classmethod
def get_nested_default_fields(cls, model):
return cls.nested_default_fields[:]
def check_query_parameters(self, queryset):
"""
Ensure that only valid query parameters are included in the URL.
"""
query_parameters = set(self.request.GET.keys())
# All query parameters must be either a database field or an operation
allowed_query_parameters = set(
self.get_available_fields(queryset.model, db_fields_only=True)
).union(self.known_query_parameters)
unknown_parameters = query_parameters - allowed_query_parameters
if unknown_parameters:
raise BadRequestError(
"query parameter is not an operation or a recognised field: %s"
% ", ".join(sorted(unknown_parameters))
)
@classmethod
def _get_serializer_class(
cls, router, model, fields_config, show_details=False, nested=False
):
# Get all available fields
body_fields = cls.get_body_fields_names(model)
meta_fields = cls.get_meta_fields_names(model)
all_fields = body_fields + meta_fields
# Remove any duplicates
all_fields = list(OrderedDict.fromkeys(all_fields))
if not show_details:
# Remove detail only fields
for field in cls.detail_only_fields:
try:
all_fields.remove(field)
except ValueError:
pass
# Get list of configured fields
if show_details:
fields = set(cls.get_detail_default_fields(model))
elif nested:
fields = set(cls.get_nested_default_fields(model))
else:
fields = set(cls.get_listing_default_fields(model))
# If first field is '*' start with all fields
# If first field is '_' start with no fields
if fields_config and fields_config[0][0] == "*":
fields = set(all_fields)
fields_config = fields_config[1:]
elif fields_config and fields_config[0][0] == "_":
fields = set()
fields_config = fields_config[1:]
mentioned_fields = set()
sub_fields = {}
for field_name, negated, field_sub_fields in fields_config:
if negated:
try:
fields.remove(field_name)
except KeyError:
pass
else:
fields.add(field_name)
if field_sub_fields:
sub_fields[field_name] = field_sub_fields
mentioned_fields.add(field_name)
unknown_fields = mentioned_fields - set(all_fields)
if unknown_fields:
raise BadRequestError(
"unknown fields: %s" % ", ".join(sorted(unknown_fields))
)
# Build nested serialisers
child_serializer_classes = {}
for field_name in fields:
try:
django_field = model._meta.get_field(field_name)
except FieldDoesNotExist:
django_field = None
if django_field and django_field.is_relation:
child_sub_fields = sub_fields.get(field_name, [])
# Inline (aka "child") models should display all fields by default
if isinstance(getattr(django_field, "field", None), ParentalKey):
if not child_sub_fields or child_sub_fields[0][0] not in ["*", "_"]:
child_sub_fields = list(child_sub_fields)
child_sub_fields.insert(0, ("*", False, None))
# Get a serializer class for the related object
child_model = django_field.related_model
child_endpoint_class = router.get_model_endpoint(child_model)
child_endpoint_class = (
child_endpoint_class[1] if child_endpoint_class else BaseAPIViewSet
)
child_serializer_classes[
field_name
] = child_endpoint_class._get_serializer_class(
router, child_model, child_sub_fields, nested=True
)
else:
if field_name in sub_fields:
# Sub fields were given for a non-related field
raise BadRequestError(
"'%s' does not support nested fields" % field_name
)
# Reorder fields so it matches the order of all_fields
fields = [field for field in all_fields if field in fields]
field_serializer_overrides = {
field[0]: field[1]
for field in cls.get_field_serializer_overrides(model).items()
if field[0] in fields
}
return get_serializer_class(
model,
fields,
meta_fields=meta_fields,
field_serializer_overrides=field_serializer_overrides,
child_serializer_classes=child_serializer_classes,
base=cls.base_serializer_class,
)
def get_serializer_class(self):
request = self.request
# Get model
if self.action == "listing_view":
model = self.get_queryset().model
else:
model = type(self.get_object())
# Fields
if "fields" in request.GET:
try:
fields_config = parse_fields_parameter(request.GET["fields"])
except ValueError as e:
raise BadRequestError("fields error: %s" % str(e))
else:
# Use default fields
fields_config = []
# Allow "detail_only" (eg parent) fields on detail view
if self.action == "listing_view":
show_details = False
else:
show_details = True
return self._get_serializer_class(
self.request.wagtailapi_router,
model,
fields_config,
show_details=show_details,
)
def get_serializer_context(self):
"""
The serialization context differs between listing and detail views.
"""
return {
"request": self.request,
"view": self,
"router": self.request.wagtailapi_router,
}
def get_renderer_context(self):
context = super().get_renderer_context()
context["indent"] = 4
return context
@classmethod
def get_urlpatterns(cls):
"""
This returns a list of URL patterns for the endpoint
"""
return [
path("", cls.as_view({"get": "listing_view"}), name="listing"),
path("<int:pk>/", cls.as_view({"get": "detail_view"}), name="detail"),
path("find/", cls.as_view({"get": "find_view"}), name="find"),
]
@classmethod
def get_model_listing_urlpath(cls, model, namespace=""):
if namespace:
url_name = namespace + ":listing"
else:
url_name = "listing"
return reverse(url_name)
@classmethod
def get_object_detail_urlpath(cls, model, pk, namespace=""):
if namespace:
url_name = namespace + ":detail"
else:
url_name = "detail"
return reverse(url_name, args=(pk,))
class PagesAPIViewSet(BaseAPIViewSet):
base_serializer_class = PageSerializer
filter_backends = [
FieldsFilter,
ChildOfFilter,
AncestorOfFilter,
DescendantOfFilter,
OrderingFilter,
TranslationOfFilter,
LocaleFilter,
SearchFilter, # needs to be last, as SearchResults querysets cannot be filtered further
]
known_query_parameters = BaseAPIViewSet.known_query_parameters.union(
[
"type",
"child_of",
"ancestor_of",
"descendant_of",
"translation_of",
"locale",
"site",
]
)
body_fields = BaseAPIViewSet.body_fields + [
"title",
]
meta_fields = BaseAPIViewSet.meta_fields + [
"html_url",
"slug",
"show_in_menus",
"seo_title",
"search_description",
"first_published_at",
"alias_of",
"parent",
"locale",
]
listing_default_fields = BaseAPIViewSet.listing_default_fields + [
"title",
"html_url",
"slug",
"first_published_at",
]
nested_default_fields = BaseAPIViewSet.nested_default_fields + [
"title",
]
detail_only_fields = ["parent"]
name = "pages"
model = Page
@classmethod
def get_detail_default_fields(cls, model):
detail_default_fields = super().get_detail_default_fields(model)
# When i18n is disabled, remove "locale" from default fields
if not getattr(settings, "WAGTAIL_I18N_ENABLED", False):
detail_default_fields.remove("locale")
return detail_default_fields
@classmethod
def get_listing_default_fields(cls, model):
listing_default_fields = super().get_listing_default_fields(model)
# When i18n is enabled, add "locale" to default fields
if getattr(settings, "WAGTAIL_I18N_ENABLED", False):
listing_default_fields.append("locale")
return listing_default_fields
def get_root_page(self):
"""
Returns the page that is used when the `&child_of=root` filter is used.
"""
return Site.find_for_request(self.request).root_page
def get_base_queryset(self):
"""
Returns a queryset containing all pages that can be seen by this user.
This is used as the base for get_queryset and is also used to find the
parent pages when using the child_of and descendant_of filters as well.
"""
request = self.request
# Get all live pages
queryset = Page.objects.all().live()
# Exclude pages that the user doesn't have access to
restricted_pages = [
restriction.page
for restriction in PageViewRestriction.objects.all().select_related("page")
if not restriction.accept_request(self.request)
]
# Exclude the restricted pages and their descendants from the queryset
for restricted_page in restricted_pages:
queryset = queryset.not_descendant_of(restricted_page, inclusive=True)
# Check if we have a specific site to look for
if "site" in request.GET:
# Optionally allow querying by port
if ":" in request.GET["site"]:
(hostname, port) = request.GET["site"].split(":", 1)
query = {
"hostname": hostname,
"port": port,
}
else:
query = {
"hostname": request.GET["site"],
}
try:
site = Site.objects.get(**query)
except Site.MultipleObjectsReturned:
raise BadRequestError(
"Your query returned multiple sites. Try adding a port number to your site filter."
)
else:
# Otherwise, find the site from the request
site = Site.find_for_request(self.request)
if site:
base_queryset = queryset
queryset = base_queryset.descendant_of(site.root_page, inclusive=True)
# If internationalisation is enabled, include pages from other language trees
if getattr(settings, "WAGTAIL_I18N_ENABLED", False):
for translation in site.root_page.get_translations():
queryset |= base_queryset.descendant_of(translation, inclusive=True)
else:
# No sites configured
queryset = queryset.none()
return queryset
def get_queryset(self):
request = self.request
# Allow pages to be filtered to a specific type
try:
models_type = request.GET.get("type", None)
models = models_type and page_models_from_string(models_type) or []
except (LookupError, ValueError):
raise BadRequestError("type doesn't exist")
if not models:
if self.model == Page:
return self.get_base_queryset()
else:
return self.model.objects.filter(
pk__in=self.get_base_queryset().values_list("pk", flat=True)
)
elif len(models) == 1:
# If a single page type has been specified, swap out the Page-based queryset for one based on
# the specific page model so that we can filter on any custom APIFields defined on that model
return models[0].objects.filter(
pk__in=self.get_base_queryset().values_list("pk", flat=True)
)
else: # len(models) > 1
return self.get_base_queryset().type(*models)
def get_object(self):
base = super().get_object()
return base.specific
def find_object(self, queryset, request):
site = Site.find_for_request(request)
if "html_path" in request.GET and site is not None:
path = request.GET["html_path"]
path_components = [component for component in path.split("/") if component]
try:
page, _, _ = site.root_page.specific.route(request, path_components)
except Http404:
return
if queryset.filter(id=page.id).exists():
return page
return super().find_object(queryset, request)
def get_serializer_context(self):
"""
The serialization context differs between listing and detail views.
"""
context = super().get_serializer_context()
context["base_queryset"] = self.get_base_queryset()
return context