Initial commit

This commit is contained in:
2024-08-27 20:33:44 +02:00
commit 1f1832267d
14794 changed files with 1599592 additions and 0 deletions

View File

@@ -0,0 +1,3 @@
from .decorators import * # NOQA: F403
from .page_tests import * # NOQA: F403
from .wagtail_tests import * # NOQA: F403

View File

@@ -0,0 +1,30 @@
from functools import wraps
class disconnect_signal_receiver:
"""
A decorator that disconnects a signal's receiver during the
execution of a test and reconnects it back at its end.
"""
def __init__(self, signal, receiver):
self.signal = signal
self.receiver = receiver
def __call__(self, func):
@wraps(func)
def wrapper(*args, **kwargs):
exception = None
self.signal.disconnect(self.receiver)
try:
func(*args, **kwargs)
except Exception as e: # noqa: BLE001
exception = e
finally:
self.signal.connect(self.receiver)
if exception:
raise exception
return wrapper

View File

@@ -0,0 +1,202 @@
"""
The ``assertCanCreate`` method requires page data to be passed in
the same format that the page edit form would submit. For complex
page types, it can be difficult to construct this data structure by hand;
the ``wagtail.test.utils.form_data`` module provides a set of helper
functions to assist with this.
"""
import bs4
from django.http import QueryDict
from wagtail.admin.rich_text import get_rich_text_editor_widget
from .wagtail_tests import WagtailTestUtils
def _nested_form_data(data):
if isinstance(data, dict):
items = data.items()
elif isinstance(data, list):
items = enumerate(data)
for key, value in items:
key = str(key)
if isinstance(value, (dict, list)):
for child_keys, child_value in _nested_form_data(value):
yield [key] + child_keys, child_value
else:
yield [key], value
def nested_form_data(data):
"""
Translates a nested dict structure into a flat form data dict
with hyphen-separated keys.
.. code-block:: python
nested_form_data({
'foo': 'bar',
'parent': {
'child': 'field',
},
})
# Returns: {'foo': 'bar', 'parent-child': 'field'}
"""
return {"-".join(key): value for key, value in _nested_form_data(data)}
def streamfield(items):
"""
Takes a list of (block_type, value) tuples and turns it in to
StreamField form data. Use this within a :func:`nested_form_data`
call, with the field name as the key.
.. code-block:: python
nested_form_data({'content': streamfield([
('text', 'Hello, world'),
])})
# Returns:
# {
# 'content-count': '1',
# 'content-0-type': 'text',
# 'content-0-value': 'Hello, world',
# 'content-0-order': '0',
# 'content-0-deleted': '',
# }
"""
def to_block(index, item):
block, value = item
return {"type": block, "value": value, "deleted": "", "order": str(index)}
data_dict = {str(index): to_block(index, item) for index, item in enumerate(items)}
data_dict["count"] = str(len(data_dict))
return data_dict
def inline_formset(items, initial=0, min=0, max=1000):
"""
Takes a list of form data for an InlineFormset and translates
it in to valid POST data. Use this within a :func:`nested_form_data`
call, with the formset relation name as the key.
.. code-block:: python
nested_form_data({'lines': inline_formset([
{'text': 'Hello'},
{'text': 'World'},
])})
# Returns:
# {
# 'lines-TOTAL_FORMS': '2',
# 'lines-INITIAL_FORMS': '0',
# 'lines-MIN_NUM_FORMS': '0',
# 'lines-MAX_NUM_FORMS': '1000',
# 'lines-0-text': 'Hello',
# 'lines-0-ORDER': '0',
# 'lines-0-DELETE': '',
# 'lines-1-text': 'World',
# 'lines-1-ORDER': '1',
# 'lines-1-DELETE': '',
# }
"""
def to_form(index, item):
defaults = {
"ORDER": str(index),
"DELETE": "",
}
defaults.update(item)
return defaults
data_dict = {str(index): to_form(index, item) for index, item in enumerate(items)}
data_dict.update(
{
"TOTAL_FORMS": str(len(data_dict)),
"INITIAL_FORMS": str(initial),
"MIN_NUM_FORMS": str(min),
"MAX_NUM_FORMS": str(max),
}
)
return data_dict
def rich_text(value, editor="default", features=None):
"""
Converts an HTML-like rich text string to the data format required by
the currently active rich text editor.
:param editor: An alternative editor name as defined in ``WAGTAILADMIN_RICH_TEXT_EDITORS``
:param features: A list of features allowed in the rich text content (see :ref:`rich_text_features`)
.. code-block:: python
self.assertCanCreate(root_page, ContentPage, nested_form_data({
'title': 'About us',
'body': rich_text('<p>Lorem ipsum dolor sit amet</p>'),
}))
"""
widget = get_rich_text_editor_widget(editor, features)
return widget.format_value(value)
def _querydict_from_form(form: bs4.Tag, exclude_csrf: bool = True) -> QueryDict:
data = QueryDict(mutable=True)
for input in form.find_all("input"):
name = input.attrs.get("name")
if (
name
and input.attrs.get("type", "") not in ("checkbox", "radio")
and (not exclude_csrf or name != "csrfmiddlewaretoken")
):
data[name] = input.attrs.get("value", "")
for input in form.find_all("input", type="radio", checked=True):
name = input.attrs.get("name")
if name:
data[name] = input.attrs.get("value")
for input in form.find_all("input", type="checkbox", checked=True):
name = input.attrs.get("name")
if name:
data.appendlist(name, input.attrs.get("value", ""))
for textarea in form.find_all("textarea"):
name = textarea.attrs.get("name")
if name:
data[name] = textarea.get_text()
for select in form.find_all("select"):
name = select.attrs.get("name")
if name:
selected_value = False
for option in select.find_all("option", selected=True):
selected_value = True
data.appendlist(name, option.attrs.get("value", option.get_text()))
if not selected_value:
first_option = select.find("option")
if first_option:
data[name] = first_option.attrs.get(
"value", first_option.get_text()
)
return data
def querydict_from_html(
html: str, form_id: str = None, form_index: int = 0, exclude_csrf: bool = True
) -> QueryDict:
soup = WagtailTestUtils.get_soup(html)
if form_id is not None:
form = soup.find("form", attrs={"id": form_id})
if form is None:
raise ValueError(f'No form was found with id "{form_id}".')
return _querydict_from_form(form, exclude_csrf)
else:
index = int(form_index)
for i, form in enumerate(soup.find_all("form", limit=index + 1)):
if i == index:
return _querydict_from_form(form, exclude_csrf)
raise ValueError(f"No form was found with index: {form_index}.")

