Initial commit
This commit is contained in:
1
env/lib/python3.10/site-packages/wagtail/api/__init__.py
vendored
Normal file
1
env/lib/python3.10/site-packages/wagtail/api/__init__.py
vendored
Normal file
@@ -0,0 +1 @@
|
||||
from .conf import APIField # noqa: F401
|
||||
BIN
env/lib/python3.10/site-packages/wagtail/api/__pycache__/__init__.cpython-310.pyc
vendored
Normal file
BIN
env/lib/python3.10/site-packages/wagtail/api/__pycache__/__init__.cpython-310.pyc
vendored
Normal file
Binary file not shown.
BIN
env/lib/python3.10/site-packages/wagtail/api/__pycache__/conf.cpython-310.pyc
vendored
Normal file
BIN
env/lib/python3.10/site-packages/wagtail/api/__pycache__/conf.cpython-310.pyc
vendored
Normal file
Binary file not shown.
10
env/lib/python3.10/site-packages/wagtail/api/conf.py
vendored
Normal file
10
env/lib/python3.10/site-packages/wagtail/api/conf.py
vendored
Normal file
@@ -0,0 +1,10 @@
|
||||
class APIField:
|
||||
def __init__(self, name, serializer=None):
|
||||
self.name = name
|
||||
self.serializer = serializer
|
||||
|
||||
def __hash__(self):
|
||||
return hash(self.name)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<APIField {self.name}>"
|
||||
0
env/lib/python3.10/site-packages/wagtail/api/v2/__init__.py
vendored
Normal file
0
env/lib/python3.10/site-packages/wagtail/api/v2/__init__.py
vendored
Normal file
BIN
env/lib/python3.10/site-packages/wagtail/api/v2/__pycache__/__init__.cpython-310.pyc
vendored
Normal file
BIN
env/lib/python3.10/site-packages/wagtail/api/v2/__pycache__/__init__.cpython-310.pyc
vendored
Normal file
Binary file not shown.
BIN
env/lib/python3.10/site-packages/wagtail/api/v2/__pycache__/apps.cpython-310.pyc
vendored
Normal file
BIN
env/lib/python3.10/site-packages/wagtail/api/v2/__pycache__/apps.cpython-310.pyc
vendored
Normal file
Binary file not shown.
BIN
env/lib/python3.10/site-packages/wagtail/api/v2/__pycache__/filters.cpython-310.pyc
vendored
Normal file
BIN
env/lib/python3.10/site-packages/wagtail/api/v2/__pycache__/filters.cpython-310.pyc
vendored
Normal file
Binary file not shown.
BIN
env/lib/python3.10/site-packages/wagtail/api/v2/__pycache__/pagination.cpython-310.pyc
vendored
Normal file
BIN
env/lib/python3.10/site-packages/wagtail/api/v2/__pycache__/pagination.cpython-310.pyc
vendored
Normal file
Binary file not shown.
BIN
env/lib/python3.10/site-packages/wagtail/api/v2/__pycache__/router.cpython-310.pyc
vendored
Normal file
BIN
env/lib/python3.10/site-packages/wagtail/api/v2/__pycache__/router.cpython-310.pyc
vendored
Normal file
Binary file not shown.
BIN
env/lib/python3.10/site-packages/wagtail/api/v2/__pycache__/serializers.cpython-310.pyc
vendored
Normal file
BIN
env/lib/python3.10/site-packages/wagtail/api/v2/__pycache__/serializers.cpython-310.pyc
vendored
Normal file
Binary file not shown.
BIN
env/lib/python3.10/site-packages/wagtail/api/v2/__pycache__/signal_handlers.cpython-310.pyc
vendored
Normal file
BIN
env/lib/python3.10/site-packages/wagtail/api/v2/__pycache__/signal_handlers.cpython-310.pyc
vendored
Normal file
Binary file not shown.
BIN
env/lib/python3.10/site-packages/wagtail/api/v2/__pycache__/utils.cpython-310.pyc
vendored
Normal file
BIN
env/lib/python3.10/site-packages/wagtail/api/v2/__pycache__/utils.cpython-310.pyc
vendored
Normal file
Binary file not shown.
BIN
env/lib/python3.10/site-packages/wagtail/api/v2/__pycache__/views.cpython-310.pyc
vendored
Normal file
BIN
env/lib/python3.10/site-packages/wagtail/api/v2/__pycache__/views.cpython-310.pyc
vendored
Normal file
Binary file not shown.
22
env/lib/python3.10/site-packages/wagtail/api/v2/apps.py
vendored
Normal file
22
env/lib/python3.10/site-packages/wagtail/api/v2/apps.py
vendored
Normal file
@@ -0,0 +1,22 @@
|
||||
from django.apps import AppConfig, apps
|
||||
from django.conf import settings
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
|
||||
class WagtailAPIV2AppConfig(AppConfig):
|
||||
name = "wagtail.api.v2"
|
||||
label = "wagtailapi_v2"
|
||||
verbose_name = _("Wagtail API v2")
|
||||
|
||||
def ready(self):
|
||||
# Install cache purging signal handlers
|
||||
if getattr(settings, "WAGTAILAPI_USE_FRONTENDCACHE", False):
|
||||
if apps.is_installed("wagtail.contrib.frontend_cache"):
|
||||
from wagtail.api.v2.signal_handlers import register_signal_handlers
|
||||
|
||||
register_signal_handlers()
|
||||
else:
|
||||
raise ImproperlyConfigured(
|
||||
"The setting 'WAGTAILAPI_USE_FRONTENDCACHE' is True but 'wagtail.contrib.frontend_cache' is not in INSTALLED_APPS."
|
||||
)
|
||||
307
env/lib/python3.10/site-packages/wagtail/api/v2/filters.py
vendored
Normal file
307
env/lib/python3.10/site-packages/wagtail/api/v2/filters.py
vendored
Normal file
@@ -0,0 +1,307 @@
|
||||
from django.conf import settings
|
||||
from django.db import models
|
||||
from django.shortcuts import get_object_or_404
|
||||
from rest_framework.filters import BaseFilterBackend
|
||||
from taggit.managers import TaggableManager
|
||||
|
||||
from wagtail.models import Locale, Page
|
||||
from wagtail.search.backends import get_search_backend
|
||||
from wagtail.search.backends.base import FilterFieldError, OrderByFieldError
|
||||
|
||||
from .utils import BadRequestError, parse_boolean
|
||||
|
||||
|
||||
class FieldsFilter(BaseFilterBackend):
|
||||
def filter_queryset(self, request, queryset, view):
|
||||
"""
|
||||
This performs field level filtering on the result set
|
||||
Eg: ?title=James Joyce
|
||||
"""
|
||||
fields = set(view.get_available_fields(queryset.model, db_fields_only=True))
|
||||
|
||||
# Locale is a database field, but we provide a separate filter for it
|
||||
if "locale" in fields:
|
||||
fields.remove("locale")
|
||||
|
||||
for field_name, value in request.GET.items():
|
||||
if field_name in fields:
|
||||
try:
|
||||
field = queryset.model._meta.get_field(field_name)
|
||||
except LookupError:
|
||||
field = None
|
||||
|
||||
# Convert value into python
|
||||
try:
|
||||
if isinstance(
|
||||
field, (models.BooleanField, models.NullBooleanField)
|
||||
):
|
||||
value = parse_boolean(value)
|
||||
elif isinstance(field, (models.IntegerField, models.AutoField)):
|
||||
value = int(value)
|
||||
elif isinstance(field, models.ForeignKey):
|
||||
value = field.target_field.get_prep_value(value)
|
||||
except ValueError as e:
|
||||
raise BadRequestError(
|
||||
"field filter error. '%s' is not a valid value for %s (%s)"
|
||||
% (value, field_name, str(e))
|
||||
)
|
||||
|
||||
if "\x00" in str(value):
|
||||
raise BadRequestError(
|
||||
"field filter error. null characters are not allowed for %s"
|
||||
% field_name
|
||||
)
|
||||
|
||||
if isinstance(field, TaggableManager):
|
||||
for tag in value.split(","):
|
||||
queryset = queryset.filter(**{field_name + "__name": tag})
|
||||
|
||||
# Stick a message on the queryset to indicate that tag filtering has been performed
|
||||
# This will let the do_search method know that it must raise an error as searching
|
||||
# and tag filtering at the same time is not supported
|
||||
queryset._filtered_by_tag = True
|
||||
else:
|
||||
queryset = queryset.filter(**{field_name: value})
|
||||
|
||||
return queryset
|
||||
|
||||
|
||||
class OrderingFilter(BaseFilterBackend):
|
||||
def filter_queryset(self, request, queryset, view):
|
||||
"""
|
||||
This applies ordering to the result set with support for multiple fields.
|
||||
Eg: ?order=title or ?order=title,created_at
|
||||
|
||||
It also supports reverse ordering
|
||||
Eg: ?order=-title
|
||||
|
||||
And random ordering
|
||||
Eg: ?order=random
|
||||
"""
|
||||
if "order" in request.GET:
|
||||
order_by_list = request.GET["order"].split(",")
|
||||
|
||||
# Random ordering
|
||||
if "random" in order_by_list:
|
||||
if len(order_by_list) > 1:
|
||||
raise BadRequestError(
|
||||
"random ordering cannot be combined with other fields"
|
||||
)
|
||||
# Prevent ordering by random with offset
|
||||
if "offset" in request.GET:
|
||||
raise BadRequestError(
|
||||
"random ordering with offset is not supported"
|
||||
)
|
||||
|
||||
return queryset.order_by("?")
|
||||
|
||||
order_by_fields = []
|
||||
for order_by in order_by_list:
|
||||
# Check if reverse ordering is set
|
||||
if order_by.startswith("-"):
|
||||
reverse_order = True
|
||||
order_by = order_by[1:]
|
||||
else:
|
||||
reverse_order = False
|
||||
|
||||
# Add ordering
|
||||
if order_by in view.get_available_fields(queryset.model):
|
||||
order_by_fields.append(order_by)
|
||||
else:
|
||||
# Unknown field
|
||||
raise BadRequestError(
|
||||
"cannot order by '%s' (unknown field)" % order_by
|
||||
)
|
||||
|
||||
# Apply ordering to the queryset
|
||||
queryset = queryset.order_by(*order_by_fields)
|
||||
|
||||
# Reverse order if needed
|
||||
if reverse_order:
|
||||
queryset = queryset.reverse()
|
||||
|
||||
return queryset
|
||||
|
||||
|
||||
class SearchFilter(BaseFilterBackend):
|
||||
def filter_queryset(self, request, queryset, view):
|
||||
"""
|
||||
This performs a full-text search on the result set
|
||||
Eg: ?search=James Joyce
|
||||
"""
|
||||
search_enabled = getattr(settings, "WAGTAILAPI_SEARCH_ENABLED", True)
|
||||
|
||||
if "search" in request.GET:
|
||||
if not search_enabled:
|
||||
raise BadRequestError("search is disabled")
|
||||
|
||||
# Searching and filtering by tag at the same time is not supported
|
||||
if getattr(queryset, "_filtered_by_tag", False):
|
||||
raise BadRequestError(
|
||||
"filtering by tag with a search query is not supported"
|
||||
)
|
||||
|
||||
search_query = request.GET["search"]
|
||||
search_operator = request.GET.get("search_operator", None)
|
||||
order_by_relevance = "order" not in request.GET
|
||||
|
||||
sb = get_search_backend()
|
||||
try:
|
||||
queryset = sb.search(
|
||||
search_query,
|
||||
queryset,
|
||||
operator=search_operator,
|
||||
order_by_relevance=order_by_relevance,
|
||||
)
|
||||
except FilterFieldError as e:
|
||||
raise BadRequestError(
|
||||
"cannot filter by '{}' while searching (field is not indexed)".format(
|
||||
e.field_name
|
||||
)
|
||||
)
|
||||
except OrderByFieldError as e:
|
||||
raise BadRequestError(
|
||||
"cannot order by '{}' while searching (field is not indexed)".format(
|
||||
e.field_name
|
||||
)
|
||||
)
|
||||
|
||||
return queryset
|
||||
|
||||
|
||||
class ChildOfFilter(BaseFilterBackend):
|
||||
"""
|
||||
Implements the ?child_of filter used to filter the results to only contain
|
||||
pages that are direct children of the specified page.
|
||||
"""
|
||||
|
||||
def filter_queryset(self, request, queryset, view):
|
||||
if "child_of" in request.GET:
|
||||
try:
|
||||
parent_page_id = int(request.GET["child_of"])
|
||||
if parent_page_id < 0:
|
||||
raise ValueError()
|
||||
|
||||
parent_page = view.get_base_queryset().get(id=parent_page_id)
|
||||
except ValueError:
|
||||
if request.GET["child_of"] == "root":
|
||||
parent_page = view.get_root_page()
|
||||
else:
|
||||
raise BadRequestError("child_of must be a positive integer")
|
||||
except Page.DoesNotExist:
|
||||
raise BadRequestError("parent page doesn't exist")
|
||||
|
||||
queryset = queryset.child_of(parent_page)
|
||||
|
||||
# Save the parent page on the queryset. This is required for the page
|
||||
# explorer, which needs to pass the parent page into
|
||||
# `construct_explorer_page_queryset` hook functions
|
||||
queryset._filtered_by_child_of = parent_page
|
||||
|
||||
return queryset
|
||||
|
||||
|
||||
class AncestorOfFilter(BaseFilterBackend):
|
||||
"""
|
||||
Implements the ?ancestor filter which limits the set of pages to a
|
||||
particular branch of the page tree.
|
||||
"""
|
||||
|
||||
def filter_queryset(self, request, queryset, view):
|
||||
if "ancestor_of" in request.GET:
|
||||
try:
|
||||
descendant_page_id = int(request.GET["ancestor_of"])
|
||||
if descendant_page_id < 0:
|
||||
raise ValueError()
|
||||
|
||||
descendant_page = view.get_base_queryset().get(id=descendant_page_id)
|
||||
except ValueError:
|
||||
raise BadRequestError("ancestor_of must be a positive integer")
|
||||
except Page.DoesNotExist:
|
||||
raise BadRequestError("descendant page doesn't exist")
|
||||
|
||||
queryset = queryset.ancestor_of(descendant_page)
|
||||
|
||||
return queryset
|
||||
|
||||
|
||||
class DescendantOfFilter(BaseFilterBackend):
|
||||
"""
|
||||
Implements the ?descendant_of filter which limits the set of pages to a
|
||||
particular branch of the page tree.
|
||||
"""
|
||||
|
||||
def filter_queryset(self, request, queryset, view):
|
||||
if "descendant_of" in request.GET:
|
||||
if hasattr(queryset, "_filtered_by_child_of"):
|
||||
raise BadRequestError(
|
||||
"filtering by descendant_of with child_of is not supported"
|
||||
)
|
||||
try:
|
||||
parent_page_id = int(request.GET["descendant_of"])
|
||||
if parent_page_id < 0:
|
||||
raise ValueError()
|
||||
|
||||
parent_page = view.get_base_queryset().get(id=parent_page_id)
|
||||
except ValueError:
|
||||
if request.GET["descendant_of"] == "root":
|
||||
parent_page = view.get_root_page()
|
||||
else:
|
||||
raise BadRequestError("descendant_of must be a positive integer")
|
||||
except Page.DoesNotExist:
|
||||
raise BadRequestError("ancestor page doesn't exist")
|
||||
|
||||
queryset = queryset.descendant_of(parent_page)
|
||||
|
||||
return queryset
|
||||
|
||||
|
||||
class TranslationOfFilter(BaseFilterBackend):
|
||||
"""
|
||||
Implements the ?translation_of filter which limits the set of pages to translations
|
||||
of a page.
|
||||
"""
|
||||
|
||||
def filter_queryset(self, request, queryset, view):
|
||||
if "translation_of" in request.GET:
|
||||
try:
|
||||
page_id = int(request.GET["translation_of"])
|
||||
if page_id < 0:
|
||||
raise ValueError()
|
||||
|
||||
page = view.get_base_queryset().get(id=page_id)
|
||||
except ValueError:
|
||||
if request.GET["translation_of"] == "root":
|
||||
page = view.get_root_page()
|
||||
else:
|
||||
raise BadRequestError("translation_of must be a positive integer")
|
||||
except Page.DoesNotExist:
|
||||
raise BadRequestError("translation_of page doesn't exist")
|
||||
|
||||
_filtered_by_child_of = getattr(queryset, "_filtered_by_child_of", None)
|
||||
|
||||
queryset = queryset.translation_of(page)
|
||||
|
||||
if _filtered_by_child_of:
|
||||
queryset._filtered_by_child_of = _filtered_by_child_of
|
||||
|
||||
return queryset
|
||||
|
||||
|
||||
class LocaleFilter(BaseFilterBackend):
|
||||
"""
|
||||
Implements the ?locale filter which limits the set of pages to a
|
||||
particular locale.
|
||||
"""
|
||||
|
||||
def filter_queryset(self, request, queryset, view):
|
||||
if "locale" in request.GET:
|
||||
_filtered_by_child_of = getattr(queryset, "_filtered_by_child_of", None)
|
||||
|
||||
locale = get_object_or_404(Locale, language_code=request.GET["locale"])
|
||||
queryset = queryset.filter(locale=locale)
|
||||
|
||||
if _filtered_by_child_of:
|
||||
queryset._filtered_by_child_of = _filtered_by_child_of
|
||||
|
||||
return queryset
|
||||
53
env/lib/python3.10/site-packages/wagtail/api/v2/pagination.py
vendored
Normal file
53
env/lib/python3.10/site-packages/wagtail/api/v2/pagination.py
vendored
Normal file
@@ -0,0 +1,53 @@
|
||||
from collections import OrderedDict
|
||||
|
||||
from django.conf import settings
|
||||
from rest_framework.pagination import BasePagination
|
||||
from rest_framework.response import Response
|
||||
|
||||
from .utils import BadRequestError
|
||||
|
||||
|
||||
class WagtailPagination(BasePagination):
|
||||
def paginate_queryset(self, queryset, request, view=None):
|
||||
limit_max = getattr(settings, "WAGTAILAPI_LIMIT_MAX", 20)
|
||||
|
||||
try:
|
||||
offset = int(request.GET.get("offset", 0))
|
||||
if offset < 0:
|
||||
raise ValueError()
|
||||
except ValueError:
|
||||
raise BadRequestError("offset must be a positive integer")
|
||||
|
||||
try:
|
||||
limit_default = 20 if not limit_max else min(20, limit_max)
|
||||
limit = int(request.GET.get("limit", limit_default))
|
||||
if limit < 0:
|
||||
raise ValueError()
|
||||
except ValueError:
|
||||
raise BadRequestError("limit must be a positive integer")
|
||||
|
||||
if limit_max and limit > limit_max:
|
||||
raise BadRequestError("limit cannot be higher than %d" % limit_max)
|
||||
|
||||
start = offset
|
||||
stop = offset + limit
|
||||
|
||||
self.view = view
|
||||
self.total_count = queryset.count()
|
||||
return queryset[start:stop]
|
||||
|
||||
def get_paginated_response(self, data):
|
||||
data = OrderedDict(
|
||||
[
|
||||
(
|
||||
"meta",
|
||||
OrderedDict(
|
||||
[
|
||||
("total_count", self.total_count),
|
||||
]
|
||||
),
|
||||
),
|
||||
("items", data),
|
||||
]
|
||||
)
|
||||
return Response(data)
|
||||
95
env/lib/python3.10/site-packages/wagtail/api/v2/router.py
vendored
Normal file
95
env/lib/python3.10/site-packages/wagtail/api/v2/router.py
vendored
Normal file
@@ -0,0 +1,95 @@
|
||||
import functools
|
||||
|
||||
from django.urls import include, re_path
|
||||
|
||||
from wagtail.utils.urlpatterns import decorate_urlpatterns
|
||||
|
||||
|
||||
class WagtailAPIRouter:
|
||||
"""
|
||||
A class that provides routing and cross-linking for a collection
|
||||
of API endpoints
|
||||
"""
|
||||
|
||||
def __init__(self, url_namespace):
|
||||
self.url_namespace = url_namespace
|
||||
self._endpoints = {}
|
||||
|
||||
def register_endpoint(self, name, class_):
|
||||
self._endpoints[name] = class_
|
||||
|
||||
def get_model_endpoint(self, model):
|
||||
"""
|
||||
Finds the endpoint in the API that represents a model
|
||||
|
||||
Returns a (name, endpoint_class) tuple. Or None if an
|
||||
endpoint is not found.
|
||||
"""
|
||||
for name, class_ in self._endpoints.items():
|
||||
if issubclass(model, class_.model):
|
||||
return name, class_
|
||||
|
||||
def get_model_listing_urlpath(self, model):
|
||||
"""
|
||||
Returns a URL path (excluding scheme and hostname) to the listing
|
||||
page of a model
|
||||
|
||||
Returns None if the model is not represented by any endpoints.
|
||||
"""
|
||||
endpoint = self.get_model_endpoint(model)
|
||||
|
||||
if endpoint:
|
||||
endpoint_name, endpoint_class = endpoint[0], endpoint[1]
|
||||
url_namespace = self.url_namespace + ":" + endpoint_name
|
||||
return endpoint_class.get_model_listing_urlpath(
|
||||
model, namespace=url_namespace
|
||||
)
|
||||
|
||||
def get_object_detail_urlpath(self, model, pk):
|
||||
"""
|
||||
Returns a URL path (excluding scheme and hostname) to the detail
|
||||
page of an object.
|
||||
|
||||
Returns None if the object is not represented by any endpoints.
|
||||
"""
|
||||
endpoint = self.get_model_endpoint(model)
|
||||
|
||||
if endpoint:
|
||||
endpoint_name, endpoint_class = endpoint[0], endpoint[1]
|
||||
url_namespace = self.url_namespace + ":" + endpoint_name
|
||||
return endpoint_class.get_object_detail_urlpath(
|
||||
model, pk, namespace=url_namespace
|
||||
)
|
||||
|
||||
def wrap_view(self, func):
|
||||
@functools.wraps(func)
|
||||
def wrapped(request, *args, **kwargs):
|
||||
request.wagtailapi_router = self
|
||||
return func(request, *args, **kwargs)
|
||||
|
||||
return wrapped
|
||||
|
||||
def get_urlpatterns(self):
|
||||
urlpatterns = []
|
||||
|
||||
for name, class_ in self._endpoints.items():
|
||||
pattern = re_path(
|
||||
rf"^{name}/",
|
||||
include((class_.get_urlpatterns(), name), namespace=name),
|
||||
)
|
||||
urlpatterns.append(pattern)
|
||||
|
||||
decorate_urlpatterns(urlpatterns, self.wrap_view)
|
||||
|
||||
return urlpatterns
|
||||
|
||||
@property
|
||||
def urls(self):
|
||||
"""
|
||||
A shortcut to allow quick registration of the API in a URLconf.
|
||||
|
||||
Use with Django's include() function:
|
||||
|
||||
path('api/', include(myapi.urls)),
|
||||
"""
|
||||
return self.get_urlpatterns(), self.url_namespace, self.url_namespace
|
||||
419
env/lib/python3.10/site-packages/wagtail/api/v2/serializers.py
vendored
Normal file
419
env/lib/python3.10/site-packages/wagtail/api/v2/serializers.py
vendored
Normal file
@@ -0,0 +1,419 @@
|
||||
from collections import OrderedDict
|
||||
|
||||
from django.urls.exceptions import NoReverseMatch
|
||||
from modelcluster.models import get_all_child_relations
|
||||
from rest_framework import relations, serializers
|
||||
from rest_framework.fields import Field, SkipField
|
||||
from taggit.managers import _TaggableManager
|
||||
|
||||
from wagtail import fields as wagtailcore_fields
|
||||
|
||||
from .utils import get_object_detail_url
|
||||
|
||||
|
||||
class TypeField(Field):
|
||||
"""
|
||||
Serializes the "type" field of each object.
|
||||
|
||||
Example:
|
||||
"type": "wagtailimages.Image"
|
||||
"""
|
||||
|
||||
def get_attribute(self, instance):
|
||||
return instance
|
||||
|
||||
def to_representation(self, obj):
|
||||
name = type(obj)._meta.app_label + "." + type(obj).__name__
|
||||
self.context["view"].seen_types[name] = type(obj)
|
||||
return name
|
||||
|
||||
|
||||
class DetailUrlField(Field):
|
||||
"""
|
||||
Serializes the "detail_url" field of each object.
|
||||
|
||||
Example:
|
||||
"detail_url": "http://api.example.com/v1/images/1/"
|
||||
"""
|
||||
|
||||
def get_attribute(self, instance):
|
||||
url = get_object_detail_url(
|
||||
self.context["router"], self.context["request"], type(instance), instance.pk
|
||||
)
|
||||
|
||||
if url:
|
||||
return url
|
||||
else:
|
||||
# Hide the detail_url field if the object doesn't have an endpoint
|
||||
raise SkipField
|
||||
|
||||
def to_representation(self, url):
|
||||
return url
|
||||
|
||||
|
||||
class PageHtmlUrlField(Field):
|
||||
"""
|
||||
Serializes the "html_url" field for pages.
|
||||
|
||||
Example:
|
||||
"html_url": "http://www.example.com/blog/blog-post/"
|
||||
"""
|
||||
|
||||
def get_attribute(self, instance):
|
||||
return instance
|
||||
|
||||
def to_representation(self, page):
|
||||
try:
|
||||
return page.full_url
|
||||
except NoReverseMatch:
|
||||
return None
|
||||
|
||||
|
||||
class PageTypeField(Field):
|
||||
"""
|
||||
Serializes the "type" field for pages.
|
||||
|
||||
This takes into account the fact that we sometimes may not have the "specific"
|
||||
page object by calling "page.specific_class" instead of looking at the object's
|
||||
type.
|
||||
|
||||
Example:
|
||||
"type": "blog.BlogPage"
|
||||
"""
|
||||
|
||||
def get_attribute(self, instance):
|
||||
return instance
|
||||
|
||||
def to_representation(self, page):
|
||||
if page.specific_class is None:
|
||||
return None
|
||||
name = page.specific_class._meta.app_label + "." + page.specific_class.__name__
|
||||
self.context["view"].seen_types[name] = page.specific_class
|
||||
return name
|
||||
|
||||
|
||||
class PageLocaleField(Field):
|
||||
"""
|
||||
Serializes the "locale" field for pages.
|
||||
"""
|
||||
|
||||
def get_attribute(self, instance):
|
||||
return instance
|
||||
|
||||
def to_representation(self, page):
|
||||
return page.locale.language_code
|
||||
|
||||
|
||||
class RelatedField(relations.RelatedField):
|
||||
"""
|
||||
Serializes related objects (eg, foreign keys).
|
||||
|
||||
Example:
|
||||
|
||||
"feed_image": {
|
||||
"id": 1,
|
||||
"meta": {
|
||||
"type": "wagtailimages.Image",
|
||||
"detail_url": "http://api.example.com/v1/images/1/"
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.serializer_class = kwargs.pop("serializer_class")
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def to_representation(self, value):
|
||||
serializer = self.serializer_class(context=self.context)
|
||||
return serializer.to_representation(value)
|
||||
|
||||
|
||||
class PageParentField(relations.RelatedField):
|
||||
"""
|
||||
Serializes the "parent" field on Page objects.
|
||||
|
||||
Pages don't have a "parent" field so some extra logic is needed to find the
|
||||
parent page. That logic is implemented in this class.
|
||||
|
||||
The representation is the same as the RelatedField class.
|
||||
"""
|
||||
|
||||
def get_attribute(self, instance):
|
||||
parent = instance.get_parent()
|
||||
|
||||
if self.context["base_queryset"].filter(id=parent.id).exists():
|
||||
return parent
|
||||
|
||||
def to_representation(self, value):
|
||||
serializer_class = get_serializer_class(
|
||||
value.__class__,
|
||||
["id", "type", "detail_url", "html_url", "title"],
|
||||
meta_fields=["type", "detail_url", "html_url"],
|
||||
base=PageSerializer,
|
||||
)
|
||||
serializer = serializer_class(context=self.context)
|
||||
return serializer.to_representation(value)
|
||||
|
||||
|
||||
class PageAliasOfField(relations.RelatedField):
|
||||
"""
|
||||
Serializes the "alias_of" field on Page objects.
|
||||
"""
|
||||
|
||||
def get_attribute(self, instance):
|
||||
return instance.alias_of
|
||||
|
||||
def to_representation(self, value):
|
||||
serializer_class = get_serializer_class(
|
||||
value.__class__,
|
||||
["id", "type", "detail_url", "html_url", "title"],
|
||||
meta_fields=["type", "detail_url", "html_url"],
|
||||
base=PageSerializer,
|
||||
)
|
||||
serializer = serializer_class(context=self.context)
|
||||
return serializer.to_representation(value)
|
||||
|
||||
|
||||
class ChildRelationField(Field):
|
||||
"""
|
||||
Serializes child relations.
|
||||
|
||||
Child relations are any model that is related to a Page using a ParentalKey.
|
||||
They are used for repeated fields on a page such as carousel items or related
|
||||
links.
|
||||
|
||||
Child objects are part of the pages content so we nest them. The relation is
|
||||
represented as a list of objects.
|
||||
|
||||
Example:
|
||||
|
||||
"carousel_items": [
|
||||
{
|
||||
"id": 1,
|
||||
"meta": {
|
||||
"type": "demo.MyCarouselItem"
|
||||
},
|
||||
"title": "First carousel item",
|
||||
"image": {
|
||||
"id": 1,
|
||||
"meta": {
|
||||
"type": "wagtailimages.Image",
|
||||
"detail_url": "http://api.example.com/v1/images/1/"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"meta": {
|
||||
"type": "demo.MyCarouselItem"
|
||||
},
|
||||
"title": "Second carousel item (no image)",
|
||||
"image": null
|
||||
}
|
||||
]
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.serializer_class = kwargs.pop("serializer_class")
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def to_representation(self, value):
|
||||
serializer = self.serializer_class(context=self.context)
|
||||
|
||||
return [
|
||||
serializer.to_representation(child_object) for child_object in value.all()
|
||||
]
|
||||
|
||||
|
||||
class StreamField(Field):
|
||||
"""
|
||||
Serializes StreamField values.
|
||||
|
||||
Stream fields are stored in JSON format in the database. We reuse that in
|
||||
the API.
|
||||
|
||||
Example:
|
||||
|
||||
"body": [
|
||||
{
|
||||
"type": "heading",
|
||||
"value": {
|
||||
"text": "Hello world!",
|
||||
"size": "h1"
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "paragraph",
|
||||
"value": "Some content"
|
||||
}
|
||||
{
|
||||
"type": "image",
|
||||
"value": 1
|
||||
}
|
||||
]
|
||||
|
||||
Where "heading" is a struct block containing "text" and "size" fields, and
|
||||
"paragraph" is a simple text block.
|
||||
|
||||
Note that foreign keys are represented slightly differently in stream fields
|
||||
to other parts of the API. In stream fields, a foreign key is represented
|
||||
by an integer (the ID of the related object) but elsewhere in the API,
|
||||
foreign objects are nested objects with id and meta as attributes.
|
||||
"""
|
||||
|
||||
def to_representation(self, value):
|
||||
return value.stream_block.get_api_representation(value, self.context)
|
||||
|
||||
|
||||
class TagsField(Field):
|
||||
"""
|
||||
Serializes django-taggit TaggableManager fields.
|
||||
|
||||
These fields are a common way to link tags to objects in Wagtail. The API
|
||||
serializes these as a list of strings taken from the name attribute of each
|
||||
tag.
|
||||
|
||||
Example:
|
||||
|
||||
"tags": ["bird", "wagtail"]
|
||||
"""
|
||||
|
||||
def to_representation(self, value):
|
||||
return list(value.all().order_by("name").values_list("name", flat=True))
|
||||
|
||||
|
||||
class BaseSerializer(serializers.ModelSerializer):
|
||||
# Add StreamField to serializer_field_mapping
|
||||
serializer_field_mapping = (
|
||||
serializers.ModelSerializer.serializer_field_mapping.copy()
|
||||
)
|
||||
serializer_field_mapping.update(
|
||||
{
|
||||
wagtailcore_fields.StreamField: StreamField,
|
||||
}
|
||||
)
|
||||
serializer_related_field = RelatedField
|
||||
|
||||
# Meta fields
|
||||
type = TypeField(read_only=True)
|
||||
detail_url = DetailUrlField(read_only=True)
|
||||
|
||||
def to_representation(self, instance):
|
||||
data = OrderedDict()
|
||||
fields = [field for field in self.fields.values() if not field.write_only]
|
||||
|
||||
# Split meta fields from core fields
|
||||
meta_fields = [
|
||||
field for field in fields if field.field_name in self.meta_fields
|
||||
]
|
||||
fields = [field for field in fields if field.field_name not in self.meta_fields]
|
||||
|
||||
# Make sure id is always first. This will be filled in later
|
||||
if "id" in [field.field_name for field in fields]:
|
||||
data["id"] = None
|
||||
|
||||
# Serialise meta fields
|
||||
meta = OrderedDict()
|
||||
for field in meta_fields:
|
||||
try:
|
||||
attribute = field.get_attribute(instance)
|
||||
except SkipField:
|
||||
continue
|
||||
|
||||
if attribute is None:
|
||||
# We skip `to_representation` for `None` values so that
|
||||
# fields do not have to explicitly deal with that case.
|
||||
meta[field.field_name] = None
|
||||
else:
|
||||
meta[field.field_name] = field.to_representation(attribute)
|
||||
|
||||
if meta:
|
||||
data["meta"] = meta
|
||||
|
||||
# Serialise core fields
|
||||
for field in fields:
|
||||
try:
|
||||
if field.field_name == "admin_display_title":
|
||||
instance = instance.specific_deferred
|
||||
|
||||
attribute = field.get_attribute(instance)
|
||||
except SkipField:
|
||||
continue
|
||||
|
||||
if attribute is None:
|
||||
# We skip `to_representation` for `None` values so that
|
||||
# fields do not have to explicitly deal with that case.
|
||||
data[field.field_name] = None
|
||||
else:
|
||||
data[field.field_name] = field.to_representation(attribute)
|
||||
|
||||
return data
|
||||
|
||||
def build_property_field(self, field_name, model_class):
|
||||
# TaggableManager is not a Django field so it gets treated as a property
|
||||
field = getattr(model_class, field_name)
|
||||
if isinstance(field, _TaggableManager):
|
||||
return TagsField, {}
|
||||
|
||||
return super().build_property_field(field_name, model_class)
|
||||
|
||||
def build_relational_field(self, field_name, relation_info):
|
||||
field_class, field_kwargs = super().build_relational_field(
|
||||
field_name, relation_info
|
||||
)
|
||||
field_kwargs["serializer_class"] = self.child_serializer_classes[field_name]
|
||||
return field_class, field_kwargs
|
||||
|
||||
|
||||
class PageSerializer(BaseSerializer):
|
||||
type = PageTypeField(read_only=True)
|
||||
locale = PageLocaleField(read_only=True)
|
||||
html_url = PageHtmlUrlField(read_only=True)
|
||||
parent = PageParentField(read_only=True)
|
||||
alias_of = PageAliasOfField(read_only=True)
|
||||
|
||||
def build_relational_field(self, field_name, relation_info):
|
||||
# Find all relation fields that point to child class and make them use
|
||||
# the ChildRelationField class.
|
||||
if relation_info.to_many:
|
||||
model = getattr(self.Meta, "model")
|
||||
child_relations = {
|
||||
child_relation.field.remote_field.related_name: child_relation.related_model
|
||||
for child_relation in get_all_child_relations(model)
|
||||
}
|
||||
|
||||
if (
|
||||
field_name in child_relations
|
||||
and field_name in self.child_serializer_classes
|
||||
):
|
||||
return ChildRelationField, {
|
||||
"serializer_class": self.child_serializer_classes[field_name]
|
||||
}
|
||||
|
||||
return super().build_relational_field(field_name, relation_info)
|
||||
|
||||
|
||||
def get_serializer_class(
|
||||
model,
|
||||
field_names,
|
||||
meta_fields,
|
||||
field_serializer_overrides=None,
|
||||
child_serializer_classes=None,
|
||||
base=BaseSerializer,
|
||||
):
|
||||
model_ = model
|
||||
|
||||
class Meta:
|
||||
model = model_
|
||||
fields = list(field_names)
|
||||
|
||||
attrs = {
|
||||
"Meta": Meta,
|
||||
"meta_fields": list(meta_fields),
|
||||
"child_serializer_classes": child_serializer_classes or {},
|
||||
}
|
||||
|
||||
if field_serializer_overrides:
|
||||
attrs.update(field_serializer_overrides)
|
||||
|
||||
return type(str(model_.__name__ + "Serializer"), (base,), attrs)
|
||||
61
env/lib/python3.10/site-packages/wagtail/api/v2/signal_handlers.py
vendored
Normal file
61
env/lib/python3.10/site-packages/wagtail/api/v2/signal_handlers.py
vendored
Normal file
@@ -0,0 +1,61 @@
|
||||
from django.db.models.signals import post_delete, post_save
|
||||
from django.urls import reverse
|
||||
|
||||
from wagtail.contrib.frontend_cache.utils import purge_url_from_cache
|
||||
from wagtail.documents import get_document_model
|
||||
from wagtail.images import get_image_model
|
||||
from wagtail.models import get_page_models
|
||||
from wagtail.signals import page_published, page_unpublished
|
||||
|
||||
from .utils import get_base_url
|
||||
|
||||
|
||||
def purge_page_from_cache(instance, **kwargs):
|
||||
base_url = get_base_url()
|
||||
purge_url_from_cache(
|
||||
base_url + reverse("wagtailapi_v2:pages:detail", args=(instance.id,))
|
||||
)
|
||||
|
||||
|
||||
def purge_image_from_cache(instance, **kwargs):
|
||||
if not kwargs.get("created", False):
|
||||
base_url = get_base_url()
|
||||
purge_url_from_cache(
|
||||
base_url + reverse("wagtailapi_v2:images:detail", args=(instance.id,))
|
||||
)
|
||||
|
||||
|
||||
def purge_document_from_cache(instance, **kwargs):
|
||||
if not kwargs.get("created", False):
|
||||
base_url = get_base_url()
|
||||
purge_url_from_cache(
|
||||
base_url + reverse("wagtailapi_v2:documents:detail", args=(instance.id,))
|
||||
)
|
||||
|
||||
|
||||
def register_signal_handlers():
|
||||
Image = get_image_model()
|
||||
Document = get_document_model()
|
||||
|
||||
for model in get_page_models():
|
||||
page_published.connect(purge_page_from_cache, sender=model)
|
||||
page_unpublished.connect(purge_page_from_cache, sender=model)
|
||||
|
||||
post_save.connect(purge_image_from_cache, sender=Image)
|
||||
post_delete.connect(purge_image_from_cache, sender=Image)
|
||||
post_save.connect(purge_document_from_cache, sender=Document)
|
||||
post_delete.connect(purge_document_from_cache, sender=Document)
|
||||
|
||||
|
||||
def unregister_signal_handlers():
|
||||
Image = get_image_model()
|
||||
Document = get_document_model()
|
||||
|
||||
for model in get_page_models():
|
||||
page_published.disconnect(purge_page_from_cache, sender=model)
|
||||
page_unpublished.disconnect(purge_page_from_cache, sender=model)
|
||||
|
||||
post_save.disconnect(purge_image_from_cache, sender=Image)
|
||||
post_delete.disconnect(purge_image_from_cache, sender=Image)
|
||||
post_save.disconnect(purge_document_from_cache, sender=Document)
|
||||
post_delete.disconnect(purge_document_from_cache, sender=Document)
|
||||
0
env/lib/python3.10/site-packages/wagtail/api/v2/tests/__init__.py
vendored
Normal file
0
env/lib/python3.10/site-packages/wagtail/api/v2/tests/__init__.py
vendored
Normal file
BIN
env/lib/python3.10/site-packages/wagtail/api/v2/tests/__pycache__/__init__.cpython-310.pyc
vendored
Normal file
BIN
env/lib/python3.10/site-packages/wagtail/api/v2/tests/__pycache__/__init__.cpython-310.pyc
vendored
Normal file
Binary file not shown.
BIN
env/lib/python3.10/site-packages/wagtail/api/v2/tests/__pycache__/test_documents.cpython-310.pyc
vendored
Normal file
BIN
env/lib/python3.10/site-packages/wagtail/api/v2/tests/__pycache__/test_documents.cpython-310.pyc
vendored
Normal file
Binary file not shown.
BIN
env/lib/python3.10/site-packages/wagtail/api/v2/tests/__pycache__/test_images.cpython-310.pyc
vendored
Normal file
BIN
env/lib/python3.10/site-packages/wagtail/api/v2/tests/__pycache__/test_images.cpython-310.pyc
vendored
Normal file
Binary file not shown.
BIN
env/lib/python3.10/site-packages/wagtail/api/v2/tests/__pycache__/test_pages.cpython-310.pyc
vendored
Normal file
BIN
env/lib/python3.10/site-packages/wagtail/api/v2/tests/__pycache__/test_pages.cpython-310.pyc
vendored
Normal file
Binary file not shown.
BIN
env/lib/python3.10/site-packages/wagtail/api/v2/tests/__pycache__/tests.cpython-310.pyc
vendored
Normal file
BIN
env/lib/python3.10/site-packages/wagtail/api/v2/tests/__pycache__/tests.cpython-310.pyc
vendored
Normal file
Binary file not shown.
615
env/lib/python3.10/site-packages/wagtail/api/v2/tests/test_documents.py
vendored
Normal file
615
env/lib/python3.10/site-packages/wagtail/api/v2/tests/test_documents.py
vendored
Normal file
@@ -0,0 +1,615 @@
|
||||
import json
|
||||
from unittest import mock
|
||||
|
||||
from django.test import TestCase, TransactionTestCase
|
||||
from django.test.utils import override_settings
|
||||
from django.urls import reverse
|
||||
|
||||
from wagtail.api.v2 import signal_handlers
|
||||
from wagtail.documents import get_document_model
|
||||
|
||||
|
||||
class TestDocumentListing(TestCase):
|
||||
fixtures = ["demosite.json"]
|
||||
|
||||
def get_response(self, **params):
|
||||
return self.client.get(reverse("wagtailapi_v2:documents:listing"), params)
|
||||
|
||||
def get_document_id_list(self, content):
|
||||
return [document["id"] for document in content["items"]]
|
||||
|
||||
# BASIC TESTS
|
||||
|
||||
def test_basic(self):
|
||||
response = self.get_response()
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertEqual(response["Content-type"], "application/json")
|
||||
|
||||
# Will crash if the JSON is invalid
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
# Check that the meta section is there
|
||||
self.assertIn("meta", content)
|
||||
self.assertIsInstance(content["meta"], dict)
|
||||
|
||||
# Check that the total count is there and correct
|
||||
self.assertIn("total_count", content["meta"])
|
||||
self.assertIsInstance(content["meta"]["total_count"], int)
|
||||
self.assertEqual(
|
||||
content["meta"]["total_count"], get_document_model().objects.count()
|
||||
)
|
||||
|
||||
# Check that the items section is there
|
||||
self.assertIn("items", content)
|
||||
self.assertIsInstance(content["items"], list)
|
||||
|
||||
# Check that each document has a meta section with type and detail_url attributes
|
||||
for document in content["items"]:
|
||||
self.assertIn("meta", document)
|
||||
self.assertIsInstance(document["meta"], dict)
|
||||
self.assertEqual(
|
||||
set(document["meta"].keys()),
|
||||
{"type", "detail_url", "download_url", "tags"},
|
||||
)
|
||||
|
||||
# Type should always be wagtaildocs.Document
|
||||
self.assertEqual(document["meta"]["type"], "wagtaildocs.Document")
|
||||
|
||||
# Check detail_url
|
||||
self.assertEqual(
|
||||
document["meta"]["detail_url"],
|
||||
"http://localhost/api/main/documents/%d/" % document["id"],
|
||||
)
|
||||
|
||||
# Check download_url
|
||||
self.assertTrue(
|
||||
document["meta"]["download_url"].startswith(
|
||||
"http://localhost/documents/%d/" % document["id"]
|
||||
)
|
||||
)
|
||||
|
||||
# FIELDS
|
||||
|
||||
def test_fields_default(self):
|
||||
response = self.get_response()
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
for document in content["items"]:
|
||||
self.assertEqual(set(document.keys()), {"id", "meta", "title"})
|
||||
self.assertEqual(
|
||||
set(document["meta"].keys()),
|
||||
{"type", "detail_url", "download_url", "tags"},
|
||||
)
|
||||
|
||||
def test_fields(self):
|
||||
response = self.get_response(fields="title")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
for document in content["items"]:
|
||||
self.assertEqual(set(document.keys()), {"id", "meta", "title"})
|
||||
self.assertEqual(
|
||||
set(document["meta"].keys()),
|
||||
{"type", "detail_url", "download_url", "tags"},
|
||||
)
|
||||
|
||||
def test_remove_fields(self):
|
||||
response = self.get_response(fields="-title")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
for document in content["items"]:
|
||||
self.assertEqual(set(document.keys()), {"id", "meta"})
|
||||
|
||||
def test_remove_meta_fields(self):
|
||||
response = self.get_response(fields="-download_url")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
for document in content["items"]:
|
||||
self.assertEqual(set(document.keys()), {"id", "meta", "title"})
|
||||
self.assertEqual(
|
||||
set(document["meta"].keys()), {"type", "detail_url", "tags"}
|
||||
)
|
||||
|
||||
def test_remove_all_meta_fields(self):
|
||||
response = self.get_response(fields="-type,-detail_url,-tags,-download_url")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
for document in content["items"]:
|
||||
self.assertEqual(set(document.keys()), {"id", "title"})
|
||||
|
||||
def test_remove_id_field(self):
|
||||
response = self.get_response(fields="-id")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
for document in content["items"]:
|
||||
self.assertEqual(set(document.keys()), {"meta", "title"})
|
||||
|
||||
def test_all_fields(self):
|
||||
response = self.get_response(fields="*")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
for document in content["items"]:
|
||||
self.assertEqual(set(document.keys()), {"id", "meta", "title"})
|
||||
self.assertEqual(
|
||||
set(document["meta"].keys()),
|
||||
{"type", "detail_url", "tags", "download_url"},
|
||||
)
|
||||
|
||||
def test_all_fields_then_remove_something(self):
|
||||
response = self.get_response(fields="*,-title,-download_url")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
for document in content["items"]:
|
||||
self.assertEqual(set(document.keys()), {"id", "meta"})
|
||||
self.assertEqual(
|
||||
set(document["meta"].keys()), {"type", "detail_url", "tags"}
|
||||
)
|
||||
|
||||
def test_fields_tags(self):
|
||||
response = self.get_response(fields="tags")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
for document in content["items"]:
|
||||
self.assertIsInstance(document["meta"]["tags"], list)
|
||||
|
||||
def test_star_in_wrong_position_gives_error(self):
|
||||
response = self.get_response(fields="title,*")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertEqual(
|
||||
content, {"message": "fields error: '*' must be in the first position"}
|
||||
)
|
||||
|
||||
def test_fields_which_are_not_in_api_fields_gives_error(self):
|
||||
response = self.get_response(fields="uploaded_by_user")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertEqual(content, {"message": "unknown fields: uploaded_by_user"})
|
||||
|
||||
def test_fields_unknown_field_gives_error(self):
|
||||
response = self.get_response(fields="123,title,abc")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertEqual(content, {"message": "unknown fields: 123, abc"})
|
||||
|
||||
def test_fields_remove_unknown_field_gives_error(self):
|
||||
response = self.get_response(fields="-123,-title,-abc")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertEqual(content, {"message": "unknown fields: 123, abc"})
|
||||
|
||||
# FILTERING
|
||||
|
||||
def test_filtering_exact_filter(self):
|
||||
response = self.get_response(title="James Joyce")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
document_id_list = self.get_document_id_list(content)
|
||||
self.assertEqual(document_id_list, [2])
|
||||
|
||||
def test_filtering_on_id(self):
|
||||
response = self.get_response(id=10)
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
document_id_list = self.get_document_id_list(content)
|
||||
self.assertEqual(document_id_list, [10])
|
||||
|
||||
def test_filtering_tags(self):
|
||||
get_document_model().objects.get(id=3).tags.add("test")
|
||||
|
||||
response = self.get_response(tags="test")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
document_id_list = self.get_document_id_list(content)
|
||||
self.assertEqual(document_id_list, [3])
|
||||
|
||||
def test_filtering_unknown_field_gives_error(self):
|
||||
response = self.get_response(not_a_field="abc")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertEqual(
|
||||
content,
|
||||
{
|
||||
"message": "query parameter is not an operation or a recognised field: not_a_field"
|
||||
},
|
||||
)
|
||||
|
||||
# ORDERING
|
||||
|
||||
def test_ordering_by_title(self):
|
||||
response = self.get_response(order="title")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
document_id_list = self.get_document_id_list(content)
|
||||
self.assertEqual(document_id_list, [3, 12, 10, 2, 7, 8, 5, 4, 1, 11, 9, 6])
|
||||
|
||||
def test_ordering_by_title_backwards(self):
|
||||
response = self.get_response(order="-title")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
document_id_list = self.get_document_id_list(content)
|
||||
self.assertEqual(document_id_list, [6, 9, 11, 1, 4, 5, 8, 7, 2, 10, 12, 3])
|
||||
|
||||
def test_ordering_by_random(self):
|
||||
response_1 = self.get_response(order="random")
|
||||
content_1 = json.loads(response_1.content.decode("UTF-8"))
|
||||
document_id_list_1 = self.get_document_id_list(content_1)
|
||||
|
||||
response_2 = self.get_response(order="random")
|
||||
content_2 = json.loads(response_2.content.decode("UTF-8"))
|
||||
document_id_list_2 = self.get_document_id_list(content_2)
|
||||
|
||||
self.assertNotEqual(document_id_list_1, document_id_list_2)
|
||||
|
||||
def test_ordering_by_random_backwards_gives_error(self):
|
||||
response = self.get_response(order="-random")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertEqual(
|
||||
content, {"message": "cannot order by 'random' (unknown field)"}
|
||||
)
|
||||
|
||||
def test_ordering_by_random_with_offset_gives_error(self):
|
||||
response = self.get_response(order="random", offset=10)
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertEqual(
|
||||
content, {"message": "random ordering with offset is not supported"}
|
||||
)
|
||||
|
||||
def test_ordering_by_unknown_field_gives_error(self):
|
||||
response = self.get_response(order="not_a_field")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertEqual(
|
||||
content, {"message": "cannot order by 'not_a_field' (unknown field)"}
|
||||
)
|
||||
|
||||
# LIMIT
|
||||
|
||||
def test_limit_only_two_items_returned(self):
|
||||
response = self.get_response(limit=2)
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
self.assertEqual(len(content["items"]), 2)
|
||||
|
||||
def test_limit_total_count(self):
|
||||
response = self.get_response(limit=2)
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
# The total count must not be affected by "limit"
|
||||
self.assertEqual(
|
||||
content["meta"]["total_count"], get_document_model().objects.count()
|
||||
)
|
||||
|
||||
def test_limit_not_integer_gives_error(self):
|
||||
response = self.get_response(limit="abc")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertEqual(content, {"message": "limit must be a positive integer"})
|
||||
|
||||
def test_limit_too_high_gives_error(self):
|
||||
response = self.get_response(limit=1000)
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertEqual(content, {"message": "limit cannot be higher than 20"})
|
||||
|
||||
@override_settings(WAGTAILAPI_LIMIT_MAX=None)
|
||||
def test_limit_max_none_gives_no_errors(self):
|
||||
response = self.get_response(limit=1000000)
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertEqual(len(content["items"]), get_document_model().objects.count())
|
||||
|
||||
@override_settings(WAGTAILAPI_LIMIT_MAX=10)
|
||||
def test_limit_maximum_can_be_changed(self):
|
||||
response = self.get_response(limit=20)
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertEqual(content, {"message": "limit cannot be higher than 10"})
|
||||
|
||||
@override_settings(WAGTAILAPI_LIMIT_MAX=2)
|
||||
def test_limit_default_changes_with_max(self):
|
||||
# The default limit is 20. If WAGTAILAPI_LIMIT_MAX is less than that,
|
||||
# the default should change accordingly.
|
||||
response = self.get_response()
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
self.assertEqual(len(content["items"]), 2)
|
||||
|
||||
# OFFSET
|
||||
|
||||
def test_offset_5_usually_appears_5th_in_list(self):
|
||||
response = self.get_response()
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
document_id_list = self.get_document_id_list(content)
|
||||
self.assertEqual(document_id_list.index(5), 4)
|
||||
|
||||
def test_offset_5_moves_after_offset(self):
|
||||
response = self.get_response(offset=4)
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
document_id_list = self.get_document_id_list(content)
|
||||
self.assertEqual(document_id_list.index(5), 0)
|
||||
|
||||
def test_offset_total_count(self):
|
||||
response = self.get_response(offset=10)
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
# The total count must not be affected by "offset"
|
||||
self.assertEqual(
|
||||
content["meta"]["total_count"], get_document_model().objects.count()
|
||||
)
|
||||
|
||||
def test_offset_not_integer_gives_error(self):
|
||||
response = self.get_response(offset="abc")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertEqual(content, {"message": "offset must be a positive integer"})
|
||||
|
||||
|
||||
class TestDocumentListingSearch(TransactionTestCase):
|
||||
fixtures = ["demosite.json"]
|
||||
|
||||
def get_response(self, **params):
|
||||
return self.client.get(reverse("wagtailapi_v2:documents:listing"), params)
|
||||
|
||||
def get_document_id_list(self, content):
|
||||
return [document["id"] for document in content["items"]]
|
||||
|
||||
def test_search_for_james_joyce(self):
|
||||
response = self.get_response(search="james")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
document_id_list = self.get_document_id_list(content)
|
||||
|
||||
self.assertEqual(set(document_id_list), {2})
|
||||
|
||||
def test_search_with_order(self):
|
||||
response = self.get_response(search="james", order="title")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
document_id_list = self.get_document_id_list(content)
|
||||
|
||||
self.assertEqual(document_id_list, [2])
|
||||
|
||||
@override_settings(WAGTAILAPI_SEARCH_ENABLED=False)
|
||||
def test_search_when_disabled_gives_error(self):
|
||||
response = self.get_response(search="james")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertEqual(content, {"message": "search is disabled"})
|
||||
|
||||
def test_search_when_filtering_by_tag_gives_error(self):
|
||||
response = self.get_response(search="james", tags="wagtail")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertEqual(
|
||||
content,
|
||||
{"message": "filtering by tag with a search query is not supported"},
|
||||
)
|
||||
|
||||
|
||||
class TestDocumentDetail(TestCase):
|
||||
fixtures = ["demosite.json"]
|
||||
|
||||
def get_response(self, image_id, **params):
|
||||
return self.client.get(
|
||||
reverse("wagtailapi_v2:documents:detail", args=(image_id,)), params
|
||||
)
|
||||
|
||||
def test_basic(self):
|
||||
response = self.get_response(1)
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertEqual(response["Content-type"], "application/json")
|
||||
|
||||
# Will crash if the JSON is invalid
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
# Check the id field
|
||||
self.assertIn("id", content)
|
||||
self.assertEqual(content["id"], 1)
|
||||
|
||||
# Check that the meta section is there
|
||||
self.assertIn("meta", content)
|
||||
self.assertIsInstance(content["meta"], dict)
|
||||
|
||||
# Check the meta type
|
||||
self.assertIn("type", content["meta"])
|
||||
self.assertEqual(content["meta"]["type"], "wagtaildocs.Document")
|
||||
|
||||
# Check the meta detail_url
|
||||
self.assertIn("detail_url", content["meta"])
|
||||
self.assertEqual(
|
||||
content["meta"]["detail_url"], "http://localhost/api/main/documents/1/"
|
||||
)
|
||||
|
||||
# Check the meta download_url
|
||||
self.assertIn("download_url", content["meta"])
|
||||
self.assertEqual(
|
||||
content["meta"]["download_url"],
|
||||
"http://localhost/documents/1/wagtail_by_markyharky.jpg",
|
||||
)
|
||||
|
||||
# Check the title field
|
||||
self.assertIn("title", content)
|
||||
self.assertEqual(content["title"], "Wagtail by mark Harkin")
|
||||
|
||||
# Check the tags field
|
||||
self.assertIn("tags", content["meta"])
|
||||
self.assertEqual(content["meta"]["tags"], [])
|
||||
|
||||
def test_tags(self):
|
||||
get_document_model().objects.get(id=1).tags.add("hello")
|
||||
get_document_model().objects.get(id=1).tags.add("world")
|
||||
|
||||
response = self.get_response(1)
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
self.assertIn("tags", content["meta"])
|
||||
self.assertEqual(content["meta"]["tags"], ["hello", "world"])
|
||||
|
||||
@override_settings(WAGTAILAPI_BASE_URL="http://api.example.com/")
|
||||
def test_download_url_with_custom_base_url(self):
|
||||
response = self.get_response(1)
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
self.assertIn("download_url", content["meta"])
|
||||
self.assertEqual(
|
||||
content["meta"]["download_url"],
|
||||
"http://api.example.com/documents/1/wagtail_by_markyharky.jpg",
|
||||
)
|
||||
|
||||
# FIELDS
|
||||
|
||||
def test_remove_fields(self):
|
||||
response = self.get_response(2, fields="-title")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
self.assertIn("id", set(content.keys()))
|
||||
self.assertNotIn("title", set(content.keys()))
|
||||
|
||||
def test_remove_meta_fields(self):
|
||||
response = self.get_response(2, fields="-download_url")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
self.assertIn("detail_url", set(content["meta"].keys()))
|
||||
self.assertNotIn("download_url", set(content["meta"].keys()))
|
||||
|
||||
def test_remove_id_field(self):
|
||||
response = self.get_response(2, fields="-id")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
self.assertIn("title", set(content.keys()))
|
||||
self.assertNotIn("id", set(content.keys()))
|
||||
|
||||
def test_remove_all_fields(self):
|
||||
response = self.get_response(2, fields="_,id,type")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
self.assertEqual(set(content.keys()), {"id", "meta"})
|
||||
self.assertEqual(set(content["meta"].keys()), {"type"})
|
||||
|
||||
def test_star_in_wrong_position_gives_error(self):
|
||||
response = self.get_response(2, fields="title,*")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertEqual(
|
||||
content, {"message": "fields error: '*' must be in the first position"}
|
||||
)
|
||||
|
||||
def test_fields_which_are_not_in_api_fields_gives_error(self):
|
||||
response = self.get_response(2, fields="path")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertEqual(content, {"message": "unknown fields: path"})
|
||||
|
||||
def test_fields_unknown_field_gives_error(self):
|
||||
response = self.get_response(2, fields="123,title,abc")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertEqual(content, {"message": "unknown fields: 123, abc"})
|
||||
|
||||
def test_fields_remove_unknown_field_gives_error(self):
|
||||
response = self.get_response(2, fields="-123,-title,-abc")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertEqual(content, {"message": "unknown fields: 123, abc"})
|
||||
|
||||
def test_nested_fields_on_non_relational_field_gives_error(self):
|
||||
response = self.get_response(2, fields="title(foo,bar)")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertEqual(content, {"message": "'title' does not support nested fields"})
|
||||
|
||||
|
||||
class TestDocumentFind(TestCase):
|
||||
fixtures = ["demosite.json"]
|
||||
|
||||
def get_response(self, **params):
|
||||
return self.client.get(reverse("wagtailapi_v2:documents:find"), params)
|
||||
|
||||
def test_without_parameters(self):
|
||||
response = self.get_response()
|
||||
|
||||
self.assertEqual(response.status_code, 404)
|
||||
self.assertEqual(response["Content-type"], "application/json")
|
||||
|
||||
# Will crash if the JSON is invalid
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
self.assertEqual(content, {"message": "not found"})
|
||||
|
||||
def test_find_by_id(self):
|
||||
response = self.get_response(id=5)
|
||||
|
||||
self.assertRedirects(
|
||||
response,
|
||||
"http://localhost" + reverse("wagtailapi_v2:documents:detail", args=[5]),
|
||||
fetch_redirect_response=False,
|
||||
)
|
||||
|
||||
def test_find_by_id_nonexistent(self):
|
||||
response = self.get_response(id=1234)
|
||||
|
||||
self.assertEqual(response.status_code, 404)
|
||||
self.assertEqual(response["Content-type"], "application/json")
|
||||
|
||||
# Will crash if the JSON is invalid
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
self.assertEqual(content, {"message": "not found"})
|
||||
|
||||
|
||||
@override_settings(
|
||||
WAGTAILFRONTENDCACHE={
|
||||
"varnish": {
|
||||
"BACKEND": "wagtail.contrib.frontend_cache.backends.HTTPBackend",
|
||||
"LOCATION": "http://localhost:8000",
|
||||
},
|
||||
},
|
||||
WAGTAILAPI_BASE_URL="http://api.example.com",
|
||||
)
|
||||
@mock.patch("wagtail.contrib.frontend_cache.backends.http.HTTPBackend.purge")
|
||||
class TestDocumentCacheInvalidation(TestCase):
|
||||
fixtures = ["demosite.json"]
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
signal_handlers.register_signal_handlers()
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
super().tearDownClass()
|
||||
signal_handlers.unregister_signal_handlers()
|
||||
|
||||
def test_resave_document_purges(self, purge):
|
||||
get_document_model().objects.get(id=5).save()
|
||||
|
||||
purge.assert_any_call("http://api.example.com/api/main/documents/5/")
|
||||
|
||||
def test_delete_document_purges(self, purge):
|
||||
get_document_model().objects.get(id=5).delete()
|
||||
|
||||
purge.assert_any_call("http://api.example.com/api/main/documents/5/")
|
||||
607
env/lib/python3.10/site-packages/wagtail/api/v2/tests/test_images.py
vendored
Normal file
607
env/lib/python3.10/site-packages/wagtail/api/v2/tests/test_images.py
vendored
Normal file
@@ -0,0 +1,607 @@
|
||||
import json
|
||||
from unittest import mock
|
||||
|
||||
from django.test import TestCase, TransactionTestCase
|
||||
from django.test.utils import override_settings
|
||||
from django.urls import reverse
|
||||
|
||||
from wagtail.api.v2 import signal_handlers
|
||||
from wagtail.images import get_image_model
|
||||
|
||||
|
||||
class TestImageListing(TestCase):
|
||||
fixtures = ["demosite.json"]
|
||||
|
||||
def get_response(self, **params):
|
||||
return self.client.get(reverse("wagtailapi_v2:images:listing"), params)
|
||||
|
||||
def get_image_id_list(self, content):
|
||||
return [image["id"] for image in content["items"]]
|
||||
|
||||
# BASIC TESTS
|
||||
|
||||
def test_basic(self):
|
||||
response = self.get_response()
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertEqual(response["Content-type"], "application/json")
|
||||
|
||||
# Will crash if the JSON is invalid
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
# Check that the meta section is there
|
||||
self.assertIn("meta", content)
|
||||
self.assertIsInstance(content["meta"], dict)
|
||||
|
||||
# Check that the total count is there and correct
|
||||
self.assertIn("total_count", content["meta"])
|
||||
self.assertIsInstance(content["meta"]["total_count"], int)
|
||||
self.assertEqual(
|
||||
content["meta"]["total_count"], get_image_model().objects.count()
|
||||
)
|
||||
|
||||
# Check that the items section is there
|
||||
self.assertIn("items", content)
|
||||
self.assertIsInstance(content["items"], list)
|
||||
|
||||
# Check that each image has a meta section with type and detail_url attributes
|
||||
for image in content["items"]:
|
||||
self.assertIn("meta", image)
|
||||
self.assertIsInstance(image["meta"], dict)
|
||||
self.assertEqual(
|
||||
set(image["meta"].keys()),
|
||||
{"type", "detail_url", "tags", "download_url"},
|
||||
)
|
||||
|
||||
# Type should always be wagtailimages.Image
|
||||
self.assertEqual(image["meta"]["type"], "wagtailimages.Image")
|
||||
|
||||
# Check detail url
|
||||
self.assertEqual(
|
||||
image["meta"]["detail_url"],
|
||||
"http://localhost/api/main/images/%d/" % image["id"],
|
||||
)
|
||||
|
||||
# FIELDS
|
||||
|
||||
def test_fields_default(self):
|
||||
response = self.get_response()
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
for image in content["items"]:
|
||||
self.assertEqual(set(image.keys()), {"id", "meta", "title"})
|
||||
self.assertEqual(
|
||||
set(image["meta"].keys()),
|
||||
{"type", "detail_url", "tags", "download_url"},
|
||||
)
|
||||
|
||||
def test_fields(self):
|
||||
response = self.get_response(fields="width,height")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
for image in content["items"]:
|
||||
self.assertEqual(
|
||||
set(image.keys()), {"id", "meta", "title", "width", "height"}
|
||||
)
|
||||
self.assertEqual(
|
||||
set(image["meta"].keys()),
|
||||
{"type", "detail_url", "tags", "download_url"},
|
||||
)
|
||||
|
||||
def test_remove_fields(self):
|
||||
response = self.get_response(fields="-title")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
for image in content["items"]:
|
||||
self.assertEqual(set(image.keys()), {"id", "meta"})
|
||||
|
||||
def test_remove_meta_fields(self):
|
||||
response = self.get_response(fields="-tags")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
for image in content["items"]:
|
||||
self.assertEqual(set(image.keys()), {"id", "meta", "title"})
|
||||
self.assertEqual(
|
||||
set(image["meta"].keys()), {"type", "detail_url", "download_url"}
|
||||
)
|
||||
|
||||
def test_remove_all_meta_fields(self):
|
||||
response = self.get_response(fields="-type,-detail_url,-tags")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
for image in content["items"]:
|
||||
self.assertEqual(set(image.keys()), {"id", "title", "meta"})
|
||||
|
||||
def test_remove_id_field(self):
|
||||
response = self.get_response(fields="-id")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
for image in content["items"]:
|
||||
self.assertEqual(set(image.keys()), {"meta", "title"})
|
||||
|
||||
def test_all_fields(self):
|
||||
response = self.get_response(fields="*")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
for image in content["items"]:
|
||||
self.assertEqual(
|
||||
set(image.keys()), {"id", "meta", "title", "width", "height"}
|
||||
)
|
||||
self.assertEqual(
|
||||
set(image["meta"].keys()),
|
||||
{"type", "detail_url", "tags", "download_url"},
|
||||
)
|
||||
|
||||
def test_all_fields_then_remove_something(self):
|
||||
response = self.get_response(fields="*,-title,-tags")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
for image in content["items"]:
|
||||
self.assertEqual(set(image.keys()), {"id", "meta", "width", "height"})
|
||||
self.assertEqual(
|
||||
set(image["meta"].keys()), {"type", "detail_url", "download_url"}
|
||||
)
|
||||
|
||||
def test_fields_tags(self):
|
||||
response = self.get_response(fields="tags")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
for image in content["items"]:
|
||||
self.assertEqual(set(image.keys()), {"id", "meta", "title"})
|
||||
self.assertEqual(set(image.keys()), {"id", "meta", "title"})
|
||||
self.assertEqual(
|
||||
set(image["meta"].keys()),
|
||||
{"type", "detail_url", "tags", "download_url"},
|
||||
)
|
||||
self.assertIsInstance(image["meta"]["tags"], list)
|
||||
|
||||
def test_star_in_wrong_position_gives_error(self):
|
||||
response = self.get_response(fields="title,*")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertEqual(
|
||||
content, {"message": "fields error: '*' must be in the first position"}
|
||||
)
|
||||
|
||||
def test_fields_which_are_not_in_api_fields_gives_error(self):
|
||||
response = self.get_response(fields="uploaded_by_user")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertEqual(content, {"message": "unknown fields: uploaded_by_user"})
|
||||
|
||||
def test_fields_unknown_field_gives_error(self):
|
||||
response = self.get_response(fields="123,title,abc")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertEqual(content, {"message": "unknown fields: 123, abc"})
|
||||
|
||||
def test_fields_remove_unknown_field_gives_error(self):
|
||||
response = self.get_response(fields="-123,-title,-abc")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertEqual(content, {"message": "unknown fields: 123, abc"})
|
||||
|
||||
# FILTERING
|
||||
|
||||
def test_filtering_exact_filter(self):
|
||||
response = self.get_response(title="James Joyce")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
image_id_list = self.get_image_id_list(content)
|
||||
self.assertEqual(image_id_list, [5])
|
||||
|
||||
def test_filtering_on_id(self):
|
||||
response = self.get_response(id=10)
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
image_id_list = self.get_image_id_list(content)
|
||||
self.assertEqual(image_id_list, [10])
|
||||
|
||||
def test_filtering_tags(self):
|
||||
get_image_model().objects.get(id=6).tags.add("test")
|
||||
|
||||
response = self.get_response(tags="test")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
image_id_list = self.get_image_id_list(content)
|
||||
self.assertEqual(image_id_list, [6])
|
||||
|
||||
def test_filtering_unknown_field_gives_error(self):
|
||||
response = self.get_response(not_a_field="abc")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertEqual(
|
||||
content,
|
||||
{
|
||||
"message": "query parameter is not an operation or a recognised field: not_a_field"
|
||||
},
|
||||
)
|
||||
|
||||
# ORDERING
|
||||
|
||||
def test_ordering_by_title(self):
|
||||
response = self.get_response(order="title")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
image_id_list = self.get_image_id_list(content)
|
||||
self.assertEqual(image_id_list, [6, 15, 13, 5, 10, 11, 8, 7, 4, 14, 12, 9])
|
||||
|
||||
def test_ordering_by_title_backwards(self):
|
||||
response = self.get_response(order="-title")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
image_id_list = self.get_image_id_list(content)
|
||||
self.assertEqual(image_id_list, [9, 12, 14, 4, 7, 8, 11, 10, 5, 13, 15, 6])
|
||||
|
||||
def test_ordering_by_random(self):
|
||||
response_1 = self.get_response(order="random")
|
||||
content_1 = json.loads(response_1.content.decode("UTF-8"))
|
||||
image_id_list_1 = self.get_image_id_list(content_1)
|
||||
|
||||
response_2 = self.get_response(order="random")
|
||||
content_2 = json.loads(response_2.content.decode("UTF-8"))
|
||||
image_id_list_2 = self.get_image_id_list(content_2)
|
||||
|
||||
self.assertNotEqual(image_id_list_1, image_id_list_2)
|
||||
|
||||
def test_ordering_by_random_backwards_gives_error(self):
|
||||
response = self.get_response(order="-random")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertEqual(
|
||||
content, {"message": "cannot order by 'random' (unknown field)"}
|
||||
)
|
||||
|
||||
def test_ordering_by_random_with_offset_gives_error(self):
|
||||
response = self.get_response(order="random", offset=10)
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertEqual(
|
||||
content, {"message": "random ordering with offset is not supported"}
|
||||
)
|
||||
|
||||
def test_ordering_by_unknown_field_gives_error(self):
|
||||
response = self.get_response(order="not_a_field")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertEqual(
|
||||
content, {"message": "cannot order by 'not_a_field' (unknown field)"}
|
||||
)
|
||||
|
||||
# LIMIT
|
||||
|
||||
def test_limit_only_two_items_returned(self):
|
||||
response = self.get_response(limit=2)
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
self.assertEqual(len(content["items"]), 2)
|
||||
|
||||
def test_limit_total_count(self):
|
||||
response = self.get_response(limit=2)
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
# The total count must not be affected by "limit"
|
||||
self.assertEqual(
|
||||
content["meta"]["total_count"], get_image_model().objects.count()
|
||||
)
|
||||
|
||||
def test_limit_not_integer_gives_error(self):
|
||||
response = self.get_response(limit="abc")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertEqual(content, {"message": "limit must be a positive integer"})
|
||||
|
||||
def test_limit_too_high_gives_error(self):
|
||||
response = self.get_response(limit=1000)
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertEqual(content, {"message": "limit cannot be higher than 20"})
|
||||
|
||||
@override_settings(WAGTAILAPI_LIMIT_MAX=None)
|
||||
def test_limit_max_none_gives_no_errors(self):
|
||||
response = self.get_response(limit=1000000)
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertEqual(len(content["items"]), get_image_model().objects.count())
|
||||
|
||||
@override_settings(WAGTAILAPI_LIMIT_MAX=10)
|
||||
def test_limit_maximum_can_be_changed(self):
|
||||
response = self.get_response(limit=20)
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertEqual(content, {"message": "limit cannot be higher than 10"})
|
||||
|
||||
@override_settings(WAGTAILAPI_LIMIT_MAX=2)
|
||||
def test_limit_default_changes_with_max(self):
|
||||
# The default limit is 20. If WAGTAILAPI_LIMIT_MAX is less than that,
|
||||
# the default should change accordingly.
|
||||
response = self.get_response()
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
self.assertEqual(len(content["items"]), 2)
|
||||
|
||||
# OFFSET
|
||||
|
||||
def test_offset_10_usually_appears_7th_in_list(self):
|
||||
response = self.get_response()
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
image_id_list = self.get_image_id_list(content)
|
||||
self.assertEqual(image_id_list.index(10), 6)
|
||||
|
||||
def test_offset_10_moves_after_offset(self):
|
||||
response = self.get_response(offset=4)
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
image_id_list = self.get_image_id_list(content)
|
||||
self.assertEqual(image_id_list.index(10), 2)
|
||||
|
||||
def test_offset_total_count(self):
|
||||
response = self.get_response(offset=10)
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
# The total count must not be affected by "offset"
|
||||
self.assertEqual(
|
||||
content["meta"]["total_count"], get_image_model().objects.count()
|
||||
)
|
||||
|
||||
def test_offset_not_integer_gives_error(self):
|
||||
response = self.get_response(offset="abc")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertEqual(content, {"message": "offset must be a positive integer"})
|
||||
|
||||
|
||||
class TestImageListingSearch(TransactionTestCase):
|
||||
fixtures = ["demosite.json"]
|
||||
|
||||
def get_response(self, **params):
|
||||
return self.client.get(reverse("wagtailapi_v2:images:listing"), params)
|
||||
|
||||
def get_image_id_list(self, content):
|
||||
return [image["id"] for image in content["items"]]
|
||||
|
||||
def test_search_for_james_joyce(self):
|
||||
response = self.get_response(search="james")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
image_id_list = self.get_image_id_list(content)
|
||||
|
||||
self.assertEqual(set(image_id_list), {5})
|
||||
|
||||
def test_search_with_order(self):
|
||||
response = self.get_response(search="james", order="title")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
image_id_list = self.get_image_id_list(content)
|
||||
|
||||
self.assertEqual(image_id_list, [5])
|
||||
|
||||
@override_settings(WAGTAILAPI_SEARCH_ENABLED=False)
|
||||
def test_search_when_disabled_gives_error(self):
|
||||
response = self.get_response(search="james")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertEqual(content, {"message": "search is disabled"})
|
||||
|
||||
def test_search_when_filtering_by_tag_gives_error(self):
|
||||
response = self.get_response(search="james", tags="wagtail")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertEqual(
|
||||
content,
|
||||
{"message": "filtering by tag with a search query is not supported"},
|
||||
)
|
||||
|
||||
|
||||
class TestImageDetail(TestCase):
|
||||
fixtures = ["demosite.json"]
|
||||
|
||||
def get_response(self, image_id, **params):
|
||||
return self.client.get(
|
||||
reverse("wagtailapi_v2:images:detail", args=(image_id,)), params
|
||||
)
|
||||
|
||||
def test_basic(self):
|
||||
response = self.get_response(5)
|
||||
|
||||
self.assertEqual(response.status_code, 200)
|
||||
self.assertEqual(response["Content-type"], "application/json")
|
||||
|
||||
# Will crash if the JSON is invalid
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
# Check the id field
|
||||
self.assertIn("id", content)
|
||||
self.assertEqual(content["id"], 5)
|
||||
|
||||
# Check that the meta section is there
|
||||
self.assertIn("meta", content)
|
||||
self.assertIsInstance(content["meta"], dict)
|
||||
|
||||
# Check the meta type
|
||||
self.assertIn("type", content["meta"])
|
||||
self.assertEqual(content["meta"]["type"], "wagtailimages.Image")
|
||||
|
||||
# Check the meta detail_url
|
||||
self.assertIn("detail_url", content["meta"])
|
||||
self.assertEqual(
|
||||
content["meta"]["detail_url"], "http://localhost/api/main/images/5/"
|
||||
)
|
||||
|
||||
# Check the title field
|
||||
self.assertIn("title", content)
|
||||
self.assertEqual(content["title"], "James Joyce")
|
||||
|
||||
# Check the width and height fields
|
||||
self.assertIn("width", content)
|
||||
self.assertIn("height", content)
|
||||
self.assertEqual(content["width"], 500)
|
||||
self.assertEqual(content["height"], 392)
|
||||
|
||||
# Check the tags field
|
||||
self.assertIn("tags", content["meta"])
|
||||
self.assertEqual(content["meta"]["tags"], [])
|
||||
|
||||
def test_tags(self):
|
||||
image = get_image_model().objects.get(id=5)
|
||||
image.tags.add("hello")
|
||||
image.tags.add("world")
|
||||
|
||||
response = self.get_response(5)
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
self.assertIn("tags", content["meta"])
|
||||
self.assertEqual(content["meta"]["tags"], ["hello", "world"])
|
||||
|
||||
# FIELDS
|
||||
|
||||
def test_remove_fields(self):
|
||||
response = self.get_response(5, fields="-title")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
self.assertIn("id", set(content.keys()))
|
||||
self.assertNotIn("title", set(content.keys()))
|
||||
|
||||
def test_remove_meta_fields(self):
|
||||
response = self.get_response(5, fields="-type")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
self.assertIn("detail_url", set(content["meta"].keys()))
|
||||
self.assertNotIn("type", set(content["meta"].keys()))
|
||||
|
||||
def test_remove_id_field(self):
|
||||
response = self.get_response(5, fields="-id")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
self.assertIn("title", set(content.keys()))
|
||||
self.assertNotIn("id", set(content.keys()))
|
||||
|
||||
def test_remove_all_fields(self):
|
||||
response = self.get_response(5, fields="_,id,type")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
self.assertEqual(set(content.keys()), {"id", "meta"})
|
||||
self.assertEqual(set(content["meta"].keys()), {"type"})
|
||||
|
||||
def test_star_in_wrong_position_gives_error(self):
|
||||
response = self.get_response(5, fields="title,*")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertEqual(
|
||||
content, {"message": "fields error: '*' must be in the first position"}
|
||||
)
|
||||
|
||||
def test_fields_which_are_not_in_api_fields_gives_error(self):
|
||||
response = self.get_response(5, fields="path")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertEqual(content, {"message": "unknown fields: path"})
|
||||
|
||||
def test_fields_unknown_field_gives_error(self):
|
||||
response = self.get_response(5, fields="123,title,abc")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertEqual(content, {"message": "unknown fields: 123, abc"})
|
||||
|
||||
def test_fields_remove_unknown_field_gives_error(self):
|
||||
response = self.get_response(5, fields="-123,-title,-abc")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertEqual(content, {"message": "unknown fields: 123, abc"})
|
||||
|
||||
def test_nested_fields_on_non_relational_field_gives_error(self):
|
||||
response = self.get_response(5, fields="title(foo,bar)")
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertEqual(content, {"message": "'title' does not support nested fields"})
|
||||
|
||||
|
||||
class TestImageFind(TestCase):
|
||||
fixtures = ["demosite.json"]
|
||||
|
||||
def get_response(self, **params):
|
||||
return self.client.get(reverse("wagtailapi_v2:images:find"), params)
|
||||
|
||||
def test_without_parameters(self):
|
||||
response = self.get_response()
|
||||
|
||||
self.assertEqual(response.status_code, 404)
|
||||
self.assertEqual(response["Content-type"], "application/json")
|
||||
|
||||
# Will crash if the JSON is invalid
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
self.assertEqual(content, {"message": "not found"})
|
||||
|
||||
def test_find_by_id(self):
|
||||
response = self.get_response(id=5)
|
||||
|
||||
self.assertRedirects(
|
||||
response,
|
||||
"http://localhost" + reverse("wagtailapi_v2:images:detail", args=[5]),
|
||||
fetch_redirect_response=False,
|
||||
)
|
||||
|
||||
def test_find_by_id_nonexistent(self):
|
||||
response = self.get_response(id=1234)
|
||||
|
||||
self.assertEqual(response.status_code, 404)
|
||||
self.assertEqual(response["Content-type"], "application/json")
|
||||
|
||||
# Will crash if the JSON is invalid
|
||||
content = json.loads(response.content.decode("UTF-8"))
|
||||
|
||||
self.assertEqual(content, {"message": "not found"})
|
||||
|
||||
|
||||
@override_settings(
|
||||
WAGTAILFRONTENDCACHE={
|
||||
"varnish": {
|
||||
"BACKEND": "wagtail.contrib.frontend_cache.backends.HTTPBackend",
|
||||
"LOCATION": "http://localhost:8000",
|
||||
},
|
||||
},
|
||||
WAGTAILAPI_BASE_URL="http://api.example.com",
|
||||
)
|
||||
@mock.patch("wagtail.contrib.frontend_cache.backends.http.HTTPBackend.purge")
|
||||
class TestImageCacheInvalidation(TestCase):
|
||||
fixtures = ["demosite.json"]
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
signal_handlers.register_signal_handlers()
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
super().tearDownClass()
|
||||
signal_handlers.unregister_signal_handlers()
|
||||
|
||||
def test_resave_image_purges(self, purge):
|
||||
get_image_model().objects.get(id=5).save()
|
||||
|
||||
purge.assert_any_call("http://api.example.com/api/main/images/5/")
|
||||
|
||||
def test_delete_image_purges(self, purge):
|
||||
get_image_model().objects.get(id=5).delete()
|
||||
|
||||
purge.assert_any_call("http://api.example.com/api/main/images/5/")
|
||||
1903
env/lib/python3.10/site-packages/wagtail/api/v2/tests/test_pages.py
vendored
Normal file
1903
env/lib/python3.10/site-packages/wagtail/api/v2/tests/test_pages.py
vendored
Normal file
File diff suppressed because it is too large
Load Diff
401
env/lib/python3.10/site-packages/wagtail/api/v2/tests/tests.py
vendored
Normal file
401
env/lib/python3.10/site-packages/wagtail/api/v2/tests/tests.py
vendored
Normal file
@@ -0,0 +1,401 @@
|
||||
from django.test import RequestFactory, TestCase, override_settings
|
||||
from django.utils.encoding import force_bytes
|
||||
|
||||
from wagtail.models import Site
|
||||
|
||||
from ..utils import (
|
||||
FieldsParameterParseError,
|
||||
get_base_url,
|
||||
parse_boolean,
|
||||
parse_fields_parameter,
|
||||
)
|
||||
|
||||
|
||||
class DynamicBaseUrl:
|
||||
def __str__(self):
|
||||
return "https://www.example.com"
|
||||
|
||||
def __bytes__(self):
|
||||
return force_bytes(self.__str__())
|
||||
|
||||
def decode(self, *args, **kwargs):
|
||||
return self.__bytes__().decode(*args, **kwargs)
|
||||
|
||||
|
||||
class TestGetBaseUrl(TestCase):
|
||||
def setUp(self):
|
||||
Site.objects.all().delete()
|
||||
|
||||
def prepare_site(self):
|
||||
return Site.objects.get_or_create(
|
||||
hostname="other.example.com",
|
||||
port=8080,
|
||||
root_page_id=1,
|
||||
is_default_site=True,
|
||||
)[0]
|
||||
|
||||
def clear_cached_site(self, request):
|
||||
del request._wagtail_site
|
||||
|
||||
def test_get_base_url_unset(self):
|
||||
self.assertIsNone(get_base_url())
|
||||
|
||||
def test_get_base_url_from_request(self):
|
||||
# base url for siteless request should be None
|
||||
request = RequestFactory().get("/")
|
||||
self.assertIsNone(Site.find_for_request(request))
|
||||
self.assertIsNone(get_base_url(request))
|
||||
|
||||
# base url for request with a site should be based on the site's details
|
||||
site = self.prepare_site()
|
||||
self.clear_cached_site(request)
|
||||
self.assertEqual(site, Site.find_for_request(request))
|
||||
self.assertEqual(get_base_url(request), "http://other.example.com:8080")
|
||||
|
||||
# port 443 should indicate https without a port
|
||||
site.port = 443
|
||||
site.save()
|
||||
self.clear_cached_site(request)
|
||||
self.assertEqual(get_base_url(request), "https://other.example.com")
|
||||
|
||||
# port 80 should indicate http without a port
|
||||
site.port = 80
|
||||
site.save()
|
||||
self.clear_cached_site(request)
|
||||
self.assertEqual(get_base_url(request), "http://other.example.com")
|
||||
|
||||
@override_settings(WAGTAILAPI_BASE_URL="https://bar.example.com")
|
||||
def test_get_base_url_prefers_setting(self):
|
||||
request = RequestFactory().get("/")
|
||||
site = self.prepare_site()
|
||||
self.assertEqual(site, Site.find_for_request(request))
|
||||
self.assertEqual(get_base_url(request), "https://bar.example.com")
|
||||
with override_settings(WAGTAILAPI_BASE_URL=None):
|
||||
self.assertEqual(get_base_url(request), "http://other.example.com:8080")
|
||||
|
||||
@override_settings(WAGTAILAPI_BASE_URL="https://bar.example.com")
|
||||
def test_get_base_url_from_setting_string(self):
|
||||
self.assertEqual(get_base_url(), "https://bar.example.com")
|
||||
|
||||
@override_settings(WAGTAILAPI_BASE_URL=b"https://baz.example.com")
|
||||
def test_get_base_url_from_setting_bytes(self):
|
||||
self.assertEqual(get_base_url(), "https://baz.example.com")
|
||||
|
||||
@override_settings(WAGTAILAPI_BASE_URL=DynamicBaseUrl())
|
||||
def test_get_base_url_from_setting_object(self):
|
||||
self.assertEqual(get_base_url(), "https://www.example.com")
|
||||
|
||||
|
||||
class TestParseFieldsParameter(TestCase):
|
||||
# GOOD STUFF
|
||||
|
||||
def test_valid_single_field(self):
|
||||
parsed = parse_fields_parameter("test")
|
||||
|
||||
self.assertEqual(
|
||||
parsed,
|
||||
[
|
||||
("test", False, None),
|
||||
],
|
||||
)
|
||||
|
||||
def test_valid_multiple_fields(self):
|
||||
parsed = parse_fields_parameter("test,another_test")
|
||||
|
||||
self.assertEqual(
|
||||
parsed,
|
||||
[
|
||||
("test", False, None),
|
||||
("another_test", False, None),
|
||||
],
|
||||
)
|
||||
|
||||
def test_valid_negated_field(self):
|
||||
parsed = parse_fields_parameter("-test")
|
||||
|
||||
self.assertEqual(
|
||||
parsed,
|
||||
[
|
||||
("test", True, None),
|
||||
],
|
||||
)
|
||||
|
||||
def test_valid_nested_fields(self):
|
||||
parsed = parse_fields_parameter("test(foo,bar)")
|
||||
|
||||
self.assertEqual(
|
||||
parsed,
|
||||
[
|
||||
(
|
||||
"test",
|
||||
False,
|
||||
[
|
||||
("foo", False, None),
|
||||
("bar", False, None),
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
def test_valid_star_field(self):
|
||||
parsed = parse_fields_parameter("*,-test")
|
||||
|
||||
self.assertEqual(
|
||||
parsed,
|
||||
[
|
||||
("*", False, None),
|
||||
("test", True, None),
|
||||
],
|
||||
)
|
||||
|
||||
def test_valid_star_with_additional_field(self):
|
||||
# Note: '*,test' is not allowed but '*,test(foo)' is
|
||||
parsed = parse_fields_parameter("*,test(foo)")
|
||||
|
||||
self.assertEqual(
|
||||
parsed,
|
||||
[
|
||||
("*", False, None),
|
||||
(
|
||||
"test",
|
||||
False,
|
||||
[
|
||||
("foo", False, None),
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
def test_valid_underscore_field(self):
|
||||
parsed = parse_fields_parameter("_,test")
|
||||
|
||||
self.assertEqual(
|
||||
parsed,
|
||||
[
|
||||
("_", False, None),
|
||||
("test", False, None),
|
||||
],
|
||||
)
|
||||
|
||||
def test_valid_field_with_underscore_in_middle(self):
|
||||
parsed = parse_fields_parameter("a_test")
|
||||
|
||||
self.assertEqual(
|
||||
parsed,
|
||||
[
|
||||
("a_test", False, None),
|
||||
],
|
||||
)
|
||||
|
||||
def test_valid_negated_field_with_underscore_in_middle(self):
|
||||
parsed = parse_fields_parameter("-a_test")
|
||||
|
||||
self.assertEqual(
|
||||
parsed,
|
||||
[
|
||||
("a_test", True, None),
|
||||
],
|
||||
)
|
||||
|
||||
def test_valid_field_with_underscore_at_beginning(self):
|
||||
parsed = parse_fields_parameter("_test")
|
||||
|
||||
self.assertEqual(
|
||||
parsed,
|
||||
[
|
||||
("_test", False, None),
|
||||
],
|
||||
)
|
||||
|
||||
def test_valid_field_with_underscore_at_end(self):
|
||||
parsed = parse_fields_parameter("test_")
|
||||
|
||||
self.assertEqual(
|
||||
parsed,
|
||||
[
|
||||
("test_", False, None),
|
||||
],
|
||||
)
|
||||
|
||||
# BAD STUFF
|
||||
|
||||
def test_invalid_char(self):
|
||||
with self.assertRaises(FieldsParameterParseError) as e:
|
||||
parse_fields_parameter("test#")
|
||||
|
||||
self.assertEqual(str(e.exception), "unexpected char '#' at position 4")
|
||||
|
||||
def test_invalid_whitespace_before_identifier(self):
|
||||
with self.assertRaises(FieldsParameterParseError) as e:
|
||||
parse_fields_parameter(" test")
|
||||
|
||||
self.assertEqual(str(e.exception), "unexpected whitespace at position 0")
|
||||
|
||||
def test_invalid_whitespace_after_identifier(self):
|
||||
with self.assertRaises(FieldsParameterParseError) as e:
|
||||
parse_fields_parameter("test ")
|
||||
|
||||
self.assertEqual(str(e.exception), "unexpected whitespace at position 4")
|
||||
|
||||
def test_invalid_whitespace_after_comma(self):
|
||||
with self.assertRaises(FieldsParameterParseError) as e:
|
||||
parse_fields_parameter("test, test")
|
||||
|
||||
self.assertEqual(str(e.exception), "unexpected whitespace at position 5")
|
||||
|
||||
def test_invalid_whitespace_before_comma(self):
|
||||
with self.assertRaises(FieldsParameterParseError) as e:
|
||||
parse_fields_parameter("test ,test")
|
||||
|
||||
self.assertEqual(str(e.exception), "unexpected whitespace at position 4")
|
||||
|
||||
def test_invalid_unexpected_negation_operator(self):
|
||||
with self.assertRaises(FieldsParameterParseError) as e:
|
||||
parse_fields_parameter("test-")
|
||||
|
||||
self.assertEqual(str(e.exception), "unexpected char '-' at position 4")
|
||||
|
||||
def test_invalid_unexpected_open_bracket(self):
|
||||
with self.assertRaises(FieldsParameterParseError) as e:
|
||||
parse_fields_parameter("test,(foo)")
|
||||
|
||||
self.assertEqual(str(e.exception), "unexpected char '(' at position 5")
|
||||
|
||||
def test_invalid_unexpected_close_bracket(self):
|
||||
with self.assertRaises(FieldsParameterParseError) as e:
|
||||
parse_fields_parameter("test)")
|
||||
|
||||
self.assertEqual(str(e.exception), "unexpected char ')' at position 4")
|
||||
|
||||
def test_invalid_unexpected_comma_in_middle(self):
|
||||
with self.assertRaises(FieldsParameterParseError) as e:
|
||||
parse_fields_parameter("test,,foo")
|
||||
|
||||
self.assertEqual(str(e.exception), "unexpected char ',' at position 5")
|
||||
|
||||
def test_invalid_unexpected_comma_at_end(self):
|
||||
with self.assertRaises(FieldsParameterParseError) as e:
|
||||
parse_fields_parameter("test,foo,")
|
||||
|
||||
self.assertEqual(str(e.exception), "unexpected char ',' at position 9")
|
||||
|
||||
def test_invalid_unclosed_bracket(self):
|
||||
with self.assertRaises(FieldsParameterParseError) as e:
|
||||
parse_fields_parameter("test(foo")
|
||||
|
||||
self.assertEqual(
|
||||
str(e.exception),
|
||||
"unexpected end of input (did you miss out a close bracket?)",
|
||||
)
|
||||
|
||||
def test_invalid_subfields_on_negated_field(self):
|
||||
with self.assertRaises(FieldsParameterParseError) as e:
|
||||
parse_fields_parameter("-test(foo)")
|
||||
|
||||
self.assertEqual(str(e.exception), "unexpected char '(' at position 5")
|
||||
|
||||
def test_invalid_star_field_in_wrong_position(self):
|
||||
with self.assertRaises(FieldsParameterParseError) as e:
|
||||
parse_fields_parameter("test,*")
|
||||
|
||||
self.assertEqual(str(e.exception), "'*' must be in the first position")
|
||||
|
||||
def test_invalid_negated_star(self):
|
||||
with self.assertRaises(FieldsParameterParseError) as e:
|
||||
parse_fields_parameter("-*")
|
||||
|
||||
self.assertEqual(str(e.exception), "'*' cannot be negated")
|
||||
|
||||
def test_invalid_star_with_nesting(self):
|
||||
with self.assertRaises(FieldsParameterParseError) as e:
|
||||
parse_fields_parameter("*(foo,bar)")
|
||||
|
||||
self.assertEqual(str(e.exception), "unexpected char '(' at position 1")
|
||||
|
||||
def test_invalid_star_with_chars_after(self):
|
||||
with self.assertRaises(FieldsParameterParseError) as e:
|
||||
parse_fields_parameter("*foo")
|
||||
|
||||
self.assertEqual(str(e.exception), "unexpected char 'f' at position 1")
|
||||
|
||||
def test_invalid_star_with_chars_before(self):
|
||||
with self.assertRaises(FieldsParameterParseError) as e:
|
||||
parse_fields_parameter("foo*")
|
||||
|
||||
self.assertEqual(str(e.exception), "unexpected char '*' at position 3")
|
||||
|
||||
def test_invalid_star_with_additional_field(self):
|
||||
with self.assertRaises(FieldsParameterParseError) as e:
|
||||
parse_fields_parameter("*,foo")
|
||||
|
||||
self.assertEqual(
|
||||
str(e.exception), "additional fields with '*' doesn't make sense"
|
||||
)
|
||||
|
||||
def test_invalid_underscore_in_wrong_position(self):
|
||||
with self.assertRaises(FieldsParameterParseError) as e:
|
||||
parse_fields_parameter("test,_")
|
||||
|
||||
self.assertEqual(str(e.exception), "'_' must be in the first position")
|
||||
|
||||
def test_invalid_negated_underscore(self):
|
||||
with self.assertRaises(FieldsParameterParseError) as e:
|
||||
parse_fields_parameter("-_")
|
||||
|
||||
self.assertEqual(str(e.exception), "'_' cannot be negated")
|
||||
|
||||
def test_invalid_underscore_with_nesting(self):
|
||||
with self.assertRaises(FieldsParameterParseError) as e:
|
||||
parse_fields_parameter("_(foo,bar)")
|
||||
|
||||
self.assertEqual(str(e.exception), "unexpected char '(' at position 1")
|
||||
|
||||
def test_invalid_underscore_with_negated_field(self):
|
||||
with self.assertRaises(FieldsParameterParseError) as e:
|
||||
parse_fields_parameter("_,-foo")
|
||||
|
||||
self.assertEqual(str(e.exception), "negated fields with '_' doesn't make sense")
|
||||
|
||||
def test_invalid_star_and_underscore(self):
|
||||
with self.assertRaises(FieldsParameterParseError) as e:
|
||||
parse_fields_parameter("*,_")
|
||||
|
||||
self.assertEqual(str(e.exception), "'_' must be in the first position")
|
||||
|
||||
|
||||
class TestParseBoolean(TestCase):
|
||||
# GOOD STUFF
|
||||
|
||||
def test_valid_true(self):
|
||||
parsed = parse_boolean("true")
|
||||
|
||||
self.assertIs(parsed, True)
|
||||
|
||||
def test_valid_false(self):
|
||||
parsed = parse_boolean("false")
|
||||
|
||||
self.assertIs(parsed, False)
|
||||
|
||||
def test_valid_1(self):
|
||||
parsed = parse_boolean("1")
|
||||
|
||||
self.assertIs(parsed, True)
|
||||
|
||||
def test_valid_0(self):
|
||||
parsed = parse_boolean("0")
|
||||
|
||||
self.assertIs(parsed, False)
|
||||
|
||||
# BAD STUFF
|
||||
|
||||
def test_invalid(self):
|
||||
with self.assertRaises(ValueError) as e:
|
||||
parse_boolean("foo")
|
||||
|
||||
self.assertEqual(str(e.exception), "expected 'true' or 'false', got 'foo'")
|
||||
|
||||
def test_invalid_integer(self):
|
||||
with self.assertRaises(ValueError) as e:
|
||||
parse_boolean("2")
|
||||
|
||||
self.assertEqual(str(e.exception), "expected 'true' or 'false', got '2'")
|
||||
270
env/lib/python3.10/site-packages/wagtail/api/v2/utils.py
vendored
Normal file
270
env/lib/python3.10/site-packages/wagtail/api/v2/utils.py
vendored
Normal file
@@ -0,0 +1,270 @@
|
||||
from urllib.parse import urlsplit
|
||||
|
||||
from django.conf import settings
|
||||
from django.utils.encoding import force_str
|
||||
|
||||
from wagtail.coreutils import resolve_model_string
|
||||
from wagtail.models import Page, Site
|
||||
|
||||
|
||||
class BadRequestError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def get_base_url(request=None):
|
||||
base_url = getattr(settings, "WAGTAILAPI_BASE_URL", None)
|
||||
|
||||
if base_url is None and request:
|
||||
site = Site.find_for_request(request)
|
||||
if site:
|
||||
base_url = site.root_url
|
||||
|
||||
if base_url:
|
||||
# We only want the scheme and netloc
|
||||
base_url_parsed = urlsplit(force_str(base_url))
|
||||
|
||||
return base_url_parsed.scheme + "://" + base_url_parsed.netloc
|
||||
|
||||
|
||||
def get_full_url(request, path):
|
||||
if path.startswith(("http://", "https://")):
|
||||
return path
|
||||
base_url = get_base_url(request) or ""
|
||||
return base_url + path
|
||||
|
||||
|
||||
def get_object_detail_url(router, request, model, pk):
|
||||
url_path = router.get_object_detail_urlpath(model, pk)
|
||||
|
||||
if url_path:
|
||||
return get_full_url(request, url_path)
|
||||
|
||||
|
||||
def page_models_from_string(string):
|
||||
page_models = []
|
||||
|
||||
for sub_string in string.split(","):
|
||||
page_model = resolve_model_string(sub_string)
|
||||
|
||||
if not issubclass(page_model, Page):
|
||||
raise ValueError("Model is not a page")
|
||||
|
||||
page_models.append(page_model)
|
||||
|
||||
return tuple(page_models)
|
||||
|
||||
|
||||
class FieldsParameterParseError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
def parse_fields_parameter(fields_str):
|
||||
"""
|
||||
Parses the ?fields= GET parameter. As this parameter is supposed to be used
|
||||
by developers, the syntax is quite tight (eg, not allowing any whitespace).
|
||||
Having a strict syntax allows us to extend the it at a later date with less
|
||||
chance of breaking anyone's code.
|
||||
|
||||
This function takes a string and returns a list of tuples representing each
|
||||
top-level field. Each tuple contains three items:
|
||||
- The name of the field (string)
|
||||
- Whether the field has been negated (boolean)
|
||||
- A list of nested fields if there are any, None otherwise
|
||||
|
||||
Some examples of how this function works:
|
||||
|
||||
>>> parse_fields_parameter("foo")
|
||||
[
|
||||
('foo', False, None),
|
||||
]
|
||||
|
||||
>>> parse_fields_parameter("foo,bar")
|
||||
[
|
||||
('foo', False, None),
|
||||
('bar', False, None),
|
||||
]
|
||||
|
||||
>>> parse_fields_parameter("-foo")
|
||||
[
|
||||
('foo', True, None),
|
||||
]
|
||||
|
||||
>>> parse_fields_parameter("foo(bar,baz)")
|
||||
[
|
||||
('foo', False, [
|
||||
('bar', False, None),
|
||||
('baz', False, None),
|
||||
]),
|
||||
]
|
||||
|
||||
It raises a FieldsParameterParseError (subclass of ValueError) if it
|
||||
encounters a syntax error
|
||||
"""
|
||||
|
||||
def get_position(current_str):
|
||||
return len(fields_str) - len(current_str)
|
||||
|
||||
def parse_field_identifier(fields_str):
|
||||
first_char = True
|
||||
negated = False
|
||||
ident = ""
|
||||
|
||||
while fields_str:
|
||||
char = fields_str[0]
|
||||
|
||||
if char in ["(", ")", ","]:
|
||||
if not ident:
|
||||
raise FieldsParameterParseError(
|
||||
"unexpected char '%s' at position %d"
|
||||
% (char, get_position(fields_str))
|
||||
)
|
||||
|
||||
if ident in ["*", "_"] and char == "(":
|
||||
# * and _ cannot have nested fields
|
||||
raise FieldsParameterParseError(
|
||||
"unexpected char '%s' at position %d"
|
||||
% (char, get_position(fields_str))
|
||||
)
|
||||
|
||||
return ident, negated, fields_str
|
||||
|
||||
elif char == "-":
|
||||
if not first_char:
|
||||
raise FieldsParameterParseError(
|
||||
"unexpected char '%s' at position %d"
|
||||
% (char, get_position(fields_str))
|
||||
)
|
||||
|
||||
negated = True
|
||||
|
||||
elif char in ["*", "_"]:
|
||||
if ident and char == "*":
|
||||
raise FieldsParameterParseError(
|
||||
"unexpected char '%s' at position %d"
|
||||
% (char, get_position(fields_str))
|
||||
)
|
||||
|
||||
ident += char
|
||||
|
||||
elif char.isalnum() or char == "_":
|
||||
if ident == "*":
|
||||
# * can only be on its own
|
||||
raise FieldsParameterParseError(
|
||||
"unexpected char '%s' at position %d"
|
||||
% (char, get_position(fields_str))
|
||||
)
|
||||
|
||||
ident += char
|
||||
|
||||
elif char.isspace():
|
||||
raise FieldsParameterParseError(
|
||||
"unexpected whitespace at position %d" % get_position(fields_str)
|
||||
)
|
||||
else:
|
||||
raise FieldsParameterParseError(
|
||||
"unexpected char '%s' at position %d"
|
||||
% (char, get_position(fields_str))
|
||||
)
|
||||
|
||||
first_char = False
|
||||
fields_str = fields_str[1:]
|
||||
|
||||
return ident, negated, fields_str
|
||||
|
||||
def parse_fields(fields_str, expect_close_bracket=False):
|
||||
first_ident = None
|
||||
is_first = True
|
||||
fields = []
|
||||
|
||||
while fields_str:
|
||||
sub_fields = None
|
||||
ident, negated, fields_str = parse_field_identifier(fields_str)
|
||||
|
||||
# Some checks specific to '*' and '_'
|
||||
if ident in ["*", "_"]:
|
||||
if not is_first:
|
||||
raise FieldsParameterParseError(
|
||||
"'%s' must be in the first position" % ident
|
||||
)
|
||||
|
||||
if negated:
|
||||
raise FieldsParameterParseError("'%s' cannot be negated" % ident)
|
||||
|
||||
if fields_str and fields_str[0] == "(":
|
||||
if negated:
|
||||
# Negated fields cannot contain subfields
|
||||
raise FieldsParameterParseError(
|
||||
"unexpected char '(' at position %d" % get_position(fields_str)
|
||||
)
|
||||
|
||||
sub_fields, fields_str = parse_fields(
|
||||
fields_str[1:], expect_close_bracket=True
|
||||
)
|
||||
|
||||
if is_first:
|
||||
first_ident = ident
|
||||
else:
|
||||
# Negated fields can't be used with '_'
|
||||
if first_ident == "_" and negated:
|
||||
# _,foo is allowed but _,-foo is not
|
||||
raise FieldsParameterParseError(
|
||||
"negated fields with '_' doesn't make sense"
|
||||
)
|
||||
|
||||
# Additional fields without sub fields can't be used with '*'
|
||||
if first_ident == "*" and not negated and not sub_fields:
|
||||
# *,foo(bar) and *,-foo are allowed but *,foo is not
|
||||
raise FieldsParameterParseError(
|
||||
"additional fields with '*' doesn't make sense"
|
||||
)
|
||||
|
||||
fields.append((ident, negated, sub_fields))
|
||||
|
||||
if fields_str and fields_str[0] == ")":
|
||||
if not expect_close_bracket:
|
||||
raise FieldsParameterParseError(
|
||||
"unexpected char ')' at position %d" % get_position(fields_str)
|
||||
)
|
||||
|
||||
return fields, fields_str[1:]
|
||||
|
||||
if fields_str and fields_str[0] == ",":
|
||||
fields_str = fields_str[1:]
|
||||
|
||||
# A comma can not exist immediately before another comma or the end of the string
|
||||
if not fields_str or fields_str[0] == ",":
|
||||
raise FieldsParameterParseError(
|
||||
"unexpected char ',' at position %d" % get_position(fields_str)
|
||||
)
|
||||
|
||||
is_first = False
|
||||
|
||||
if expect_close_bracket:
|
||||
# This parser should've exited with a close bracket but instead we
|
||||
# hit the end of the input. Raise an error
|
||||
raise FieldsParameterParseError(
|
||||
"unexpected end of input (did you miss out a close bracket?)"
|
||||
)
|
||||
|
||||
return fields, fields_str
|
||||
|
||||
fields, _ = parse_fields(fields_str)
|
||||
|
||||
return fields
|
||||
|
||||
|
||||
def parse_boolean(value):
|
||||
"""
|
||||
Parses strings into booleans using the following mapping (case-sensitive):
|
||||
|
||||
'true' => True
|
||||
'false' => False
|
||||
'1' => True
|
||||
'0' => False
|
||||
"""
|
||||
if value in ["true", "1"]:
|
||||
return True
|
||||
elif value in ["false", "0"]:
|
||||
return False
|
||||
else:
|
||||
raise ValueError("expected 'true' or 'false', got '%s'" % value)
|
||||
607
env/lib/python3.10/site-packages/wagtail/api/v2/views.py
vendored
Normal file
607
env/lib/python3.10/site-packages/wagtail/api/v2/views.py
vendored
Normal file
@@ -0,0 +1,607 @@
|
||||
from collections import OrderedDict
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.exceptions import FieldDoesNotExist
|
||||
from django.http import Http404
|
||||
from django.shortcuts import redirect
|
||||
from django.urls import path, reverse
|
||||
from modelcluster.fields import ParentalKey
|
||||
from rest_framework import status
|
||||
from rest_framework.renderers import BrowsableAPIRenderer, JSONRenderer
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.viewsets import GenericViewSet
|
||||
|
||||
from wagtail.api import APIField
|
||||
from wagtail.models import Page, PageViewRestriction, Site
|
||||
|
||||
from .filters import (
|
||||
AncestorOfFilter,
|
||||
ChildOfFilter,
|
||||
DescendantOfFilter,
|
||||
FieldsFilter,
|
||||
LocaleFilter,
|
||||
OrderingFilter,
|
||||
SearchFilter,
|
||||
TranslationOfFilter,
|
||||
)
|
||||
from .pagination import WagtailPagination
|
||||
from .serializers import BaseSerializer, PageSerializer, get_serializer_class
|
||||
from .utils import (
|
||||
BadRequestError,
|
||||
get_object_detail_url,
|
||||
page_models_from_string,
|
||||
parse_fields_parameter,
|
||||
)
|
||||
|
||||
|
||||
class BaseAPIViewSet(GenericViewSet):
|
||||
renderer_classes = [JSONRenderer, BrowsableAPIRenderer]
|
||||
|
||||
pagination_class = WagtailPagination
|
||||
base_serializer_class = BaseSerializer
|
||||
filter_backends = []
|
||||
model = None # Set on subclass
|
||||
|
||||
known_query_parameters = frozenset(
|
||||
[
|
||||
"limit",
|
||||
"offset",
|
||||
"fields",
|
||||
"order",
|
||||
"search",
|
||||
"search_operator",
|
||||
# Used by jQuery for cache-busting. See #1671
|
||||
"_",
|
||||
# Required by BrowsableAPIRenderer
|
||||
"format",
|
||||
]
|
||||
)
|
||||
body_fields = ["id"]
|
||||
meta_fields = ["type", "detail_url"]
|
||||
listing_default_fields = ["id", "type", "detail_url"]
|
||||
nested_default_fields = ["id", "type", "detail_url"]
|
||||
detail_only_fields = []
|
||||
name = None # Set on subclass.
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# seen_types is a mapping of type name strings (format: "app_label.ModelName")
|
||||
# to model classes. When an object is serialised in the API, its model
|
||||
# is added to this mapping. This is used by the Admin API which appends a
|
||||
# summary of the used types to the response.
|
||||
self.seen_types = OrderedDict()
|
||||
|
||||
def get_queryset(self):
|
||||
return self.model.objects.all().order_by("id")
|
||||
|
||||
def listing_view(self, request):
|
||||
queryset = self.get_queryset()
|
||||
self.check_query_parameters(queryset)
|
||||
queryset = self.filter_queryset(queryset)
|
||||
queryset = self.paginate_queryset(queryset)
|
||||
serializer = self.get_serializer(queryset, many=True)
|
||||
return self.get_paginated_response(serializer.data)
|
||||
|
||||
def detail_view(self, request, pk):
|
||||
instance = self.get_object()
|
||||
serializer = self.get_serializer(instance)
|
||||
return Response(serializer.data)
|
||||
|
||||
def find_view(self, request):
|
||||
queryset = self.get_queryset()
|
||||
|
||||
try:
|
||||
obj = self.find_object(queryset, request)
|
||||
|
||||
if obj is None:
|
||||
raise self.model.DoesNotExist
|
||||
|
||||
except self.model.DoesNotExist:
|
||||
raise Http404("not found")
|
||||
|
||||
# Generate redirect
|
||||
url = get_object_detail_url(
|
||||
self.request.wagtailapi_router, request, self.model, obj.pk
|
||||
)
|
||||
|
||||
if url is None:
|
||||
# Shouldn't happen unless this endpoint isn't actually installed in the router
|
||||
raise Exception(
|
||||
"Cannot generate URL to detail view. Is '{}' installed in the API router?".format(
|
||||
self.__class__.__name__
|
||||
)
|
||||
)
|
||||
|
||||
return redirect(url)
|
||||
|
||||
def find_object(self, queryset, request):
|
||||
"""
|
||||
Override this to implement more find methods.
|
||||
"""
|
||||
if "id" in request.GET:
|
||||
return queryset.get(id=request.GET["id"])
|
||||
|
||||
def handle_exception(self, exc):
|
||||
if isinstance(exc, Http404):
|
||||
data = {"message": str(exc)}
|
||||
return Response(data, status=status.HTTP_404_NOT_FOUND)
|
||||
elif isinstance(exc, BadRequestError):
|
||||
data = {"message": str(exc)}
|
||||
return Response(data, status=status.HTTP_400_BAD_REQUEST)
|
||||
return super().handle_exception(exc)
|
||||
|
||||
@classmethod
|
||||
def _convert_api_fields(cls, fields):
|
||||
return [
|
||||
field if isinstance(field, APIField) else APIField(field)
|
||||
for field in fields
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def get_body_fields(cls, model):
|
||||
return cls._convert_api_fields(
|
||||
cls.body_fields + list(getattr(model, "api_fields", ()))
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_body_fields_names(cls, model):
|
||||
return [field.name for field in cls.get_body_fields(model)]
|
||||
|
||||
@classmethod
|
||||
def get_meta_fields(cls, model):
|
||||
return cls._convert_api_fields(
|
||||
cls.meta_fields + list(getattr(model, "api_meta_fields", ()))
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_meta_fields_names(cls, model):
|
||||
return [field.name for field in cls.get_meta_fields(model)]
|
||||
|
||||
@classmethod
|
||||
def get_field_serializer_overrides(cls, model):
|
||||
return {
|
||||
field.name: field.serializer
|
||||
for field in cls.get_body_fields(model) + cls.get_meta_fields(model)
|
||||
if field.serializer is not None
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_available_fields(cls, model, db_fields_only=False):
|
||||
"""
|
||||
Returns a list of all the fields that can be used in the API for the
|
||||
specified model class.
|
||||
|
||||
Setting db_fields_only to True will remove all fields that do not have
|
||||
an underlying column in the database (eg, type/detail_url and any custom
|
||||
fields that are callables)
|
||||
"""
|
||||
fields = cls.get_body_fields_names(model) + cls.get_meta_fields_names(model)
|
||||
|
||||
if db_fields_only:
|
||||
# Get list of available database fields then remove any fields in our
|
||||
# list that isn't a database field
|
||||
database_fields = set()
|
||||
for field in model._meta.get_fields():
|
||||
database_fields.add(field.name)
|
||||
|
||||
if hasattr(field, "attname"):
|
||||
database_fields.add(field.attname)
|
||||
|
||||
fields = [field for field in fields if field in database_fields]
|
||||
|
||||
return fields
|
||||
|
||||
@classmethod
|
||||
def get_detail_default_fields(cls, model):
|
||||
return cls.get_available_fields(model)
|
||||
|
||||
@classmethod
|
||||
def get_listing_default_fields(cls, model):
|
||||
return cls.listing_default_fields[:]
|
||||
|
||||
@classmethod
|
||||
def get_nested_default_fields(cls, model):
|
||||
return cls.nested_default_fields[:]
|
||||
|
||||
def check_query_parameters(self, queryset):
|
||||
"""
|
||||
Ensure that only valid query parameters are included in the URL.
|
||||
"""
|
||||
query_parameters = set(self.request.GET.keys())
|
||||
|
||||
# All query parameters must be either a database field or an operation
|
||||
allowed_query_parameters = set(
|
||||
self.get_available_fields(queryset.model, db_fields_only=True)
|
||||
).union(self.known_query_parameters)
|
||||
unknown_parameters = query_parameters - allowed_query_parameters
|
||||
if unknown_parameters:
|
||||
raise BadRequestError(
|
||||
"query parameter is not an operation or a recognised field: %s"
|
||||
% ", ".join(sorted(unknown_parameters))
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _get_serializer_class(
|
||||
cls, router, model, fields_config, show_details=False, nested=False
|
||||
):
|
||||
# Get all available fields
|
||||
body_fields = cls.get_body_fields_names(model)
|
||||
meta_fields = cls.get_meta_fields_names(model)
|
||||
all_fields = body_fields + meta_fields
|
||||
|
||||
# Remove any duplicates
|
||||
all_fields = list(OrderedDict.fromkeys(all_fields))
|
||||
|
||||
if not show_details:
|
||||
# Remove detail only fields
|
||||
for field in cls.detail_only_fields:
|
||||
try:
|
||||
all_fields.remove(field)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Get list of configured fields
|
||||
if show_details:
|
||||
fields = set(cls.get_detail_default_fields(model))
|
||||
elif nested:
|
||||
fields = set(cls.get_nested_default_fields(model))
|
||||
else:
|
||||
fields = set(cls.get_listing_default_fields(model))
|
||||
|
||||
# If first field is '*' start with all fields
|
||||
# If first field is '_' start with no fields
|
||||
if fields_config and fields_config[0][0] == "*":
|
||||
fields = set(all_fields)
|
||||
fields_config = fields_config[1:]
|
||||
elif fields_config and fields_config[0][0] == "_":
|
||||
fields = set()
|
||||
fields_config = fields_config[1:]
|
||||
|
||||
mentioned_fields = set()
|
||||
sub_fields = {}
|
||||
|
||||
for field_name, negated, field_sub_fields in fields_config:
|
||||
if negated:
|
||||
try:
|
||||
fields.remove(field_name)
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
fields.add(field_name)
|
||||
if field_sub_fields:
|
||||
sub_fields[field_name] = field_sub_fields
|
||||
|
||||
mentioned_fields.add(field_name)
|
||||
|
||||
unknown_fields = mentioned_fields - set(all_fields)
|
||||
|
||||
if unknown_fields:
|
||||
raise BadRequestError(
|
||||
"unknown fields: %s" % ", ".join(sorted(unknown_fields))
|
||||
)
|
||||
|
||||
# Build nested serialisers
|
||||
child_serializer_classes = {}
|
||||
|
||||
for field_name in fields:
|
||||
try:
|
||||
django_field = model._meta.get_field(field_name)
|
||||
except FieldDoesNotExist:
|
||||
django_field = None
|
||||
|
||||
if django_field and django_field.is_relation:
|
||||
child_sub_fields = sub_fields.get(field_name, [])
|
||||
|
||||
# Inline (aka "child") models should display all fields by default
|
||||
if isinstance(getattr(django_field, "field", None), ParentalKey):
|
||||
if not child_sub_fields or child_sub_fields[0][0] not in ["*", "_"]:
|
||||
child_sub_fields = list(child_sub_fields)
|
||||
child_sub_fields.insert(0, ("*", False, None))
|
||||
|
||||
# Get a serializer class for the related object
|
||||
child_model = django_field.related_model
|
||||
child_endpoint_class = router.get_model_endpoint(child_model)
|
||||
child_endpoint_class = (
|
||||
child_endpoint_class[1] if child_endpoint_class else BaseAPIViewSet
|
||||
)
|
||||
child_serializer_classes[
|
||||
field_name
|
||||
] = child_endpoint_class._get_serializer_class(
|
||||
router, child_model, child_sub_fields, nested=True
|
||||
)
|
||||
|
||||
else:
|
||||
if field_name in sub_fields:
|
||||
# Sub fields were given for a non-related field
|
||||
raise BadRequestError(
|
||||
"'%s' does not support nested fields" % field_name
|
||||
)
|
||||
|
||||
# Reorder fields so it matches the order of all_fields
|
||||
fields = [field for field in all_fields if field in fields]
|
||||
|
||||
field_serializer_overrides = {
|
||||
field[0]: field[1]
|
||||
for field in cls.get_field_serializer_overrides(model).items()
|
||||
if field[0] in fields
|
||||
}
|
||||
return get_serializer_class(
|
||||
model,
|
||||
fields,
|
||||
meta_fields=meta_fields,
|
||||
field_serializer_overrides=field_serializer_overrides,
|
||||
child_serializer_classes=child_serializer_classes,
|
||||
base=cls.base_serializer_class,
|
||||
)
|
||||
|
||||
def get_serializer_class(self):
|
||||
request = self.request
|
||||
|
||||
# Get model
|
||||
if self.action == "listing_view":
|
||||
model = self.get_queryset().model
|
||||
else:
|
||||
model = type(self.get_object())
|
||||
|
||||
# Fields
|
||||
if "fields" in request.GET:
|
||||
try:
|
||||
fields_config = parse_fields_parameter(request.GET["fields"])
|
||||
except ValueError as e:
|
||||
raise BadRequestError("fields error: %s" % str(e))
|
||||
else:
|
||||
# Use default fields
|
||||
fields_config = []
|
||||
|
||||
# Allow "detail_only" (eg parent) fields on detail view
|
||||
if self.action == "listing_view":
|
||||
show_details = False
|
||||
else:
|
||||
show_details = True
|
||||
|
||||
return self._get_serializer_class(
|
||||
self.request.wagtailapi_router,
|
||||
model,
|
||||
fields_config,
|
||||
show_details=show_details,
|
||||
)
|
||||
|
||||
def get_serializer_context(self):
|
||||
"""
|
||||
The serialization context differs between listing and detail views.
|
||||
"""
|
||||
return {
|
||||
"request": self.request,
|
||||
"view": self,
|
||||
"router": self.request.wagtailapi_router,
|
||||
}
|
||||
|
||||
def get_renderer_context(self):
|
||||
context = super().get_renderer_context()
|
||||
context["indent"] = 4
|
||||
return context
|
||||
|
||||
@classmethod
|
||||
def get_urlpatterns(cls):
|
||||
"""
|
||||
This returns a list of URL patterns for the endpoint
|
||||
"""
|
||||
return [
|
||||
path("", cls.as_view({"get": "listing_view"}), name="listing"),
|
||||
path("<int:pk>/", cls.as_view({"get": "detail_view"}), name="detail"),
|
||||
path("find/", cls.as_view({"get": "find_view"}), name="find"),
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def get_model_listing_urlpath(cls, model, namespace=""):
|
||||
if namespace:
|
||||
url_name = namespace + ":listing"
|
||||
else:
|
||||
url_name = "listing"
|
||||
|
||||
return reverse(url_name)
|
||||
|
||||
@classmethod
|
||||
def get_object_detail_urlpath(cls, model, pk, namespace=""):
|
||||
if namespace:
|
||||
url_name = namespace + ":detail"
|
||||
else:
|
||||
url_name = "detail"
|
||||
|
||||
return reverse(url_name, args=(pk,))
|
||||
|
||||
|
||||
class PagesAPIViewSet(BaseAPIViewSet):
|
||||
base_serializer_class = PageSerializer
|
||||
filter_backends = [
|
||||
FieldsFilter,
|
||||
ChildOfFilter,
|
||||
AncestorOfFilter,
|
||||
DescendantOfFilter,
|
||||
OrderingFilter,
|
||||
TranslationOfFilter,
|
||||
LocaleFilter,
|
||||
SearchFilter, # needs to be last, as SearchResults querysets cannot be filtered further
|
||||
]
|
||||
known_query_parameters = BaseAPIViewSet.known_query_parameters.union(
|
||||
[
|
||||
"type",
|
||||
"child_of",
|
||||
"ancestor_of",
|
||||
"descendant_of",
|
||||
"translation_of",
|
||||
"locale",
|
||||
"site",
|
||||
]
|
||||
)
|
||||
body_fields = BaseAPIViewSet.body_fields + [
|
||||
"title",
|
||||
]
|
||||
meta_fields = BaseAPIViewSet.meta_fields + [
|
||||
"html_url",
|
||||
"slug",
|
||||
"show_in_menus",
|
||||
"seo_title",
|
||||
"search_description",
|
||||
"first_published_at",
|
||||
"alias_of",
|
||||
"parent",
|
||||
"locale",
|
||||
]
|
||||
listing_default_fields = BaseAPIViewSet.listing_default_fields + [
|
||||
"title",
|
||||
"html_url",
|
||||
"slug",
|
||||
"first_published_at",
|
||||
]
|
||||
nested_default_fields = BaseAPIViewSet.nested_default_fields + [
|
||||
"title",
|
||||
]
|
||||
detail_only_fields = ["parent"]
|
||||
name = "pages"
|
||||
model = Page
|
||||
|
||||
@classmethod
|
||||
def get_detail_default_fields(cls, model):
|
||||
detail_default_fields = super().get_detail_default_fields(model)
|
||||
|
||||
# When i18n is disabled, remove "locale" from default fields
|
||||
if not getattr(settings, "WAGTAIL_I18N_ENABLED", False):
|
||||
detail_default_fields.remove("locale")
|
||||
|
||||
return detail_default_fields
|
||||
|
||||
@classmethod
|
||||
def get_listing_default_fields(cls, model):
|
||||
listing_default_fields = super().get_listing_default_fields(model)
|
||||
|
||||
# When i18n is enabled, add "locale" to default fields
|
||||
if getattr(settings, "WAGTAIL_I18N_ENABLED", False):
|
||||
listing_default_fields.append("locale")
|
||||
|
||||
return listing_default_fields
|
||||
|
||||
def get_root_page(self):
|
||||
"""
|
||||
Returns the page that is used when the `&child_of=root` filter is used.
|
||||
"""
|
||||
return Site.find_for_request(self.request).root_page
|
||||
|
||||
def get_base_queryset(self):
|
||||
"""
|
||||
Returns a queryset containing all pages that can be seen by this user.
|
||||
|
||||
This is used as the base for get_queryset and is also used to find the
|
||||
parent pages when using the child_of and descendant_of filters as well.
|
||||
"""
|
||||
|
||||
request = self.request
|
||||
|
||||
# Get all live pages
|
||||
queryset = Page.objects.all().live()
|
||||
|
||||
# Exclude pages that the user doesn't have access to
|
||||
restricted_pages = [
|
||||
restriction.page
|
||||
for restriction in PageViewRestriction.objects.all().select_related("page")
|
||||
if not restriction.accept_request(self.request)
|
||||
]
|
||||
|
||||
# Exclude the restricted pages and their descendants from the queryset
|
||||
for restricted_page in restricted_pages:
|
||||
queryset = queryset.not_descendant_of(restricted_page, inclusive=True)
|
||||
|
||||
# Check if we have a specific site to look for
|
||||
if "site" in request.GET:
|
||||
# Optionally allow querying by port
|
||||
if ":" in request.GET["site"]:
|
||||
(hostname, port) = request.GET["site"].split(":", 1)
|
||||
query = {
|
||||
"hostname": hostname,
|
||||
"port": port,
|
||||
}
|
||||
else:
|
||||
query = {
|
||||
"hostname": request.GET["site"],
|
||||
}
|
||||
try:
|
||||
site = Site.objects.get(**query)
|
||||
except Site.MultipleObjectsReturned:
|
||||
raise BadRequestError(
|
||||
"Your query returned multiple sites. Try adding a port number to your site filter."
|
||||
)
|
||||
else:
|
||||
# Otherwise, find the site from the request
|
||||
site = Site.find_for_request(self.request)
|
||||
|
||||
if site:
|
||||
base_queryset = queryset
|
||||
queryset = base_queryset.descendant_of(site.root_page, inclusive=True)
|
||||
|
||||
# If internationalisation is enabled, include pages from other language trees
|
||||
if getattr(settings, "WAGTAIL_I18N_ENABLED", False):
|
||||
for translation in site.root_page.get_translations():
|
||||
queryset |= base_queryset.descendant_of(translation, inclusive=True)
|
||||
|
||||
else:
|
||||
# No sites configured
|
||||
queryset = queryset.none()
|
||||
|
||||
return queryset
|
||||
|
||||
def get_queryset(self):
|
||||
request = self.request
|
||||
|
||||
# Allow pages to be filtered to a specific type
|
||||
try:
|
||||
models_type = request.GET.get("type", None)
|
||||
models = models_type and page_models_from_string(models_type) or []
|
||||
except (LookupError, ValueError):
|
||||
raise BadRequestError("type doesn't exist")
|
||||
|
||||
if not models:
|
||||
if self.model == Page:
|
||||
return self.get_base_queryset()
|
||||
else:
|
||||
return self.model.objects.filter(
|
||||
pk__in=self.get_base_queryset().values_list("pk", flat=True)
|
||||
)
|
||||
|
||||
elif len(models) == 1:
|
||||
# If a single page type has been specified, swap out the Page-based queryset for one based on
|
||||
# the specific page model so that we can filter on any custom APIFields defined on that model
|
||||
return models[0].objects.filter(
|
||||
pk__in=self.get_base_queryset().values_list("pk", flat=True)
|
||||
)
|
||||
|
||||
else: # len(models) > 1
|
||||
return self.get_base_queryset().type(*models)
|
||||
|
||||
def get_object(self):
|
||||
base = super().get_object()
|
||||
return base.specific
|
||||
|
||||
def find_object(self, queryset, request):
|
||||
site = Site.find_for_request(request)
|
||||
if "html_path" in request.GET and site is not None:
|
||||
path = request.GET["html_path"]
|
||||
path_components = [component for component in path.split("/") if component]
|
||||
|
||||
try:
|
||||
page, _, _ = site.root_page.specific.route(request, path_components)
|
||||
except Http404:
|
||||
return
|
||||
|
||||
if queryset.filter(id=page.id).exists():
|
||||
return page
|
||||
|
||||
return super().find_object(queryset, request)
|
||||
|
||||
def get_serializer_context(self):
|
||||
"""
|
||||
The serialization context differs between listing and detail views.
|
||||
"""
|
||||
context = super().get_serializer_context()
|
||||
context["base_queryset"] = self.get_base_queryset()
|
||||
return context
|
||||
Reference in New Issue
Block a user