View File

@@ -0,0 +1,457 @@
from typing import Any, Dict, Optional
from unittest import mock
from django.conf import settings
from django.contrib.auth.base_user import AbstractBaseUser
from django.http import Http404
from django.test import TestCase
from django.urls import reverse
from django.utils.http import urlencode
from django.utils.text import slugify
from wagtail.coreutils import get_dummy_request
from wagtail.models import Page
from .form_data import querydict_from_html
from .wagtail_tests import WagtailTestUtils
AUTH_BACKEND = settings.AUTHENTICATION_BACKENDS[0]
class WagtailPageTestCase(WagtailTestUtils, TestCase):
"""
A set of assertions to help write tests for custom Wagtail page types
"""
@classmethod
def setUpClass(cls):
super().setUpClass()
cls.dummy_request = get_dummy_request()
def _testCanCreateAt(self, parent_model, child_model):
return child_model in parent_model.allowed_subpage_models()
def assertCanCreateAt(self, parent_model, child_model, msg=None):
"""
Assert a particular child Page type can be created under a parent
Page type. ``parent_model`` and ``child_model`` should be the Page
classes being tested.
"""
if not self._testCanCreateAt(parent_model, child_model):
msg = self._formatMessage(
msg,
"Can not create a %s.%s under a %s.%s"
% (
child_model._meta.app_label,
child_model._meta.model_name,
parent_model._meta.app_label,
parent_model._meta.model_name,
),
)
raise self.failureException(msg)
def assertCanNotCreateAt(self, parent_model, child_model, msg=None):
"""
Assert a particular child Page type can not be created under a parent
Page type. ``parent_model`` and ``child_model`` should be the Page
classes being tested.
"""
if self._testCanCreateAt(parent_model, child_model):
msg = self._formatMessage(
msg,
"Can create a %s.%s under a %s.%s"
% (
child_model._meta.app_label,
child_model._meta.model_name,
parent_model._meta.app_label,
parent_model._meta.model_name,
),
)
raise self.failureException(msg)
def assertCanCreate(self, parent, child_model, data, msg=None, publish=True):
"""
Assert that a child of the given Page type can be created under the
parent, using the supplied POST data.
``parent`` should be a Page instance, and ``child_model`` should be a
Page subclass. ``data`` should be a dict that will be POSTed at the
Wagtail admin Page creation method.
"""
self.assertCanCreateAt(parent.specific_class, child_model)
if "slug" not in data and "title" in data:
data["slug"] = slugify(data["title"])
if publish:
data["action-publish"] = "action-publish"
add_url = reverse(
"wagtailadmin_pages:add",
args=[child_model._meta.app_label, child_model._meta.model_name, parent.pk],
)
response = self.client.post(add_url, data, follow=True)
if response.status_code != 200:
msg = self._formatMessage(
msg,
"Creating a %s.%s returned a %d"
% (
child_model._meta.app_label,
child_model._meta.model_name,
response.status_code,
),
)
raise self.failureException(msg)
if response.redirect_chain == []:
if "form" not in response.context:
msg = self._formatMessage(msg, "Creating a page failed unusually")
raise self.failureException(msg)
form = response.context["form"]
if not form.errors:
msg = self._formatMessage(
msg, "Creating a page failed for an unknown reason"
)
raise self.failureException(msg)
errors = "\n".join(
" {}:\n {}".format(field, "\n ".join(errors))
for field, errors in sorted(form.errors.items())
)
msg = self._formatMessage(
msg,
"Validation errors found when creating a %s.%s:\n%s"
% (child_model._meta.app_label, child_model._meta.model_name, errors),
)
raise self.failureException(msg)
if publish:
expected_url = reverse("wagtailadmin_explore", args=[parent.pk])
else:
expected_url = reverse(
"wagtailadmin_pages:edit", args=[Page.objects.order_by("pk").last().pk]
)
if response.redirect_chain != [(expected_url, 302)]:
msg = self._formatMessage(
msg,
"Creating a page %s.%s didn't redirect the user to the expected page %s, but to %s"
% (
child_model._meta.app_label,
child_model._meta.model_name,
expected_url,
response.redirect_chain,
),
)
raise self.failureException(msg)
def assertAllowedSubpageTypes(self, parent_model, child_models, msg=None):
"""
Test that the only page types that can be created under
``parent_model`` are ``child_models``.
The list of allowed child models may differ from those set in
``Page.subpage_types``, if the child models have set
``Page.parent_page_types``.
"""
self.assertEqual(
set(parent_model.allowed_subpage_models()), set(child_models), msg=msg
)
def assertAllowedParentPageTypes(self, child_model, parent_models, msg=None):
"""
Test that the only page types that ``child_model`` can be created under
are ``parent_models``.
The list of allowed parent models may differ from those set in
``Page.parent_page_types``, if the parent models have set
``Page.subpage_types``.
"""
self.assertEqual(
set(child_model.allowed_parent_page_models()), set(parent_models), msg=msg
)
def assertPageIsRoutable(
self,
page: Page,
route_path: Optional[str] = "/",
msg: Optional[str] = None,
):
"""
Asserts that ``page`` can be routed to without raising a ``Http404`` error.
For page types with multiple routes, you can use ``route_path`` to specify an alternate route to test.
"""
path = page.get_url(self.dummy_request)
if route_path != "/":
path = path.rstrip("/") + "/" + route_path.lstrip("/")
site = page.get_site()
if site is None:
msg = self._formatMessage(
msg,
'Failed to route to "%s" for %s "%s". The page does not belong to any sites.'
% (type(page).__name__, route_path, page),
)
raise self.failureException(msg)
path_components = [component for component in path.split("/") if component]
try:
page, args, kwargs = site.root_page.localized.specific.route(
self.dummy_request, path_components
)
except Http404:
msg = self._formatMessage(
msg,
'Failed to route to "%(route_path)s" for %(page_type)s "%(page)s". A Http404 was raised for path: "%(full_path)s".'
% {
"route_path": route_path,
"page_type": type(page).__name__,
"page": page,
"full_path": path,
},
)
raise self.failureException(msg)
def assertPageIsRenderable(
self,
page: Page,
route_path: Optional[str] = "/",
query_data: Optional[Dict[str, Any]] = None,
post_data: Optional[Dict[str, Any]] = None,
user: Optional[AbstractBaseUser] = None,
accept_404: Optional[bool] = False,
accept_redirect: Optional[bool] = False,
msg: Optional[str] = None,
):
"""
Asserts that ``page`` can be rendered without raising a fatal error.
For page types with multiple routes, you can use ``route_path`` to specify an alternate route to test.
When ``post_data`` is provided, the test makes a ``POST`` request with ``post_data`` in the request body. Otherwise, a ``GET`` request is made.
When supplied, ``query_data`` is converted to a querystring and added to the request URL (regardless of whether ``post_data`` is provided).
When ``user`` is provided, the test is conducted with them as the active user.
By default, the assertion will fail if the request to the page URL results in a 301, 302 or 404 HTTP response. If you are testing a page/route
where a 404 response is expected, you can use ``accept_404=True`` to indicate this, and the assertion will pass when encountering a 404. Likewise,
if you are testing a page/route where a redirect response is expected, you can use `accept_redirect=True` to indicate this, and the assertion will
pass when encountering 301 or 302.
"""
if user:
self.client.force_login(user, AUTH_BACKEND)
path = page.get_url(self.dummy_request)
if route_path != "/":
path = path.rstrip("/") + "/" + route_path.lstrip("/")
post_kwargs = {}
if post_data is not None:
post_kwargs = {"data": post_data}
if query_data:
post_kwargs["QUERYSTRING"] = urlencode(query_data, doseq=True)
try:
if post_data is None:
resp = self.client.get(path, data=query_data)
else:
resp = self.client.post(path, **post_kwargs)
except Exception as e: # noqa: BLE001
msg = self._formatMessage(
msg,
'Failed to render route "%(route_path)s" for %(page_type)s "%(page)s":\n%(exc)s'
% {
"route_path": route_path,
"page_type": type(page).__name__,
"page": page,
"exc": e,
},
)
raise self.failureException(msg)
finally:
if user:
self.client.logout()
if (
resp.status_code == 200
or (accept_404 and resp.status_code == 404)
or (accept_redirect and resp.status_code in (301, 302))
or isinstance(resp, mock.MagicMock)
):
return
msg = self._formatMessage(
msg,
'Failed to render route "%(route_path)s" for %(page_type)s "%(page)s":\nA HTTP %(code)s response was received for path: "%(full_path)s".'
% {
"route_path": route_path,
"page_type": type(page).__name__,
"page": page,
"code": resp.status_code,
"full_path": path,
},
)
raise self.failureException(msg)
def assertPageIsEditable(
self,
page: Page,
post_data: Optional[Dict[str, Any]] = None,
user: Optional[AbstractBaseUser] = None,
msg: Optional[str] = None,
):
"""
Asserts that the page edit view works for ``page`` without raising a fatal error.
When ``user`` is provided, the test is conducted with them as the active user. Otherwise, a superuser is created and used for the test.
After a successful ``GET`` request, a ``POST`` request is made with field data in the request body. If ``post_data`` is provided, that will be used for this purpose. If not, this data will be extracted from the ``GET`` response HTML.
"""
if user:
# rule out permission issues early on
if not page.permissions_for_user(user).can_edit():
self._formatMessage(
msg,
'Failed to load edit view for %(page_type)s "%(page)s":\nUser "%(user)s" have insufficient permissions.'
% {
"page_type": type(page).__name__,
"page": page,
"user": user,
},
)
raise self.failureException(msg)
else:
if not hasattr(self, "_pageiseditable_superuser"):
self._pageiseditable_superuser = self.create_superuser(
"assertpageiseditable"
)
user = self._pageiseditable_superuser
self.client.force_login(user, AUTH_BACKEND)
path = reverse("wagtailadmin_pages:edit", kwargs={"page_id": page.id})
try:
response = self.client.get(path)
except Exception as e: # noqa: BLE001
self.client.logout()
msg = self._formatMessage(
msg,
'Failed to load edit view via GET for %(page_type)s "%(page)s":\n%(exc)s'
% {"page_type": type(page).__name__, "page": page, "exc": e},
)
raise self.failureException(msg)
if response.status_code != 200:
self.client.logout()
msg = self._formatMessage(
msg,
'Failed to load edit view via GET for %(page_type)s "%(page)s":\nReceived response with HTTP status code: %(code)s.'
% {
"page_type": type(page).__name__,
"page": page,
"code": response.status_code,
},
)
raise self.failureException(msg)
if post_data is not None:
data_to_post = post_data
else:
data_to_post = querydict_from_html(
response.content.decode(), form_id="page-edit-form"
)
data_to_post["action-publish"] = ""
try:
self.client.post(path, data_to_post)
except Exception as e: # noqa: BLE001
msg = self._formatMessage(
msg,
'Failed to load edit view via POST for %(page_type)s "%(page)s":\n%(exc)s'
% {"page_type": type(page).__name__, "page": page, "exc": e},
)
raise self.failureException(msg)
finally:
page.save() # undo any changes to page
self.client.logout()
def assertPageIsPreviewable(
self,
page: Page,
mode: Optional[str] = "",
post_data: Optional[Dict[str, Any]] = None,
user: Optional[AbstractBaseUser] = None,
msg: Optional[str] = None,
):
"""
Asserts that the page preview view can be loaded for ``page`` without raising a fatal error.
For page types that support multiple preview modes, ``mode`` can be used to specify the preview mode to be tested.
When ``user`` is provided, the test is conducted with them as the active user. Otherwise, a superuser is created and used for the test.
To load the preview, the test client needs to make a ``POST`` request including all required field data in the request body.
If ``post_data`` is provided, that will be used for this purpose. If not, the method will attempt to extract this data from the page edit view.
"""
if not user:
if not hasattr(self, "_pageispreviewable_superuser"):
self._pageispreviewable_superuser = self.create_superuser(
"assertpageispreviewable"
)
user = self._pageispreviewable_superuser
self.client.force_login(user, AUTH_BACKEND)
if post_data is None:
edit_path = reverse("wagtailadmin_pages:edit", kwargs={"page_id": page.id})
html = self.client.get(edit_path).content.decode()
post_data = querydict_from_html(html, form_id="page-edit-form")
preview_path = reverse(
"wagtailadmin_pages:preview_on_edit", kwargs={"page_id": page.id}
)
try:
response = self.client.post(
preview_path, data=post_data, QUERYSTRING=f"mode={mode}"
)
self.assertEqual(response.status_code, 200)
self.assertJSONEqual(
response.content.decode(),
{"is_valid": True, "is_available": True},
)
except Exception as e: # noqa: BLE001
self.client.logout()
msg = self._formatMessage(
msg,
'Failed to load preview for %(page_type)s "%(page)s" with mode="%(mode)s":\n%(exc)s'
% {
"page_type": type(page).__name__,
"page": page,
"mode": mode,
"exc": e,
},
)
raise self.failureException(msg)
try:
self.client.get(preview_path, data={"mode": mode})
except Exception as e: # noqa: BLE001
msg = self._formatMessage(
msg,
'Failed to load preview for %(page_type)s "%(page)s" with mode="%(mode)s":\n%(exc)s'
% {
"page_type": type(page).__name__,
"page": page,
"mode": mode,
"exc": e,
},
)
raise self.failureException(msg)
finally:
self.client.logout()
class WagtailPageTests(WagtailPageTestCase):
def setUp(self):
super().setUp()
self.login()

View File

@@ -0,0 +1,90 @@
from typing import Dict, List, Union
from django.test import SimpleTestCase
from .wagtail_tests import WagtailTestUtils
class AdminTemplateTestUtils:
base_breadcrumb_items = [{"label": "Home", "url": "/admin/"}]
def assertBreadcrumbsItemsRendered(
self: Union[WagtailTestUtils, SimpleTestCase],
items: List[Dict[str, str]],
html: Union[str, bytes],
):
soup = self.get_soup(html)
# Select with a class instead of a data-controller attribute because
# the controller is only applied if the breadcrumbs are collapsible
breadcrumbs = soup.select(".w-breadcrumbs")
num_breadcrumbs = len(breadcrumbs)
self.assertEqual(
num_breadcrumbs,
1,
f"Expected one breadcrumbs component to be rendered, found {num_breadcrumbs}",
)
items = self.base_breadcrumb_items + items
rendered_items = breadcrumbs[0].select("ol > li")
num_rendered_items = len(rendered_items)
num_items = len(items)
arrows = soup.select("ol > li > svg")
num_arrows = len(arrows)
self.assertEqual(
num_rendered_items,
num_items,
f"Expected {num_items} breadcrumbs items to be rendered, found {num_rendered_items}",
)
self.assertEqual(
num_arrows,
num_items - 1,
f"Expected {num_items - 1} arrows to be rendered, found {num_arrows}",
)
for item, rendered_item in zip(items, rendered_items):
if item.get("url") is not None:
element = rendered_item.select_one("a")
self.assertIsNotNone(
element,
f"Expected '{item['label']}' breadcrumbs item to be a link",
)
self.assertEqual(
element["href"],
item["url"],
f"Expected '{item['label']}' breadcrumbs item to link to '{item['url']}'",
)
else:
element = rendered_item.select_one("div")
self.assertIsNotNone(
element,
f"Expected '{item['label']}' breadcrumbs item to be a div",
)
# Sublabel is optional and the : separator is invisible
label = element.get_text(strip=True)
sublabel = None
if item.get("sublabel"):
label, sublabel = label.split(":", maxsplit=1)
self.assertEqual(
label,
item["label"],
f"Expected '{item['label']}' breadcrumbs item label, found '{label}'",
)
if sublabel:
self.assertEqual(
sublabel,
item["sublabel"],
f"Expected '{item['sublabel']}' breadcrumbs item sublabel, found '{sublabel}'",
)
def assertBreadcrumbsNotRendered(
self: Union[WagtailTestUtils, SimpleTestCase],
html: Union[str, bytes],
):
soup = self.get_soup(html)
# Select with a class instead of a data-controller attribute because
# the controller is only applied if the breadcrumbs are collapsible
breadcrumbs = soup.select_one(".w-breadcrumbs")
# Confirmation views (e.g. delete view) shouldn't render breadcrumbs
self.assertIsNone(breadcrumbs)

View File

@@ -0,0 +1,22 @@
import datetime
from django.utils import timezone
def submittable_timestamp(timestamp):
"""
Helper function to translate a possibly-timezone-aware datetime into the format used in the
go_live_at / expire_at form fields - "YYYY-MM-DD hh:mm", with no timezone indicator.
This will be interpreted as being in the server's timezone (settings.TIME_ZONE), so we
need to pass it through timezone.localtime to ensure that the client and server are in
agreement about what the timestamp means.
"""
if timezone.is_aware(timestamp):
return timezone.localtime(timestamp).strftime("%Y-%m-%d %H:%M")
else:
return timestamp.strftime("%Y-%m-%d %H:%M")
def local_datetime(*args):
dt = datetime.datetime(*args)
return timezone.make_aware(dt)

View File

@@ -0,0 +1,2 @@
from .blocks import * # noqa: F403
from .factories import * # noqa: F403

View File

@@ -0,0 +1,242 @@
from collections import defaultdict
import factory
from factory.declarations import ParameteredAttribute
from wagtail import blocks
from wagtail.documents.blocks import DocumentChooserBlock
from wagtail.images.blocks import ImageChooserBlock
from .builder import (
ListBlockStepBuilder,
StreamBlockStepBuilder,
StructBlockStepBuilder,
)
from .factories import DocumentFactory, ImageFactory, PageFactory
from .options import BlockFactoryOptions, StreamBlockFactoryOptions
__all__ = [
"CharBlockFactory",
"IntegerBlockFactory",
"StreamBlockFactory",
"StreamFieldFactory",
"ListBlockFactory",
"StructBlockFactory",
"PageChooserBlockFactory",
"ImageChooserBlockFactory",
"DocumentChooserBlockFactory",
]
class StreamBlockFactory(factory.Factory):
_options_class = StreamBlockFactoryOptions
_builder_class = StreamBlockStepBuilder
@classmethod
def _generate(cls, strategy, params):
if cls._meta.abstract and not hasattr(cls, "__generate_abstract__"):
raise factory.errors.FactoryError(
"Cannot generate instances of abstract factory %(f)s; "
"Ensure %(f)s.Meta.model is set and %(f)s.Meta.abstract "
"is either not set or False." % {"f": cls.__name__}
)
step = cls._builder_class(cls._meta, params, strategy)
return step.build()
@classmethod
def _construct_stream(cls, block_class, *args, **kwargs):
def get_index(key):
return int(key.split(".")[0])
stream_length = max(map(get_index, kwargs.keys())) + 1 if kwargs else 0
stream_data = [None] * stream_length
for indexed_block_name, value in kwargs.items():
i, name = indexed_block_name.split(".")
stream_data[int(i)] = (name, value)
block_def = cls._meta.get_block_definition()
if block_def is None:
# We got an old style definition, so aren't aware of a StreamBlock class for the
# StreamField's child blocks. As nesting of StreamBlocks isn't supported for this
# kind of declaration, returning the stream data without up-casting it to a
# StreamValue is OK here. StreamField handles conversion to a StreamValue, but not
# recursively.
return stream_data
return blocks.StreamValue(block_def, stream_data)
@classmethod
def _build(cls, block_class, *args, **kwargs):
return cls._construct_stream(block_class, *args, **kwargs)
@classmethod
def _create(cls, block_class, *args, **kwargs):
return cls._construct_stream(block_class, *args, **kwargs)
class Meta:
abstract = True
class StreamFieldFactory(ParameteredAttribute):
"""
Syntax:
<streamfield>__<index>__<block_name>__<key>='foo',
Syntax to generate blocks with default factory values:
<streamfield>__<index>=<block_name>
"""
def __init__(self, block_types, **kwargs):
super().__init__(**kwargs)
if isinstance(block_types, dict):
# Old style definition, dict mapping block name -> block factory
self.stream_block_factory = type(
"_GeneratedStreamBlockFactory",
(StreamBlockFactory,),
{**block_types, "__generate_abstract__": True},
)
elif isinstance(block_types, type) and issubclass(
block_types, StreamBlockFactory
):
block_types._meta.block_def = block_types._meta.model()
self.stream_block_factory = block_types
else:
raise TypeError(
"StreamFieldFactory argument must be a StreamBlockFactory subclass or dict "
"mapping block names to factories"
)
def evaluate(self, instance, step, extra):
return self.stream_block_factory(**extra)
class ListBlockFactory(factory.SubFactory):
_builder_class = ListBlockStepBuilder
def __call__(self, **kwargs):
return self.evaluate(None, None, kwargs)
def evaluate(self, instance, step, extra):
result = defaultdict(dict)
for key, value in extra.items():
if key.isdigit():
result[int(key)]["value"] = value
else:
prefix, label = key.split("__", maxsplit=1)
if prefix and prefix.isdigit():
result[int(prefix)][label] = value
subfactory = self.get_factory()
force_sequence = step.sequence if self.FORCE_SEQUENCE else None
values = [
step.recurse(subfactory, params, force_sequence=force_sequence)
for _, params in sorted(result.items())
]
list_block_def = blocks.list_block.ListBlock(subfactory._meta.model())
return blocks.list_block.ListValue(list_block_def, values)
class StructBlockFactory(factory.Factory):
_options_class = BlockFactoryOptions
_builder_class = StructBlockStepBuilder
class Meta:
abstract = True
model = blocks.StructBlock
@classmethod
def _construct_struct_value(cls, block_class, params):
return blocks.StructValue(
block_class(),
list(params.items()),
)
@classmethod
def _build(cls, block_class, *args, **kwargs):
return cls._construct_struct_value(block_class, kwargs)
@classmethod
def _create(cls, block_class, *args, **kwargs):
return cls._construct_struct_value(block_class, kwargs)
class BlockFactory(factory.Factory):
_options_class = BlockFactoryOptions
_builder_class = factory.builder.StepBuilder
class Meta:
abstract = True
@classmethod
def _construct_block(cls, block_class, *args, **kwargs):
if kwargs.get("value"):
return block_class().clean(kwargs["value"])
return block_class().get_default()
@classmethod
def _build(cls, block_class, *args, **kwargs):
return cls._construct_block(block_class, *args, **kwargs)
@classmethod
def _create(cls, block_class, *args, **kwargs):
return cls._construct_block(block_class, *args, **kwargs)
class CharBlockFactory(BlockFactory):
class Meta:
model = blocks.CharBlock
class IntegerBlockFactory(BlockFactory):
class Meta:
model = blocks.IntegerBlock
class ChooserBlockFactory(BlockFactory):
pass
class PageChooserBlockFactory(ChooserBlockFactory):
page = factory.SubFactory(PageFactory)
class Meta:
model = blocks.PageChooserBlock
@classmethod
def _build(cls, model_class, page):
return page
@classmethod
def _create(cls, model_class, page):
return page
class ImageChooserBlockFactory(ChooserBlockFactory):
image = factory.SubFactory(ImageFactory)
class Meta:
model = ImageChooserBlock
@classmethod
def _build(cls, model_class, image):
return image
@classmethod
def _create(cls, model_class, image):
return image
class DocumentChooserBlockFactory(ChooserBlockFactory):
document = factory.SubFactory(DocumentFactory)
class Meta:
model = DocumentChooserBlock
@classmethod
def _build(cls, model_class, document):
return document
@classmethod
def _create(cls, model_class, document):
return document

View File

@@ -0,0 +1,141 @@
from itertools import zip_longest
from factory import SubFactory
from factory.builder import StepBuilder
from wagtail import blocks
class StreamFieldFactoryException(Exception):
pass
class InvalidDeclaration(StreamFieldFactoryException):
pass
class DuplicateDeclaration(StreamFieldFactoryException):
pass
class UnknownChildBlockFactory(StreamFieldFactoryException):
pass
class BaseBlockStepBuilder(StepBuilder):
def recurse(self, factory_meta, extras):
"""Recurse into a sub-factory call."""
builder_class = factory_meta.factory._builder_class
return builder_class(factory_meta, extras, strategy=self.strategy)
class StructBlockStepBuilder(BaseBlockStepBuilder):
pass
class ListBlockStepBuilder(BaseBlockStepBuilder):
pass
class StreamBlockStepBuilder(BaseBlockStepBuilder):
def __init__(self, factory_meta, extras, strategy):
indexed_block_names, extra_declarations = self.get_block_declarations(
factory_meta, extras
)
new_factory_class = self.create_factory_class(factory_meta, indexed_block_names)
super().__init__(new_factory_class._meta, extra_declarations, strategy)
def get_block_declarations(self, factory_meta, extras):
# Mapping of StreamValue index -> block name. We will use this to create a
# StreamBlockFactory subclass with one declaration for each pair, named
# <index>.<block_name>
indexed_block_names = {}
# Extra declarations passed at instantiation, renamed from <index>__<name>__... to
# <index>.<block_name>__..., to match the declarations on the StreamBlockFactory subclass
# we will generate. As DeclarationSet splits parameter names on "__" the
# <index>.<block_name> keys won't cause errors for unknown declarations (0__foo_block
# implies a declaration "0" with context "foo_block"). They will also have the important
# property of being uniquely hashable
extra_declarations = {}
for k, v in extras.items():
if k.isdigit():
# We got a declaration like `<index>="foo_block"' - <index> should get the
# default value for foo_block, so don't store this item in extra_declarations
if v not in factory_meta.base_declarations:
raise UnknownChildBlockFactory(
f"No factory defined for block '{v}'"
)
key = int(k)
if key in indexed_block_names and indexed_block_names[key] != v:
raise DuplicateDeclaration(
f"Multiple declarations for index {key} at this level of nesting "
f"(got {v}, already have {indexed_block_names[key]})"
)
indexed_block_names[key] = v
else:
try:
i, name, *params = k.split("__", maxsplit=2)
key = int(i)
except (ValueError, TypeError):
raise InvalidDeclaration(
"StreamFieldFactory declarations must be of the form "
"<index>=<block_name>, <index>__<block_name>=value or "
f"<index>__<block_name>__<param>=value, got: {k}"
)
if key in indexed_block_names and indexed_block_names[key] != name:
raise DuplicateDeclaration(
f"Multiple declarations for index {key} at this level of nesting "
f"(got {name}, already have {indexed_block_names[key]})"
)
indexed_block_names[key] = name
transformed_key = self.reconstruct_key(i, name, params)
extra_declarations[transformed_key] = v
self.validate_block_indexes_sequential(indexed_block_names, factory_meta)
return indexed_block_names, extra_declarations
def reconstruct_key(self, index, name, params):
return f"{index}.{'__'.join((name, *params))}"
def validate_block_indexes_sequential(self, indexed_block_names, factory_meta):
if not indexed_block_names:
# There were no declarations for this block, we will ultimately return an empty
# StreamValue
return
indexes = sorted(indexed_block_names.keys())
for declared, expected in zip_longest(indexes, range(max(indexes) + 1)):
if declared != expected:
raise InvalidDeclaration(
f"Parameters for {factory_meta.factory} missing required index {expected}"
)
def create_factory_class(self, old_factory_meta, indexed_block_names):
# Create a new StreamBlockFactory subclass, with a declaration for each block the user
# requested at instantiation. This way we can rely on the factory_boy internals for
# object generation
new_class_dict = {"Meta": old_factory_meta.to_meta_class()}
block_def = old_factory_meta.get_block_definition()
for i, name in indexed_block_names.items():
declared_value = old_factory_meta.base_declarations[name]
if block_def is not None and isinstance(declared_value, SubFactory):
# Annotate the subfactory's factory with the correct block definition for that
# branch of the tree, so we can construct a StreamValue if there's no explicit
# block class defined (e.g. if a nested StreamBlock was declared inline like
# `inner_stream = StreamBlock(...))'
child_def = block_def.child_blocks[name]
if isinstance(child_def, blocks.ListBlock):
# ListBlock is a special case as it is a concrete node in the stream block
# tree, but ListBlockFactory is a SubFactory subclass, making it "abstract"
# in the factory tree
child_def = child_def.child_block
declared_value.get_factory()._meta.block_def = child_def
new_class_dict[f"{i}.{name}"] = declared_value
from .blocks import StreamBlockFactory
return type(
"_GeneratedStreamBlockFactory", (StreamBlockFactory,), new_class_dict
)

View File

@@ -0,0 +1,147 @@
import logging
import factory
from django.utils.text import slugify
from factory import errors, utils
from factory.declarations import ParameteredAttribute
from factory.django import DjangoModelFactory
from wagtail.documents import get_document_model
from wagtail.images import get_image_model
from wagtail.models import Collection, Page, Site
__all__ = [
"CollectionFactory",
"ImageFactory",
"PageFactory",
"SiteFactory",
"DocumentFactory",
]
logger = logging.getLogger(__file__)
class ParentNodeFactory(ParameteredAttribute):
EXTEND_CONTAINERS = True
FORCE_SEQUENCE = False
UNROLL_CONTEXT_BEFORE_EVALUATION = False
def generate(self, step, params):
if not params:
return None
subfactory = step.builder.factory_meta.factory
logger.debug(
"ParentNodeFactory: Instantiating %s.%s(%s), create=%r",
subfactory.__module__,
subfactory.__name__,
utils.log_pprint(kwargs=params),
step,
)
force_sequence = step.sequence if self.FORCE_SEQUENCE else None
return step.recurse(subfactory, params, force_sequence=force_sequence)
class MP_NodeFactory(DjangoModelFactory):
parent = ParentNodeFactory()
@classmethod
def _build(cls, model_class, *args, **kwargs):
kwargs.pop("parent")
return model_class(**kwargs)
@classmethod
def _create(cls, model_class, *args, **kwargs):
parent = kwargs.pop("parent")
if cls._meta.django_get_or_create:
instance = cls._get_or_create(model_class, *args, parent=parent, **kwargs)
else:
instance = cls._create_instance(model_class, parent, kwargs)
assert instance.pk
return instance
@classmethod
def _create_instance(cls, model_class, parent, kwargs):
instance = model_class(**kwargs)
if parent:
parent.add_child(instance=instance)
else:
model_class.add_root(instance=instance)
return instance
@classmethod
def _get_or_create(cls, model_class, *args, **kwargs):
"""Create an instance of the model through objects.get_or_create."""
manager = cls._get_manager(model_class)
assert "defaults" not in cls._meta.django_get_or_create, (
"'defaults' is a reserved keyword for get_or_create "
"(in %s._meta.django_get_or_create=%r)"
% (cls, cls._meta.django_get_or_create)
)
lookup_fields = {}
for field in cls._meta.django_get_or_create:
if field not in kwargs:
raise errors.FactoryError(
"django_get_or_create - "
"Unable to find initialization value for '%s' in factory %s"
% (field, cls.__name__)
)
lookup_fields[field] = kwargs[field]
parent = lookup_fields.pop("parent", None)
kwargs.pop("parent", None)
if parent:
try:
return manager.child_of(parent).get(**lookup_fields)
except model_class.DoesNotExist:
return cls._create_instance(model_class, parent, kwargs)
else:
return super()._get_or_create(model_class, *args, **kwargs)
class CollectionFactory(MP_NodeFactory):
name = "Test collection"
class Meta:
model = Collection
class PageFactory(MP_NodeFactory):
title = "Test page"
slug = factory.LazyAttribute(lambda obj: slugify(obj.title))
class Meta:
model = Page
class CollectionMemberFactory(DjangoModelFactory):
collection = factory.SubFactory(CollectionFactory, parent=None)
class ImageFactory(CollectionMemberFactory):
class Meta:
model = get_image_model()
title = "An image"
file = factory.django.ImageField()
class SiteFactory(DjangoModelFactory):
hostname = "localhost"
port = factory.Sequence(lambda n: 81 + n)
site_name = "Test site"
root_page = factory.SubFactory(PageFactory, parent=None)
is_default_site = False
class Meta:
model = Site
class DocumentFactory(CollectionMemberFactory):
class Meta:
model = get_document_model()
title = "A document"
file = factory.django.FileField()

View File

@@ -0,0 +1,60 @@
from factory import declarations
from factory.base import FactoryOptions, OptionDefault
class BlockFactoryOptions(FactoryOptions):
def _build_default_options(self):
options = super()._build_default_options()
options.append(OptionDefault("block_def", None))
return options
def get_meta_dict(self):
return {
"model": self.model,
"block_def": self.block_def,
"abstract": self.abstract,
"strategy": self.strategy,
"inline_args": self.inline_args,
"exclude": self.exclude,
"rename": self.rename,
}
def to_meta_class(self):
"""
Create a new Meta class from this instance's options, suitable for
inclusion on a factory subclass
"""
return type("Meta", (), self.get_meta_dict())
class StreamBlockFactoryOptions(BlockFactoryOptions):
def prepare_arguments(self, attributes):
# Like the base implementation, but ignore args as they are not relevant
# for instantiating StreamValues.
def get_block_name(key):
# Keys at this point will be like <index>.<block_name>
return key.split(".")[1]
kwargs = dict(attributes)
# 1. Extension points
kwargs = self.factory._adjust_kwargs(**kwargs)
# 2. Remove hidden objects
filtered_kwargs = {}
for k, v in kwargs.items():
block_name = get_block_name(k)
if (
block_name not in self.exclude
and block_name not in self.parameters
and v is not declarations.SKIP
):
filtered_kwargs[k] = v
return (), filtered_kwargs
def get_block_definition(self):
if self.block_def is not None:
return self.block_def
elif self.model is not None:
return self.model()

View File

@@ -0,0 +1,256 @@
import warnings
from contextlib import contextmanager
from typing import Union
from bs4 import BeautifulSoup
from django.contrib.auth import get_user_model
from django.test.testcases import assert_and_parse_html
class WagtailTestUtils:
@staticmethod
def get_soup(markup: Union[str, bytes]) -> BeautifulSoup:
return BeautifulSoup(markup, "html.parser")
@staticmethod
def create_test_user():
"""
Override this method to return an instance of your custom user model
"""
user_model = get_user_model()
# Create a user
user_data = {
user_model.USERNAME_FIELD: "test@email.com",
"email": "test@email.com",
"password": "password",
}
for field in user_model.REQUIRED_FIELDS:
if field not in user_data:
user_data[field] = field
return user_model.objects.create_superuser(**user_data)
def login(self, user=None, username=None, password="password"):
# wrapper for self.client.login that works interchangeably for user models
# with email as the username field; in this case it will use the passed username
# plus '@example.com'
user_model = get_user_model()
if username is None:
if user is None:
user = self.create_test_user()
username = getattr(user, user_model.USERNAME_FIELD)
if user_model.USERNAME_FIELD == "email" and "@" not in username:
username = "%s@example.com" % username
# Login
self.assertTrue(
self.client.login(
password=password, **{user_model.USERNAME_FIELD: username}
)
)
return user
@staticmethod
def create_user(username, email=None, password=None, **kwargs):
# wrapper for get_user_model().objects.create_user that works interchangeably for user models
# with and without a username field
User = get_user_model()
kwargs["email"] = email or "%s@example.com" % username
kwargs["password"] = password
if User.USERNAME_FIELD != "email":
kwargs[User.USERNAME_FIELD] = username
return User.objects.create_user(**kwargs)
@staticmethod
def create_superuser(username, email=None, password=None, **kwargs):
# wrapper for get_user_model().objects.create_user that works interchangeably for user models
# with and without a username field
User = get_user_model()
kwargs["email"] = email or "%s@example.com" % username
kwargs["password"] = password
if User.USERNAME_FIELD != "email":
kwargs[User.USERNAME_FIELD] = username
return User.objects.create_superuser(**kwargs)
@staticmethod
@contextmanager
def ignore_deprecation_warnings():
with warnings.catch_warnings(record=True) as warning_list: # catch all warnings
yield
# rethrow all warnings that were not DeprecationWarnings or PendingDeprecationWarnings
for w in warning_list:
if not issubclass(
w.category, (DeprecationWarning, PendingDeprecationWarning)
):
warnings.showwarning(
message=w.message,
category=w.category,
filename=w.filename,
lineno=w.lineno,
file=w.file,
line=w.line,
)
@contextmanager
def register_hook(self, hook_name, fn, order=0):
from wagtail import hooks
hooks.register(hook_name, fn, order)
try:
yield
finally:
hooks._hooks[hook_name].remove((fn, order))
def _tag_is_equal(self, tag1, tag2):
if not hasattr(tag1, "name") or not hasattr(tag2, "name"):
return False
if tag1.name != tag2.name:
return False
if len(tag1.attributes) != len(tag2.attributes):
return False
if tag1.attributes != tag2.attributes:
# attributes without a value is same as attribute with value that
# equals the attributes name:
# <input checked> == <input checked="checked">
for i in range(len(tag1.attributes)):
attr, value = tag1.attributes[i]
other_attr, other_value = tag2.attributes[i]
if value is None:
value = attr
if other_value is None:
other_value = other_attr
if attr != other_attr or value != other_value:
return False
return True
def _tag_matches_with_extra_attrs(self, thin_tag, fat_tag):
# return true if thin_tag and fat_tag have the same name,
# and all attributes on thin_tag exist on fat_tag
if not hasattr(thin_tag, "name") or not hasattr(fat_tag, "name"):
return False
if thin_tag.name != fat_tag.name:
return False
for attr, value in thin_tag.attributes:
if value is None:
# attributes without a value is same as attribute with value that
# equals the attributes name:
# <input checked> == <input checked="checked">
if (attr, None) not in fat_tag.attributes and (
attr,
attr,
) not in fat_tag.attributes:
return False
else:
if (attr, value) not in fat_tag.attributes:
return False
return True
def _count_tag_occurrences(self, needle, haystack, allow_extra_attrs=False):
count = 0
if allow_extra_attrs:
if self._tag_matches_with_extra_attrs(needle, haystack):
count += 1
else:
if self._tag_is_equal(needle, haystack):
count += 1
if hasattr(haystack, "children"):
count += sum(
self._count_tag_occurrences(
needle, child, allow_extra_attrs=allow_extra_attrs
)
for child in haystack.children
)
return count
def _tag_is_template_script(self, tag):
if tag.name != "script":
return False
return any(attr == ("type", "text/template") for attr in tag.attributes)
def _find_template_script_tags(self, haystack):
if not hasattr(haystack, "name"):
return
if self._tag_is_template_script(haystack):
yield haystack
else:
for child in haystack.children:
yield from self._find_template_script_tags(child)
def assertTagInHTML(
self, needle, haystack, count=None, msg_prefix="", allow_extra_attrs=False
):
needle = assert_and_parse_html(
self, needle, None, "First argument is not valid HTML:"
)
haystack = assert_and_parse_html(
self, haystack, None, "Second argument is not valid HTML:"
)
real_count = self._count_tag_occurrences(
needle, haystack, allow_extra_attrs=allow_extra_attrs
)
if count is not None:
self.assertEqual(
real_count,
count,
msg_prefix
+ "Found %d instances of '%s' in response (expected %d)"
% (real_count, needle, count),
)
else:
self.assertNotEqual(
real_count, 0, msg_prefix + "Couldn't find '%s' in response" % needle
)
def assertNotInHTML(self, needle, haystack, msg_prefix=""):
self.assertInHTML(needle, haystack, count=0, msg_prefix=msg_prefix)
def assertTagInTemplateScript(self, needle, haystack, count=None, msg_prefix=""):
needle = assert_and_parse_html(
self, needle, None, "First argument is not valid HTML:"
)
haystack = assert_and_parse_html(
self, haystack, None, "Second argument is not valid HTML:"
)
real_count = 0
for script_tag in self._find_template_script_tags(haystack):
if script_tag.children:
self.assertEqual(len(script_tag.children), 1)
script_html = assert_and_parse_html(
self,
script_tag.children[0],
None,
"Script tag content is not valid HTML:",
)
real_count += self._count_tag_occurrences(needle, script_html)
if count is not None:
self.assertEqual(
real_count,
count,
msg_prefix
+ "Found %d instances of '%s' in template script (expected %d)"
% (real_count, needle, count),
)
else:
self.assertNotEqual(
real_count,
0,
msg_prefix + "Couldn't find '%s' in template script" % needle,
)