Initial commit

This commit is contained in:
2024-08-27 20:33:44 +02:00
commit 1f1832267d
14794 changed files with 1599592 additions and 0 deletions

View File

@@ -0,0 +1,265 @@
import datetime
from django.test import TestCase
from django.utils import timezone
from wagtail.blocks import StreamValue
from wagtail.blocks.migrations import migrate_operation
from wagtail.blocks.migrations.operations import (
RenameStreamChildrenOperation,
RenameStructChildrenOperation,
)
from wagtail.blocks.migrations.utils import (
InvalidBlockDefError,
apply_changes_to_raw_data,
)
from wagtail.signal_handlers import disable_reference_index_auto_update
from wagtail.test.streamfield_migrations import factories, models
from wagtail.test.streamfield_migrations.testutils import MigrationTestMixin
class TestExceptionRaisedInRawData(TestCase):
"""Directly test whether an exception is raised by apply_changes_to_raw_data for invalid defs.
This would happen in a situation where the user gives a block path which contains a block name
which is not present in the block definition in the project state at which the migration is
applied. (There should also be a block in the stream data with the said name for this to happen)
"""
def setUp(self):
raw_data = factories.SampleModelFactory(
content__0__char1__value="Char Block 1",
content__1="nestedstruct",
).content.raw_data
raw_data.extend(
[
{
"type": "invalid_name1",
"id": "0001",
"value": {"char1": "foo", "char2": "foo"},
},
{
"type": "invalid_name1",
"id": "0002",
"value": {"char1": "foo", "char2": "foo"},
},
]
)
raw_data[1]["value"]["invalid_name2"] = [
{"type": "char1", "value": "foo", "id": "0003"}
]
self.raw_data = raw_data
def test_rename_invalid_stream_child(self):
"""Test whether Exception is raised in when recursing through stream block data"""
with self.assertRaisesMessage(
InvalidBlockDefError, "No current block def named invalid_name1"
):
apply_changes_to_raw_data(
raw_data=self.raw_data,
block_path_str="invalid_name1",
operation=RenameStructChildrenOperation(
old_name="char1", new_name="renamed1"
),
streamfield=models.SampleModel.content,
)
def test_rename_invalid_struct_child(self):
"""Test whether Exception is raised in when recursing through struct block data"""
with self.assertRaisesMessage(
InvalidBlockDefError, "No current block def named invalid_name2"
):
apply_changes_to_raw_data(
raw_data=self.raw_data,
block_path_str="nestedstruct.invalid_name2",
operation=RenameStreamChildrenOperation(
old_name="char1", new_name="renamed1"
),
streamfield=models.SampleModel.content,
)
class BadDataMigrationTestCase(TestCase, MigrationTestMixin):
model = models.SamplePage
default_operation_and_block_path = [
(
RenameStructChildrenOperation(old_name="char1", new_name="renamed1"),
"invalid_name1",
)
]
app_name = "streamfield_migration_tests"
def create_instance(self):
instance = factories.SamplePageFactory(
content__0__char1__value="Char Block 1",
content__1="nestedstruct",
)
self.instance = instance
def append_invalid_instance_data(self):
raw_data = self.instance.content.raw_data
raw_data.extend(
[
{
"type": "invalid_name1",
"id": "0001",
"value": {"char1": "foo", "char2": "foo"},
},
{
"type": "invalid_name1",
"id": "0002",
"value": {"char1": "foo", "char2": "foo"},
},
]
)
stream_block = self.instance.content.stream_block
self.instance.content = StreamValue(
stream_block=stream_block, stream_data=raw_data, is_lazy=True
)
self.instance.save()
def create_invalid_revision(self, delta):
self.append_invalid_instance_data()
invalid_revision = self.create_revision(delta)
# remove the invalid data from the instance
raw_data = self.instance.content.raw_data
raw_data = raw_data[:2]
stream_block = self.instance.content.stream_block
self.instance.content = StreamValue(
stream_block=stream_block, stream_data=raw_data, is_lazy=True
)
self.instance.save()
return invalid_revision.id, invalid_revision.created_at
def create_revision(self, delta):
revision = self.instance.save_revision()
revision.created_at = timezone.now() - datetime.timedelta(days=(delta))
revision.save()
return revision
class TestExceptionRaisedForInstance(BadDataMigrationTestCase):
"""Exception should always be raised when applying migration if it occurs while migrating the
instance data"""
def setUp(self):
with disable_reference_index_auto_update():
self.create_instance()
self.append_invalid_instance_data()
def test_migrate(self):
with self.assertRaisesMessage(
InvalidBlockDefError,
"Invalid block def in {} object ({})".format(
self.instance.__class__.__name__, self.instance.id
),
):
self.apply_migration(
revisions_from=timezone.now() + datetime.timedelta(days=2),
)
class TestExceptionRaisedForLatestRevision(BadDataMigrationTestCase):
"""Exception should always be raised when applying migration if it occurs while migrating the
latest revision data"""
def setUp(self):
with disable_reference_index_auto_update():
self.create_instance()
for i in range(4):
self.create_revision(5 - i)
(
self.invalid_revision_id,
self.invalid_revision_created_at,
) = self.create_invalid_revision(0)
def test_migrate(self):
with self.assertRaisesMessage(
InvalidBlockDefError,
"Invalid block def in {} object ({}) for revision id ({}) created at {}".format(
self.instance.__class__.__name__,
self.instance.id,
self.invalid_revision_id,
self.invalid_revision_created_at,
),
):
self.apply_migration(revisions_from=None)
class TestExceptionRaisedForLiveRevision(BadDataMigrationTestCase):
"""Exception should always be raised when applying migration if it occurs while migrating the
live revision data"""
def setUp(self):
with disable_reference_index_auto_update():
self.create_instance()
(
self.invalid_revision_id,
self.invalid_revision_created_at,
) = self.create_invalid_revision(5)
self.instance.live_revision_id = self.invalid_revision_id
self.instance.save()
for i in range(1, 5):
self.create_revision(5 - i)
def test_migrate(self):
with self.assertRaisesMessage(
InvalidBlockDefError,
"Invalid block def in {} object ({}) for revision id ({}) created at {}".format(
self.instance.__class__.__name__,
self.instance.id,
self.invalid_revision_id,
self.invalid_revision_created_at,
),
):
self.apply_migration(revisions_from=None)
class TestExceptionIgnoredForOtherRevisions(BadDataMigrationTestCase):
"""Exception should not be be raised when applying migration if it occurs while migrating
revision data which is not of a live or latest revision. Instead an exception should be logged"""
model = models.SamplePage
def setUp(self):
with disable_reference_index_auto_update():
self.create_instance()
(
self.invalid_revision_id,
self.invalid_revision_created_at,
) = self.create_invalid_revision(5)
for i in range(1, 5):
self.create_revision(5 - i)
def test_migrate(self):
with self.assertLogs(level="ERROR") as cm:
self.apply_migration(revisions_from=None)
self.assertEqual(
cm.output[0].splitlines()[0],
"ERROR:{}:Invalid block def in {} object ({}) for revision id ({}) created at {}".format(
migrate_operation.__name__,
self.instance.__class__.__name__,
self.instance.id,
self.invalid_revision_id,
self.invalid_revision_created_at,
),
)
self.assertEqual(
cm.output[0].splitlines()[-1],
"{}: No current block def named invalid_name1".format(
InvalidBlockDefError.__module__
+ "."
+ InvalidBlockDefError.__name__
),
)

View File

@@ -0,0 +1,59 @@
from django.test import TestCase
from wagtail.blocks.migrations.operations import (
RemoveStreamChildrenOperation,
RenameStreamChildrenOperation,
)
from wagtail.test.streamfield_migrations import models
from wagtail.test.streamfield_migrations.testutils import MigrationTestMixin
class MigrationNameTest(TestCase, MigrationTestMixin):
model = models.SamplePage
app_name = "wagtail_streamfield_migration_toolkit_test"
def test_rename(self):
operations_and_block_path = [
(
RenameStreamChildrenOperation(old_name="char1", new_name="renamed1"),
"",
)
]
migration = self.init_migration(
operations_and_block_path=operations_and_block_path
)
suggested_name = migration.suggest_name()
self.assertEqual(suggested_name, "rename_char1_to_renamed1")
def test_remove(self):
operations_and_block_path = [
(
RemoveStreamChildrenOperation(name="char1"),
"",
)
]
migration = self.init_migration(
operations_and_block_path=operations_and_block_path
)
suggested_name = migration.suggest_name()
self.assertEqual(suggested_name, "remove_char1")
def test_multiple(self):
operations_and_block_path = [
(
RenameStreamChildrenOperation(old_name="char1", new_name="renamed1"),
"",
),
(
RemoveStreamChildrenOperation(name="char1"),
"simplestruct",
),
]
migration = self.init_migration(
operations_and_block_path=operations_and_block_path
)
suggested_name = migration.suggest_name()
self.assertEqual(suggested_name, "rename_char1_to_renamed1_remove_char1")

View File

@@ -0,0 +1,252 @@
import datetime
import json
from django.db import connection
from django.db.models import F, JSONField, TextField
from django.db.models.functions import Cast
from django.test import TestCase
from django.utils import timezone
from wagtail.blocks.migrations.operations import RenameStreamChildrenOperation
from wagtail.test.streamfield_migrations import factories, models
from wagtail.test.streamfield_migrations.testutils import MigrationTestMixin
# TODO test multiple operations in one go
class BaseMigrationTest(TestCase, MigrationTestMixin):
factory = None
has_revisions = False
default_operation_and_block_path = [
(
RenameStreamChildrenOperation(old_name="char1", new_name="renamed1"),
"",
)
]
app_name = None
def _get_test_instances(self):
return [
self.factory(
content__0__char1="Test char 1",
content__1__char1="Test char 2",
content__2__char2="Test char 3",
content__3__char2="Test char 4",
),
self.factory(
content__0__char1="Test char 1",
content__1__char1="Test char 2",
content__2__char2="Test char 3",
),
self.factory(
content__0__char2="Test char 1",
content__1__char2="Test char 2",
content__2__char2="Test char 3",
),
]
def setUp(self):
instances = self._get_test_instances()
self.original_raw_data = {}
self.original_revisions = {}
for instance in instances:
self.original_raw_data[instance.id] = instance.content.raw_data
if self.has_revisions:
for i in range(5):
revision = instance.save_revision()
revision.created_at = timezone.now() - datetime.timedelta(
days=(5 - i)
)
revision.save()
if i == 1:
instance.live_revision = revision
instance.save()
self.original_revisions[instance.id] = list(
instance.revisions.all().order_by("id")
)
def assertBlocksRenamed(self, old_content, new_content, is_altered=True):
for old_block, new_block in zip(old_content, new_content):
self.assertEqual(old_block["id"], new_block["id"])
if is_altered and old_block["type"] == "char1":
self.assertEqual(new_block["type"], "renamed1")
else:
self.assertEqual(old_block["type"], new_block["type"])
def _test_migrate_stream_data(self):
"""Test whether the stream data of the model instances have been updated properly
Apply the migration and then query the raw data of the updated instances. Compare with
original raw data and check whether all relevant `char1` blocks have been renamed and
whether ids and other block types are intact.
"""
self.apply_migration()
instances = self.model.objects.all().annotate(
raw_content=Cast(F("content"), JSONField())
)
for instance in instances:
prev_content = self.original_raw_data[instance.id]
self.assertBlocksRenamed(
old_content=prev_content, new_content=instance.raw_content
)
# TODO test multiple operations applied in one migration
def _test_migrate_revisions(self):
"""Test whether all revisions have been updated properly
Applying migration with `revisions_from=None`, so all revisions should be updated.
"""
self.apply_migration()
instances = self.model.objects.all()
for instance in instances:
old_revisions = self.original_revisions[instance.id]
for old_revision, new_revision in zip(
old_revisions, instance.revisions.all().order_by("id")
):
old_content = json.loads(old_revision.content["content"])
new_content = json.loads(new_revision.content["content"])
self.assertBlocksRenamed(
old_content=old_content, new_content=new_content
)
def _test_always_migrate_live_and_latest_revisions(self):
"""Test whether latest and live revisions are always updated
Applying migration with `revisions_from` set to a date in the future, so there should be
no revisions which are made after the date. Only the live and latest revisions should
update in this case.
"""
revisions_from = timezone.now() + datetime.timedelta(days=2)
self.apply_migration(revisions_from=revisions_from)
instances = self.model.objects.all()
for instance in instances:
old_revisions = self.original_revisions[instance.id]
for old_revision, new_revision in zip(
old_revisions, instance.revisions.all().order_by("id")
):
is_latest_or_live = old_revision.id == instance.live_revision_id or (
old_revision.id == instance.latest_revision_id
)
old_content = json.loads(old_revision.content["content"])
new_content = json.loads(new_revision.content["content"])
self.assertBlocksRenamed(
old_content=old_content,
new_content=new_content,
is_altered=is_latest_or_live,
)
def _test_migrate_revisions_from_date(self):
"""Test whether revisions from a given date onwards are updated
Applying migration with `revisions_from` set to a date between the created date of the first
and last revision, so only the revisions after the date and the live and latest revision
should be updated.
"""
revisions_from = timezone.now() - datetime.timedelta(days=2)
self.apply_migration(revisions_from=revisions_from)
instances = self.model.objects.all()
for instance in instances:
old_revisions = self.original_revisions[instance.id]
for old_revision, new_revision in zip(
old_revisions, instance.revisions.all().order_by("id")
):
is_latest_or_live = old_revision.id == instance.live_revision_id or (
old_revision.id == instance.latest_revision_id
)
is_after_revisions_from = old_revision.created_at > revisions_from
is_altered = is_latest_or_live or is_after_revisions_from
old_content = json.loads(old_revision.content["content"])
new_content = json.loads(new_revision.content["content"])
self.assertBlocksRenamed(
old_content=old_content,
new_content=new_content,
is_altered=is_altered,
)
class TestNonPageModelWithoutRevisions(BaseMigrationTest):
model = models.SampleModel
factory = factories.SampleModelFactory
has_revisions = False
app_name = "streamfield_migration_tests"
def test_migrate_stream_data(self):
self._test_migrate_stream_data()
class TestPage(BaseMigrationTest):
model = models.SamplePage
factory = factories.SamplePageFactory
has_revisions = True
app_name = "streamfield_migration_tests"
def test_migrate_stream_data(self):
self._test_migrate_stream_data()
def test_migrate_revisions(self):
self._test_migrate_revisions()
def test_always_migrate_live_and_latest_revisions(self):
self._test_always_migrate_live_and_latest_revisions()
def test_migrate_revisions_from_date(self):
self._test_migrate_revisions_from_date()
class TestNullStreamField(BaseMigrationTest):
"""
Migrations are processed if the underlying JSON is null.
This might occur if we're operating on a StreamField that was added to a model that
had existing records.
"""
model = models.SamplePage
factory = factories.SamplePageFactory
has_revisions = True
app_name = "streamfield_migration_tests"
def _get_test_instances(self):
return self.factory.create_batch(1, content=None)
def setUp(self):
super().setUp()
# Bypass StreamField/StreamBlock processing that cast a None stream field value
# to the empty StreamValue, and set the underlying JSON to null.
with connection.cursor() as cursor:
cursor.execute(f"UPDATE {self.model._meta.db_table} SET content = 'null'")
def assert_null_content(self):
"""
The raw JSON of all instances for this test is null.
"""
instances = self.model.objects.all().annotate(
raw_content=Cast(F("content"), TextField())
)
for instance in instances:
with self.subTest(instance=instance):
self.assertEqual(instance.raw_content, "null")
def test_migrate_stream_data(self):
self.assert_null_content()
self.apply_migration()
self.assert_null_content()

View File

@@ -0,0 +1,783 @@
from django.test import TestCase
from wagtail.blocks.migrations.operations import (
RemoveStreamChildrenOperation,
RemoveStructChildrenOperation,
RenameStreamChildrenOperation,
RenameStructChildrenOperation,
)
from wagtail.blocks.migrations.utils import apply_changes_to_raw_data
from wagtail.test.streamfield_migrations import factories, models
class FieldStructStreamChildBlockTest(TestCase):
"""Tests involving changes to children of a StreamBlock nested inside a StructBlock
We use `nestedstruct.simplestream` blocks here.
"""
def setUp(self):
raw_data = factories.SampleModelFactory(
content__0__char1__value="Char Block 1",
content__1="nestedstruct",
content__1__nestedstruct__list1__0__value="a",
content__1__nestedstruct__stream1__0__char1__value="Char Block 1",
content__1__nestedstruct__stream1__1__char2__value="Char Block 2",
content__1__nestedstruct__stream1__2__char1__value="Char Block 1",
content__2="nestedstruct",
content__2__nestedstruct__list1__0__value="a",
content__2__nestedstruct__stream1__0__char1__value="Char Block 1",
content__3="simplestream",
content__3__simplestream__0__char1__value="Char Block 1",
content__3__simplestream__1__char2__value="Char Block 2",
).content.raw_data
self.raw_data = raw_data
def test_blocks_and_data_not_operated_on_intact(self):
"""Test whether other blocks and data not passed to an operation are intact.
We are checking whether the parts of the data which are not passed to an operation are
intact. Since the recursion process depends just on the block path and block structure,
this check is independent of the operation used. We will use a rename operation for now.
"""
altered_raw_data = apply_changes_to_raw_data(
raw_data=self.raw_data,
block_path_str="nestedstruct.stream1",
operation=RenameStreamChildrenOperation(
old_name="char1", new_name="renamed1"
),
streamfield=models.SampleModel.content,
)
self.assertEqual(altered_raw_data[0], self.raw_data[0])
self.assertEqual(altered_raw_data[3], self.raw_data[3])
self.assertEqual(altered_raw_data[1]["id"], self.raw_data[1]["id"])
self.assertEqual(altered_raw_data[2]["id"], self.raw_data[2]["id"])
self.assertEqual(altered_raw_data[1]["type"], self.raw_data[1]["type"])
self.assertEqual(altered_raw_data[2]["type"], self.raw_data[2]["type"])
for key in self.raw_data[1]["value"].keys():
self.assertIn(key, altered_raw_data[1]["value"])
for key in self.raw_data[1]["value"].keys():
self.assertIn(key, altered_raw_data[2]["value"])
self.assertEqual(
altered_raw_data[1]["value"]["char1"], self.raw_data[1]["value"]["char1"]
)
self.assertEqual(
altered_raw_data[2]["value"]["char1"], self.raw_data[2]["value"]["char1"]
)
self.assertEqual(
altered_raw_data[1]["value"]["struct1"],
self.raw_data[1]["value"]["struct1"],
)
self.assertEqual(
altered_raw_data[2]["value"]["struct1"],
self.raw_data[2]["value"]["struct1"],
)
self.assertEqual(
altered_raw_data[1]["value"]["list1"], self.raw_data[1]["value"]["list1"]
)
self.assertEqual(
altered_raw_data[2]["value"]["list1"], self.raw_data[2]["value"]["list1"]
)
def test_rename(self):
altered_raw_data = apply_changes_to_raw_data(
raw_data=self.raw_data,
block_path_str="nestedstruct.stream1",
operation=RenameStreamChildrenOperation(
old_name="char1", new_name="renamed1"
),
streamfield=models.SampleModel.content,
)
self.assertEqual(altered_raw_data[1]["value"]["stream1"][0]["type"], "renamed1")
self.assertEqual(altered_raw_data[1]["value"]["stream1"][2]["type"], "renamed1")
self.assertEqual(altered_raw_data[2]["value"]["stream1"][0]["type"], "renamed1")
self.assertEqual(
altered_raw_data[1]["value"]["stream1"][0]["id"],
self.raw_data[1]["value"]["stream1"][0]["id"],
)
self.assertEqual(
altered_raw_data[1]["value"]["stream1"][2]["id"],
self.raw_data[1]["value"]["stream1"][2]["id"],
)
self.assertEqual(
altered_raw_data[2]["value"]["stream1"][0]["id"],
self.raw_data[2]["value"]["stream1"][0]["id"],
)
self.assertEqual(
altered_raw_data[1]["value"]["stream1"][0]["value"],
self.raw_data[1]["value"]["stream1"][0]["value"],
)
self.assertEqual(
altered_raw_data[1]["value"]["stream1"][2]["value"],
self.raw_data[1]["value"]["stream1"][2]["value"],
)
self.assertEqual(
altered_raw_data[2]["value"]["stream1"][0]["value"],
self.raw_data[2]["value"]["stream1"][0]["value"],
)
self.assertEqual(
altered_raw_data[1]["value"]["stream1"][1],
self.raw_data[1]["value"]["stream1"][1],
)
def test_remove(self):
altered_raw_data = apply_changes_to_raw_data(
raw_data=self.raw_data,
block_path_str="nestedstruct.stream1",
operation=RemoveStreamChildrenOperation(name="char1"),
streamfield=models.SampleModel.content,
)
self.assertEqual(len(altered_raw_data[1]["value"]["stream1"]), 1)
self.assertEqual(len(altered_raw_data[2]["value"]["stream1"]), 0)
self.assertEqual(
altered_raw_data[1]["value"]["stream1"][0],
self.raw_data[1]["value"]["stream1"][1],
)
class FieldStructStructChildBlockTest(TestCase):
"""Tests involving changes to a children of a StructBlock nested inside a StructBlock
We use `nestedstruct.simplestruct` blocks here
"""
def setUp(self):
raw_data = factories.SampleModelFactory(
content__0__char1__value="Char Block 1",
content__1="nestedstruct",
content__1__nestedstruct__list1__0__value="a",
content__2="nestedstruct",
content__2__nestedstruct__list1__0__value="a",
content__3="simplestruct",
).content.raw_data
self.raw_data = raw_data
def test_blocks_and_data_not_operated_on_intact(self):
"""Test whether other blocks and data not passed to an operation are intact.
We are checking whether the parts of the data which are not passed to an operation are
intact. Since the recursion process depends just on the block path and block structure,
this check is independent of the operation used. We will use a rename operation for now.
"""
altered_raw_data = apply_changes_to_raw_data(
raw_data=self.raw_data,
block_path_str="nestedstruct.struct1",
operation=RenameStructChildrenOperation(
old_name="char1", new_name="renamed1"
),
streamfield=models.SampleModel.content,
)
self.assertEqual(altered_raw_data[0], self.raw_data[0])
self.assertEqual(altered_raw_data[3], self.raw_data[3])
self.assertEqual(altered_raw_data[1]["type"], self.raw_data[1]["type"])
self.assertEqual(altered_raw_data[2]["type"], self.raw_data[2]["type"])
self.assertEqual(altered_raw_data[1]["id"], self.raw_data[1]["id"])
self.assertEqual(altered_raw_data[2]["id"], self.raw_data[2]["id"])
for key in self.raw_data[1]["value"].keys():
self.assertIn(key, altered_raw_data[1]["value"])
for key in self.raw_data[1]["value"].keys():
self.assertIn(key, altered_raw_data[2]["value"])
self.assertEqual(
altered_raw_data[1]["value"]["char1"], self.raw_data[1]["value"]["char1"]
)
self.assertEqual(
altered_raw_data[2]["value"]["char1"], self.raw_data[2]["value"]["char1"]
)
self.assertEqual(
altered_raw_data[1]["value"]["stream1"],
self.raw_data[1]["value"]["stream1"],
)
self.assertEqual(
altered_raw_data[2]["value"]["stream1"],
self.raw_data[2]["value"]["stream1"],
)
self.assertEqual(
altered_raw_data[1]["value"]["list1"], self.raw_data[1]["value"]["list1"]
)
self.assertEqual(
altered_raw_data[2]["value"]["list1"], self.raw_data[2]["value"]["list1"]
)
def test_rename(self):
altered_raw_data = apply_changes_to_raw_data(
raw_data=self.raw_data,
block_path_str="nestedstruct.struct1",
operation=RenameStructChildrenOperation(
old_name="char1", new_name="renamed1"
),
streamfield=models.SampleModel.content,
)
self.assertNotIn("char1", altered_raw_data[1]["value"]["struct1"])
self.assertNotIn("char1", altered_raw_data[2]["value"]["struct1"])
self.assertIn("renamed1", altered_raw_data[2]["value"]["struct1"])
self.assertIn("renamed1", altered_raw_data[2]["value"]["struct1"])
self.assertIn("char2", altered_raw_data[1]["value"]["struct1"])
self.assertIn("char2", altered_raw_data[2]["value"]["struct1"])
def test_remove(self):
altered_raw_data = apply_changes_to_raw_data(
raw_data=self.raw_data,
block_path_str="nestedstruct.struct1",
operation=RemoveStructChildrenOperation(name="char1"),
streamfield=models.SampleModel.content,
)
self.assertEqual(len(altered_raw_data[1]["value"]["struct1"]), 1)
self.assertEqual(len(altered_raw_data[2]["value"]["struct1"]), 1)
self.assertNotIn("char1", altered_raw_data[1]["value"]["struct1"])
self.assertNotIn("char1", altered_raw_data[2]["value"]["struct1"])
self.assertIn("char2", altered_raw_data[1]["value"]["struct1"])
self.assertIn("char2", altered_raw_data[2]["value"]["struct1"])
class FieldStreamStreamChildBlockTest(TestCase):
"""Tests involving changes to children of a StreamBlock nested inside a StreamBlock.
We use `nestedstream.stream1` blocks here.
"""
def setUp(self):
raw_data = factories.SampleModelFactory(
content__0__char1__value="Char Block 1",
content__1="nestedstream",
content__1__nestedstream__0__char1__value="Char Block 1",
content__1__nestedstream__1="stream1",
content__1__nestedstream__1__stream1__0__char1__value="Char Block 1",
content__1__nestedstream__1__stream1__1__char2__value="Char Block 2",
content__1__nestedstream__1__stream1__2__char1__value="Char Block 1",
content__1__nestedstream__2="stream1",
content__1__nestedstream__2__stream1__0__char1__value="Char Block 1",
content__2="nestedstream",
content__2__nestedstream__0="stream1",
content__2__nestedstream__0__stream1__0__char1__value="Char Block 1",
content__3="simplestream",
content__3__simplestream__0__char1__value="Char Block 1",
).content.raw_data
self.raw_data = raw_data
def test_blocks_and_data_not_operated_on_intact(self):
"""Test whether other blocks and data not passed to an operation are intact.
We are checking whether the parts of the data which are not passed to an operation are
intact. Since the recursion process depends just on the block path and block structure,
this check is independent of the operation used. We will use a rename operation for now.
"""
altered_raw_data = apply_changes_to_raw_data(
raw_data=self.raw_data,
block_path_str="nestedstream.stream1",
operation=RenameStreamChildrenOperation(
old_name="char1", new_name="renamed1"
),
streamfield=models.SampleModel.content,
)
self.assertEqual(altered_raw_data[0], self.raw_data[0])
self.assertEqual(altered_raw_data[3], self.raw_data[3])
self.assertEqual(altered_raw_data[1]["type"], self.raw_data[1]["type"])
self.assertEqual(altered_raw_data[2]["type"], self.raw_data[2]["type"])
self.assertEqual(altered_raw_data[1]["id"], self.raw_data[1]["id"])
self.assertEqual(altered_raw_data[2]["id"], self.raw_data[2]["id"])
self.assertEqual(altered_raw_data[1]["value"][0], self.raw_data[1]["value"][0])
self.assertEqual(
altered_raw_data[1]["value"][1]["type"],
self.raw_data[1]["value"][1]["type"],
)
self.assertEqual(
altered_raw_data[1]["value"][2]["type"],
self.raw_data[1]["value"][2]["type"],
)
self.assertEqual(
altered_raw_data[2]["value"][0]["type"],
self.raw_data[2]["value"][0]["type"],
)
self.assertEqual(
altered_raw_data[1]["value"][1]["id"],
self.raw_data[1]["value"][1]["id"],
)
self.assertEqual(
altered_raw_data[1]["value"][2]["id"],
self.raw_data[1]["value"][2]["id"],
)
self.assertEqual(
altered_raw_data[2]["value"][0]["id"],
self.raw_data[2]["value"][0]["id"],
)
def test_rename(self):
altered_raw_data = apply_changes_to_raw_data(
raw_data=self.raw_data,
block_path_str="nestedstream.stream1",
operation=RenameStreamChildrenOperation(
old_name="char1", new_name="renamed1"
),
streamfield=models.SampleModel.content,
)
self.assertEqual(
altered_raw_data[1]["value"][1]["value"][0]["type"], "renamed1"
)
self.assertEqual(
altered_raw_data[1]["value"][1]["value"][2]["type"], "renamed1"
)
self.assertEqual(
altered_raw_data[1]["value"][2]["value"][0]["type"], "renamed1"
)
self.assertEqual(
altered_raw_data[2]["value"][0]["value"][0]["type"], "renamed1"
)
self.assertEqual(
altered_raw_data[1]["value"][1]["value"][0]["id"],
self.raw_data[1]["value"][1]["value"][0]["id"],
)
self.assertEqual(
altered_raw_data[1]["value"][1]["value"][2]["id"],
self.raw_data[1]["value"][1]["value"][2]["id"],
)
self.assertEqual(
altered_raw_data[1]["value"][2]["value"][0]["id"],
self.raw_data[1]["value"][2]["value"][0]["id"],
)
self.assertEqual(
altered_raw_data[2]["value"][0]["value"][0]["id"],
self.raw_data[2]["value"][0]["value"][0]["id"],
)
self.assertEqual(
altered_raw_data[1]["value"][1]["value"][1],
self.raw_data[1]["value"][1]["value"][1],
)
def test_remove(self):
altered_raw_data = apply_changes_to_raw_data(
raw_data=self.raw_data,
block_path_str="nestedstream.stream1",
operation=RemoveStreamChildrenOperation(name="char1"),
streamfield=models.SampleModel.content,
)
self.assertEqual(len(altered_raw_data[1]["value"][1]["value"]), 1)
self.assertEqual(len(altered_raw_data[1]["value"][2]["value"]), 0)
self.assertEqual(len(altered_raw_data[2]["value"][0]["value"]), 0)
self.assertEqual(
altered_raw_data[1]["value"][1]["value"][0],
self.raw_data[1]["value"][1]["value"][1],
)
class FieldStreamStructChildBlockTest(TestCase):
"""Tests involving changes to children of a StructBlock nested inside a StreamBlock.
We use `nestedstream.simplestruct` blocks here.
"""
def setUp(self):
raw_data = factories.SampleModelFactory(
content__0__char1__value="Char Block 1",
content__1="nestedstream",
content__1__nestedstream__0__char1="Char Block 1",
content__1__nestedstream__1="struct1",
content__1__nestedstream__2="struct1",
content__2="nestedstream",
content__2__nestedstream__0="struct1",
content__3="simplestream",
content__3__simplestream__0__char1__value="Char Block 1",
).content.raw_data
self.raw_data = raw_data
def test_blocks_and_data_not_operated_on_intact(self):
"""Test whether other blocks and data not passed to an operation are intact.
We are checking whether the parts of the data which are not passed to an operation are
intact. Since the recursion process depends just on the block path and block structure,
this check is independent of the operation used. We will use a rename operation for now.
"""
altered_raw_data = apply_changes_to_raw_data(
raw_data=self.raw_data,
block_path_str="nestedstream.struct1",
operation=RenameStructChildrenOperation(
old_name="char1", new_name="renamed1"
),
streamfield=models.SampleModel.content,
)
self.assertEqual(altered_raw_data[0], self.raw_data[0])
self.assertEqual(altered_raw_data[3], self.raw_data[3])
self.assertEqual(altered_raw_data[1]["type"], self.raw_data[1]["type"])
self.assertEqual(altered_raw_data[2]["type"], self.raw_data[2]["type"])
self.assertEqual(altered_raw_data[1]["id"], self.raw_data[1]["id"])
self.assertEqual(altered_raw_data[2]["id"], self.raw_data[2]["id"])
self.assertEqual(altered_raw_data[1]["value"][0], self.raw_data[1]["value"][0])
self.assertEqual(
altered_raw_data[1]["value"][1]["type"],
self.raw_data[1]["value"][1]["type"],
)
self.assertEqual(
altered_raw_data[1]["value"][2]["type"],
self.raw_data[1]["value"][2]["type"],
)
self.assertEqual(
altered_raw_data[2]["value"][0]["type"],
self.raw_data[2]["value"][0]["type"],
)
self.assertEqual(
altered_raw_data[1]["value"][1]["id"],
self.raw_data[1]["value"][1]["id"],
)
self.assertEqual(
altered_raw_data[1]["value"][2]["id"],
self.raw_data[1]["value"][2]["id"],
)
self.assertEqual(
altered_raw_data[2]["value"][0]["id"],
self.raw_data[2]["value"][0]["id"],
)
def test_rename(self):
altered_raw_data = apply_changes_to_raw_data(
raw_data=self.raw_data,
block_path_str="nestedstream.struct1",
operation=RenameStructChildrenOperation(
old_name="char1", new_name="renamed1"
),
streamfield=models.SampleModel.content,
)
self.assertNotIn("char1", altered_raw_data[1]["value"][1]["value"])
self.assertNotIn("char1", altered_raw_data[1]["value"][2]["value"])
self.assertNotIn("char1", altered_raw_data[2]["value"][0]["value"])
self.assertIn("renamed1", altered_raw_data[1]["value"][1]["value"])
self.assertIn("renamed1", altered_raw_data[1]["value"][2]["value"])
self.assertIn("renamed1", altered_raw_data[2]["value"][0]["value"])
self.assertEqual(
altered_raw_data[1]["value"][1]["value"]["renamed1"],
self.raw_data[1]["value"][1]["value"]["char1"],
)
self.assertEqual(
altered_raw_data[1]["value"][2]["value"]["renamed1"],
self.raw_data[1]["value"][2]["value"]["char1"],
)
self.assertEqual(
altered_raw_data[2]["value"][0]["value"]["renamed1"],
self.raw_data[2]["value"][0]["value"]["char1"],
)
self.assertIn("char2", altered_raw_data[1]["value"][1]["value"])
self.assertIn("char2", altered_raw_data[1]["value"][2]["value"])
self.assertIn("char2", altered_raw_data[2]["value"][0]["value"])
self.assertEqual(
altered_raw_data[1]["value"][1]["value"]["char2"],
self.raw_data[1]["value"][1]["value"]["char2"],
)
self.assertEqual(
altered_raw_data[1]["value"][2]["value"]["char2"],
self.raw_data[1]["value"][2]["value"]["char2"],
)
self.assertEqual(
altered_raw_data[2]["value"][0]["value"]["char2"],
self.raw_data[2]["value"][0]["value"]["char2"],
)
def test_remove(self):
altered_raw_data = apply_changes_to_raw_data(
raw_data=self.raw_data,
block_path_str="nestedstream.struct1",
operation=RemoveStructChildrenOperation(name="char1"),
streamfield=models.SampleModel.content,
)
self.assertEqual(len(altered_raw_data[1]["value"][1]["value"]), 1)
self.assertEqual(len(altered_raw_data[1]["value"][2]["value"]), 1)
self.assertEqual(len(altered_raw_data[2]["value"][0]["value"]), 1)
self.assertIn("char2", altered_raw_data[1]["value"][1]["value"])
self.assertIn("char2", altered_raw_data[1]["value"][2]["value"])
self.assertIn("char2", altered_raw_data[2]["value"][0]["value"])
self.assertEqual(
altered_raw_data[1]["value"][1]["value"]["char2"],
self.raw_data[1]["value"][1]["value"]["char2"],
)
self.assertEqual(
altered_raw_data[1]["value"][2]["value"]["char2"],
self.raw_data[1]["value"][2]["value"]["char2"],
)
self.assertEqual(
altered_raw_data[2]["value"][0]["value"]["char2"],
self.raw_data[2]["value"][0]["value"]["char2"],
)
class FieldListStreamChildBlockTest(TestCase):
"""Tests involving changes to children of a StreamBlock nested inside a ListBlock.
We use `nestedlist_stream.item` blocks here.
"""
def setUp(self):
raw_data = factories.SampleModelFactory(
content__0__char1__value="Char Block 1",
content__1="nestedlist_stream",
content__1__nestedlist_stream__0__0__char1__value="Char Block 1",
content__1__nestedlist_stream__0__1__char2__value="Char Block 2",
content__1__nestedlist_stream__0__2__char1__value="Char Block 1",
content__1__nestedlist_stream__1__0__char1__value="Char Block 1",
content__2="nestedlist_stream",
content__2__nestedlist_stream__0__0__char1__value="Char Block 1",
content__3="simplestream",
content__3__simplestream__0__char1__value="Char Block 1",
content__3__simplestream__1__char2__value="Char Block 2",
).content.raw_data
self.raw_data = raw_data
def test_blocks_and_data_not_operated_on_intact(self):
"""Test whether other blocks and data not passed to an operation are intact.
We are checking whether the parts of the data which are not passed to an operation are
intact. Since the recursion process depends just on the block path and block structure,
this check is independent of the operation used. We will use a rename operation for now.
"""
altered_raw_data = apply_changes_to_raw_data(
raw_data=self.raw_data,
block_path_str="nestedlist_stream.item",
operation=RenameStreamChildrenOperation(
old_name="char1", new_name="renamed1"
),
streamfield=models.SampleModel.content,
)
self.assertEqual(altered_raw_data[0], self.raw_data[0])
self.assertEqual(altered_raw_data[3], self.raw_data[3])
self.assertEqual(altered_raw_data[1]["type"], self.raw_data[1]["type"])
self.assertEqual(altered_raw_data[2]["type"], self.raw_data[2]["type"])
self.assertEqual(altered_raw_data[1]["id"], self.raw_data[1]["id"])
self.assertEqual(altered_raw_data[2]["id"], self.raw_data[2]["id"])
self.assertEqual(
altered_raw_data[1]["value"][0]["type"],
self.raw_data[1]["value"][0]["type"],
)
self.assertEqual(
altered_raw_data[1]["value"][1]["type"],
self.raw_data[1]["value"][1]["type"],
)
self.assertEqual(
altered_raw_data[2]["value"][0]["type"],
self.raw_data[2]["value"][0]["type"],
)
self.assertEqual(
altered_raw_data[1]["value"][0]["id"], self.raw_data[1]["value"][0]["id"]
)
self.assertEqual(
altered_raw_data[1]["value"][1]["id"], self.raw_data[1]["value"][1]["id"]
)
self.assertEqual(
altered_raw_data[2]["value"][0]["id"], self.raw_data[2]["value"][0]["id"]
)
def test_rename(self):
altered_raw_data = apply_changes_to_raw_data(
raw_data=self.raw_data,
block_path_str="nestedlist_stream.item",
operation=RenameStreamChildrenOperation(
old_name="char1", new_name="renamed1"
),
streamfield=models.SampleModel.content,
)
self.assertEqual(
altered_raw_data[1]["value"][0]["value"][0]["type"], "renamed1"
)
self.assertEqual(
altered_raw_data[1]["value"][0]["value"][2]["type"], "renamed1"
)
self.assertEqual(
altered_raw_data[1]["value"][1]["value"][0]["type"], "renamed1"
)
self.assertEqual(
altered_raw_data[2]["value"][0]["value"][0]["type"], "renamed1"
)
self.assertEqual(
altered_raw_data[1]["value"][0]["value"][0]["id"],
self.raw_data[1]["value"][0]["value"][0]["id"],
)
self.assertEqual(
altered_raw_data[1]["value"][0]["value"][2]["id"],
self.raw_data[1]["value"][0]["value"][2]["id"],
)
self.assertEqual(
altered_raw_data[1]["value"][1]["value"][0]["id"],
self.raw_data[1]["value"][1]["value"][0]["id"],
)
self.assertEqual(
altered_raw_data[2]["value"][0]["value"][0]["id"],
self.raw_data[2]["value"][0]["value"][0]["id"],
)
self.assertEqual(
altered_raw_data[1]["value"][0]["value"][0]["value"],
self.raw_data[1]["value"][0]["value"][0]["value"],
)
self.assertEqual(
altered_raw_data[1]["value"][0]["value"][2]["value"],
self.raw_data[1]["value"][0]["value"][2]["value"],
)
self.assertEqual(
altered_raw_data[1]["value"][1]["value"][0]["value"],
self.raw_data[1]["value"][1]["value"][0]["value"],
)
self.assertEqual(
altered_raw_data[2]["value"][0]["value"][0]["value"],
self.raw_data[2]["value"][0]["value"][0]["value"],
)
self.assertEqual(
altered_raw_data[1]["value"][0]["value"][1],
self.raw_data[1]["value"][0]["value"][1],
)
def test_remove(self):
altered_raw_data = apply_changes_to_raw_data(
raw_data=self.raw_data,
block_path_str="nestedlist_stream.item",
operation=RemoveStreamChildrenOperation(name="char1"),
streamfield=models.SampleModel.content,
)
self.assertEqual(len(altered_raw_data[1]["value"][0]["value"]), 1)
self.assertEqual(len(altered_raw_data[1]["value"][1]["value"]), 0)
self.assertEqual(len(altered_raw_data[2]["value"][0]["value"]), 0)
self.assertEqual(
altered_raw_data[1]["value"][0]["value"][0],
self.raw_data[1]["value"][0]["value"][1],
)
class FieldListStructChildBlockTest(TestCase):
"""Tests involving changes to children of a StructBlock nested inside a ListBlock.
We use `nestedlist_struct.item` blocks here.
"""
def setUp(self):
raw_data = factories.SampleModelFactory(
content__0__char1__value="Char Block 1",
content__1__nestedlist_struct__0__char1="Nested List Struct 1",
content__1__nestedlist_struct__1__char1="Nested List Struct 2",
content__2__nestedlist_struct__0__char1="Nested List Struct 3",
content__3="simplestruct",
).content.raw_data
self.raw_data = raw_data
def test_blocks_and_data_not_operated_on_intact(self):
"""Test whether other blocks and data not passed to an operation are intact.
We are checking whether the parts of the data which are not passed to an operation are
intact. Since the recursion process depends just on the block path and block structure,
this check is independent of the operation used. We will use a rename operation for now.
"""
altered_raw_data = apply_changes_to_raw_data(
raw_data=self.raw_data,
block_path_str="nestedlist_struct.item",
operation=RenameStructChildrenOperation(
old_name="char1", new_name="renamed1"
),
streamfield=models.SampleModel.content,
)
self.assertEqual(altered_raw_data[0], self.raw_data[0])
self.assertEqual(altered_raw_data[3], self.raw_data[3])
self.assertEqual(altered_raw_data[1]["type"], self.raw_data[1]["type"])
self.assertEqual(altered_raw_data[2]["type"], self.raw_data[2]["type"])
self.assertEqual(altered_raw_data[1]["id"], self.raw_data[1]["id"])
self.assertEqual(altered_raw_data[2]["id"], self.raw_data[2]["id"])
self.assertEqual(
altered_raw_data[1]["value"][0]["id"], self.raw_data[1]["value"][0]["id"]
)
self.assertEqual(
altered_raw_data[1]["value"][1]["id"], self.raw_data[1]["value"][1]["id"]
)
self.assertEqual(
altered_raw_data[2]["value"][0]["id"], self.raw_data[2]["value"][0]["id"]
)
def test_rename(self):
altered_raw_data = apply_changes_to_raw_data(
raw_data=self.raw_data,
block_path_str="nestedlist_struct.item",
operation=RenameStructChildrenOperation(
old_name="char1", new_name="renamed1"
),
streamfield=models.SampleModel.content,
)
self.assertNotIn("char1", altered_raw_data[1]["value"][0]["value"])
self.assertNotIn("char1", altered_raw_data[1]["value"][1]["value"])
self.assertNotIn("char1", altered_raw_data[2]["value"][0]["value"])
self.assertIn("renamed1", altered_raw_data[1]["value"][0]["value"])
self.assertIn("renamed1", altered_raw_data[1]["value"][1]["value"])
self.assertIn("renamed1", altered_raw_data[2]["value"][0]["value"])
self.assertEqual(
altered_raw_data[1]["value"][0]["value"]["renamed1"],
self.raw_data[1]["value"][0]["value"]["char1"],
)
self.assertEqual(
altered_raw_data[1]["value"][1]["value"]["renamed1"],
self.raw_data[1]["value"][1]["value"]["char1"],
)
self.assertEqual(
altered_raw_data[2]["value"][0]["value"]["renamed1"],
self.raw_data[2]["value"][0]["value"]["char1"],
)
self.assertIn("char2", altered_raw_data[1]["value"][0]["value"])
self.assertIn("char2", altered_raw_data[1]["value"][1]["value"])
self.assertIn("char2", altered_raw_data[2]["value"][0]["value"])
self.assertEqual(
altered_raw_data[1]["value"][0]["value"]["char2"],
self.raw_data[1]["value"][0]["value"]["char2"],
)
self.assertEqual(
altered_raw_data[1]["value"][1]["value"]["char2"],
self.raw_data[1]["value"][1]["value"]["char2"],
)
self.assertEqual(
altered_raw_data[2]["value"][0]["value"]["char2"],
self.raw_data[2]["value"][0]["value"]["char2"],
)

View File

@@ -0,0 +1,264 @@
from django.test import TestCase
from wagtail.blocks.migrations.operations import (
ListChildrenToStructBlockOperation,
RenameStreamChildrenOperation,
RenameStructChildrenOperation,
)
from wagtail.blocks.migrations.utils import apply_changes_to_raw_data
from wagtail.test.streamfield_migrations import models
class OldListFormatNestedStreamTestCase(TestCase):
"""Tests involving changes to ListBlocks in the old format with StreamBlock children"""
@classmethod
def setUpTestData(cls):
raw_data = [
{"type": "char1", "id": "0001", "value": "Char Block 1"},
{
"type": "nestedlist_stream",
"id": "0002",
"value": [
[
{"type": "char1", "id": "0003", "value": "Char Block 1"},
{"type": "char2", "id": "0004", "value": "Char Block 2"},
{"type": "char1", "id": "0005", "value": "Char Block 1"},
],
[
{"type": "char1", "id": "0006", "value": "Char Block 1"},
],
],
},
{
"type": "nestedlist_stream",
"id": "0007",
"value": [
[
{"type": "char1", "id": "0008", "value": "Char Block 1"},
]
],
},
]
cls.raw_data = raw_data
def test_list_converted_to_new_format_in_recursion(self):
"""Test whether all ListBlock children have converted formats during the recursion.
This tests the changes done in the recursion process only, so the operation used isn't
important. We will use a rename operation for now.
Check whether each ListBlock child has attributes id, value, type and type is item.
Check whether rename operation was done successfully.
"""
altered_raw_data = apply_changes_to_raw_data(
self.raw_data,
"nestedlist_stream.item",
RenameStreamChildrenOperation(old_name="char1", new_name="renamed1"),
streamfield=models.SampleModel.content,
)
for listitem in altered_raw_data[1]["value"]:
self.assertIsInstance(listitem, dict)
self.assertIn("type", listitem)
self.assertIn("value", listitem)
self.assertEqual(listitem["type"], "item")
for listitem in altered_raw_data[2]["value"]:
self.assertIsInstance(listitem, dict)
self.assertIn("type", listitem)
self.assertIn("value", listitem)
self.assertEqual(listitem["type"], "item")
# the nested blocks which were renamed
altered_block_path_indices = [
(1, 0, 0),
(1, 0, 2),
(1, 1, 0),
(2, 0, 0),
]
for ind0, ind1, ind2 in altered_block_path_indices:
self.assertEqual(
altered_raw_data[ind0]["value"][ind1]["value"][ind2]["type"], "renamed1"
)
self.assertEqual(
altered_raw_data[ind0]["value"][ind1]["value"][ind2]["id"],
self.raw_data[ind0]["value"][ind1][ind2]["id"],
)
self.assertEqual(
altered_raw_data[ind0]["value"][ind1]["value"][ind2]["value"],
self.raw_data[ind0]["value"][ind1][ind2]["value"],
)
self.assertEqual(
altered_raw_data[1]["value"][0]["value"][1],
self.raw_data[1]["value"][0][1],
)
def test_list_converted_to_new_format_in_operation(self):
"""Test whether all ListBlock children have converted formats in an operation using the generator
We will test this with the ListChildrenToStructBlockOperation.
Check whether each ListBlock child has attributes id, value, type and type is item.
Check whether the ListBlock child value is a struct with the previous block as value.
Check whether the previous values are intact.
"""
altered_raw_data = apply_changes_to_raw_data(
self.raw_data,
"nestedlist_stream",
ListChildrenToStructBlockOperation(block_name="stream1"),
streamfield=models.SampleModel.content,
)
for ind, listitem in enumerate(altered_raw_data[1]["value"]):
self.assertIsInstance(listitem, dict)
self.assertIn("type", listitem)
self.assertIn("value", listitem)
self.assertEqual(listitem["type"], "item")
self.assertIsInstance(listitem["value"], dict)
self.assertEqual(len(listitem["value"]), 1)
self.assertIn("stream1", listitem["value"])
self.assertEqual(
listitem["value"]["stream1"], self.raw_data[1]["value"][ind]
)
for ind, listitem in enumerate(altered_raw_data[2]["value"]):
self.assertIsInstance(listitem, dict)
self.assertIn("type", listitem)
self.assertIn("value", listitem)
self.assertEqual(listitem["type"], "item")
self.assertIsInstance(listitem["value"], dict)
self.assertEqual(len(listitem["value"]), 1)
self.assertIn("stream1", listitem["value"])
self.assertEqual(
listitem["value"]["stream1"], self.raw_data[2]["value"][ind]
)
class OldListFormatNestedStructTestCase(TestCase):
"""Tests involving changes to ListBlocks in the old format with StructBlock children"""
@classmethod
def setUpTestData(cls):
raw_data = [
{"type": "char1", "id": "0001", "value": "Char Block 1"},
{
"type": "nestedlist_struct",
"id": "0002",
"value": [
{"char1": "Char Block 1", "char2": "Char Block 2"},
{"char1": "Char Block 1", "char2": "Char Block 2"},
],
},
{
"type": "nestedlist_struct",
"id": "0007",
"value": [
{"char1": "Char Block 1", "char2": "Char Block 2"},
],
},
]
cls.raw_data = raw_data
def test_list_converted_to_new_format_in_recursion(self):
"""Test whether all ListBlock children have converted formats during the recursion.
This tests the changes done in the recursion process only, so the operation used isn't
important. We will use a rename operation for now.
Check whether each ListBlock child has attributes id, value, type and type is item.
Check whether rename operation was done successfully.
"""
altered_raw_data = apply_changes_to_raw_data(
self.raw_data,
"nestedlist_struct.item",
RenameStructChildrenOperation(old_name="char1", new_name="renamed1"),
streamfield=models.SampleModel.content,
)
for listitem in altered_raw_data[1]["value"]:
self.assertIsInstance(listitem, dict)
self.assertIn("type", listitem)
self.assertIn("value", listitem)
self.assertEqual(listitem["type"], "item")
for listitem in altered_raw_data[2]["value"]:
self.assertIsInstance(listitem, dict)
self.assertIn("type", listitem)
self.assertIn("value", listitem)
self.assertEqual(listitem["type"], "item")
# The nested blocks which were renamed
altered_block_indices = [(1, 0), (1, 1), (2, 0)]
for ind0, ind1 in altered_block_indices:
self.assertNotIn("char1", altered_raw_data[ind0]["value"][ind1]["value"])
self.assertIn("renamed1", altered_raw_data[ind0]["value"][ind1]["value"])
self.assertEqual(
altered_raw_data[ind0]["value"][ind1]["value"]["renamed1"],
self.raw_data[ind0]["value"][ind1]["char1"],
)
self.assertIn("char2", altered_raw_data[ind0]["value"][ind1]["value"])
self.assertEqual(
altered_raw_data[ind0]["value"][ind1]["value"]["char2"],
self.raw_data[ind0]["value"][ind1]["char2"],
)
def test_list_converted_to_new_format_in_operation(self):
"""Test whether all ListBlock children have converted formats in an operation using the generator
We will test this with the ListChildrenToStructBlockOperation.
Check whether each ListBlock child has attributes id, value, type and type is item.
Check whether the ListBlock child value is a struct with the previous block as value.
Check whether the previous values are intact.
"""
altered_raw_data = apply_changes_to_raw_data(
self.raw_data,
"nestedlist_struct",
ListChildrenToStructBlockOperation(block_name="struct1"),
streamfield=models.SampleModel.content,
)
for ind, listitem in enumerate(altered_raw_data[1]["value"]):
self.assertIsInstance(listitem, dict)
self.assertIn("type", listitem)
self.assertIn("value", listitem)
self.assertEqual(listitem["type"], "item")
self.assertIsInstance(listitem["value"], dict)
self.assertEqual(len(listitem["value"]), 1)
self.assertIn("struct1", listitem["value"])
self.assertEqual(
listitem["value"]["struct1"], self.raw_data[1]["value"][ind]
)
for ind, listitem in enumerate(altered_raw_data[2]["value"]):
self.assertIsInstance(listitem, dict)
self.assertIn("type", listitem)
self.assertIn("value", listitem)
self.assertEqual(listitem["type"], "item")
self.assertIsInstance(listitem["value"], dict)
self.assertEqual(len(listitem["value"]), 1)
self.assertIn("struct1", listitem["value"])
self.assertEqual(
listitem["value"]["struct1"], self.raw_data[2]["value"][ind]
)

View File

@@ -0,0 +1,535 @@
from django.test import TestCase
from wagtail.blocks.migrations.operations import (
AlterBlockValueOperation,
ListChildrenToStructBlockOperation,
RemoveStreamChildrenOperation,
RemoveStructChildrenOperation,
RenameStreamChildrenOperation,
RenameStructChildrenOperation,
StreamChildrenToListBlockOperation,
StreamChildrenToStreamBlockOperation,
StreamChildrenToStructBlockOperation,
)
from wagtail.blocks.migrations.utils import apply_changes_to_raw_data
from wagtail.test.streamfield_migrations import factories, models
class FieldChildBlockTest(TestCase):
"""Tests involving changes to top level blocks"""
def setUp(self):
raw_data = factories.SampleModelFactory(
content__0__char1__value="Char Block 1",
content__1__char2__value="Char Block 2",
content__2__char1__value="Char Block 1",
content__3__char2__value="Char Block 2",
).content.raw_data
self.raw_data = raw_data
def test_rename(self):
"""Rename `char1` blocks to `renamed1`
Check whether all `char1` blocks have been renamed correctly.
Check whether ids and values for renamed blocks are intact.
Check whether other block types are intact.
"""
altered_raw_data = apply_changes_to_raw_data(
raw_data=self.raw_data,
block_path_str="",
operation=RenameStreamChildrenOperation(
old_name="char1", new_name="renamed1"
),
streamfield=models.SampleModel.content,
)
self.assertEqual(altered_raw_data[0]["type"], "renamed1")
self.assertEqual(altered_raw_data[2]["type"], "renamed1")
self.assertEqual(altered_raw_data[0]["id"], self.raw_data[0]["id"])
self.assertEqual(altered_raw_data[2]["id"], self.raw_data[2]["id"])
self.assertEqual(altered_raw_data[0]["value"], self.raw_data[0]["value"])
self.assertEqual(altered_raw_data[2]["value"], self.raw_data[2]["value"])
self.assertEqual(altered_raw_data[1], self.raw_data[1])
self.assertEqual(altered_raw_data[3], self.raw_data[3])
def test_remove(self):
"""Remove all `char1` blocks
Check whether all `char1` blocks have been removed and whether other blocks are intact.
"""
altered_raw_data = apply_changes_to_raw_data(
raw_data=self.raw_data,
block_path_str="",
operation=RemoveStreamChildrenOperation(name="char1"),
streamfield=models.SampleModel.content,
)
self.assertEqual(len(altered_raw_data), 2)
self.assertEqual(altered_raw_data[0], self.raw_data[1])
self.assertEqual(altered_raw_data[1], self.raw_data[3])
def test_combine_to_listblock(self):
"""Combine all `char1` blocks into a new ListBlock named `list1`
Check whether no `char1` blocks are present among the stream children and whether other
blocks are intact.
Check whether a new `list1` block has been added to the stream children and whether it has
child blocks corresponding to the previous `char1` blocks.
Check whether the ids and values from the `char1` blocks are intact in the list children.
"""
altered_raw_data = apply_changes_to_raw_data(
raw_data=self.raw_data,
block_path_str="",
operation=StreamChildrenToListBlockOperation(
block_name="char1", list_block_name="list1"
),
streamfield=models.SampleModel.content,
)
self.assertEqual(len(altered_raw_data), 3)
self.assertEqual(altered_raw_data[0], self.raw_data[1])
self.assertEqual(altered_raw_data[1], self.raw_data[3])
self.assertEqual(altered_raw_data[2]["type"], "list1")
self.assertEqual(len(altered_raw_data[2]["value"]), 2)
self.assertEqual(altered_raw_data[2]["value"][0]["type"], "item")
self.assertEqual(altered_raw_data[2]["value"][1]["type"], "item")
self.assertEqual(altered_raw_data[2]["value"][0]["id"], self.raw_data[0]["id"])
self.assertEqual(altered_raw_data[2]["value"][1]["id"], self.raw_data[2]["id"])
self.assertEqual(
altered_raw_data[2]["value"][0]["value"], self.raw_data[0]["value"]
)
self.assertEqual(
altered_raw_data[2]["value"][1]["value"], self.raw_data[2]["value"]
)
def test_combine_to_listblock_no_existing_children(self):
"""Combine all `simplestruct` blocks into a new ListBlock named `list1`
We have no `simplestruct` blocks in our existing data, so there should be no list1 blocks
created and the data should be intact.
"""
altered_raw_data = apply_changes_to_raw_data(
raw_data=self.raw_data,
block_path_str="",
operation=StreamChildrenToListBlockOperation(
block_name="simplestruct", list_block_name="list1"
),
streamfield=models.SampleModel.content,
)
self.assertEqual(len(altered_raw_data), 4)
self.assertEqual(altered_raw_data[0], self.raw_data[0])
self.assertEqual(altered_raw_data[1], self.raw_data[1])
self.assertEqual(altered_raw_data[2], self.raw_data[2])
self.assertEqual(altered_raw_data[3], self.raw_data[3])
def test_combine_single_type_to_streamblock(self):
"""Combine all `char1` blocks as children of a new StreamBlock named `stream1`
Check whether no `char1` blocks are present among the (top) stream children and whether
other blocks are intact.
Check whether a new `stream1` block has been added to the (top) stream children.
Check whether the new `stream1` block has the `char1` blocks as children and whether they
are intact.
"""
altered_raw_data = apply_changes_to_raw_data(
raw_data=self.raw_data,
block_path_str="",
operation=StreamChildrenToStreamBlockOperation(
block_names=["char1"], stream_block_name="stream1"
),
streamfield=models.SampleModel.content,
)
self.assertEqual(len(altered_raw_data), 3)
self.assertEqual(altered_raw_data[0], self.raw_data[1])
self.assertEqual(altered_raw_data[1], self.raw_data[3])
self.assertEqual(altered_raw_data[2]["type"], "stream1")
self.assertEqual(len(altered_raw_data[2]["value"]), 2)
self.assertEqual(altered_raw_data[2]["value"][0], self.raw_data[0])
self.assertEqual(altered_raw_data[2]["value"][1], self.raw_data[2])
def test_combine_multiple_types_to_streamblock(self):
"""Combine all `char1` and `char2` blocks as children of a new StreamBlock named `stream1`
Check whether no `char1` or `char2` blocks are present among the (top) stream children.
Check whether a new `stream1` block has been added to the (top) stream children.
Check whether the new `stream1` block has the `char1` and `char2` blocks as children and
that they are intact.
Note:
We only have `char1` and `char2` blocks in our existing data.
"""
altered_raw_data = apply_changes_to_raw_data(
raw_data=self.raw_data,
block_path_str="",
operation=StreamChildrenToStreamBlockOperation(
block_names=["char1", "char2"], stream_block_name="stream1"
),
streamfield=models.SampleModel.content,
)
self.assertEqual(len(altered_raw_data), 1)
self.assertEqual(altered_raw_data[0]["type"], "stream1")
self.assertEqual(len(altered_raw_data[0]["value"]), 4)
self.assertEqual(altered_raw_data[0]["value"][0], self.raw_data[0])
self.assertEqual(altered_raw_data[0]["value"][1], self.raw_data[1])
self.assertEqual(altered_raw_data[0]["value"][2], self.raw_data[2])
self.assertEqual(altered_raw_data[0]["value"][3], self.raw_data[3])
def test_combine_to_streamblock_no_existing_children(self):
"""Combine all `simplestruct` blocks as children of a new StreamBlock named `stream1`
We have no `simplestruct` blocks in our existing data, so there should be no stream1 blocks
created and the data should be intact.
"""
altered_raw_data = apply_changes_to_raw_data(
raw_data=self.raw_data,
block_path_str="",
operation=StreamChildrenToStreamBlockOperation(
block_names=["simplestruct"], stream_block_name="stream1"
),
streamfield=models.SampleModel.content,
)
self.assertEqual(len(altered_raw_data), 4)
self.assertEqual(altered_raw_data[0], self.raw_data[0])
self.assertEqual(altered_raw_data[1], self.raw_data[1])
self.assertEqual(altered_raw_data[2], self.raw_data[2])
self.assertEqual(altered_raw_data[3], self.raw_data[3])
def test_to_structblock(self):
"""Move each `char1` block inside a new StructBlock named `struct1`
Check whether each `char1` block has been replaced with a `struct1` block in the stream
children.
Check whether other blocks are intact.
Check whether each `struct1` block has a `char1` child and whether it has the value of the
previous `char1` block.
Note:
Block ids are not preserved here.
"""
altered_raw_data = apply_changes_to_raw_data(
raw_data=self.raw_data,
block_path_str="",
operation=StreamChildrenToStructBlockOperation("char1", "struct1"),
streamfield=models.SampleModel.content,
)
self.assertEqual(altered_raw_data[0]["type"], "struct1")
self.assertEqual(altered_raw_data[2]["type"], "struct1")
self.assertEqual(altered_raw_data[1], self.raw_data[1])
self.assertEqual(altered_raw_data[3], self.raw_data[3])
self.assertIn("char1", altered_raw_data[0]["value"])
self.assertIn("char1", altered_raw_data[2]["value"])
self.assertEqual(
altered_raw_data[0]["value"]["char1"], self.raw_data[0]["value"]
)
self.assertEqual(
altered_raw_data[2]["value"]["char1"], self.raw_data[2]["value"]
)
def test_alter_value(self):
"""Change the value of each `char1` block to `foo`
Check whether the value of each `char1` block has changed to `foo`.
Check whether the values of other blocks are intact.
"""
altered_raw_data = apply_changes_to_raw_data(
raw_data=self.raw_data,
block_path_str="char1",
operation=AlterBlockValueOperation(new_value="foo"),
streamfield=models.SampleModel.content,
)
self.assertEqual(altered_raw_data[0]["value"], "foo")
self.assertEqual(altered_raw_data[1]["value"], self.raw_data[1]["value"])
self.assertEqual(altered_raw_data[2]["value"], "foo")
self.assertEqual(altered_raw_data[3]["value"], self.raw_data[3]["value"])
class FieldStructChildBlockTest(TestCase):
"""Tests involving changes to direct children of a StructBlock
We use `simplestruct` blocks as the StructBlocks here.
"""
def setUp(self):
raw_data = factories.SampleModelFactory(
content__0__char1__value="Char Block 1",
content__1="simplestruct",
content__2="simplestruct",
content__3__char2__value="Char Block 2",
).content.raw_data
self.raw_data = raw_data
def test_blocks_and_data_not_operated_on_intact(self):
"""Test whether other blocks and data not passed to an operation are intact.
We are checking whether the parts of the data which are not passed to an operation are
intact. Since the recursion process depends just on the block path and block structure,
this check is independent of the operation used. We will use a rename operation for now.
"""
altered_raw_data = apply_changes_to_raw_data(
raw_data=self.raw_data,
block_path_str="simplestruct",
operation=RenameStructChildrenOperation(
old_name="char1", new_name="renamed1"
),
streamfield=models.SampleModel.content,
)
self.assertEqual(altered_raw_data[0], self.raw_data[0])
self.assertEqual(altered_raw_data[3], self.raw_data[3])
self.assertEqual(altered_raw_data[1]["id"], self.raw_data[1]["id"])
self.assertEqual(altered_raw_data[2]["id"], self.raw_data[2]["id"])
self.assertEqual(altered_raw_data[1]["type"], self.raw_data[1]["type"])
self.assertEqual(altered_raw_data[2]["type"], self.raw_data[2]["type"])
def test_rename(self):
"""Rename `simplestruct.char1` blocks to `renamed1`
Check whether all `simplestruct.char1` blocks have been renamed correctly.
Check whether values for renamed blocks are intact.
Check whether other children of `simplestruct` are intact.
"""
altered_raw_data = apply_changes_to_raw_data(
raw_data=self.raw_data,
block_path_str="simplestruct",
operation=RenameStructChildrenOperation(
old_name="char1", new_name="renamed1"
),
streamfield=models.SampleModel.content,
)
self.assertEqual(len(altered_raw_data[1]["value"]), 2)
self.assertEqual(len(altered_raw_data[2]["value"]), 2)
self.assertNotIn("char1", altered_raw_data[1]["value"])
self.assertNotIn("char1", altered_raw_data[2]["value"])
self.assertIn("renamed1", altered_raw_data[1]["value"])
self.assertIn("renamed1", altered_raw_data[2]["value"])
self.assertEqual(
altered_raw_data[1]["value"]["renamed1"], self.raw_data[1]["value"]["char1"]
)
self.assertEqual(
altered_raw_data[2]["value"]["renamed1"], self.raw_data[2]["value"]["char1"]
)
self.assertIn("char2", altered_raw_data[1]["value"])
self.assertIn("char2", altered_raw_data[2]["value"])
self.assertEqual(
altered_raw_data[1]["value"]["char2"], self.raw_data[1]["value"]["char2"]
)
self.assertEqual(
altered_raw_data[2]["value"]["char2"], self.raw_data[2]["value"]["char2"]
)
def test_remove(self):
"""Remove `simplestruct.char1` blocks
Check whether all `simplestruct.char1` blocks have been removed.
Check whether other children of `simplestruct` are intact.
"""
altered_raw_data = apply_changes_to_raw_data(
raw_data=self.raw_data,
block_path_str="simplestruct",
operation=RemoveStructChildrenOperation(name="char1"),
streamfield=models.SampleModel.content,
)
self.assertEqual(len(altered_raw_data[1]["value"]), 1)
self.assertEqual(len(altered_raw_data[2]["value"]), 1)
self.assertNotIn("char1", altered_raw_data[1]["value"])
self.assertNotIn("char1", altered_raw_data[2]["value"])
self.assertIn("char2", altered_raw_data[1]["value"])
self.assertIn("char2", altered_raw_data[2]["value"])
self.assertEqual(
altered_raw_data[1]["value"]["char2"], self.raw_data[1]["value"]["char2"]
)
self.assertEqual(
altered_raw_data[2]["value"]["char2"], self.raw_data[2]["value"]["char2"]
)
class FieldStreamChildBlockTest(TestCase):
"""Tests involving changes to direct children of a StreamBlock
We use `simplestream` blocks as the StreamBlocks here.
"""
def setUp(self):
raw_data = factories.SampleModelFactory(
content__0__char1__value="Char Block 1",
content__1="simplestream",
content__1__simplestream__0__char1__value="Char Block 1",
content__1__simplestream__1__char2__value="Char Block 2",
content__1__simplestream__2__char1__value="Char Block 1",
content__2="simplestream",
content__2__simplestream__0__char1__value="Char Block 1",
content__3__char2__value="Char Block 2",
).content.raw_data
self.raw_data = raw_data
def test_blocks_and_data_not_operated_on_intact(self):
"""Test whether other blocks and data not passed to an operation are intact.
We are checking whether the parts of the data which are not passed to an operation are
intact. Since the recursion process depends just on the block path and block structure,
this check is independent of the operation used. We will use a rename operation for now.
"""
altered_raw_data = apply_changes_to_raw_data(
raw_data=self.raw_data,
block_path_str="simplestream",
operation=RenameStreamChildrenOperation(
old_name="char1", new_name="renamed1"
),
streamfield=models.SampleModel.content,
)
self.assertEqual(altered_raw_data[0], self.raw_data[0])
self.assertEqual(altered_raw_data[3], self.raw_data[3])
self.assertEqual(altered_raw_data[1]["id"], self.raw_data[1]["id"])
self.assertEqual(altered_raw_data[2]["id"], self.raw_data[2]["id"])
self.assertEqual(altered_raw_data[1]["type"], self.raw_data[1]["type"])
self.assertEqual(altered_raw_data[2]["type"], self.raw_data[2]["type"])
def test_rename(self):
"""Rename `simplestream.char1` blocks to `renamed1`
Check whether all `simplestream.char1` blocks have been renamed correctly.
Check whether values and ids for renamed blocks are intact.
Check whether other children of `simplestream` are intact.
"""
altered_raw_data = apply_changes_to_raw_data(
raw_data=self.raw_data,
block_path_str="simplestream",
operation=RenameStreamChildrenOperation(
old_name="char1", new_name="renamed1"
),
streamfield=models.SampleModel.content,
)
self.assertEqual(altered_raw_data[1]["value"][0]["type"], "renamed1")
self.assertEqual(altered_raw_data[1]["value"][2]["type"], "renamed1")
self.assertEqual(altered_raw_data[2]["value"][0]["type"], "renamed1")
self.assertEqual(
altered_raw_data[1]["value"][0]["id"], self.raw_data[1]["value"][0]["id"]
)
self.assertEqual(
altered_raw_data[1]["value"][2]["id"], self.raw_data[1]["value"][2]["id"]
)
self.assertEqual(
altered_raw_data[2]["value"][0]["id"], self.raw_data[2]["value"][0]["id"]
)
self.assertEqual(
altered_raw_data[1]["value"][0]["value"],
self.raw_data[1]["value"][0]["value"],
)
self.assertEqual(
altered_raw_data[1]["value"][2]["value"],
self.raw_data[1]["value"][2]["value"],
)
self.assertEqual(
altered_raw_data[2]["value"][0]["value"],
self.raw_data[2]["value"][0]["value"],
)
self.assertEqual(altered_raw_data[1]["value"][1], self.raw_data[1]["value"][1])
def test_remove(self):
"""Remove `simplestream.char1` blocks
Check whether all `simplestream.char1` blocks have been removed.
Check whether other children of `simplestream` are intact.
"""
altered_raw_data = apply_changes_to_raw_data(
raw_data=self.raw_data,
block_path_str="simplestream",
operation=RemoveStreamChildrenOperation(name="char1"),
streamfield=models.SampleModel.content,
)
self.assertEqual(len(altered_raw_data[1]["value"]), 1)
self.assertEqual(len(altered_raw_data[2]["value"]), 0)
self.assertEqual(altered_raw_data[1]["value"][0], self.raw_data[1]["value"][1])
class FieldListChildBlockTest(TestCase):
"""Tests involving changes to direct children of a ListBlock
We use `simplelist` blocks as the ListBlocks here.
"""
def setUp(self):
raw_data = factories.SampleModelFactory(
content__0__char1__value="Char Block 1",
content__1__simplelist__0="Foo 1",
content__1__simplelist__1="Foo 2",
content__2__simplelist__0="Foo 3",
).content.raw_data
self.raw_data = raw_data
def test_to_structblock(self):
"""Turn each list child into a StructBlock and move value inside as a child named `text`
Check whether each list child has been converted to a StructBlock with a child named `text`
in it.
Check whether the previous value of each list child is now the value that `text` takes.
Note:
Block ids are not preserved here.
"""
altered_raw_data = apply_changes_to_raw_data(
raw_data=self.raw_data,
block_path_str="simplelist",
operation=ListChildrenToStructBlockOperation(block_name="text"),
streamfield=models.SampleModel.content,
)
self.assertEqual(type(altered_raw_data[1]["value"][0]["value"]), dict)
self.assertEqual(type(altered_raw_data[1]["value"][1]["value"]), dict)
self.assertEqual(type(altered_raw_data[2]["value"][0]["value"]), dict)
self.assertIn("text", altered_raw_data[1]["value"][0]["value"])
self.assertIn("text", altered_raw_data[1]["value"][1]["value"])
self.assertIn("text", altered_raw_data[2]["value"][0]["value"])
self.assertEqual(
altered_raw_data[1]["value"][0]["value"]["text"],
self.raw_data[1]["value"][0]["value"],
)
self.assertEqual(
altered_raw_data[1]["value"][1]["value"]["text"],
self.raw_data[1]["value"][1]["value"],
)
self.assertEqual(
altered_raw_data[2]["value"][0]["value"]["text"],
self.raw_data[2]["value"][0]["value"],
)

View File

@@ -0,0 +1,658 @@
import datetime
import json
from django.conf import settings
from django.contrib.auth import get_user_model
from django.contrib.contenttypes.models import ContentType
from django.core.exceptions import ValidationError
from django.core.serializers.json import DjangoJSONEncoder
from django.test import TestCase
from django.utils import timezone
from freezegun import freeze_time
from wagtail.log_actions import LogActionRegistry
from wagtail.log_actions import registry as log_registry
from wagtail.models import (
Page,
PageLogEntry,
PageViewRestriction,
Task,
Workflow,
WorkflowTask,
)
from wagtail.models.audit_log import ModelLogEntry
from wagtail.test.testapp.models import FullFeaturedSnippet, SimplePage
from wagtail.test.utils import WagtailTestUtils
class TestAuditLogManager(WagtailTestUtils, TestCase):
def setUp(self):
self.user = self.create_superuser(
username="administrator",
email="administrator@email.com",
password="password",
)
self.page = Page.objects.get(pk=1)
self.simple_page = self.page.add_child(
instance=SimplePage(
title="Simple page", slug="simple", content="Hello", owner=self.user
)
)
self.snippet_1 = FullFeaturedSnippet.objects.create(text="snippet 1")
self.snippet_2 = FullFeaturedSnippet.objects.create(text="snippet 2")
self.snippet_content_type = ContentType.objects.get_for_model(
FullFeaturedSnippet
)
def test_log_action_for_page(self):
now = timezone.now()
with freeze_time(now):
entry = PageLogEntry.objects.log_action(
self.page, "wagtail.edit", user=self.user
)
self.assertEqual(entry.content_type, self.page.content_type)
self.assertEqual(entry.user, self.user)
self.assertEqual(entry.timestamp, now)
def test_log_action_for_snippet(self):
now = timezone.now()
with freeze_time(now):
entry = ModelLogEntry.objects.log_action(
self.snippet_1, "wagtail.edit", user=self.user
)
self.assertEqual(entry.content_type, self.snippet_content_type)
self.assertEqual(entry.user, self.user)
self.assertEqual(entry.timestamp, now)
def test_get_for_page_model(self):
PageLogEntry.objects.log_action(self.page, "wagtail.edit")
PageLogEntry.objects.log_action(self.simple_page, "wagtail.edit")
entries = PageLogEntry.objects.get_for_model(SimplePage)
self.assertEqual(entries.count(), 2)
self.assertListEqual(
list(entries), list(PageLogEntry.objects.filter(page=self.simple_page))
)
def test_get_for_snippet_model(self):
ModelLogEntry.objects.log_action(self.snippet_1, "wagtail.edit")
ModelLogEntry.objects.log_action(self.snippet_2, "wagtail.edit")
entries = ModelLogEntry.objects.get_for_model(FullFeaturedSnippet)
self.assertEqual(entries.count(), 2)
self.assertListEqual(
list(entries),
list(ModelLogEntry.objects.filter(content_type=self.snippet_content_type)),
)
def test_get_for_user(self):
self.assertEqual(
PageLogEntry.objects.get_for_user(self.user).count(), 1
) # the create from setUp
def test_get_for_page_instance(self):
PageLogEntry.objects.log_action(self.page, "wagtail.edit")
PageLogEntry.objects.log_action(self.simple_page, "wagtail.edit")
other_simple_page = self.page.add_child(
instance=SimplePage(
title="Simple page 2", slug="simple2", content="Hello", owner=self.user
)
)
PageLogEntry.objects.log_action(other_simple_page, "wagtail.edit")
entries = PageLogEntry.objects.for_instance(self.simple_page)
expected_entries = list(PageLogEntry.objects.filter(page=self.simple_page))
self.assertEqual(entries.count(), 2)
self.assertListEqual(list(entries), expected_entries)
# should also be able to retrieve entries via the log registry, which
# eliminates the need to know that PageLogEntry is the log entry model
entries = log_registry.get_logs_for_instance(self.simple_page)
self.assertEqual(entries.count(), 2)
self.assertListEqual(list(entries), expected_entries)
def test_get_for_snippet_instance(self):
ModelLogEntry.objects.log_action(self.snippet_1, "wagtail.edit")
ModelLogEntry.objects.log_action(self.snippet_2, "wagtail.edit")
entries = ModelLogEntry.objects.for_instance(self.snippet_1)
expected_entries = list(
ModelLogEntry.objects.filter(
content_type=self.snippet_content_type, object_id=self.snippet_1.pk
)
)
self.assertEqual(entries.count(), 1)
self.assertListEqual(list(entries), expected_entries)
# should also be able to retrieve entries via the log registry, which
# eliminates the need to know that ModelLogEntry is the log entry model
entries = log_registry.get_logs_for_instance(self.snippet_1)
self.assertEqual(entries.count(), 1)
self.assertListEqual(list(entries), expected_entries)
class TestAuditLog(TestCase):
def setUp(self):
self.root_page = Page.objects.get(id=1)
self.home_page = self.root_page.add_child(
instance=SimplePage(title="Homepage", slug="home2", content="hello")
)
PageLogEntry.objects.all().delete() # clean up the log entries here.
def test_page_create(self):
self.assertEqual(PageLogEntry.objects.count(), 0) # homepage
page = self.home_page.add_child(
instance=SimplePage(title="Hello", slug="my-page", content="world")
)
self.assertEqual(PageLogEntry.objects.count(), 1)
log_entry = PageLogEntry.objects.order_by("pk").last()
self.assertEqual(log_entry.action, "wagtail.create")
self.assertEqual(log_entry.page_id, page.id)
self.assertEqual(log_entry.content_type, page.content_type)
self.assertEqual(log_entry.label, page.get_admin_display_title())
def test_alias_create_from_published_page_doesnt_log_publish_action(self):
self.home_page.live = True
self.home_page.save()
alias = self.home_page.create_alias(update_slug="the-alias")
self.assertTrue(alias.live)
self.assertEqual(
PageLogEntry.objects.filter(action="wagtail.publish").count(), 0
)
def test_page_edit(self):
# Directly saving a revision should not yield a log entry
self.home_page.save_revision()
self.assertEqual(PageLogEntry.objects.count(), 0)
# Explicitly ask to record the revision change
self.home_page.save_revision(log_action=True)
self.assertEqual(PageLogEntry.objects.count(), 1)
self.assertEqual(PageLogEntry.objects.filter(action="wagtail.edit").count(), 1)
# passing a string for the action should log this.
self.home_page.save_revision(log_action="wagtail.revert")
self.assertEqual(
PageLogEntry.objects.filter(action="wagtail.revert").count(), 1
)
def test_page_publish(self):
revision = self.home_page.save_revision()
revision.publish()
self.assertEqual(PageLogEntry.objects.count(), 1)
self.assertEqual(
PageLogEntry.objects.filter(action="wagtail.publish").count(), 1
)
def test_page_publish_doesnt_log_for_aliases(self):
self.home_page.create_alias(update_slug="the-alias")
revision = self.home_page.save_revision()
revision.publish()
self.assertEqual(
PageLogEntry.objects.filter(action="wagtail.publish").count(), 1
)
def test_page_rename(self):
# Should not log a name change when publishing the first revision
revision = self.home_page.save_revision()
self.home_page.title = "Old title"
self.home_page.save()
revision.publish()
self.assertEqual(
PageLogEntry.objects.filter(action="wagtail.publish").count(), 1
)
self.assertEqual(
PageLogEntry.objects.filter(action="wagtail.rename").count(), 0
)
# Now, check the rename is logged
revision = self.home_page.save_revision()
self.home_page.title = "New title"
self.home_page.save()
revision.publish()
self.assertEqual(PageLogEntry.objects.count(), 3)
self.assertEqual(
PageLogEntry.objects.filter(action="wagtail.publish").count(), 2
)
self.assertEqual(
PageLogEntry.objects.filter(action="wagtail.rename").count(), 1
)
def test_page_unpublish(self):
self.home_page.unpublish()
self.assertEqual(PageLogEntry.objects.count(), 1)
self.assertEqual(
PageLogEntry.objects.filter(action="wagtail.unpublish").count(), 1
)
def test_page_unpublish_doesnt_log_for_aliases(self):
self.home_page.create_alias(update_slug="the-alias")
self.home_page.unpublish()
self.assertEqual(
PageLogEntry.objects.filter(action="wagtail.unpublish").count(), 1
)
def test_revision_revert(self):
revision1 = self.home_page.save_revision()
self.home_page.save_revision()
self.home_page.save_revision(log_action=True, previous_revision=revision1)
self.assertEqual(
PageLogEntry.objects.filter(action="wagtail.revert").count(), 1
)
def test_revision_schedule_publish(self):
go_live_at = datetime.datetime.now() + datetime.timedelta(days=1)
if settings.USE_TZ:
go_live_at = timezone.make_aware(go_live_at)
expected_go_live_at = timezone.localtime(go_live_at, datetime.timezone.utc)
else:
expected_go_live_at = go_live_at
self.home_page.go_live_at = go_live_at
# with no live revision
revision = self.home_page.save_revision()
revision.publish()
log_entries = PageLogEntry.objects.filter(action="wagtail.publish.schedule")
self.assertEqual(log_entries.count(), 1)
self.assertEqual(log_entries[0].data["revision"]["id"], revision.id)
self.assertEqual(
log_entries[0].data["revision"]["go_live_at"],
# skip double quotes
json.dumps(expected_go_live_at, cls=DjangoJSONEncoder)[1:-1],
)
def test_revision_schedule_revert(self):
revision1 = self.home_page.save_revision()
revision2 = self.home_page.save_revision()
if settings.USE_TZ:
self.home_page.go_live_at = timezone.make_aware(
datetime.datetime.now() + datetime.timedelta(days=1)
)
else:
self.home_page.go_live_at = datetime.datetime.now() + datetime.timedelta(
days=1
)
schedule_revision = self.home_page.save_revision(
log_action=True, previous_revision=revision2
)
schedule_revision.publish(previous_revision=revision1)
self.assertListEqual(
list(PageLogEntry.objects.values_list("action", flat=True)),
[
"wagtail.publish.schedule",
"wagtail.revert",
], # order_by -timestamp, by default
)
def test_revision_cancel_schedule(self):
go_live_at = datetime.datetime.now() + datetime.timedelta(days=1)
if settings.USE_TZ:
go_live_at = timezone.make_aware(go_live_at)
expected_go_live_at = timezone.localtime(go_live_at, datetime.timezone.utc)
else:
expected_go_live_at = go_live_at
self.home_page.go_live_at = go_live_at
revision = self.home_page.save_revision()
revision.publish()
revision.approved_go_live_at = None
revision.save(update_fields=["approved_go_live_at"])
log_entries = PageLogEntry.objects.filter(action="wagtail.schedule.cancel")
self.assertEqual(log_entries.count(), 1)
self.assertEqual(log_entries[0].data["revision"]["id"], revision.id)
self.assertEqual(
log_entries[0].data["revision"]["go_live_at"],
# skip double quotes
json.dumps(expected_go_live_at, cls=DjangoJSONEncoder)[1:-1],
)
# The home_page was live already and we've only cancelled the publication of the above revision.
self.assertTrue(log_entries[0].data["revision"]["has_live_version"])
def test_page_lock_unlock(self):
self.home_page.save(log_action="wagtail.lock")
self.home_page.save(log_action="wagtail.unlock")
self.assertEqual(
PageLogEntry.objects.filter(
action__in=["wagtail.lock", "wagtail.unlock"]
).count(),
2,
)
def test_page_copy(self):
self.home_page.copy(update_attrs={"title": "About us", "slug": "about-us"})
self.assertListEqual(
list(PageLogEntry.objects.values_list("action", flat=True)),
["wagtail.publish", "wagtail.copy", "wagtail.create"],
)
def test_page_reorder(self):
section_1 = self.root_page.add_child(
instance=SimplePage(title="Child 1", slug="child-1", content="hello")
)
self.root_page.add_child(
instance=SimplePage(title="Child 2", slug="child-2", content="hello")
)
user = get_user_model().objects.first()
# Reorder section 1 to be the last page under root_page.
# This should log as `wagtail.reorder` because the page was moved under the same parent page
section_1.move(self.root_page, user=user, pos="last-child")
self.assertEqual(
PageLogEntry.objects.filter(action="wagtail.reorder", user=user).count(), 1
)
self.assertEqual(
PageLogEntry.objects.filter(action="wagtail.move", user=user).count(), 0
)
def test_page_move(self):
section = self.root_page.add_child(
instance=SimplePage(title="About us", slug="about", content="hello")
)
user = get_user_model().objects.first()
# move() interprets `target` as an intended 'sibling' by default, so
# we must use `pos` to indicate that `self.home_page` should be the
# new 'parent'
section.move(self.home_page, pos="last-child", user=user)
self.assertEqual(
PageLogEntry.objects.filter(action="wagtail.move", user=user).count(), 1
)
self.assertEqual(
PageLogEntry.objects.filter(action="wagtail.reorder", user=user).count(), 0
)
def test_page_delete(self):
self.home_page.add_child(
instance=SimplePage(title="Child", slug="child-page", content="hello")
)
child = self.home_page.add_child(
instance=SimplePage(
title="Another child", slug="child-page-2", content="hello"
)
)
child.add_child(
instance=SimplePage(
title="Grandchild", slug="grandchild-page", content="hello"
)
)
# check deleting a parent page logs descendent deletion
self.home_page.delete()
self.assertEqual(
PageLogEntry.objects.filter(action="wagtail.delete").count(), 4
)
self.assertEqual(
set(
PageLogEntry.objects.filter(action="wagtail.delete").values_list(
"label", flat=True
)
),
{
"Homepage (simple page)",
"Grandchild (simple page)",
"Child (simple page)",
"Another child (simple page)",
},
)
def test_workflow_actions(self):
workflow = Workflow.objects.create(name="test_workflow")
task_1 = Task.objects.create(name="test_task_1")
task_2 = Task.objects.create(name="test_task_2")
WorkflowTask.objects.create(workflow=workflow, task=task_1, sort_order=1)
WorkflowTask.objects.create(workflow=workflow, task=task_2, sort_order=2)
self.home_page.save_revision()
user = get_user_model().objects.first()
workflow_state = workflow.start(self.home_page, user)
workflow_entry = PageLogEntry.objects.filter(action="wagtail.workflow.start")
self.assertEqual(workflow_entry.count(), 1)
self.assertEqual(
workflow_entry[0].data,
{
"workflow": {
"id": workflow.id,
"title": workflow.name,
"status": workflow_state.status,
"task_state_id": workflow_state.current_task_state_id,
"next": {
"id": workflow_state.current_task_state.task.id,
"title": workflow_state.current_task_state.task.name,
},
}
},
)
# Approve
for action in ["approve", "reject"]:
with self.subTest(action):
task_state = workflow_state.current_task_state
task_state.task.on_action(
task_state,
user=None,
action_name=action,
comment="This is my comment",
)
workflow_state.refresh_from_db()
entry = PageLogEntry.objects.filter(action=f"wagtail.workflow.{action}")
self.assertEqual(entry.count(), 1)
self.assertEqual(
entry[0].data,
{
"workflow": {
"id": workflow.id,
"title": workflow.name,
"status": task_state.status,
"task_state_id": task_state.id,
"task": {
"id": task_state.task.id,
"title": task_state.task.name,
},
"next": {
"id": workflow_state.current_task_state.task.id,
"title": workflow_state.current_task_state.task.name,
},
},
"comment": "This is my comment",
},
)
self.assertEqual(entry[0].comment, "This is my comment")
def test_snippet_workflow_actions(self):
workflow = Workflow.objects.create(name="test_workflow")
task_1 = Task.objects.create(name="test_task_1")
task_2 = Task.objects.create(name="test_task_2")
WorkflowTask.objects.create(workflow=workflow, task=task_1, sort_order=1)
WorkflowTask.objects.create(workflow=workflow, task=task_2, sort_order=2)
snippet = FullFeaturedSnippet.objects.create(text="Initial", live=False)
snippet.save_revision()
user = get_user_model().objects.first()
workflow_state = workflow.start(snippet, user)
workflow_entry = ModelLogEntry.objects.filter(action="wagtail.workflow.start")
self.assertEqual(workflow_entry.count(), 1)
self.assertEqual(
workflow_entry[0].data,
{
"workflow": {
"id": workflow.id,
"title": workflow.name,
"status": workflow_state.status,
"task_state_id": workflow_state.current_task_state_id,
"next": {
"id": workflow_state.current_task_state.task.id,
"title": workflow_state.current_task_state.task.name,
},
}
},
)
# Approve
for action in ["approve", "reject"]:
with self.subTest(action):
task_state = workflow_state.current_task_state
task_state.task.on_action(
task_state,
user=None,
action_name=action,
comment="This is my comment",
)
workflow_state.refresh_from_db()
entry = ModelLogEntry.objects.filter(
action=f"wagtail.workflow.{action}"
)
self.assertEqual(entry.count(), 1)
self.assertEqual(
entry[0].data,
{
"workflow": {
"id": workflow.id,
"title": workflow.name,
"status": task_state.status,
"task_state_id": task_state.id,
"task": {
"id": task_state.task.id,
"title": task_state.task.name,
},
"next": {
"id": workflow_state.current_task_state.task.id,
"title": workflow_state.current_task_state.task.name,
},
},
"comment": "This is my comment",
},
)
self.assertEqual(entry[0].comment, "This is my comment")
def test_workflow_completions_logs_publishing_user(self):
workflow = Workflow.objects.create(name="test_workflow")
task_1 = Task.objects.create(name="test_task_1")
WorkflowTask.objects.create(workflow=workflow, task=task_1, sort_order=1)
self.assertFalse(PageLogEntry.objects.filter(action="wagtail.publish").exists())
self.home_page.save_revision()
user = get_user_model().objects.first()
workflow_state = workflow.start(self.home_page, user)
publisher = get_user_model().objects.last()
task_state = workflow_state.current_task_state
task_state.task.on_action(task_state, user=None, action_name="approve")
self.assertEqual(
PageLogEntry.objects.get(action="wagtail.publish").user, publisher
)
def test_snippet_workflow_completions_logs_publishing_user(self):
workflow = Workflow.objects.create(name="test_workflow")
task_1 = Task.objects.create(name="test_task_1")
WorkflowTask.objects.create(workflow=workflow, task=task_1, sort_order=1)
self.assertFalse(
ModelLogEntry.objects.filter(action="wagtail.publish").exists()
)
snippet = FullFeaturedSnippet.objects.create(text="Initial", live=False)
snippet.save_revision()
user = get_user_model().objects.first()
workflow_state = workflow.start(snippet, user)
publisher = get_user_model().objects.last()
task_state = workflow_state.current_task_state
task_state.task.on_action(task_state, user=None, action_name="approve")
self.assertEqual(
ModelLogEntry.objects.get(action="wagtail.publish").user, publisher
)
def test_page_privacy(self):
restriction = PageViewRestriction.objects.create(page=self.home_page)
self.assertEqual(
PageLogEntry.objects.filter(
action="wagtail.view_restriction.create"
).count(),
1,
)
restriction.restriction_type = PageViewRestriction.PASSWORD
restriction.save()
self.assertEqual(
PageLogEntry.objects.filter(action="wagtail.view_restriction.edit").count(),
1,
)
def test_hook(actions):
return actions.register_action("test.custom_action", "Custom action", "Tested!")
class TestAuditLogHooks(WagtailTestUtils, TestCase):
def setUp(self):
self.root_page = Page.objects.get(id=2)
def test_register_log_actions_hook(self):
log_actions = LogActionRegistry()
self.assertTrue(log_actions.action_exists("wagtail.create"))
def test_action_must_be_registered(self):
# We check actions are registered to let developers know if they have forgotten to register
# a new action or made a spelling mistake. It's not intended as a database-level constraint.
with self.assertRaises(ValidationError) as e:
PageLogEntry.objects.log_action(self.root_page, action="test.custom_action")
self.assertEqual(
e.exception.message_dict,
{
"action": [
"The log action 'test.custom_action' has not been registered."
]
},
)
def test_action_format_message(self):
# All new logs should pass our validation, but older logs or logs that were added in bulk
# may be invalid.
# Using LogEntry.objects.update, we can bypass the on save validation.
log_entry = PageLogEntry.objects.log_action(
self.root_page, action="wagtail.create"
)
PageLogEntry.objects.update(action="test.custom_action")
log_entry.refresh_from_db()
log_actions = LogActionRegistry()
self.assertEqual(log_entry.message, "Unknown test.custom_action")
self.assertFalse(log_actions.action_exists("test.custom_action"))
with self.register_hook("register_log_actions", test_hook):
log_actions = LogActionRegistry()
self.assertTrue(log_actions.action_exists("test.custom_action"))
self.assertEqual(
log_actions.get_formatter(log_entry).format_message(log_entry),
"Tested!",
)
self.assertEqual(
log_actions.get_action_label("test.custom_action"), "Custom action"
)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,111 @@
from django.test import TestCase
from wagtail.models import Collection
class TestCollectionTreeOperations(TestCase):
def setUp(self):
self.root_collection = Collection.get_first_root_node()
self.holiday_photos_collection = self.root_collection.add_child(
name="Holiday photos"
)
self.evil_plans_collection = self.root_collection.add_child(name="Evil plans")
# self.holiday_photos_collection's path has been updated out from under it by the addition of a sibling with
# an alphabetically earlier name (due to Collection.node_order_by = ['name']), so we need to refresh it from
# the DB to get the new path.
self.holiday_photos_collection.refresh_from_db()
def test_alphabetic_sorting(self):
old_evil_path = self.evil_plans_collection.path
old_holiday_path = self.holiday_photos_collection.path
# Add a child to Root that has an earlier name than "Evil plans" and "Holiday Photos".
alpha_collection = self.root_collection.add_child(name="Alpha")
# Take note that self.evil_plans_collection and self.holiday_photos_collection have not yet changed.
self.assertEqual(old_evil_path, self.evil_plans_collection.path)
self.assertEqual(old_holiday_path, self.holiday_photos_collection.path)
# Update the two Collections from the database.
self.evil_plans_collection.refresh_from_db()
self.holiday_photos_collection.refresh_from_db()
# Confirm that the "Evil plans" and "Holiday photos" paths have changed in the DB due to adding "Alpha".
self.assertNotEqual(old_evil_path, self.evil_plans_collection.path)
self.assertNotEqual(old_holiday_path, self.holiday_photos_collection.path)
# Confirm that Alpha is before Evil Plans and Holiday Photos, due to Collection.node_order_by = ['name'].
self.assertLess(alpha_collection.path, self.evil_plans_collection.path)
self.assertLess(alpha_collection.path, self.holiday_photos_collection.path)
def test_get_ancestors(self):
self.assertEqual(
list(self.holiday_photos_collection.get_ancestors().order_by("path")),
[self.root_collection],
)
self.assertEqual(
list(
self.holiday_photos_collection.get_ancestors(inclusive=True).order_by(
"path"
)
),
[self.root_collection, self.holiday_photos_collection],
)
def test_get_descendants(self):
self.assertEqual(
list(self.root_collection.get_descendants().order_by("path")),
[self.evil_plans_collection, self.holiday_photos_collection],
)
self.assertEqual(
list(self.root_collection.get_descendants(inclusive=True).order_by("path")),
[
self.root_collection,
self.evil_plans_collection,
self.holiday_photos_collection,
],
)
def test_get_siblings(self):
self.assertEqual(
list(self.holiday_photos_collection.get_siblings().order_by("path")),
[self.evil_plans_collection, self.holiday_photos_collection],
)
self.assertEqual(
list(
self.holiday_photos_collection.get_siblings(inclusive=False).order_by(
"path"
)
),
[self.evil_plans_collection],
)
def test_get_next_siblings(self):
self.assertEqual(
list(self.evil_plans_collection.get_next_siblings().order_by("path")),
[self.holiday_photos_collection],
)
self.assertEqual(
list(
self.holiday_photos_collection.get_next_siblings(
inclusive=True
).order_by("path")
),
[self.holiday_photos_collection],
)
self.assertEqual(
list(self.holiday_photos_collection.get_next_siblings().order_by("path")),
[],
)
def test_get_prev_siblings(self):
self.assertEqual(
list(self.holiday_photos_collection.get_prev_siblings().order_by("path")),
[self.evil_plans_collection],
)
self.assertEqual(
list(self.evil_plans_collection.get_prev_siblings().order_by("path")), []
)
self.assertEqual(
list(
self.evil_plans_collection.get_prev_siblings(inclusive=True).order_by(
"path"
)
),
[self.evil_plans_collection],
)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,43 @@
from django.contrib.auth import get_user_model
from django.test import TestCase
from wagtail.models import Comment, Page
class CommentTestingUtils:
def setUp(self):
self.page = Page.objects.get(title="Welcome to the Wagtail test site!")
self.revision_1 = self.page.save_revision()
self.revision_2 = self.page.save_revision()
def create_comment(self, revision_created):
return Comment.objects.create(
page=self.page,
user=get_user_model().objects.first(),
text="test",
contentpath="title",
revision_created=revision_created,
)
class TestRevisionDeletion(CommentTestingUtils, TestCase):
fixtures = ["test.json"]
def setUp(self):
super().setUp()
self.revision_3 = self.page.save_revision()
self.old_comment = self.create_comment(self.revision_1)
self.new_comment = self.create_comment(self.revision_3)
def test_deleting_old_revision_moves_comment_revision_created_forwards(self):
# test that when a revision is deleted, a comment linked to it via revision_created has its revision_created moved
# to the next revision
self.revision_1.delete()
self.old_comment.refresh_from_db()
self.assertEqual(self.old_comment.revision_created, self.revision_2)
def test_deleting_most_recent_revision_deletes_created_comments(self):
# test that when the most recent revision is deleted, any comments created on it are also deleted
self.revision_3.delete()
with self.assertRaises(Comment.DoesNotExist):
self.new_comment.refresh_from_db()

View File

@@ -0,0 +1,64 @@
from django.apps import apps
from django.core import checks
from django.db import models
from django.test import TestCase
from wagtail.models import DraftStateMixin, RevisionMixin
class TestDraftStateMixin(TestCase):
def tearDown(self):
# Unregister the models from the overall model registry
# so that it doesn't break tests elsewhere.
# We can probably replace this with Django's @isolate_apps decorator.
for package in ("wagtailcore", "wagtail.tests"):
try:
for model in (
"draftstatewithoutrevisionmodel",
"draftstateincorrectrevisionmodel",
"draftstatewithrevisionmodel",
):
del apps.all_models[package][model]
except KeyError:
pass
apps.clear_cache()
def test_missing_revision_mixin(self):
class DraftStateWithoutRevisionModel(DraftStateMixin, models.Model):
pass
self.assertEqual(
DraftStateWithoutRevisionModel.check(),
[
checks.Error(
"DraftStateMixin requires RevisionMixin to be applied after DraftStateMixin.",
hint="Add RevisionMixin to the model's base classes after DraftStateMixin.",
obj=DraftStateWithoutRevisionModel,
id="wagtailcore.E004",
)
],
)
def test_incorrect_revision_mixin_order(self):
class DraftStateIncorrectRevisionModel(
RevisionMixin, DraftStateMixin, models.Model
):
pass
self.assertEqual(
DraftStateIncorrectRevisionModel.check(),
[
checks.Error(
"DraftStateMixin requires RevisionMixin to be applied after DraftStateMixin.",
hint="Add RevisionMixin to the model's base classes after DraftStateMixin.",
obj=DraftStateIncorrectRevisionModel,
id="wagtailcore.E004",
)
],
)
def test_correct_model(self):
class DraftStateWithRevisionModel(DraftStateMixin, RevisionMixin, models.Model):
pass
self.assertEqual(DraftStateWithRevisionModel.check(), [])

View File

@@ -0,0 +1,189 @@
from django.test import SimpleTestCase
from wagtail.test.utils.form_data import querydict_from_html
class TestQueryDictFromHTML(SimpleTestCase):
html = """
<form id="personal-details">
<input type="hidden" name="csrfmiddlewaretoken" value="Z783HTL5Bc2J54WhAtEeR3eefM1FBkq0EbTfNnYnepFGuJSOfvosFvwjeKYtMwFr">
<input type="hidden" name="no_value_input">
<input type="hidden" value="no name input">
<div>
<div>
<label>
<span>Full name</span>
<input type="text" name="name" value="Jane Doe" placeholder="">
</label>
<label>
<span>Email address</span>
<input type="email" name="email" value="jane@example.com" placeholder="name@example.com">
</label>
</div>
</div>
</form>
<form id="event-details">
<div>
<div>
<label>
<span>When is your event?</span>
<input type="date" name="date" value="2023-01-01">
</label>
<label>
<span>What type of event is it?</span>
<select name="event_type">
<option value="corporate">Corporate event</option>
<option value="wedding">Wedding</option>
<option value="birthday">Birthday</option>
<option value="other" selected>Other</option>
</select>
</label>
<label>
<span>What age groups is it suitable for?</span>
<select name="ages" multiple>
<option>Infants</option>
<option>Children</option>
<option>Teenagers</option>
<option selected>18-30</option>
<option selected>30-50</option>
<option>50-70</option>
<option>70+</option>
</select>
</label>
</div>
</div>
</form>
<form id="market-research">
<div>
<div>
<fieldset>
<legend>How many pets do you have?</legend>
<div class="radio-list">
<div class="radio">
<label>
<input type="radio" name="pets" value="0" />
None
</label>
</div>
<div class="radio">
<label>
<input type="radio" name="pets" value="1" />
One
</label>
</div>
<div class="radio">
<label>
<input type="radio" name="pets" value="2" checked />
Two
</label>
</div>
<div class="radio">
<label>
<input type="radio" name="pets" value="3+" />
Three or more
</label>
</div>
</div>
</fieldset>
<fieldset>
<legend>Which two colours do you like best?</legend>
<div class="checkbox-list">
<div class="checkbox">
<label>
<input type="checkbox" name="colours" value="cyan">
Cyan
</label>
</div>
<div class="checkbox">
<label>
<input type="checkbox" name="colours" value="magenta" checked />
Magenta
</label>
</div>
<div class="checkbox">
<label>
<input type="checkbox" name="colours" value="yellow" />
Yellow
</label>
</div>
<div class="checkbox">
<label>
<input type="checkbox" name="colours" value="black" checked />
Black
</label>
</div>
<div class="checkbox">
<label>
<input type="checkbox" name="colours" value="white" />
White
</label>
</div>
</div>
</fieldset>
<label>
<span>Tell us what you love</span>
<textarea name="love" rows="3">Comic books</textarea>
</label>
</div>
</div>
</form>
"""
personal_details = [
("no_value_input", [""]),
("name", ["Jane Doe"]),
("email", ["jane@example.com"]),
]
event_details = [
("date", ["2023-01-01"]),
("event_type", ["other"]),
("ages", ["18-30", "30-50"]),
]
market_research = [
("pets", ["2"]),
("colours", ["magenta", "black"]),
("love", ["Comic books"]),
]
def test_html_only(self):
# data should be extracted from the 'first' form by default
result = querydict_from_html(self.html)
self.assertEqual(list(result.lists()), self.personal_details)
def test_include_csrf(self):
result = querydict_from_html(self.html, exclude_csrf=False)
expected_result = [
(
"csrfmiddlewaretoken",
["Z783HTL5Bc2J54WhAtEeR3eefM1FBkq0EbTfNnYnepFGuJSOfvosFvwjeKYtMwFr"],
)
] + self.personal_details
self.assertEqual(list(result.lists()), expected_result)
def test_form_index(self):
for index, expected_data in (
(0, self.personal_details),
("2", self.market_research),
(1, self.event_details),
):
result = querydict_from_html(self.html, form_index=index)
self.assertEqual(list(result.lists()), expected_data)
def test_form_id(self):
for id, expected_data in (
("event-details", self.event_details),
("personal-details", self.personal_details),
("market-research", self.market_research),
):
result = querydict_from_html(self.html, form_id=id)
self.assertEqual(list(result.lists()), expected_data)
def test_invalid_form_id(self):
with self.assertRaises(ValueError):
querydict_from_html(self.html, form_id="invalid-id")
def test_invalid_index(self):
with self.assertRaises(ValueError):
querydict_from_html(self.html, form_index=5)

View File

@@ -0,0 +1,36 @@
from django.test import TestCase
from wagtail import hooks
from wagtail.test.utils import WagtailTestUtils
def test_hook():
pass
class TestLoginView(WagtailTestUtils, TestCase):
fixtures = ["test.json"]
@classmethod
def setUpClass(cls):
hooks.register("test_hook_name", test_hook)
@classmethod
def tearDownClass(cls):
del hooks._hooks["test_hook_name"]
def test_before_hook(self):
def before_hook():
pass
with self.register_hook("test_hook_name", before_hook, order=-1):
hook_fns = hooks.get_hooks("test_hook_name")
self.assertEqual(hook_fns, [before_hook, test_hook])
def test_after_hook(self):
def after_hook():
pass
with self.register_hook("test_hook_name", after_hook, order=1):
hook_fns = hooks.get_hooks("test_hook_name")
self.assertEqual(hook_fns, [test_hook, after_hook])

View File

@@ -0,0 +1,312 @@
from django.template import engines
from django.template.loader import render_to_string
from django.test import TestCase
from django.utils.safestring import mark_safe
from wagtail import __version__, blocks
from wagtail.coreutils import get_dummy_request
from wagtail.models import Page, Site
from wagtail.test.testapp.blocks import SectionBlock
class TestCoreGlobalsAndFilters(TestCase):
def setUp(self):
self.engine = engines["jinja2"]
def render(self, string, context=None, request_context=True):
if context is None:
context = {}
# Add a request to the template, to simulate a RequestContext
if request_context:
site = Site.objects.get(is_default_site=True)
context["request"] = get_dummy_request(site=site)
template = self.engine.from_string(string)
return template.render(context)
def test_richtext(self):
richtext = '<p>Merry <a linktype="page" id="2">Christmas</a>!</p>'
self.assertEqual(
self.render("{{ text|richtext }}", {"text": richtext}),
'<p>Merry <a href="/">Christmas</a>!</p>',
)
def test_pageurl(self):
page = Page.objects.get(pk=2)
self.assertEqual(self.render("{{ pageurl(page) }}", {"page": page}), page.url)
def test_fullpageurl(self):
page = Page.objects.get(pk=2)
self.assertEqual(
self.render("{{ fullpageurl(page) }}", {"page": page}), page.full_url
)
def test_slugurl(self):
page = Page.objects.get(pk=2)
self.assertEqual(
self.render("{{ slugurl(page.slug) }}", {"page": page}), page.url
)
def test_bad_slugurl(self):
self.assertEqual(
self.render('{{ slugurl("bad-slug-doesnt-exist") }}', {}), "None"
)
def test_wagtail_site(self):
self.assertEqual(self.render("{{ wagtail_site().hostname }}"), "localhost")
def test_wagtail_version(self):
self.assertEqual(self.render("{{ wagtail_version() }}"), __version__)
class TestJinjaEscaping(TestCase):
fixtures = ["test.json"]
def test_block_render_result_is_safe(self):
"""
Ensure that any results of template rendering in block.render are marked safe
so that they don't get double-escaped when inserted into a parent template (#2541)
"""
stream_block = blocks.StreamBlock(
[("paragraph", blocks.CharBlock(template="tests/jinja2/paragraph.html"))]
)
stream_value = stream_block.to_python(
[
{"type": "paragraph", "value": "hello world"},
]
)
result = render_to_string(
"tests/jinja2/stream.html",
{
"value": stream_value,
},
)
self.assertIn("<p>hello world</p>", result)
def test_rich_text_is_safe(self):
"""
Ensure that RichText values are marked safe
so that they don't get double-escaped when inserted into a parent template (#2542)
"""
stream_block = blocks.StreamBlock(
[
(
"paragraph",
blocks.RichTextBlock(template="tests/jinja2/rich_text.html"),
)
]
)
stream_value = stream_block.to_python(
[
{
"type": "paragraph",
"value": '<p>Merry <a linktype="page" id="4">Christmas</a>!</p>',
},
]
)
result = render_to_string(
"tests/jinja2/stream.html",
{
"value": stream_value,
},
)
self.assertIn(
'<p>Merry <a href="/events/christmas/">Christmas</a>!</p>', result
)
class TestIncludeBlockTag(TestCase):
def test_include_block_tag_with_boundblock(self):
"""
The include_block tag should be able to render a BoundBlock's template
while keeping the parent template's context
"""
block = blocks.CharBlock(template="tests/jinja2/heading_block.html")
bound_block = block.bind("bonjour")
result = render_to_string(
"tests/jinja2/include_block_test.html",
{
"test_block": bound_block,
"language": "fr",
},
)
self.assertIn('<body><h1 lang="fr">bonjour</h1></body>', result)
def test_include_block_tag_with_structvalue(self):
"""
The include_block tag should be able to render a StructValue's template
while keeping the parent template's context
"""
block = SectionBlock()
struct_value = block.to_python(
{"title": "Bonjour", "body": "monde <i>italique</i>"}
)
result = render_to_string(
"tests/jinja2/include_block_test.html",
{
"test_block": struct_value,
"language": "fr",
},
)
self.assertIn(
"""<body><h1 lang="fr">Bonjour</h1>monde <i>italique</i></body>""", result
)
def test_include_block_tag_with_streamvalue(self):
"""
The include_block tag should be able to render a StreamValue's template
while keeping the parent template's context
"""
block = blocks.StreamBlock(
[
(
"heading",
blocks.CharBlock(template="tests/jinja2/heading_block.html"),
),
("paragraph", blocks.CharBlock()),
],
template="tests/jinja2/stream_with_language.html",
)
stream_value = block.to_python([{"type": "heading", "value": "Bonjour"}])
result = render_to_string(
"tests/jinja2/include_block_test.html",
{
"test_block": stream_value,
"language": "fr",
},
)
self.assertIn(
'<div class="heading" lang="fr"><h1 lang="fr">Bonjour</h1></div>', result
)
def test_include_block_tag_with_plain_value(self):
"""
The include_block tag should be able to render a value without a render_as_block method
by just rendering it as a string
"""
result = render_to_string(
"tests/jinja2/include_block_test.html",
{
"test_block": 42,
},
)
self.assertIn("<body>42</body>", result)
def test_include_block_tag_with_filtered_value(self):
"""
The block parameter on include_block tag should support complex values including filters,
e.g. {% include_block foo|default:123 %}
"""
block = blocks.CharBlock(template="tests/jinja2/heading_block.html")
bound_block = block.bind("bonjour")
result = render_to_string(
"tests/jinja2/include_block_test_with_filter.html",
{
"test_block": bound_block,
"language": "fr",
},
)
self.assertIn('<body><h1 lang="fr">bonjour</h1></body>', result)
result = render_to_string(
"tests/jinja2/include_block_test_with_filter.html",
{
"test_block": None,
"language": "fr",
},
)
self.assertIn("<body>999</body>", result)
def test_include_block_tag_with_additional_variable(self):
"""
The include_block tag should be able to pass local variables from parent context to the
child context
"""
block = blocks.CharBlock(template="tests/blocks/heading_block.html")
bound_block = block.bind("bonjour")
result = render_to_string(
"tests/jinja2/include_block_tag_with_additional_variable.html",
{"test_block": bound_block},
)
self.assertIn('<body><h1 class="important">bonjour</h1></body>', result)
def test_include_block_html_escaping(self):
"""
Output of include_block should be escaped as per Django autoescaping rules
"""
block = blocks.CharBlock()
bound_block = block.bind(block.to_python("some <em>evil</em> HTML"))
result = render_to_string(
"tests/jinja2/include_block_test.html",
{
"test_block": bound_block,
},
)
self.assertIn("<body>some &lt;em&gt;evil&lt;/em&gt; HTML</body>", result)
# {% autoescape off %} should be respected
result = render_to_string(
"tests/blocks/include_block_autoescape_off_test.html",
{
"test_block": bound_block,
},
)
self.assertIn("<body>some <em>evil</em> HTML</body>", result)
# The same escaping should be applied when passed a plain value rather than a BoundBlock -
# a typical situation where this would occur would be rendering an item of a StructBlock,
# e.g. {% include_block person_block.first_name %} as opposed to
# {% include_block person_block.bound_blocks.first_name %}
result = render_to_string(
"tests/jinja2/include_block_test.html",
{
"test_block": "some <em>evil</em> HTML",
},
)
self.assertIn("<body>some &lt;em&gt;evil&lt;/em&gt; HTML</body>", result)
result = render_to_string(
"tests/jinja2/include_block_autoescape_off_test.html",
{
"test_block": "some <em>evil</em> HTML",
},
)
self.assertIn("<body>some <em>evil</em> HTML</body>", result)
# Blocks that explicitly return 'safe HTML'-marked values (such as RawHTMLBlock) should
# continue to produce unescaped output
block = blocks.RawHTMLBlock()
bound_block = block.bind(block.to_python("some <em>evil</em> HTML"))
result = render_to_string(
"tests/jinja2/include_block_test.html",
{
"test_block": bound_block,
},
)
self.assertIn("<body>some <em>evil</em> HTML</body>", result)
# likewise when applied to a plain 'safe HTML' value rather than a BoundBlock
result = render_to_string(
"tests/jinja2/include_block_test.html",
{
"test_block": mark_safe("some <em>evil</em> HTML"),
},
)
self.assertIn("<body>some <em>evil</em> HTML</body>", result)

View File

@@ -0,0 +1,168 @@
from django.conf import settings
from django.test import TestCase, override_settings
from django.utils import translation
from django.utils.translation import gettext_lazy as _
from wagtail.models import Locale, Page
from wagtail.test.i18n.models import TestPage
def make_test_page(**kwargs):
root_page = Page.objects.get(id=1)
kwargs.setdefault("title", "Test page")
return root_page.add_child(instance=TestPage(**kwargs))
class TestLocaleModel(TestCase):
def setUp(self):
language_codes = dict(settings.LANGUAGES).keys()
for language_code in language_codes:
Locale.objects.get_or_create(language_code=language_code)
def test_default(self):
locale = Locale.get_default()
self.assertEqual(locale.language_code, "en")
@override_settings(LANGUAGE_CODE="fr-ca")
def test_default_doesnt_have_to_be_english(self):
locale = Locale.get_default()
self.assertEqual(locale.language_code, "fr")
def test_get_active_default(self):
self.assertEqual(Locale.get_active().language_code, "en")
def test_get_active_overridden(self):
with translation.override("fr"):
self.assertEqual(Locale.get_active().language_code, "fr")
def test_language_name(self):
for language_code, expected_result in (
("en", "English"),
("fr", "French"),
("zh-hans", "Simplified Chinese"),
):
with self.subTest(language_code):
locale = Locale(language_code=language_code)
self.assertEqual(locale.language_name, expected_result)
def test_language_name_for_unrecognised_language(self):
locale = Locale(language_code="foo")
with self.assertRaises(KeyError):
locale.language_name
def test_language_name_local(self):
for language_code, expected_result in (
("en", "English"),
("fr", "français"),
("zh-hans", "简体中文"),
):
with self.subTest(language_code):
locale = Locale(language_code=language_code)
self.assertEqual(locale.language_name_local, expected_result)
def test_language_name_local_for_unrecognised_language(self):
locale = Locale(language_code="foo")
with self.assertRaises(KeyError):
locale.language_name_local
def test_language_name_localized_reflects_active_language(self):
for language_code in (
"fr", # French
"zh-hans", # Simplified Chinese
"ca", # Catalan
"de", # German
):
with self.subTest(language_code):
locale = Locale(language_code=language_code)
with translation.override("en"):
self.assertEqual(
locale.language_name_localized, locale.language_name
)
with translation.override(language_code):
# NB: Casing can differ between these, hence the lower()
self.assertEqual(
locale.language_name_localized.lower(),
locale.language_name_local.lower(),
)
def test_language_name_localized_for_unconfigured_language(self):
locale = Locale(language_code="zh-hans")
self.assertEqual(locale.language_name_localized, "Simplified Chinese")
with translation.override("zh-hans"):
self.assertEqual(locale.language_name_localized, locale.language_name_local)
def test_language_name_localized_for_unrecognised_language(self):
locale = Locale(language_code="foo")
with self.assertRaises(KeyError):
locale.language_name_localized
def test_is_bidi(self):
for language_code, expected_result in (
("en", False),
("ar", True),
("he", True),
("fr", False),
("foo", False),
):
with self.subTest(language_code):
locale = Locale(language_code=language_code)
self.assertIs(locale.is_bidi, expected_result)
def test_is_default(self):
for language_code, expected_result in (
(settings.LANGUAGE_CODE, True), # default
("zh-hans", False), # alternative
("foo", False), # invalid
):
with self.subTest(language_code):
locale = Locale(language_code=language_code)
self.assertIs(locale.is_default, expected_result)
def test_is_active(self):
for locale_language, active_language, expected_result in (
(settings.LANGUAGE_CODE, settings.LANGUAGE_CODE, True),
(settings.LANGUAGE_CODE, "fr", False),
("zh-hans", settings.LANGUAGE_CODE, False),
("en", "en-gb", True),
("foo", settings.LANGUAGE_CODE, False),
):
with self.subTest(f"locale={locale_language} active={active_language}"):
with translation.override(active_language):
locale = Locale(language_code=locale_language)
self.assertEqual(locale.is_active, expected_result)
def test_get_display_name(self):
for language_code, expected_result in (
("en", "English"), # configured
("zh-hans", "Simplified Chinese"), # not configured but valid
("foo", "foo"), # not configured or valid
):
locale = Locale(language_code=language_code)
with self.subTest(language_code):
self.assertEqual(locale.get_display_name(), expected_result)
def test_str_reflects_get_display(self):
for language_code in ("en", "zh-hans", "foo"):
locale = Locale(language_code=language_code)
with self.subTest(language_code):
self.assertEqual(str(locale), locale.get_display_name())
@override_settings(LANGUAGES=[("en", _("English")), ("fr", _("French"))])
def test_str_when_languages_uses_gettext(self):
locale = Locale(language_code="en")
self.assertIsInstance(locale.__str__(), str)
@override_settings(LANGUAGE_CODE="fr")
def test_change_root_page_locale_on_locale_deletion(self):
"""
On deleting the locale used for the root page (but no 'real' pages), the
root page should be reassigned to a new locale (the default one, if possible)
"""
# change 'real' pages first
Page.objects.filter(depth__gt=1).update(
locale=Locale.objects.get(language_code="fr")
)
self.assertEqual(Page.get_first_root_node().locale.language_code, "en")
Locale.objects.get(language_code="en").delete()
self.assertEqual(Page.get_first_root_node().locale.language_code, "fr")

View File

@@ -0,0 +1,54 @@
from django.apps import apps
from django.core import checks
from django.db import models
from django.test import TestCase
from wagtail.models import LockableMixin, RevisionMixin
class TestLockableMixin(TestCase):
def tearDown(self):
# Unregister the models from the overall model registry
# so that it doesn't break tests elsewhere.
# We can probably replace this with Django's @isolate_apps decorator.
for package in ("wagtailcore", "wagtail.tests"):
try:
for model in (
"lockablewithoutrevisionmodel",
"lockableincorrectrevisionmodel",
"lockablewithrevisionmodel",
):
del apps.all_models[package][model]
except KeyError:
pass
apps.clear_cache()
def test_lockable_mixin_only(self):
class LockableWithoutRevisionModel(LockableMixin, models.Model):
pass
self.assertEqual(LockableWithoutRevisionModel.check(), [])
def test_incorrect_revision_mixin_order(self):
class LockableIncorrectRevisionModel(
RevisionMixin, LockableMixin, models.Model
):
pass
self.assertEqual(
LockableIncorrectRevisionModel.check(),
[
checks.Error(
"LockableMixin must be applied before RevisionMixin.",
hint="Move LockableMixin in the model's base classes before RevisionMixin.",
obj=LockableIncorrectRevisionModel,
id="wagtailcore.E005",
)
],
)
def test_correct_revision_mixin_order(self):
class LockableWithRevisionModel(LockableMixin, RevisionMixin, models.Model):
pass
self.assertEqual(LockableWithRevisionModel.check(), [])

View File

@@ -0,0 +1,851 @@
from datetime import timedelta
from io import StringIO
from unittest import mock
from django.contrib.auth import get_user_model
from django.contrib.auth.models import Group
from django.core import management
from django.db import models
from django.test import TestCase, override_settings
from django.utils import timezone
from wagtail.embeds.models import Embed
from wagtail.models import (
Collection,
Page,
PageLogEntry,
Revision,
Task,
Workflow,
WorkflowTask,
)
from wagtail.signals import page_published, page_unpublished, published, unpublished
from wagtail.test.testapp.models import (
DraftStateModel,
EventPage,
FullFeaturedSnippet,
PurgeRevisionsProtectedTestModel,
SecretPage,
SimplePage,
)
from wagtail.test.utils import WagtailTestUtils
class TestFixTreeCommand(TestCase):
fixtures = ["test.json"]
def badly_delete_page(self, page):
# Deletes a page the wrong way.
# This will not update numchild and may leave orphans
models.Model.delete(page)
def run_command(self, **options):
options.setdefault("interactive", False)
output = StringIO()
management.call_command("fixtree", stdout=output, **options)
output.seek(0)
return output
def test_fixes_numchild(self):
# Get homepage and save old value
homepage = Page.objects.get(url_path="/home/")
old_numchild = homepage.numchild
# Break it
homepage.numchild = 12345
homepage.save()
# Check that its broken
self.assertEqual(Page.objects.get(url_path="/home/").numchild, 12345)
# Call command
self.run_command()
# Check if its fixed
self.assertEqual(Page.objects.get(url_path="/home/").numchild, old_numchild)
def test_fixes_depth(self):
# Get homepage and save old value
homepage = Page.objects.get(url_path="/home/")
old_depth = homepage.depth
# Break it
homepage.depth = 12345
homepage.save()
# also break the root collection's depth
root_collection = Collection.get_first_root_node()
root_collection.depth = 42
root_collection.save()
# Check that its broken
self.assertEqual(Page.objects.get(url_path="/home/").depth, 12345)
self.assertEqual(Collection.objects.get(id=root_collection.id).depth, 42)
# Call command
self.run_command()
# Check if its fixed
self.assertEqual(Page.objects.get(url_path="/home/").depth, old_depth)
self.assertEqual(Collection.objects.get(id=root_collection.id).depth, 1)
def test_detects_orphans(self):
events_index = Page.objects.get(url_path="/home/events/")
christmas_page = EventPage.objects.get(url_path="/home/events/christmas/")
# Delete the events index badly
self.badly_delete_page(events_index)
# Check that christmas_page is still in the tree
self.assertTrue(Page.objects.filter(id=christmas_page.id).exists())
# Call command
output = self.run_command()
# Check that the issues were detected
output_string = output.read()
self.assertIn("Incorrect numchild value found for pages: [2]", output_string)
# Note that page ID 15 was also deleted, but is not picked up here, as
# it is a child of 14.
self.assertIn("Orphaned pages found: [4, 5, 6, 9, 13, 15]", output_string)
# Check that christmas_page is still in the tree
self.assertTrue(Page.objects.filter(id=christmas_page.id).exists())
def test_deletes_orphans(self):
events_index = Page.objects.get(url_path="/home/events/")
christmas_page = EventPage.objects.get(url_path="/home/events/christmas/")
# Delete the events index badly
self.badly_delete_page(events_index)
# Check that christmas_page is still in the tree
self.assertTrue(Page.objects.filter(id=christmas_page.id).exists())
# Call command
# delete_orphans simulates a user pressing "y" at the prompt
output = self.run_command(delete_orphans=True)
# Check that the issues were detected
output_string = output.read()
self.assertIn("Incorrect numchild value found for pages: [2]", output_string)
self.assertIn("7 orphaned pages deleted.", output_string)
# Check that christmas_page has been deleted
self.assertFalse(Page.objects.filter(id=christmas_page.id).exists())
def test_remove_path_holes(self):
events_index = Page.objects.get(url_path="/home/events/")
# Delete the event page in path position 0001
Page.objects.get(path=events_index.path + "0001").delete()
self.run_command(full=True)
# the gap at position 0001 should have been closed
events_index = Page.objects.get(url_path="/home/events/")
self.assertTrue(Page.objects.filter(path=events_index.path + "0001").exists())
class TestMovePagesCommand(TestCase):
fixtures = ["test.json"]
def run_command(self, from_, to):
management.call_command("move_pages", str(from_), str(to), stdout=StringIO())
def test_move_pages(self):
# Get pages
events_index = Page.objects.get(url_path="/home/events/")
about_us = Page.objects.get(url_path="/home/about-us/")
page_ids = events_index.get_children().values_list("id", flat=True)
# Move all events into "about us"
self.run_command(events_index.id, about_us.id)
# Check that all pages moved
for page_id in page_ids:
self.assertEqual(Page.objects.get(id=page_id).get_parent(), about_us)
class TestSetUrlPathsCommand(TestCase):
fixtures = ["test.json"]
def run_command(self):
management.call_command("set_url_paths", stdout=StringIO())
def test_set_url_paths(self):
self.run_command()
class TestPublishScheduledPagesCommand(WagtailTestUtils, TestCase):
fixtures = ["test.json"]
def setUp(self):
# Find root page
self.root_page = Page.objects.get(id=2)
def test_go_live_page_will_be_published(self):
# Connect a mock signal handler to page_published signal
signal_fired = [False]
signal_page = [None]
def page_published_handler(sender, instance, **kwargs):
signal_fired[0] = True
signal_page[0] = instance
page_published.connect(page_published_handler)
try:
page = SimplePage(
title="Hello world!",
slug="hello-world",
content="hello",
live=False,
has_unpublished_changes=True,
go_live_at=timezone.now() - timedelta(days=1),
)
self.root_page.add_child(instance=page)
page.save_revision(approved_go_live_at=timezone.now() - timedelta(days=1))
p = Page.objects.get(slug="hello-world")
self.assertFalse(p.live)
self.assertTrue(
Revision.page_revisions.filter(object_id=p.id)
.exclude(approved_go_live_at__isnull=True)
.exists()
)
management.call_command("publish_scheduled_pages")
p = Page.objects.get(slug="hello-world")
self.assertTrue(p.live)
self.assertTrue(p.first_published_at)
self.assertFalse(p.has_unpublished_changes)
self.assertFalse(
Revision.page_revisions.filter(object_id=p.id)
.exclude(approved_go_live_at__isnull=True)
.exists()
)
# Check that the page_published signal was fired
self.assertTrue(signal_fired[0])
self.assertEqual(signal_page[0], page)
self.assertEqual(signal_page[0], signal_page[0].specific)
finally:
page_published.disconnect(page_published_handler)
def test_go_live_page_created_by_editor_will_be_published(self):
# Connect a mock signal handler to page_published signal
signal_fired = [False]
signal_page = [None]
editor = self.create_user("ed")
editor.groups.add(Group.objects.get(name="Site-wide editors"))
def page_published_handler(sender, instance, **kwargs):
signal_fired[0] = True
signal_page[0] = instance
page_published.connect(page_published_handler)
try:
page = SimplePage(
title="Hello world!",
slug="hello-world",
content="hello",
live=False,
has_unpublished_changes=True,
go_live_at=timezone.now() - timedelta(days=1),
)
self.root_page.add_child(instance=page)
page.save_revision(
user=editor, approved_go_live_at=timezone.now() - timedelta(days=1)
)
p = Page.objects.get(slug="hello-world")
self.assertFalse(p.live)
self.assertTrue(
Revision.page_revisions.filter(object_id=p.id)
.exclude(approved_go_live_at__isnull=True)
.exists()
)
management.call_command("publish_scheduled_pages")
p = Page.objects.get(slug="hello-world")
self.assertTrue(p.live)
self.assertTrue(p.first_published_at)
self.assertFalse(p.has_unpublished_changes)
self.assertFalse(
Revision.page_revisions.filter(object_id=p.id)
.exclude(approved_go_live_at__isnull=True)
.exists()
)
# Check that the page_published signal was fired
self.assertTrue(signal_fired[0])
self.assertEqual(signal_page[0], page)
self.assertEqual(signal_page[0], signal_page[0].specific)
finally:
page_published.disconnect(page_published_handler)
def test_go_live_when_newer_revision_exists(self):
page = SimplePage(
title="Hello world!",
slug="hello-world",
content="hello",
live=False,
has_unpublished_changes=True,
go_live_at=timezone.now() - timedelta(days=1),
)
self.root_page.add_child(instance=page)
page.save_revision(approved_go_live_at=timezone.now() - timedelta(days=1))
page.title = "Goodbye world!"
page.save_revision()
management.call_command("publish_scheduled_pages")
p = Page.objects.get(slug="hello-world")
self.assertTrue(p.live)
self.assertTrue(p.has_unpublished_changes)
self.assertEqual(p.title, "Hello world!")
def test_future_go_live_page_will_not_be_published(self):
page = SimplePage(
title="Hello world!",
slug="hello-world",
content="hello",
live=False,
go_live_at=timezone.now() + timedelta(days=1),
)
self.root_page.add_child(instance=page)
page.save_revision(approved_go_live_at=timezone.now() - timedelta(days=1))
p = Page.objects.get(slug="hello-world")
self.assertFalse(p.live)
self.assertTrue(
Revision.page_revisions.filter(object_id=p.id)
.exclude(approved_go_live_at__isnull=True)
.exists()
)
management.call_command("publish_scheduled_pages")
p = Page.objects.get(slug="hello-world")
self.assertFalse(p.live)
self.assertTrue(
Revision.page_revisions.filter(object_id=p.id)
.exclude(approved_go_live_at__isnull=True)
.exists()
)
def test_expired_page_will_be_unpublished(self):
# Connect a mock signal handler to page_unpublished signal
signal_fired = [False]
signal_page = [None]
def page_unpublished_handler(sender, instance, **kwargs):
signal_fired[0] = True
signal_page[0] = instance
page_unpublished.connect(page_unpublished_handler)
try:
page = SimplePage(
title="Hello world!",
slug="hello-world",
content="hello",
live=True,
has_unpublished_changes=False,
expire_at=timezone.now() - timedelta(days=1),
)
self.root_page.add_child(instance=page)
p = Page.objects.get(slug="hello-world")
self.assertTrue(p.live)
management.call_command("publish_scheduled_pages")
p = Page.objects.get(slug="hello-world")
self.assertFalse(p.live)
self.assertTrue(p.has_unpublished_changes)
self.assertTrue(p.expired)
# Check that the page_published signal was fired
self.assertTrue(signal_fired[0])
self.assertEqual(signal_page[0], page)
self.assertEqual(signal_page[0], signal_page[0].specific)
finally:
page_unpublished.disconnect(page_unpublished_handler)
def test_future_expired_page_will_not_be_unpublished(self):
page = SimplePage(
title="Hello world!",
slug="hello-world",
content="hello",
live=True,
expire_at=timezone.now() + timedelta(days=1),
)
self.root_page.add_child(instance=page)
p = Page.objects.get(slug="hello-world")
self.assertTrue(p.live)
management.call_command("publish_scheduled_pages")
p = Page.objects.get(slug="hello-world")
self.assertTrue(p.live)
self.assertFalse(p.expired)
class TestPublishScheduledCommand(WagtailTestUtils, TestCase):
fixtures = ["test.json"]
def setUp(self):
self.snippet = DraftStateModel.objects.create(text="Hello world!", live=False)
def test_go_live_will_be_published(self):
# Connect a mock signal handler to published signal
signal_fired = [False]
signal_obj = [None]
def published_handler(sender, instance, **kwargs):
signal_fired[0] = True
signal_obj[0] = instance
published.connect(published_handler)
try:
go_live_at = timezone.now() - timedelta(days=1)
self.snippet.has_unpublished_changes = True
self.snippet.go_live_at = go_live_at
self.snippet.save_revision(approved_go_live_at=go_live_at)
self.snippet.refresh_from_db()
self.assertFalse(self.snippet.live)
self.assertTrue(
Revision.objects.for_instance(self.snippet)
.exclude(approved_go_live_at__isnull=True)
.exists()
)
management.call_command("publish_scheduled")
self.snippet.refresh_from_db()
self.assertTrue(self.snippet.live)
self.assertTrue(self.snippet.first_published_at)
self.assertFalse(self.snippet.has_unpublished_changes)
self.assertFalse(
Revision.objects.for_instance(self.snippet)
.exclude(approved_go_live_at__isnull=True)
.exists()
)
# Check that the published signal was fired
self.assertTrue(signal_fired[0])
self.assertEqual(signal_obj[0], self.snippet)
finally:
published.disconnect(published_handler)
def test_go_live_created_by_editor_will_be_published(self):
# Connect a mock signal handler to published signal
signal_fired = [False]
signal_obj = [None]
editor = self.create_user("ed")
editor.groups.add(Group.objects.get(name="Site-wide editors"))
def published_handler(sender, instance, **kwargs):
signal_fired[0] = True
signal_obj[0] = instance
published.connect(published_handler)
try:
go_live_at = timezone.now() - timedelta(days=1)
self.snippet.has_unpublished_changes = True
self.snippet.go_live_at = go_live_at
self.snippet.save_revision(user=editor, approved_go_live_at=go_live_at)
self.snippet.refresh_from_db()
self.assertFalse(self.snippet.live)
self.assertTrue(
Revision.objects.for_instance(self.snippet)
.exclude(approved_go_live_at__isnull=True)
.exists()
)
management.call_command("publish_scheduled")
self.snippet.refresh_from_db()
self.assertTrue(self.snippet.live)
self.assertTrue(self.snippet.first_published_at)
self.assertFalse(self.snippet.has_unpublished_changes)
self.assertFalse(
Revision.objects.for_instance(self.snippet)
.exclude(approved_go_live_at__isnull=True)
.exists()
)
# Check that the published signal was fired
self.assertTrue(signal_fired[0])
self.assertEqual(signal_obj[0], self.snippet)
finally:
published.disconnect(published_handler)
def test_go_live_when_newer_revision_exists(self):
go_live_at = timezone.now() - timedelta(days=1)
self.snippet.has_unpublished_changes = True
self.snippet.go_live_at = go_live_at
self.snippet.save_revision(approved_go_live_at=go_live_at)
self.snippet.text = "Goodbye world!"
self.snippet.save_revision()
management.call_command("publish_scheduled")
self.snippet.refresh_from_db()
self.assertTrue(self.snippet.live)
self.assertTrue(self.snippet.has_unpublished_changes)
self.assertEqual(self.snippet.text, "Hello world!")
def test_future_go_live_will_not_be_published(self):
self.snippet.has_unpublished_changes = True
self.snippet.go_live_at = timezone.now() + timedelta(days=1)
self.snippet.save_revision(
approved_go_live_at=timezone.now() - timedelta(days=1)
)
self.snippet.refresh_from_db()
self.assertFalse(self.snippet.live)
self.assertTrue(
Revision.objects.for_instance(self.snippet)
.exclude(approved_go_live_at__isnull=True)
.exists()
)
management.call_command("publish_scheduled")
self.assertFalse(self.snippet.live)
self.assertTrue(
Revision.objects.for_instance(self.snippet)
.exclude(approved_go_live_at__isnull=True)
.exists()
)
def test_expired_will_be_unpublished(self):
# Connect a mock signal handler to unpublished signal
signal_fired = [False]
signal_obj = [None]
def unpublished_handler(sender, instance, **kwargs):
signal_fired[0] = True
signal_obj[0] = instance
unpublished.connect(unpublished_handler)
try:
self.snippet.expire_at = timezone.now() - timedelta(days=1)
self.snippet.save_revision().publish()
self.snippet.refresh_from_db()
self.assertTrue(self.snippet.live)
management.call_command("publish_scheduled")
self.snippet.refresh_from_db()
self.assertFalse(self.snippet.live)
self.assertTrue(self.snippet.has_unpublished_changes)
self.assertTrue(self.snippet.expired)
# Check that the unpublished signal was fired
self.assertTrue(signal_fired[0])
self.assertEqual(signal_obj[0], self.snippet)
finally:
unpublished.disconnect(unpublished_handler)
def test_future_expired_will_not_be_unpublished(self):
self.snippet.expire_at = timezone.now() + timedelta(days=1)
self.snippet.save_revision().publish()
self.snippet.refresh_from_db()
self.assertTrue(self.snippet.live)
management.call_command("publish_scheduled")
self.snippet.refresh_from_db()
self.assertTrue(self.snippet.live)
self.assertFalse(self.snippet.expired)
class TestPurgeRevisionsCommandForPages(TestCase):
base_options = {}
def setUp(self):
self.object = self.get_object()
def get_object(self):
# Find root page
self.root_page = Page.objects.get(id=2)
self.page = SimplePage(
title="Hello world!",
slug="hello-world",
content="hello",
live=False,
)
self.root_page.add_child(instance=self.page)
self.page.refresh_from_db()
return self.page
def assertRevisionNotExists(self, revision):
self.assertFalse(Revision.objects.filter(id=revision.id).exists())
def assertRevisionExists(self, revision):
self.assertTrue(Revision.objects.filter(id=revision.id).exists())
def run_command(self, **options):
return management.call_command(
"purge_revisions", **{**self.base_options, **options}, stdout=StringIO()
)
def test_latest_revision_not_purged(self):
revision_1 = self.object.save_revision()
revision_2 = self.object.save_revision()
self.run_command()
# revision 1 should be deleted, revision 2 should not be
self.assertRevisionNotExists(revision_1)
self.assertRevisionExists(revision_2)
def test_revisions_in_moderation_or_workflow_not_purged(self):
workflow = Workflow.objects.create(name="test_workflow")
task_1 = Task.objects.create(name="test_task_1")
user = get_user_model().objects.first()
WorkflowTask.objects.create(workflow=workflow, task=task_1, sort_order=1)
revision = self.object.save_revision()
workflow.start(self.object, user)
# Save a new revision to ensure that the revision in the workflow
# is not the latest one
self.object.save_revision()
self.run_command()
# even though they're no longer the latest revisions, the old revisions
# should stay as they are attached to an in progress workflow
self.assertRevisionExists(revision)
# If workflow is disabled at some point after that, the revision should
# be deleted
with override_settings(WAGTAIL_WORKFLOW_ENABLED=False):
self.run_command()
self.assertRevisionNotExists(revision)
def test_revisions_with_approve_go_live_not_purged(self):
revision = self.object.save_revision(
approved_go_live_at=timezone.now() + timedelta(days=1)
)
# Save a new revision to ensure that the approved revision
# is not the latest one
self.object.save_revision()
self.run_command()
self.assertRevisionExists(revision)
def test_purge_revisions_with_date_cutoff(self):
old_revision = self.object.save_revision()
self.object.save_revision()
self.run_command(days=30)
# revision should not be deleted, as it is younger than 30 days
self.assertRevisionExists(old_revision)
old_revision.created_at = timezone.now() - timedelta(days=31)
old_revision.save()
self.run_command(days=30)
# revision is now older than 30 days, so should be deleted
self.assertRevisionNotExists(old_revision)
def test_purge_revisions_protected_error(self):
revision_old = self.object.save_revision()
PurgeRevisionsProtectedTestModel.objects.create(revision=revision_old)
revision_purged = self.object.save_revision()
self.object.save_revision()
self.run_command()
# revision should not be deleted, as it is protected
self.assertRevisionExists(revision_old)
# Any other revisions are deleted
self.assertRevisionNotExists(revision_purged)
class TestPurgeRevisionsCommandForSnippets(TestPurgeRevisionsCommandForPages):
def get_object(self):
return FullFeaturedSnippet.objects.create(text="Hello world!")
class TestPurgeRevisionsCommandForPagesWithPagesOnly(TestPurgeRevisionsCommandForPages):
base_options = {"pages": True}
class TestPurgeRevisionsCommandForPagesWithNonPagesOnly(
TestPurgeRevisionsCommandForPages
):
base_options = {"non_pages": True}
def assertRevisionNotExists(self, revision):
# Page revisions won't be purged if only non_pages is specified
return self.assertRevisionExists(revision)
class TestPurgeRevisionsCommandForSnippetsWithNonPagesOnly(
TestPurgeRevisionsCommandForSnippets
):
base_options = {"non_pages": True}
class TestPurgeRevisionsCommandForSnippetsWithPagesOnly(
TestPurgeRevisionsCommandForSnippets
):
base_options = {"pages": True}
def assertRevisionNotExists(self, revision):
# Snippet revisions won't be purged if only pages is specified
return self.assertRevisionExists(revision)
class TestPurgeEmbedsCommand(TestCase):
fixtures = ["test.json"]
def setUp(self):
# create dummy Embed objects
for i in range(5):
embed = Embed(
hash=f"{i}",
url="https://www.youtube.com/watch?v=Js8dIRxwSRY",
max_width=None,
type="video",
html="test html",
title="test title",
author_name="test author name",
provider_name="test provider name",
thumbnail_url="http://test/thumbnail.url",
width=1000,
height=1000,
)
embed.save()
def test_purge_embeds(self):
"""
fetch all dummy embeds and confirm they are deleted when the management command runs
"""
self.assertEqual(Embed.objects.count(), 5)
management.call_command("purge_embeds", stdout=StringIO())
self.assertEqual(Embed.objects.count(), 0)
class TestCreateLogEntriesFromRevisionsCommand(TestCase):
fixtures = ["test.json"]
def setUp(self):
self.page = SimplePage(
title="Hello world!",
slug="hello-world",
content="hello",
live=False,
expire_at=timezone.now() - timedelta(days=1),
)
Page.objects.get(id=2).add_child(instance=self.page)
# Create empty revisions, which should not be converted to log entries
for i in range(3):
self.page.save_revision()
# Add another revision with a content change
self.page.title = "Hello world!!"
revision = self.page.save_revision()
revision.publish()
# Do the same with a SecretPage (to check that the version comparison code doesn't
# trip up on permission-dependent edit handlers)
self.secret_page = SecretPage(
title="The moon",
slug="the-moon",
boring_data="the moon",
secret_data="is made of cheese",
live=False,
)
Page.objects.get(id=2).add_child(instance=self.secret_page)
# Create empty revisions, which should not be converted to log entries
for i in range(3):
self.secret_page.save_revision()
# Add another revision with a content change
self.secret_page.secret_data = "is flat"
revision = self.secret_page.save_revision()
revision.publish()
# clean up log entries
PageLogEntry.objects.all().delete()
def test_log_entries_created_from_revisions(self):
management.call_command("create_log_entries_from_revisions")
# Should not create entries for empty revisions.
self.assertListEqual(
list(PageLogEntry.objects.values_list("page_id", "action")),
# Default PageLogEntry sort order is from newest event to oldest.
# We reverse here to make it easier to understand what is being
# tested. The events here should correspond with setUp above.
list(
reversed(
[
# The SimplePage was created in draft mode, with an initial revision.
(self.page.pk, "wagtail.create"),
(self.page.pk, "wagtail.edit"),
# The SimplePage was edited as a new draft, then published.
(self.page.pk, "wagtail.edit"),
(self.page.pk, "wagtail.publish"),
# The SecretPage was created in draft mode, with an initial revision.
(self.secret_page.pk, "wagtail.create"),
(self.secret_page.pk, "wagtail.edit"),
# The SecretPage was edited as a new draft, then published.
(self.secret_page.pk, "wagtail.edit"),
(self.secret_page.pk, "wagtail.publish"),
]
)
),
)
def test_command_doesnt_crash_for_revisions_without_page_model(self):
with mock.patch(
"wagtail.models.Page.specific_class",
return_value=None,
new_callable=mock.PropertyMock,
):
management.call_command("create_log_entries_from_revisions")
self.assertEqual(PageLogEntry.objects.count(), 0)

View File

@@ -0,0 +1,63 @@
"""
Check that all changes to Wagtail models have had migrations created. If there
are outstanding model changes that need migrations, fail the tests.
"""
from django.apps import apps
from django.db.migrations.autodetector import MigrationAutodetector
from django.db.migrations.loader import MigrationLoader
from django.db.migrations.questioner import MigrationQuestioner
from django.db.migrations.state import ProjectState
from django.test import TestCase
class TestForMigrations(TestCase):
def test__migrations(self):
app_labels = {
app.label
for app in apps.get_app_configs()
if app.name.split(".")[0] == "wagtail"
}
for app_label in app_labels:
apps.get_app_config(app_label.split(".")[-1])
loader = MigrationLoader(None, ignore_no_migrations=True)
conflicts = {
(app_label, conflict)
for app_label, conflict in loader.detect_conflicts().items()
if app_label in app_labels
}
if conflicts:
name_str = "; ".join(
"{} in {}".format(", ".join(names), app)
for app, names in conflicts.items()
)
self.fail("Conflicting migrations detected (%s)." % name_str)
autodetector = MigrationAutodetector(
loader.project_state(),
ProjectState.from_apps(apps),
MigrationQuestioner(specified_apps=app_labels, dry_run=True),
)
changes = autodetector.changes(
graph=loader.graph,
trim_to_apps=app_labels or None,
convert_apps=app_labels or None,
)
if changes:
migrations = "\n".join(
" {migration}\n{changes}".format(
migration=migration,
changes="\n".join(
f" {operation.describe()}"
for operation in migration.operations
),
)
for (_, migrations) in changes.items()
for migration in migrations
)
self.fail("Model changes with no migrations detected:\n%s" % migrations)

View File

@@ -0,0 +1,173 @@
from unittest import mock
from django.conf import settings
from wagtail.models import Page
from wagtail.test.routablepage.models import RoutablePageTest
from wagtail.test.utils import WagtailPageTestCase
class TestCustomPageAssertions(WagtailPageTestCase):
@classmethod
def setUpTestData(cls):
cls.superuser = cls.create_superuser("super")
def setUp(self):
self.parent = Page.objects.get(id=2)
self.page = RoutablePageTest(
title="Hello world!",
slug="hello-world",
)
self.parent.add_child(instance=self.page)
def test_is_routable(self):
self.assertPageIsRoutable(self.page)
def test_is_routable_with_alternative_route(self):
self.assertPageIsRoutable(self.page, "archive/year/1984/")
def test_is_routable_fails_for_draft_page(self):
self.page.live = False
self.page.save()
with self.assertRaises(self.failureException):
self.assertPageIsRoutable(self.page)
def test_is_routable_fails_for_invalid_route_path(self):
with self.assertRaises(self.failureException):
self.assertPageIsRoutable(self.page, "invalid-route-path/")
@mock.patch("django.test.testcases.Client.get")
@mock.patch("django.test.testcases.Client.force_login")
def test_is_renderable(self, mocked_force_login, mocked_get):
self.assertPageIsRenderable(self.page)
mocked_force_login.assert_not_called()
mocked_get.assert_called_once_with("/hello-world/", data=None)
@mock.patch("django.test.testcases.Client.get")
@mock.patch("django.test.testcases.Client.force_login")
def test_is_renderable_for_alternative_route(self, mocked_force_login, mocked_get):
self.assertPageIsRenderable(self.page, "archive/year/1984/")
mocked_force_login.assert_not_called()
mocked_get.assert_called_once_with("/hello-world/archive/year/1984/", data=None)
@mock.patch("django.test.testcases.Client.get")
@mock.patch("django.test.testcases.Client.force_login")
def test_is_renderable_for_user(self, mocked_force_login, mocked_get):
self.assertPageIsRenderable(self.page, user=self.superuser)
mocked_force_login.assert_called_once_with(
self.superuser, settings.AUTHENTICATION_BACKENDS[0]
)
mocked_get.assert_called_once_with("/hello-world/", data=None)
@mock.patch("django.test.testcases.Client.get")
def test_is_renderable_with_query_data(self, mocked_get):
query_data = {"p": 1, "q": "test"}
self.assertPageIsRenderable(self.page, query_data=query_data)
mocked_get.assert_called_once_with("/hello-world/", data=query_data)
@mock.patch("django.test.testcases.Client.post")
def test_is_renderable_with_query_and_post_data(self, mocked_post):
query_data = {"p": 1, "q": "test"}
post_data = {"subscribe": True}
self.assertPageIsRenderable(
self.page, query_data=query_data, post_data=post_data
)
mocked_post.assert_called_once_with(
"/hello-world/", data=post_data, QUERYSTRING="p=1&q=test"
)
def test_is_renderable_for_draft_page(self):
self.page.live = False
self.page.save()
# When accept_404 is False (the default) the test should fail
with self.assertRaises(self.failureException):
self.assertPageIsRenderable(self.page)
# When accept_404 is True, the test should pass
self.assertPageIsRenderable(self.page, accept_404=True)
def test_is_renderable_for_invalid_route_path(self):
# When accept_404 is False (the default) the test should fail
with self.assertRaises(self.failureException):
self.assertPageIsRenderable(self.page, "invalid-route-path/")
# When accept_404 is True, the test should pass
self.assertPageIsRenderable(self.page, "invalid-route-path/", accept_404=True)
def test_is_rendereable_accept_redirect(self):
redirect_route_paths = [
"permanant-homepage-redirect/",
"temporary-homepage-redirect/",
]
# When accept_redirect is False (the default) the tests should fail
for route_path in redirect_route_paths:
with self.assertRaises(self.failureException):
self.assertPageIsRenderable(self.page, route_path)
# When accept_redirect is True, the tests should pass
for route_path in redirect_route_paths:
self.assertPageIsRenderable(self.page, route_path, accept_redirect=True)
def test_is_editable(self):
self.assertPageIsEditable(self.page)
@mock.patch("django.test.testcases.Client.force_login")
def test_is_editable_always_authenticates(self, mocked_force_login):
try:
self.assertPageIsEditable(self.page)
except self.failureException:
pass
mocked_force_login.assert_called_with(
self._pageiseditable_superuser, settings.AUTHENTICATION_BACKENDS[0]
)
try:
self.assertPageIsEditable(self.page, user=self.superuser)
except self.failureException:
pass
mocked_force_login.assert_called_with(
self.superuser, settings.AUTHENTICATION_BACKENDS[0]
)
@mock.patch("django.test.testcases.Client.get")
@mock.patch("django.test.testcases.Client.force_login")
def test_is_editable_with_permission_lacking_user(
self, mocked_force_login, mocked_get
):
user = self.create_user("bob")
with self.assertRaises(self.failureException):
self.assertPageIsEditable(self.page, user=user)
mocked_force_login.assert_not_called()
mocked_get.assert_not_called()
def test_is_editable_with_post_data(self):
self.assertPageIsEditable(
self.page,
post_data={
"title": "Goodbye world?",
"slug": "goodbye-world",
"content": "goodbye",
},
)
def test_is_previewable(self):
self.assertPageIsPreviewable(self.page)
def test_is_previewable_with_post_data(self):
self.assertPageIsPreviewable(
self.page, post_data={"title": "test", "slug": "test"}
)
def test_is_previewable_with_custom_user(self):
self.assertPageIsPreviewable(self.page, user=self.superuser)
def test_is_previewable_for_alternative_mode(self):
self.assertPageIsPreviewable(self.page, mode="extra")
def test_is_previewable_for_broken_mode(self):
with self.assertRaises(self.failureException):
self.assertPageIsPreviewable(self.page, mode="broken")

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,521 @@
from django.contrib.auth.models import AnonymousUser, Group, Permission
from django.test import TestCase
from wagtail.models import GroupPagePermission, Page, get_default_page_content_type
from wagtail.permission_policies.pages import PagePermissionPolicy
from wagtail.test.utils import WagtailTestUtils
from wagtail.tests.test_permission_policies import PermissionPolicyTestUtils
class PermissionPolicyTestCase(PermissionPolicyTestUtils, WagtailTestUtils, TestCase):
def setUp(self):
page_type = get_default_page_content_type()
self.root_page = Page.objects.get(id=2)
self.reports_page = self.root_page.add_child(
instance=Page(
title="Reports",
slug="reports",
)
)
root_editors_group = Group.objects.create(name="Root editors")
self.root_edit_perm = GroupPagePermission.objects.create(
group=root_editors_group,
page=self.root_page,
permission=Permission.objects.get(
content_type=page_type, codename="change_page"
),
)
report_editors_group = Group.objects.create(name="Report editors")
self.report_edit_perm = GroupPagePermission.objects.create(
group=report_editors_group,
page=self.reports_page,
permission=Permission.objects.get(
content_type=page_type, codename="change_page"
),
)
report_adders_group = Group.objects.create(name="Report adders")
self.report_add_perm = GroupPagePermission.objects.create(
group=report_adders_group,
page=self.reports_page,
permission=Permission.objects.get(
content_type=page_type, codename="add_page"
),
)
# Users
self.superuser = self.create_superuser(
"superuser", "superuser@example.com", "password"
)
self.inactive_superuser = self.create_superuser(
"inactivesuperuser", "inactivesuperuser@example.com", "password"
)
self.inactive_superuser.is_active = False
self.inactive_superuser.save()
# a user with edit permission through the root_editors_group
self.root_editor = self.create_user(
"rooteditor", "rooteditor@example.com", "password"
)
self.root_editor.groups.add(root_editors_group)
# a user that has edit permission, but is inactive
self.inactive_root_editor = self.create_user(
"inactiverooteditor", "inactiverooteditor@example.com", "password"
)
self.inactive_root_editor.groups.add(root_editors_group)
self.inactive_root_editor.is_active = False
self.inactive_root_editor.save()
# a user with edit permission on reports via the report_editors_group
self.report_editor = self.create_user(
"reporteditor", "reporteditor@example.com", "password"
)
self.report_editor.groups.add(report_editors_group)
# a user with add permission on reports via the report_adders_group
self.report_adder = self.create_user(
"reportadder", "reportadder@example.com", "password"
)
self.report_adder.groups.add(report_adders_group)
# a user with no permissions
self.useless_user = self.create_user(
"uselessuser", "uselessuser@example.com", "password"
)
self.anonymous_user = AnonymousUser()
# a page in the root owned by 'reporteditor'
self.editor_page = self.root_page.add_child(
instance=Page(
title="reporteditor's page",
slug="reporteditor-page",
owner=self.report_editor,
)
)
# a page in reports owned by 'reporteditor'
self.editor_report = self.reports_page.add_child(
instance=Page(
title="reporteditor's report",
slug="reporteditor-report",
owner=self.report_editor,
)
)
# a page in reports owned by 'reportadder'
self.adder_report = self.reports_page.add_child(
instance=Page(
title="reportadder's report",
slug="reportadder-report",
owner=self.report_adder,
)
)
# a page in reports owned by 'uselessuser'
self.useless_report = self.reports_page.add_child(
instance=Page(
title="uselessuser's report",
slug="uselessuser-report",
owner=self.useless_user,
)
)
# a page in reports with no owner
self.anonymous_report = self.reports_page.add_child(
instance=Page(
title="anonymous report",
slug="anonymous-report",
)
)
class TestPagePermissionPolicy(PermissionPolicyTestCase):
def setUp(self):
super().setUp()
self.policy = PagePermissionPolicy()
def _test_get_all_permissions_for_user(self):
self.assertResultSetEqual(
self.policy.get_cached_permissions_for_user(self.superuser),
{},
)
self.assertResultSetEqual(
self.policy.get_cached_permissions_for_user(self.inactive_superuser),
{},
)
self.assertResultSetEqual(
self.policy.get_cached_permissions_for_user(self.inactive_root_editor),
{},
)
self.assertResultSetEqual(
self.policy.get_cached_permissions_for_user(self.useless_user),
{},
)
self.assertResultSetEqual(
self.policy.get_cached_permissions_for_user(self.anonymous_user),
{},
)
self.assertResultSetEqual(
self.policy.get_cached_permissions_for_user(self.root_editor),
{self.root_edit_perm},
)
self.assertResultSetEqual(
self.policy.get_cached_permissions_for_user(self.report_editor),
{self.report_edit_perm},
)
self.assertResultSetEqual(
self.policy.get_cached_permissions_for_user(self.report_adder),
{self.report_add_perm},
)
def test_get_cached_permissions_for_user(self):
self._test_get_all_permissions_for_user()
with self.assertNumQueries(0):
self._test_get_all_permissions_for_user()
def test_user_has_permission(self):
self.assertUserPermissionMatrix(
[
(self.superuser, True, True, True, True),
(self.inactive_superuser, False, False, False, False),
(self.root_editor, False, True, False, False),
(self.inactive_root_editor, False, False, False, False),
(self.report_editor, False, True, False, False),
(self.report_adder, True, True, False, False),
(self.useless_user, False, False, False, False),
(self.anonymous_user, False, False, False, False),
],
["add", "change", "delete", "frobnicate"],
)
def test_user_has_any_permission(self):
self.assertTrue(
self.policy.user_has_any_permission(self.superuser, ["add", "change"])
)
self.assertFalse(
self.policy.user_has_any_permission(
self.inactive_superuser, ["add", "change"]
)
)
self.assertTrue(
self.policy.user_has_any_permission(self.report_editor, ["add", "change"])
)
self.assertTrue(
self.policy.user_has_any_permission(self.report_adder, ["add", "change"])
)
self.assertFalse(
self.policy.user_has_any_permission(self.anonymous_user, ["add", "change"])
)
self.assertTrue(
self.policy.user_has_any_permission(self.report_adder, ["change"])
)
def test_users_with_any_permission(self):
users_with_add_or_change_permission = self.policy.users_with_any_permission(
["add", "change"]
)
self.assertResultSetEqual(
users_with_add_or_change_permission,
[
self.superuser,
self.root_editor,
self.report_editor,
self.report_adder,
],
)
users_with_add_or_frobnicate_permission = self.policy.users_with_any_permission(
["add", "frobnicate"]
)
self.assertResultSetEqual(
users_with_add_or_frobnicate_permission,
[
self.superuser,
self.report_adder,
],
)
users_with_edit_or_frobnicate_permission = (
self.policy.users_with_any_permission(["change", "frobnicate"])
)
self.assertResultSetEqual(
users_with_edit_or_frobnicate_permission,
[
self.superuser,
self.root_editor,
self.report_editor,
self.report_adder,
],
)
def test_users_with_permission(self):
users_with_change_permission = self.policy.users_with_permission("change")
self.assertResultSetEqual(
users_with_change_permission,
[
self.superuser,
self.root_editor,
self.report_editor,
self.report_adder,
],
)
users_with_custom_permission = self.policy.users_with_permission("frobnicate")
self.assertResultSetEqual(
users_with_custom_permission,
[
self.superuser,
],
)
def test_user_has_permission_for_instance(self):
# page in the root is only editable by users with permissions
# on the root page
self.assertUserInstancePermissionMatrix(
self.editor_page,
[
(self.superuser, True, True, True),
(self.inactive_superuser, False, False, False),
(self.root_editor, True, False, False),
(self.inactive_root_editor, False, False, False),
(self.report_editor, False, False, False),
(self.report_adder, False, False, False),
(self.useless_user, False, False, False),
(self.anonymous_user, False, False, False),
],
["change", "delete", "frobnicate"],
)
# page in 'reports' is editable by users with permissions
# on 'reports' or the root page
self.assertUserInstancePermissionMatrix(
self.useless_report,
[
(self.superuser, True, True, True),
(self.inactive_superuser, False, False, False),
(self.root_editor, True, False, False),
(self.inactive_root_editor, False, False, False),
(self.report_editor, True, False, False),
(self.report_adder, False, False, False),
(self.useless_user, False, False, False),
(self.anonymous_user, False, False, False),
],
["change", "delete", "frobnicate"],
)
def test_user_has_any_permission_for_instance(self):
self.assertTrue(
self.policy.user_has_any_permission_for_instance(
self.report_editor, ["change", "delete"], self.useless_report
)
)
self.assertFalse(
self.policy.user_has_any_permission_for_instance(
self.report_editor, ["change", "delete"], self.editor_page
)
)
self.assertTrue(
self.policy.user_has_any_permission_for_instance(
self.report_adder, ["change", "delete"], self.adder_report
)
)
self.assertFalse(
self.policy.user_has_any_permission_for_instance(
self.anonymous_user, ["change", "delete"], self.editor_page
)
)
def test_instances_user_has_permission_for(self):
self.assertResultSetEqual(
self.policy.instances_user_has_permission_for(
self.superuser,
"change",
),
Page.objects.all(),
)
self.assertResultSetEqual(
self.policy.instances_user_has_permission_for(
self.inactive_superuser,
"change",
),
[],
)
self.assertResultSetEqual(
self.policy.instances_user_has_permission_for(
self.root_editor,
"change",
),
[
self.root_page,
self.reports_page,
self.editor_page,
self.editor_report,
self.adder_report,
self.useless_report,
self.anonymous_report,
],
)
self.assertResultSetEqual(
self.policy.instances_user_has_permission_for(
self.report_editor,
"change",
),
[
self.reports_page,
self.editor_report,
self.useless_report,
self.adder_report,
self.anonymous_report,
],
)
self.assertResultSetEqual(
self.policy.instances_user_has_permission_for(
self.useless_user,
"change",
),
[],
)
self.assertResultSetEqual(
self.policy.instances_user_has_permission_for(
self.anonymous_user,
"change",
),
[],
)
def test_instances_user_has_any_permission_for(self):
self.assertResultSetEqual(
self.policy.instances_user_has_any_permission_for(
self.superuser, ["change", "delete"]
),
Page.objects.all(),
)
self.assertResultSetEqual(
self.policy.instances_user_has_any_permission_for(
self.inactive_superuser, ["change", "delete"]
),
[],
)
self.assertResultSetEqual(
self.policy.instances_user_has_any_permission_for(
self.root_editor, ["change", "delete"]
),
[
self.root_page,
self.reports_page,
self.editor_page,
self.editor_report,
self.adder_report,
self.useless_report,
self.anonymous_report,
],
)
self.assertResultSetEqual(
self.policy.instances_user_has_any_permission_for(
self.report_editor, ["change", "delete"]
),
[
self.reports_page,
self.editor_report,
self.adder_report,
self.useless_report,
self.anonymous_report,
],
)
self.assertResultSetEqual(
self.policy.instances_user_has_any_permission_for(
self.report_adder, ["change", "delete"]
),
[self.adder_report],
)
self.assertResultSetEqual(
self.policy.instances_user_has_any_permission_for(
self.useless_user, ["change", "delete"]
),
[],
)
self.assertResultSetEqual(
self.policy.instances_user_has_any_permission_for(
self.anonymous_user, ["change", "delete"]
),
[],
)
def test_users_with_permission_for_instance(self):
self.assertResultSetEqual(
self.policy.users_with_permission_for_instance("change", self.editor_page),
[self.superuser, self.root_editor],
)
self.assertResultSetEqual(
self.policy.users_with_permission_for_instance("change", self.adder_report),
[self.superuser, self.root_editor, self.report_editor, self.report_adder],
)
self.assertResultSetEqual(
self.policy.users_with_permission_for_instance(
"change", self.editor_report
),
[self.superuser, self.root_editor, self.report_editor],
)
self.assertResultSetEqual(
self.policy.users_with_permission_for_instance(
"change", self.useless_report
),
[self.superuser, self.root_editor, self.report_editor],
)
self.assertResultSetEqual(
self.policy.users_with_permission_for_instance(
"change", self.anonymous_report
),
[self.superuser, self.root_editor, self.report_editor],
)
def test_users_with_any_permission_for_instance(self):
self.assertResultSetEqual(
self.policy.users_with_any_permission_for_instance(
["change", "delete"], self.editor_page
),
[self.superuser, self.root_editor],
)
self.assertResultSetEqual(
self.policy.users_with_any_permission_for_instance(
["change", "delete"], self.adder_report
),
[self.superuser, self.root_editor, self.report_editor, self.report_adder],
)
self.assertResultSetEqual(
self.policy.users_with_any_permission_for_instance(
["change", "delete"], self.useless_report
),
[self.superuser, self.root_editor, self.report_editor],
)
self.assertResultSetEqual(
self.policy.users_with_any_permission_for_instance(
["delete", "frobnicate"], self.useless_report
),
[self.superuser],
)

View File

@@ -0,0 +1,961 @@
import json
from django.contrib.auth import get_user_model
from django.contrib.auth.models import Group
from django.test import Client, TestCase, override_settings
from django.utils import timezone
from wagtail.models import (
GroupApprovalTask,
GroupPagePermission,
Locale,
Page,
Workflow,
WorkflowTask,
)
from wagtail.permission_policies.pages import PagePermissionPolicy
from wagtail.test.testapp.models import (
BusinessSubIndex,
CustomPermissionPage,
CustomPermissionTester,
EventIndex,
EventPage,
SingletonPageViaMaxCount,
)
class TestPagePermission(TestCase):
fixtures = ["test.json"]
def create_workflow_and_task(self):
workflow = Workflow.objects.create(name="test_workflow")
task_1 = GroupApprovalTask.objects.create(name="test_task_1")
task_1.groups.add(Group.objects.get(name="Event moderators"))
WorkflowTask.objects.create(
workflow=workflow, task=task_1.task_ptr, sort_order=1
)
return workflow, task_1
def test_nonpublisher_page_permissions(self):
event_editor = get_user_model().objects.get(email="eventeditor@example.com")
homepage = Page.objects.get(url_path="/home/")
christmas_page = EventPage.objects.get(url_path="/home/events/christmas/")
unpublished_event_page = EventPage.objects.get(
url_path="/home/events/tentative-unpublished-event/"
)
someone_elses_event_page = EventPage.objects.get(
url_path="/home/events/someone-elses-event/"
)
board_meetings_page = BusinessSubIndex.objects.get(
url_path="/home/events/businessy-events/board-meetings/"
)
homepage_perms = homepage.permissions_for_user(event_editor)
christmas_page_perms = christmas_page.permissions_for_user(event_editor)
unpub_perms = unpublished_event_page.permissions_for_user(event_editor)
someone_elses_event_perms = someone_elses_event_page.permissions_for_user(
event_editor
)
board_meetings_perms = board_meetings_page.permissions_for_user(event_editor)
self.assertFalse(homepage_perms.can_add_subpage())
self.assertTrue(christmas_page_perms.can_add_subpage())
self.assertTrue(unpub_perms.can_add_subpage())
self.assertTrue(someone_elses_event_perms.can_add_subpage())
self.assertFalse(homepage_perms.can_edit())
self.assertTrue(christmas_page_perms.can_edit())
self.assertTrue(unpub_perms.can_edit())
# basic 'add' permission doesn't allow editing pages owned by someone else
self.assertFalse(someone_elses_event_perms.can_edit())
self.assertFalse(homepage_perms.can_delete())
self.assertFalse(
christmas_page_perms.can_delete()
) # cannot delete because it is published
self.assertTrue(unpub_perms.can_delete())
self.assertFalse(someone_elses_event_perms.can_delete())
self.assertFalse(homepage_perms.can_publish())
self.assertFalse(christmas_page_perms.can_publish())
self.assertFalse(unpub_perms.can_publish())
self.assertFalse(homepage_perms.can_unpublish())
self.assertFalse(christmas_page_perms.can_unpublish())
self.assertFalse(unpub_perms.can_unpublish())
self.assertFalse(homepage_perms.can_publish_subpage())
self.assertFalse(christmas_page_perms.can_publish_subpage())
self.assertFalse(unpub_perms.can_publish_subpage())
self.assertFalse(homepage_perms.can_reorder_children())
self.assertFalse(christmas_page_perms.can_reorder_children())
self.assertFalse(unpub_perms.can_reorder_children())
self.assertFalse(homepage_perms.can_move())
# cannot move because this would involve unpublishing from its current location
self.assertFalse(christmas_page_perms.can_move())
self.assertTrue(unpub_perms.can_move())
self.assertFalse(someone_elses_event_perms.can_move())
# cannot move because this would involve unpublishing from its current location
self.assertFalse(christmas_page_perms.can_move_to(unpublished_event_page))
self.assertTrue(unpub_perms.can_move_to(christmas_page))
self.assertFalse(
unpub_perms.can_move_to(homepage)
) # no permission to create pages at destination
self.assertFalse(
unpub_perms.can_move_to(unpublished_event_page)
) # cannot make page a child of itself
# cannot move because the subpage_types rule of BusinessSubIndex forbids EventPage as a subpage
self.assertFalse(unpub_perms.can_move_to(board_meetings_page))
self.assertTrue(board_meetings_perms.can_move())
# cannot move because the parent_page_types rule of BusinessSubIndex forbids EventPage as a parent
self.assertFalse(board_meetings_perms.can_move_to(christmas_page))
def test_publisher_page_permissions(self):
event_moderator = get_user_model().objects.get(
email="eventmoderator@example.com"
)
homepage = Page.objects.get(url_path="/home/")
christmas_page = EventPage.objects.get(url_path="/home/events/christmas/")
unpublished_event_page = EventPage.objects.get(
url_path="/home/events/tentative-unpublished-event/"
)
board_meetings_page = BusinessSubIndex.objects.get(
url_path="/home/events/businessy-events/board-meetings/"
)
homepage_perms = homepage.permissions_for_user(event_moderator)
christmas_page_perms = christmas_page.permissions_for_user(event_moderator)
unpub_perms = unpublished_event_page.permissions_for_user(event_moderator)
board_meetings_perms = board_meetings_page.permissions_for_user(event_moderator)
self.assertFalse(homepage_perms.can_add_subpage())
self.assertTrue(christmas_page_perms.can_add_subpage())
self.assertTrue(unpub_perms.can_add_subpage())
self.assertFalse(homepage_perms.can_edit())
self.assertTrue(christmas_page_perms.can_edit())
self.assertTrue(unpub_perms.can_edit())
self.assertFalse(homepage_perms.can_delete())
# can delete a published page because we have publish permission
self.assertTrue(christmas_page_perms.can_delete())
self.assertTrue(unpub_perms.can_delete())
self.assertFalse(homepage_perms.can_publish())
self.assertTrue(christmas_page_perms.can_publish())
self.assertTrue(unpub_perms.can_publish())
self.assertFalse(homepage_perms.can_unpublish())
self.assertTrue(christmas_page_perms.can_unpublish())
self.assertFalse(
unpub_perms.can_unpublish()
) # cannot unpublish a page that isn't published
self.assertFalse(homepage_perms.can_publish_subpage())
self.assertTrue(christmas_page_perms.can_publish_subpage())
self.assertTrue(unpub_perms.can_publish_subpage())
self.assertFalse(homepage_perms.can_reorder_children())
self.assertTrue(christmas_page_perms.can_reorder_children())
self.assertTrue(unpub_perms.can_reorder_children())
self.assertFalse(homepage_perms.can_move())
self.assertTrue(christmas_page_perms.can_move())
self.assertTrue(unpub_perms.can_move())
self.assertTrue(christmas_page_perms.can_move_to(unpublished_event_page))
self.assertTrue(unpub_perms.can_move_to(christmas_page))
self.assertFalse(
unpub_perms.can_move_to(homepage)
) # no permission to create pages at destination
self.assertFalse(
unpub_perms.can_move_to(unpublished_event_page)
) # cannot make page a child of itself
# cannot move because the subpage_types rule of BusinessSubIndex forbids EventPage as a subpage
self.assertFalse(unpub_perms.can_move_to(board_meetings_page))
self.assertTrue(board_meetings_perms.can_move())
# cannot move because the parent_page_types rule of BusinessSubIndex forbids EventPage as a parent
self.assertFalse(board_meetings_perms.can_move_to(christmas_page))
def test_publish_page_permissions_without_edit(self):
event_moderator = get_user_model().objects.get(
email="eventmoderator@example.com"
)
# Remove 'edit' permission from the event_moderator group
GroupPagePermission.objects.filter(
group__name="Event moderators", permission__codename="change_page"
).delete()
homepage = Page.objects.get(url_path="/home/")
christmas_page = EventPage.objects.get(url_path="/home/events/christmas/")
unpublished_event_page = EventPage.objects.get(
url_path="/home/events/tentative-unpublished-event/"
)
# 'someone else's event' is owned by eventmoderator
moderator_event_page = EventPage.objects.get(
url_path="/home/events/someone-elses-event/"
)
homepage_perms = homepage.permissions_for_user(event_moderator)
christmas_page_perms = christmas_page.permissions_for_user(event_moderator)
unpub_perms = unpublished_event_page.permissions_for_user(event_moderator)
moderator_event_perms = moderator_event_page.permissions_for_user(
event_moderator
)
# we still have add permission within events
self.assertFalse(homepage_perms.can_add_subpage())
self.assertTrue(christmas_page_perms.can_add_subpage())
# add permission lets us edit our own event
self.assertFalse(christmas_page_perms.can_edit())
self.assertTrue(moderator_event_perms.can_edit())
# with add + publish permissions, can delete a published page owned by us
self.assertTrue(moderator_event_perms.can_delete())
# but NOT a page owned by someone else (which would require edit permission)
self.assertFalse(christmas_page_perms.can_delete())
# ...even an unpublished one
self.assertFalse(unpub_perms.can_delete())
# we can still publish/unpublish events regardless of owner
self.assertFalse(homepage_perms.can_publish())
self.assertTrue(christmas_page_perms.can_publish())
self.assertTrue(unpub_perms.can_publish())
self.assertFalse(homepage_perms.can_unpublish())
self.assertTrue(christmas_page_perms.can_unpublish())
self.assertFalse(
unpub_perms.can_unpublish()
) # cannot unpublish a page that isn't published
self.assertFalse(homepage_perms.can_publish_subpage())
self.assertTrue(christmas_page_perms.can_publish_subpage())
self.assertTrue(unpub_perms.can_publish_subpage())
# reorder permission is considered equivalent to publish permission
# (so we can do it on pages we can't edit)
self.assertFalse(homepage_perms.can_reorder_children())
self.assertTrue(christmas_page_perms.can_reorder_children())
self.assertTrue(unpub_perms.can_reorder_children())
# moving requires edit permission
self.assertFalse(homepage_perms.can_move())
self.assertFalse(christmas_page_perms.can_move())
self.assertTrue(moderator_event_perms.can_move())
# and add permission on the destination
self.assertFalse(moderator_event_perms.can_move_to(homepage))
self.assertTrue(moderator_event_perms.can_move_to(unpublished_event_page))
def test_cannot_bulk_delete_without_permissions(self):
event_moderator = get_user_model().objects.get(
email="eventmoderator@example.com"
)
events_page = EventIndex.objects.get(url_path="/home/events/")
events_perms = events_page.permissions_for_user(event_moderator)
self.assertFalse(events_perms.can_delete())
def test_can_bulk_delete_with_permissions(self):
event_moderator = get_user_model().objects.get(
email="eventmoderator@example.com"
)
events_page = EventIndex.objects.get(url_path="/home/events/")
# Assign 'bulk_delete' permission to the event_moderator group
event_moderators_group = Group.objects.get(name="Event moderators")
GroupPagePermission.objects.create(
group=event_moderators_group,
page=events_page,
permission_type="bulk_delete",
)
events_perms = events_page.permissions_for_user(event_moderator)
self.assertTrue(events_perms.can_delete())
def test_need_delete_permission_to_bulk_delete(self):
"""
Having bulk_delete permission is not in itself sufficient to allow deleting pages -
you need actual edit permission on the pages too.
In this test the event editor is given bulk_delete permission, but since their
only other permission is 'add', they cannot delete published pages or pages owned
by other users, and therefore the bulk deletion cannot happen.
"""
event_editor = get_user_model().objects.get(email="eventeditor@example.com")
events_page = EventIndex.objects.get(url_path="/home/events/")
# Assign 'bulk_delete' permission to the event_editor group
event_editors_group = Group.objects.get(name="Event editors")
GroupPagePermission.objects.create(
group=event_editors_group, page=events_page, permission_type="bulk_delete"
)
events_perms = events_page.permissions_for_user(event_editor)
self.assertFalse(events_perms.can_delete())
def test_inactive_user_has_no_permissions(self):
user = get_user_model().objects.get(email="inactiveuser@example.com")
christmas_page = EventPage.objects.get(url_path="/home/events/christmas/")
unpublished_event_page = EventPage.objects.get(
url_path="/home/events/tentative-unpublished-event/"
)
christmas_page_perms = christmas_page.permissions_for_user(user)
unpub_perms = unpublished_event_page.permissions_for_user(user)
self.assertFalse(unpub_perms.can_add_subpage())
self.assertFalse(unpub_perms.can_edit())
self.assertFalse(unpub_perms.can_delete())
self.assertFalse(unpub_perms.can_publish())
self.assertFalse(christmas_page_perms.can_unpublish())
self.assertFalse(unpub_perms.can_publish_subpage())
self.assertFalse(unpub_perms.can_reorder_children())
self.assertFalse(unpub_perms.can_move())
self.assertFalse(unpub_perms.can_move_to(christmas_page))
def test_superuser_has_full_permissions(self):
user = get_user_model().objects.get(email="superuser@example.com")
homepage = Page.objects.get(url_path="/home/").specific
root = Page.objects.get(url_path="/").specific
unpublished_event_page = EventPage.objects.get(
url_path="/home/events/tentative-unpublished-event/"
)
board_meetings_page = BusinessSubIndex.objects.get(
url_path="/home/events/businessy-events/board-meetings/"
)
homepage_perms = homepage.permissions_for_user(user)
root_perms = root.permissions_for_user(user)
unpub_perms = unpublished_event_page.permissions_for_user(user)
board_meetings_perms = board_meetings_page.permissions_for_user(user)
self.assertTrue(homepage_perms.can_add_subpage())
self.assertTrue(root_perms.can_add_subpage())
self.assertTrue(homepage_perms.can_edit())
self.assertFalse(
root_perms.can_edit()
) # root is not a real editable page, even to superusers
self.assertTrue(homepage_perms.can_delete())
self.assertFalse(root_perms.can_delete())
self.assertTrue(homepage_perms.can_publish())
self.assertFalse(root_perms.can_publish())
self.assertTrue(homepage_perms.can_unpublish())
self.assertFalse(root_perms.can_unpublish())
self.assertFalse(unpub_perms.can_unpublish())
self.assertTrue(homepage_perms.can_publish_subpage())
self.assertTrue(root_perms.can_publish_subpage())
self.assertTrue(homepage_perms.can_reorder_children())
self.assertTrue(root_perms.can_reorder_children())
self.assertTrue(homepage_perms.can_move())
self.assertFalse(root_perms.can_move())
self.assertTrue(homepage_perms.can_move_to(root))
self.assertFalse(homepage_perms.can_move_to(unpublished_event_page))
# cannot move because the subpage_types rule of BusinessSubIndex forbids EventPage as a subpage
self.assertFalse(unpub_perms.can_move_to(board_meetings_page))
self.assertTrue(board_meetings_perms.can_move())
# cannot move because the parent_page_types rule of BusinessSubIndex forbids EventPage as a parent
self.assertFalse(board_meetings_perms.can_move_to(unpublished_event_page))
def test_cant_move_pages_between_locales(self):
user = get_user_model().objects.get(email="superuser@example.com")
homepage = Page.objects.get(url_path="/home/").specific
root = Page.objects.get(url_path="/").specific
fr_locale = Locale.objects.create(language_code="fr")
fr_page = root.add_child(
instance=Page(
title="French page",
slug="french-page",
locale=fr_locale,
)
)
fr_homepage = root.add_child(
instance=Page(
title="French homepage",
slug="french-homepage",
locale=fr_locale,
)
)
french_page_perms = fr_page.permissions_for_user(user)
# fr_page can be moved into fr_homepage but not homepage
self.assertFalse(french_page_perms.can_move_to(homepage))
self.assertTrue(french_page_perms.can_move_to(fr_homepage))
# All pages can be moved to the root, regardless what language they are
self.assertTrue(french_page_perms.can_move_to(root))
events_index = Page.objects.get(url_path="/home/events/")
events_index_perms = events_index.permissions_for_user(user)
self.assertTrue(events_index_perms.can_move_to(root))
def test_editable_pages_for_user_with_add_permission(self):
event_editor = get_user_model().objects.get(email="eventeditor@example.com")
homepage = Page.objects.get(url_path="/home/")
christmas_page = EventPage.objects.get(url_path="/home/events/christmas/")
unpublished_event_page = EventPage.objects.get(
url_path="/home/events/tentative-unpublished-event/"
)
someone_elses_event_page = EventPage.objects.get(
url_path="/home/events/someone-elses-event/"
)
policy = PagePermissionPolicy()
editable_pages = policy.instances_user_has_permission_for(
event_editor, "change"
)
can_edit_pages = policy.user_has_permission(event_editor, "change")
publishable_pages = policy.instances_user_has_permission_for(
event_editor, "publish"
)
can_publish_pages = policy.user_has_permission(event_editor, "publish")
self.assertFalse(editable_pages.filter(id=homepage.id).exists())
self.assertTrue(editable_pages.filter(id=christmas_page.id).exists())
self.assertTrue(editable_pages.filter(id=unpublished_event_page.id).exists())
self.assertFalse(editable_pages.filter(id=someone_elses_event_page.id).exists())
self.assertTrue(can_edit_pages)
self.assertFalse(publishable_pages.filter(id=homepage.id).exists())
self.assertFalse(publishable_pages.filter(id=christmas_page.id).exists())
self.assertFalse(
publishable_pages.filter(id=unpublished_event_page.id).exists()
)
self.assertFalse(
publishable_pages.filter(id=someone_elses_event_page.id).exists()
)
self.assertFalse(can_publish_pages)
def test_explorable_pages(self):
event_editor = get_user_model().objects.get(email="eventeditor@example.com")
christmas_page = EventPage.objects.get(url_path="/home/events/christmas/")
unpublished_event_page = EventPage.objects.get(
url_path="/home/events/tentative-unpublished-event/"
)
someone_elses_event_page = EventPage.objects.get(
url_path="/home/events/someone-elses-event/"
)
about_us_page = Page.objects.get(url_path="/home/about-us/")
policy = PagePermissionPolicy()
explorable_pages = policy.explorable_instances(event_editor)
# Verify all pages below /home/events/ are explorable
self.assertTrue(explorable_pages.filter(id=christmas_page.id).exists())
self.assertTrue(explorable_pages.filter(id=unpublished_event_page.id).exists())
self.assertTrue(
explorable_pages.filter(id=someone_elses_event_page.id).exists()
)
# Verify page outside /events/ tree are not explorable
self.assertFalse(explorable_pages.filter(id=about_us_page.id).exists())
def test_explorable_pages_in_explorer(self):
event_editor = get_user_model().objects.get(email="eventeditor@example.com")
client = Client()
client.force_login(event_editor)
homepage = Page.objects.get(url_path="/home/")
explorer_response = client.get(
f"/admin/api/main/pages/?child_of={homepage.pk}&for_explorer=1"
)
explorer_json = json.loads(explorer_response.content.decode("utf-8"))
events_page = Page.objects.get(url_path="/home/events/")
about_us_page = Page.objects.get(url_path="/home/about-us/")
explorable_titles = [t.get("title") for t in explorer_json.get("items")]
self.assertIn(events_page.title, explorable_titles)
self.assertNotIn(about_us_page.title, explorable_titles)
def test_explorable_pages_with_permission_gap_in_hierarchy(self):
corporate_editor = get_user_model().objects.get(
email="corporateeditor@example.com"
)
policy = PagePermissionPolicy()
about_us_page = Page.objects.get(url_path="/home/about-us/")
businessy_events = Page.objects.get(url_path="/home/events/businessy-events/")
events_page = Page.objects.get(url_path="/home/events/")
explorable_pages = policy.explorable_instances(corporate_editor)
self.assertTrue(explorable_pages.filter(id=about_us_page.id).exists())
self.assertTrue(explorable_pages.filter(id=businessy_events.id).exists())
self.assertTrue(explorable_pages.filter(id=events_page.id).exists())
def test_editable_pages_for_user_with_edit_permission(self):
event_moderator = get_user_model().objects.get(
email="eventmoderator@example.com"
)
homepage = Page.objects.get(url_path="/home/")
christmas_page = EventPage.objects.get(url_path="/home/events/christmas/")
unpublished_event_page = EventPage.objects.get(
url_path="/home/events/tentative-unpublished-event/"
)
someone_elses_event_page = EventPage.objects.get(
url_path="/home/events/someone-elses-event/"
)
policy = PagePermissionPolicy()
editable_pages = policy.instances_user_has_permission_for(
event_moderator, "change"
)
can_edit_pages = policy.user_has_permission(event_moderator, "change")
publishable_pages = policy.instances_user_has_permission_for(
event_moderator, "publish"
)
can_publish_pages = policy.user_has_permission(event_moderator, "publish")
self.assertFalse(editable_pages.filter(id=homepage.id).exists())
self.assertTrue(editable_pages.filter(id=christmas_page.id).exists())
self.assertTrue(editable_pages.filter(id=unpublished_event_page.id).exists())
self.assertTrue(editable_pages.filter(id=someone_elses_event_page.id).exists())
self.assertTrue(can_edit_pages)
self.assertFalse(publishable_pages.filter(id=homepage.id).exists())
self.assertTrue(publishable_pages.filter(id=christmas_page.id).exists())
self.assertTrue(publishable_pages.filter(id=unpublished_event_page.id).exists())
self.assertTrue(
publishable_pages.filter(id=someone_elses_event_page.id).exists()
)
self.assertTrue(can_publish_pages)
def test_editable_pages_for_inactive_user(self):
user = get_user_model().objects.get(email="inactiveuser@example.com")
homepage = Page.objects.get(url_path="/home/")
christmas_page = EventPage.objects.get(url_path="/home/events/christmas/")
unpublished_event_page = EventPage.objects.get(
url_path="/home/events/tentative-unpublished-event/"
)
someone_elses_event_page = EventPage.objects.get(
url_path="/home/events/someone-elses-event/"
)
policy = PagePermissionPolicy()
editable_pages = policy.instances_user_has_permission_for(user, "change")
can_edit_pages = policy.user_has_permission(user, "change")
publishable_pages = policy.instances_user_has_permission_for(user, "publish")
can_publish_pages = policy.user_has_permission(user, "publish")
self.assertFalse(editable_pages.filter(id=homepage.id).exists())
self.assertFalse(editable_pages.filter(id=christmas_page.id).exists())
self.assertFalse(editable_pages.filter(id=unpublished_event_page.id).exists())
self.assertFalse(editable_pages.filter(id=someone_elses_event_page.id).exists())
self.assertFalse(can_edit_pages)
self.assertFalse(publishable_pages.filter(id=homepage.id).exists())
self.assertFalse(publishable_pages.filter(id=christmas_page.id).exists())
self.assertFalse(
publishable_pages.filter(id=unpublished_event_page.id).exists()
)
self.assertFalse(
publishable_pages.filter(id=someone_elses_event_page.id).exists()
)
self.assertFalse(can_publish_pages)
def test_editable_pages_for_superuser(self):
user = get_user_model().objects.get(email="superuser@example.com")
homepage = Page.objects.get(url_path="/home/")
christmas_page = EventPage.objects.get(url_path="/home/events/christmas/")
unpublished_event_page = EventPage.objects.get(
url_path="/home/events/tentative-unpublished-event/"
)
someone_elses_event_page = EventPage.objects.get(
url_path="/home/events/someone-elses-event/"
)
policy = PagePermissionPolicy()
editable_pages = policy.instances_user_has_permission_for(user, "change")
can_edit_pages = policy.user_has_permission(user, "change")
publishable_pages = policy.instances_user_has_permission_for(user, "publish")
can_publish_pages = policy.user_has_permission(user, "publish")
self.assertTrue(editable_pages.filter(id=homepage.id).exists())
self.assertTrue(editable_pages.filter(id=christmas_page.id).exists())
self.assertTrue(editable_pages.filter(id=unpublished_event_page.id).exists())
self.assertTrue(editable_pages.filter(id=someone_elses_event_page.id).exists())
self.assertTrue(can_edit_pages)
self.assertTrue(publishable_pages.filter(id=homepage.id).exists())
self.assertTrue(publishable_pages.filter(id=christmas_page.id).exists())
self.assertTrue(publishable_pages.filter(id=unpublished_event_page.id).exists())
self.assertTrue(
publishable_pages.filter(id=someone_elses_event_page.id).exists()
)
self.assertTrue(can_publish_pages)
def test_editable_pages_for_non_editing_user(self):
user = get_user_model().objects.get(email="admin_only_user@example.com")
homepage = Page.objects.get(url_path="/home/")
christmas_page = EventPage.objects.get(url_path="/home/events/christmas/")
unpublished_event_page = EventPage.objects.get(
url_path="/home/events/tentative-unpublished-event/"
)
someone_elses_event_page = EventPage.objects.get(
url_path="/home/events/someone-elses-event/"
)
policy = PagePermissionPolicy()
editable_pages = policy.instances_user_has_permission_for(user, "change")
can_edit_pages = policy.user_has_permission(user, "change")
publishable_pages = policy.instances_user_has_permission_for(user, "publish")
can_publish_pages = policy.user_has_permission(user, "publish")
self.assertFalse(editable_pages.filter(id=homepage.id).exists())
self.assertFalse(editable_pages.filter(id=christmas_page.id).exists())
self.assertFalse(editable_pages.filter(id=unpublished_event_page.id).exists())
self.assertFalse(editable_pages.filter(id=someone_elses_event_page.id).exists())
self.assertFalse(can_edit_pages)
self.assertFalse(publishable_pages.filter(id=homepage.id).exists())
self.assertFalse(publishable_pages.filter(id=christmas_page.id).exists())
self.assertFalse(
publishable_pages.filter(id=unpublished_event_page.id).exists()
)
self.assertFalse(
publishable_pages.filter(id=someone_elses_event_page.id).exists()
)
self.assertFalse(can_publish_pages)
def test_lock_page_for_superuser(self):
user = get_user_model().objects.get(email="superuser@example.com")
christmas_page = EventPage.objects.get(url_path="/home/events/christmas/")
locked_page = Page.objects.get(url_path="/home/my-locked-page/")
perms = christmas_page.permissions_for_user(user)
locked_perms = locked_page.permissions_for_user(user)
self.assertTrue(perms.can_lock())
self.assertFalse(
locked_perms.can_unpublish()
) # locked pages can't be unpublished
self.assertTrue(perms.can_unlock())
def test_lock_page_for_moderator(self):
user = get_user_model().objects.get(email="eventmoderator@example.com")
christmas_page = EventPage.objects.get(url_path="/home/events/christmas/")
perms = christmas_page.permissions_for_user(user)
self.assertTrue(perms.can_lock())
self.assertTrue(perms.can_unlock())
def test_lock_page_for_moderator_without_unlock_permission(self):
user = get_user_model().objects.get(email="eventmoderator@example.com")
christmas_page = EventPage.objects.get(url_path="/home/events/christmas/")
GroupPagePermission.objects.filter(
group__name="Event moderators", permission__codename="unlock_page"
).delete()
perms = christmas_page.permissions_for_user(user)
self.assertTrue(perms.can_lock())
self.assertFalse(perms.can_unlock())
def test_lock_page_for_moderator_whole_locked_page_without_unlock_permission(self):
user = get_user_model().objects.get(email="eventmoderator@example.com")
christmas_page = EventPage.objects.get(url_path="/home/events/christmas/")
# Lock the page
christmas_page.locked = True
christmas_page.locked_by = user
christmas_page.locked_at = timezone.now()
christmas_page.save()
GroupPagePermission.objects.filter(
group__name="Event moderators", permission__codename="unlock_page"
).delete()
perms = christmas_page.permissions_for_user(user)
# Unlike in the previous test, the user can unlock this page as it was them who locked
self.assertTrue(perms.can_lock())
self.assertTrue(perms.can_unlock())
def test_lock_page_for_editor(self):
user = get_user_model().objects.get(email="eventeditor@example.com")
christmas_page = EventPage.objects.get(url_path="/home/events/christmas/")
perms = christmas_page.permissions_for_user(user)
self.assertFalse(perms.can_lock())
self.assertFalse(perms.can_unlock())
def test_lock_page_for_non_editing_user(self):
user = get_user_model().objects.get(email="admin_only_user@example.com")
christmas_page = EventPage.objects.get(url_path="/home/events/christmas/")
perms = christmas_page.permissions_for_user(user)
self.assertFalse(perms.can_lock())
self.assertFalse(perms.can_unlock())
def test_lock_page_for_editor_with_lock_permission(self):
user = get_user_model().objects.get(email="eventeditor@example.com")
christmas_page = EventPage.objects.get(url_path="/home/events/christmas/")
GroupPagePermission.objects.create(
group=Group.objects.get(name="Event editors"),
page=christmas_page,
permission_type="lock",
)
perms = christmas_page.permissions_for_user(user)
self.assertTrue(perms.can_lock())
# Still shouldn't have unlock permission
self.assertFalse(perms.can_unlock())
def test_page_locked_for_unlocked_page(self):
user = get_user_model().objects.get(email="eventmoderator@example.com")
christmas_page = EventPage.objects.get(url_path="/home/events/christmas/")
perms = christmas_page.permissions_for_user(user)
self.assertFalse(perms.page_locked())
def test_page_locked_for_locked_page(self):
user = get_user_model().objects.get(email="eventmoderator@example.com")
christmas_page = EventPage.objects.get(url_path="/home/events/christmas/")
# Lock the page
christmas_page.locked = True
christmas_page.locked_by = user
christmas_page.locked_at = timezone.now()
christmas_page.save()
perms = christmas_page.permissions_for_user(user)
# The user who locked the page shouldn't see the page as locked
self.assertFalse(perms.page_locked())
# Other users should see the page as locked
other_user = get_user_model().objects.get(email="eventeditor@example.com")
other_perms = christmas_page.permissions_for_user(other_user)
self.assertTrue(other_perms.page_locked())
@override_settings(WAGTAILADMIN_GLOBAL_EDIT_LOCK=True)
def test_page_locked_for_locked_page_with_global_lock_enabled(self):
user = get_user_model().objects.get(email="eventmoderator@example.com")
christmas_page = EventPage.objects.get(url_path="/home/events/christmas/")
# Lock the page
christmas_page.locked = True
christmas_page.locked_by = user
christmas_page.locked_at = timezone.now()
christmas_page.save()
perms = christmas_page.permissions_for_user(user)
# The user who locked the page should now also see the page as locked
self.assertTrue(perms.page_locked())
# Other users should see the page as locked, like before
other_user = get_user_model().objects.get(email="eventeditor@example.com")
other_perms = christmas_page.permissions_for_user(other_user)
self.assertTrue(other_perms.page_locked())
def test_page_locked_in_workflow(self):
workflow, task = self.create_workflow_and_task()
editor = get_user_model().objects.get(email="eventeditor@example.com")
moderator = get_user_model().objects.get(email="eventmoderator@example.com")
superuser = get_user_model().objects.get(email="superuser@example.com")
christmas_page = EventPage.objects.get(url_path="/home/events/christmas/")
christmas_page.save_revision()
workflow.start(christmas_page, editor)
moderator_perms = christmas_page.permissions_for_user(moderator)
# the moderator is in the group assigned to moderate the task, so the page should
# not be locked for them
self.assertFalse(moderator_perms.page_locked())
superuser_perms = christmas_page.permissions_for_user(superuser)
# superusers can moderate any GroupApprovalTask, so the page should not be locked
# for them
self.assertFalse(superuser_perms.page_locked())
editor_perms = christmas_page.permissions_for_user(editor)
# the editor is not in the group assigned to moderate the task, so the page should
# be locked for them
self.assertTrue(editor_perms.page_locked())
def test_page_lock_in_workflow(self):
workflow, task = self.create_workflow_and_task()
editor = get_user_model().objects.get(email="eventeditor@example.com")
moderator = get_user_model().objects.get(email="eventmoderator@example.com")
christmas_page = EventPage.objects.get(url_path="/home/events/christmas/")
christmas_page.save_revision()
workflow.start(christmas_page, editor)
moderator_perms = christmas_page.permissions_for_user(moderator)
# the moderator is in the group assigned to moderate the task, so they can lock the page, but can't unlock it
# unless they're the locker
self.assertTrue(moderator_perms.can_lock())
self.assertFalse(moderator_perms.can_unlock())
editor_perms = christmas_page.permissions_for_user(editor)
# the editor is not in the group assigned to moderate the task, so they can't lock or unlock the page
self.assertFalse(editor_perms.can_lock())
self.assertFalse(editor_perms.can_unlock())
def test_custom_permission_tester_page(self):
homepage = Page.objects.get(url_path="/home/")
instance = CustomPermissionPage(
title="This page has a custom permission tester",
slug="page-with-custom-permission-tester",
)
homepage.add_child(instance=instance)
page = Page.objects.get(pk=instance.pk)
user = get_user_model().objects.get(email="eventeditor@example.com")
self.assertIsInstance(page.permissions_for_user(user), CustomPermissionTester)
class TestPagePermissionTesterCanCopyTo(TestCase):
"""Tests PagePermissionTester.can_copy_to()"""
fixtures = ["test.json"]
def setUp(self):
# These same pages will be used for testing the result for each user
self.board_meetings_page = BusinessSubIndex.objects.get(
url_path="/home/events/businessy-events/board-meetings/"
)
self.event_page = EventPage.objects.get(url_path="/home/events/christmas/")
# We'll also create a SingletonPageViaMaxCount to use
homepage = Page.objects.get(url_path="/home/")
self.singleton_page = SingletonPageViaMaxCount(title="there can be only one")
homepage.add_child(instance=self.singleton_page)
def test_inactive_user_cannot_copy_any_pages(self):
user = get_user_model().objects.get(email="inactiveuser@example.com")
# Create PagePermissionTester objects for this user, for each page
board_meetings_page_perms = self.board_meetings_page.permissions_for_user(user)
event_page_perms = self.event_page.permissions_for_user(user)
singleton_page_perms = self.singleton_page.permissions_for_user(user)
# This user should not be able to copy any pages
self.assertFalse(event_page_perms.can_copy_to(self.event_page.get_parent()))
self.assertFalse(
board_meetings_page_perms.can_copy_to(self.board_meetings_page.get_parent())
)
self.assertFalse(
singleton_page_perms.can_copy_to(self.singleton_page.get_parent())
)
def test_no_permissions_admin_cannot_copy_any_pages(self):
user = get_user_model().objects.get(email="admin_only_user@example.com")
# Create PagePermissionTester objects for this user, for each page
board_meetings_page_perms = self.board_meetings_page.permissions_for_user(user)
event_page_perms = self.event_page.permissions_for_user(user)
singleton_page_perms = self.singleton_page.permissions_for_user(user)
# This user should not be able to copy any pages
self.assertFalse(event_page_perms.can_copy_to(self.event_page.get_parent()))
self.assertFalse(
board_meetings_page_perms.can_copy_to(self.board_meetings_page.get_parent())
)
self.assertFalse(
singleton_page_perms.can_copy_to(self.singleton_page.get_parent())
)
def test_event_moderator_cannot_copy_a_singleton_page(self):
user = get_user_model().objects.get(email="eventmoderator@example.com")
# Create PagePermissionTester objects for this user, for each page
board_meetings_page_perms = self.board_meetings_page.permissions_for_user(user)
event_page_perms = self.event_page.permissions_for_user(user)
singleton_page_perms = self.singleton_page.permissions_for_user(user)
# We'd expect an event moderator to be able to copy an event page
self.assertTrue(event_page_perms.can_copy_to(self.event_page.get_parent()))
# This works because copying doesn't necessarily have to mean publishing
self.assertTrue(
board_meetings_page_perms.can_copy_to(self.board_meetings_page.get_parent())
)
# SingletonPageViaMaxCount.can_create_at() prevents copying, regardless of a user's permissions
self.assertFalse(
singleton_page_perms.can_copy_to(self.singleton_page.get_parent())
)
def test_not_even_a_superuser_can_copy_a_singleton_page(self):
user = get_user_model().objects.get(email="superuser@example.com")
# Create PagePermissionTester object for this user, for each page
board_meetings_page_perms = self.board_meetings_page.permissions_for_user(user)
event_page_perms = self.event_page.permissions_for_user(user)
singleton_page_perms = self.singleton_page.permissions_for_user(user)
# A superuser has full permissions, so these are self explanatory
self.assertTrue(event_page_perms.can_copy_to(self.event_page.get_parent()))
self.assertTrue(
board_meetings_page_perms.can_copy_to(self.board_meetings_page.get_parent())
)
# However, SingletonPageViaMaxCount.can_create_at() prevents copying, regardless of a user's permissions
self.assertFalse(
singleton_page_perms.can_copy_to(self.singleton_page.get_parent())
)
class TestPagePermissionModel(TestCase):
fixtures = [
"test.json",
]
def test_create_with_permission_type_only(self):
user = get_user_model().objects.get(email="eventmoderator@example.com")
page = Page.objects.get(url_path="/home/secret-plans/steal-underpants/")
group_permission = GroupPagePermission.objects.create(
group=user.groups.first(), page=page, permission_type="add"
)
self.assertEqual(group_permission.permission.codename, "add_page")

View File

@@ -0,0 +1,233 @@
from django.contrib.auth.models import Group
from django.test import TestCase, override_settings
from wagtail.models import Page, PageViewRestriction
from wagtail.test.utils import WagtailTestUtils
class TestPagePrivacy(WagtailTestUtils, TestCase):
fixtures = ["test.json"]
def setUp(self):
self.secret_plans_page = Page.objects.get(url_path="/home/secret-plans/")
self.view_restriction = PageViewRestriction.objects.get(
page=self.secret_plans_page
)
self.secret_event_editor_plans_page = Page.objects.get(
url_path="/home/secret-event-editor-plans/"
)
self.event_editors_group = Group.objects.get(name="Event editors")
self.secret_login_plans_page = Page.objects.get(
url_path="/home/secret-login-plans/"
)
def test_anonymous_user_must_authenticate(self):
response = self.client.get("/secret-plans/")
self.assertEqual(
response.templates[0].name, "wagtailcore/password_required.html"
)
submit_url = "/_util/authenticate_with_password/%d/%d/" % (
self.view_restriction.id,
self.secret_plans_page.id,
)
self.assertContains(response, '<form action="%s"' % submit_url)
self.assertContains(
response,
'<input id="id_return_url" name="return_url" type="hidden" value="/secret-plans/" />',
html=True,
)
# posting the wrong password should redisplay the password page
response = self.client.post(
submit_url,
{
"password": "wrongpassword",
"return_url": "/secret-plans/",
},
)
self.assertEqual(
response.templates[0].name, "wagtailcore/password_required.html"
)
self.assertContains(response, '<form action="%s"' % submit_url)
# posting the correct password should redirect back to return_url
response = self.client.post(
submit_url,
{
"password": "swordfish",
"return_url": "/secret-plans/",
},
)
self.assertRedirects(response, "/secret-plans/")
# now requests to /secret-plans/ should pass authentication
response = self.client.get("/secret-plans/")
self.assertEqual(response.templates[0].name, "tests/simple_page.html")
self.client.logout()
# posting an invalid return_url will redirect to default login redirect
with self.settings(LOGIN_REDIRECT_URL="/"):
response = self.client.post(
submit_url,
{
"password": "swordfish",
"return_url": "https://invaliddomain.com",
},
)
self.assertRedirects(response, "/")
@override_settings(
WAGTAIL_PASSWORD_REQUIRED_TEMPLATE="tests/custom_page_password_required.html"
)
def test_anonymous_user_must_authenticate_with_custom_password_required_template(
self
):
response = self.client.get("/secret-plans/")
self.assertNotEqual(
"wagtailcore/password_required.html",
response.templates[0].name,
)
self.assertEqual(
"tests/custom_page_password_required.html",
response.templates[0].name,
)
def test_view_restrictions_apply_to_subpages(self):
underpants_page = Page.objects.get(
url_path="/home/secret-plans/steal-underpants/"
)
response = self.client.get("/secret-plans/steal-underpants/")
# check that we're overriding the default password_required template for this page type
self.assertEqual(
response.templates[0].name, "tests/event_page_password_required.html"
)
submit_url = "/_util/authenticate_with_password/%d/%d/" % (
self.view_restriction.id,
underpants_page.id,
)
self.assertContains(response, "<title>Steal underpants</title>")
self.assertContains(response, '<form action="%s"' % submit_url)
self.assertContains(
response,
'<input id="id_return_url" name="return_url" type="hidden" value="/secret-plans/steal-underpants/" />',
html=True,
)
# posting the wrong password should redisplay the password page
response = self.client.post(
submit_url,
{
"password": "wrongpassword",
"return_url": "/secret-plans/steal-underpants/",
},
)
self.assertEqual(
response.templates[0].name, "tests/event_page_password_required.html"
)
self.assertContains(response, '<form action="%s"' % submit_url)
# posting the correct password should redirect back to return_url
response = self.client.post(
submit_url,
{
"password": "swordfish",
"return_url": "/secret-plans/steal-underpants/",
},
)
self.assertRedirects(response, "/secret-plans/steal-underpants/")
# now requests to /secret-plans/ should pass authentication
response = self.client.get("/secret-plans/steal-underpants/")
self.assertEqual(response.templates[0].name, "tests/event_page.html")
def test_view_restrictions_apply_to_aliases(self):
secret_plans_page = Page.objects.get(url_path="/home/secret-plans/")
secret_plans_alias_page = secret_plans_page.create_alias(
update_slug="alias-secret-plans"
)
response = self.client.get("/alias-secret-plans/")
self.assertEqual(
response.templates[0].name, "wagtailcore/password_required.html"
)
submit_url = "/_util/authenticate_with_password/%d/%d/" % (
self.view_restriction.id,
secret_plans_alias_page.id,
)
self.assertContains(response, '<form action="%s"' % submit_url)
self.assertContains(
response,
'<input id="id_return_url" name="return_url" type="hidden" value="/alias-secret-plans/" />',
html=True,
)
def test_view_restrictions_apply_to_subpages_of_aliases(self):
secret_plans_page = Page.objects.get(url_path="/home/secret-plans/")
secret_plans_alias_page = secret_plans_page.create_alias(
update_slug="alias-secret-plans"
)
underpants_page = Page.objects.get(
url_path="/home/secret-plans/steal-underpants/"
)
underpants_alias_page = underpants_page.create_alias(
parent=secret_plans_alias_page
)
response = self.client.get("/alias-secret-plans/steal-underpants/")
# check that we're overriding the default password_required template for this page type
self.assertEqual(
response.templates[0].name, "tests/event_page_password_required.html"
)
submit_url = "/_util/authenticate_with_password/%d/%d/" % (
self.view_restriction.id,
underpants_alias_page.id,
)
self.assertContains(response, "<title>Steal underpants</title>")
self.assertContains(response, '<form action="%s"' % submit_url)
self.assertContains(
response,
'<input id="id_return_url" name="return_url" type="hidden" value="/alias-secret-plans/steal-underpants/" />',
html=True,
)
def test_group_restriction_with_anonymous_user(self):
response = self.client.get("/secret-event-editor-plans/")
self.assertRedirects(response, "/_util/login/?next=/secret-event-editor-plans/")
def test_group_restriction_with_unpermitted_user(self):
self.login(username="eventmoderator", password="password")
response = self.client.get("/secret-event-editor-plans/")
self.assertRedirects(response, "/_util/login/?next=/secret-event-editor-plans/")
def test_group_restriction_with_permitted_user(self):
self.login(username="eventeditor", password="password")
response = self.client.get("/secret-event-editor-plans/")
self.assertEqual(response.status_code, 200)
self.assertContains(response, "<title>Secret event editor plans</title>")
def test_group_restriction_with_superuser(self):
self.login(username="superuser", password="password")
response = self.client.get("/secret-event-editor-plans/")
self.assertEqual(response.status_code, 200)
self.assertContains(response, "<title>Secret event editor plans</title>")
def test_login_restriction_with_anonymous_user(self):
response = self.client.get("/secret-login-plans/")
self.assertRedirects(response, "/_util/login/?next=/secret-login-plans/")
def test_login_restriction_with_logged_in_user(self):
self.login(username="eventmoderator", password="password")
response = self.client.get("/secret-login-plans/")
self.assertEqual(response.status_code, 200)
self.assertContains(response, "<title>Secret login plans</title>")

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,480 @@
from io import StringIO
from django.contrib.contenttypes.models import ContentType
from django.core import management
from django.test import TestCase
from django.utils.functional import SimpleLazyObject
from wagtail.blocks import StreamValue, StructValue
from wagtail.documents import get_document_model
from wagtail.documents.tests.utils import get_test_document_file
from wagtail.images import get_image_model
from wagtail.images.tests.utils import get_test_image_file
from wagtail.models import Page, ReferenceIndex
from wagtail.rich_text import RichText
from wagtail.test.testapp.models import (
Advert,
AdvertWithCustomUUIDPrimaryKey,
EventPage,
EventPageCarouselItem,
EventPageRelatedLink,
GenericSnippetNoFieldIndexPage,
GenericSnippetNoIndexPage,
GenericSnippetPage,
ModelWithNullableParentalKey,
VariousOnDeleteModel,
)
class TestCreateOrUpdateForObject(TestCase):
def setUp(self):
image_model = get_image_model()
self.image_content_type = ContentType.objects.get_for_model(image_model)
self.test_feed_image = image_model.objects.create(
title="Test feed image",
file=get_test_image_file(),
)
self.test_image_1 = image_model.objects.create(
title="Test image 1",
file=get_test_image_file(),
)
self.test_image_2 = image_model.objects.create(
title="Test image 2",
file=get_test_image_file(),
)
# Add event page
self.event_page = EventPage(
title="Event page",
slug="event-page",
location="the moon",
audience="public",
cost="free",
date_from="2001-01-01",
feed_image=self.test_feed_image,
)
self.event_page.carousel_items = [
EventPageCarouselItem(
caption="1234567", image=self.test_image_1, sort_order=1
),
EventPageCarouselItem(
caption="7654321", image=self.test_image_2, sort_order=2
),
EventPageCarouselItem(
caption="abcdefg", image=self.test_image_1, sort_order=3
),
]
self.root_page = Page.objects.get(id=2)
self.root_page.add_child(instance=self.event_page)
self.expected_references = {
(
self.image_content_type.id,
str(self.test_feed_image.pk),
"feed_image",
"feed_image",
),
(
self.image_content_type.id,
str(self.test_image_1.pk),
"carousel_items.item.image",
f"carousel_items.{self.event_page.carousel_items.get(sort_order=1).id}.image",
),
(
self.image_content_type.id,
str(self.test_image_2.pk),
"carousel_items.item.image",
f"carousel_items.{self.event_page.carousel_items.get(sort_order=2).id}.image",
),
(
self.image_content_type.id,
str(self.test_image_1.pk),
"carousel_items.item.image",
f"carousel_items.{self.event_page.carousel_items.get(sort_order=3).id}.image",
),
}
def test(self):
self.assertSetEqual(
set(
ReferenceIndex.get_references_for_object(self.event_page).values_list(
"to_content_type", "to_object_id", "model_path", "content_path"
)
),
self.expected_references,
)
def test_update(self):
reference_to_keep = ReferenceIndex.objects.get(
base_content_type=ReferenceIndex._get_base_content_type(self.event_page),
content_type=ContentType.objects.get_for_model(self.event_page),
content_path="feed_image",
)
reference_to_remove = ReferenceIndex.objects.create(
base_content_type=ReferenceIndex._get_base_content_type(self.event_page),
content_type=ContentType.objects.get_for_model(self.event_page),
object_id=self.event_page.pk,
to_content_type=self.image_content_type,
to_object_id=self.test_image_1.pk,
model_path="hero_image", # Field doesn't exist
content_path="hero_image",
content_path_hash=ReferenceIndex._get_content_path_hash("hero_image"),
)
ReferenceIndex.create_or_update_for_object(self.event_page)
# Check that the record for the reference to be kept has been preserved/reused
self.assertTrue(ReferenceIndex.objects.filter(id=reference_to_keep.id).exists())
# Check that the record for the reference to be removed has been deleted
self.assertFalse(
ReferenceIndex.objects.filter(id=reference_to_remove.id).exists()
)
# Check that the current stored references are correct
self.assertSetEqual(
set(
ReferenceIndex.get_references_for_object(self.event_page).values_list(
"to_content_type", "to_object_id", "model_path", "content_path"
)
),
{
(
self.image_content_type.id,
str(self.test_feed_image.pk),
"feed_image",
"feed_image",
),
(
self.image_content_type.id,
str(self.test_image_1.pk),
"carousel_items.item.image",
f"carousel_items.{self.event_page.carousel_items.get(sort_order=1).id}.image",
),
(
self.image_content_type.id,
str(self.test_image_2.pk),
"carousel_items.item.image",
f"carousel_items.{self.event_page.carousel_items.get(sort_order=2).id}.image",
),
(
self.image_content_type.id,
str(self.test_image_1.pk),
"carousel_items.item.image",
f"carousel_items.{self.event_page.carousel_items.get(sort_order=3).id}.image",
),
},
)
def test_saving_base_model_does_not_remove_references(self):
page = Page.objects.get(pk=self.event_page.pk)
page.save()
self.assertSetEqual(
set(
ReferenceIndex.get_references_for_object(self.event_page).values_list(
"to_content_type", "to_object_id", "model_path", "content_path"
)
),
self.expected_references,
)
def test_null_parental_key(self):
obj = ModelWithNullableParentalKey(
content="""<p><a linktype="page" id="%d">event page</a></p>"""
% self.event_page.id
)
obj.save()
# Models with a ParentalKey are not considered indexable - references are recorded against the parent model
# instead. Since the ParentalKey is null here, no reference will be recorded.
refs = ReferenceIndex.get_references_to(self.event_page)
self.assertEqual(refs.count(), 0)
def test_lazy_parental_key(self):
event_page_related_link = EventPageRelatedLink()
# The parent model is a lazy object
event_page_related_link.page = SimpleLazyObject(lambda: self.event_page)
event_page_related_link.link_page = self.root_page
event_page_related_link.save()
refs = ReferenceIndex.get_references_to(self.root_page)
self.assertEqual(refs.count(), 1)
def test_generic_foreign_key(self):
page1 = GenericSnippetPage(
title="generic snippet page", snippet_content_object=self.event_page
)
self.root_page.add_child(instance=page1)
page2 = GenericSnippetPage(
title="generic snippet page", snippet_content_object=None
)
self.root_page.add_child(instance=page2)
refs = ReferenceIndex.get_references_to(self.event_page)
self.assertEqual(refs.count(), 1)
def test_model_index_ignore_generic_foreign_key(self):
page1 = GenericSnippetNoIndexPage(
title="generic snippet page", snippet_content_object=self.event_page
)
self.root_page.add_child(instance=page1)
page2 = GenericSnippetNoIndexPage(
title="generic snippet page", snippet_content_object=None
)
self.root_page.add_child(instance=page2)
# There should be no references
refs = ReferenceIndex.get_references_to(self.event_page)
self.assertEqual(refs.count(), 0)
def test_model_field_index_ignore_generic_foreign_key(self):
content_type = ContentType.objects.get_for_model(self.event_page)
page1 = GenericSnippetNoFieldIndexPage(
title="generic snippet page", snippet_content_type_nonindexed=content_type
)
self.root_page.add_child(instance=page1)
page2 = GenericSnippetNoFieldIndexPage(
title="generic snippet page", snippet_content_type_nonindexed=None
)
self.root_page.add_child(instance=page2)
# There should be no references
refs = ReferenceIndex.get_references_to(content_type)
self.assertEqual(refs.count(), 0)
def test_rebuild_references_index_no_verbosity(self):
stdout = StringIO()
management.call_command(
"rebuild_references_index",
verbosity=0,
stdout=stdout,
)
self.assertFalse(stdout.getvalue())
def test_show_references_index(self):
stdout = StringIO()
management.call_command(
"show_references_index",
stdout=stdout,
)
self.assertIn(" 3 wagtail.images.models.Image", stdout.getvalue())
self.assertIn(" 4 wagtail.test.testapp.models.EventPage", stdout.getvalue())
class TestDescribeOnDelete(TestCase):
fixtures = ["test.json"]
@classmethod
def setUpTestData(cls):
management.call_command("rebuild_references_index", stdout=StringIO())
def setUp(self):
field = VariousOnDeleteModel._meta.get_field("stream_field")
advertisement_content = field.stream_block.child_blocks["advertisement_content"]
captioned_advert = advertisement_content.child_blocks["captioned_advert"]
self.advert = Advert.objects.create(text="An advertisement")
self.advert_uuid = AdvertWithCustomUUIDPrimaryKey.objects.create(
text="A UUID advertisement"
)
self.page = EventPage.objects.first()
page_link = f'<p>Link to <a id="{self.page.id}" linktype="page">a page</a></p>'
self.image = get_image_model().objects.create(
title="My image",
file=get_test_image_file(),
)
self.document = get_document_model().objects.create(
title="My document",
file=get_test_document_file(),
)
# Each case is a tuple of (
# VariousOnDeleteModel init kwargs,
# referred object,
# expected field description,
# expected on delete description,
# )
self.cases = [
# References from ForeignKey
(
{"text": "on_delete=CASCADE", "on_delete_cascade": self.advert},
self.advert,
"On delete cascade",
"the various on delete model will also be deleted",
),
(
{"text": "on_delete=PROTECT", "on_delete_protect": self.advert},
self.advert,
"On delete protect",
"prevents deletion",
),
(
{"text": "on_delete=RESTRICT", "on_delete_restrict": self.advert},
self.advert,
"On delete restrict",
"may prevent deletion",
),
(
{"text": "on_delete=SET_NULL", "on_delete_set_null": self.advert},
self.advert,
"On delete set null",
"will unset the reference",
),
(
{"text": "on_delete=SET_DEFAULT", "on_delete_set_default": self.advert},
self.advert,
"On delete set default",
"will be set to the default various on delete model",
),
(
{"text": "on_delete=SET", "on_delete_set": self.advert},
self.advert,
"On delete set",
"will be set to a various on delete model specified by the system",
),
(
{"text": "on_delete=DO_NOTHING", "on_delete_do_nothing": self.advert},
self.advert,
"On delete do nothing",
"will do nothing",
),
# References from GenericForeignKey
(
{"text": "GenericForeignKey", "content_object": self.advert_uuid},
self.advert_uuid,
"Content object",
"will unset the reference",
),
# References from RichTextField
(
{"text": "RichTextField model field", "rich_text": page_link},
self.page,
"Rich text",
"will unset the reference",
),
(
{
"text": "deep RichTextBlock",
"stream_field": StreamValue(
field.stream_block,
[
(
"advertisement_content",
StreamValue(
advertisement_content,
[
(
"rich_text",
RichText(page_link),
)
],
),
)
],
),
},
self.page,
"Stream field → Advertisement content → Rich text",
"will unset the reference",
),
# References from StreamField
(
{
"text": "deep SnippetChooserBlock",
"stream_field": StreamValue(
field.stream_block,
[
(
"advertisement_content",
StreamValue(
advertisement_content,
[
(
"captioned_advert",
StructValue(
captioned_advert,
[
("advert", self.advert),
("caption", "Deep text"),
],
),
)
],
),
)
],
),
},
self.advert,
"Stream field → Advertisement content → Captioned advert",
"will unset the reference",
),
(
{
"text": "ImageChooserBlock",
"stream_field": StreamValue(
field.stream_block, [("image", self.image)]
),
},
self.image,
"Stream field → Image",
"will unset the reference",
),
(
{
"text": "DocumentChooserBlock",
"stream_field": StreamValue(
field.stream_block, [("document", self.document)]
),
},
self.document,
"Stream field → Document",
"will unset the reference",
),
]
def test_describe_source_field_and_on_delete(self):
for (
init_kwargs,
referred_object,
field_description,
on_delete_description,
) in self.cases:
with self.subTest(test=init_kwargs["text"]):
# Explicitly pass None to this field so that it is not set to
# the default value for test cases other than the SET_DEFAULT case
if "on_delete_set_default" not in init_kwargs:
init_kwargs["on_delete_set_default"] = None
obj = VariousOnDeleteModel.objects.create(**init_kwargs)
usage = ReferenceIndex.get_references_to(
referred_object
).group_by_source_object()
referrer, references = usage[0]
reference = references[0]
self.assertIs(usage.is_protected, "on_delete_protect" in init_kwargs)
self.assertEqual(usage.count(), 1)
self.assertEqual(referrer, obj)
self.assertEqual(len(references), 1)
self.assertEqual(reference.describe_source_field(), field_description)
self.assertEqual(reference.describe_on_delete(), on_delete_description)
obj.delete()
def test_describe_source_field_and_on_delete_parental_key(self):
# The test fixtures contain two references to the advert:
# 1. One advert placement on the home page
# 2. One advert placement on the Christmas page
advert = Advert.objects.first()
usage = ReferenceIndex.get_references_to(advert).group_by_source_object()
self.assertEqual(usage.count(), 2)
for _, references in usage:
reference = references[0]
self.assertEqual(reference.describe_source_field(), "Advert")
self.assertEqual(
reference.describe_on_delete(),
"the advert placement will also be deleted",
)

View File

@@ -0,0 +1,193 @@
import datetime
from django.contrib.contenttypes.models import ContentType
from django.test import TestCase
from freezegun import freeze_time
from wagtail.models import Page, Revision, get_default_page_content_type
from wagtail.test.testapp.models import (
FullFeaturedSnippet,
RevisableGrandChildModel,
RevisableModel,
SimplePage,
)
class TestRevisableModel(TestCase):
@classmethod
def setUpTestData(cls):
cls.instance = RevisableModel.objects.create(text="foo")
cls.content_type = ContentType.objects.get_for_model(RevisableModel)
@classmethod
def create_page(cls):
homepage = Page.objects.get(url_path="/home/")
hello_page = SimplePage(
title="Hello world", slug="hello-world", content="hello"
)
homepage.add_child(instance=hello_page)
return hello_page
def test_can_save_revision(self):
self.instance.text = "updated"
revision = self.instance.save_revision()
revision_from_db = self.instance.revisions.first()
self.instance.refresh_from_db()
self.assertEqual(revision, revision_from_db)
# The latest revision should be set
self.assertEqual(self.instance.latest_revision, revision_from_db)
# The revision should have the updated data
self.assertEqual(revision_from_db.content["text"], "updated")
# Only saving a revision should not update the instance itself
self.assertEqual(self.instance.text, "foo")
def test_get_latest_revision_exists(self):
self.instance.text = "updated"
self.instance.save_revision()
self.instance.text = "updated twice"
revision = self.instance.save_revision()
self.instance.refresh_from_db()
with self.assertNumQueries(1):
# Should be able to query directly using latest_revision ForeignKey
revision_from_db = self.instance.get_latest_revision()
self.assertEqual(revision, revision_from_db)
self.assertEqual(revision_from_db.content["text"], "updated twice")
def test_content_type_without_inheritance(self):
self.instance.text = "updated"
revision = self.instance.save_revision()
revision_from_db = Revision.objects.filter(
base_content_type=self.content_type,
content_type=self.content_type,
object_id=self.instance.pk,
).first()
self.assertEqual(revision, revision_from_db)
self.assertEqual(self.instance.get_base_content_type(), self.content_type)
self.assertEqual(self.instance.get_content_type(), self.content_type)
def test_content_type_with_inheritance(self):
instance = RevisableGrandChildModel.objects.create(text="test")
instance.text = "test updated"
revision = instance.save_revision()
base_content_type = self.content_type
content_type = ContentType.objects.get_for_model(RevisableGrandChildModel)
revision_from_db = Revision.objects.filter(
base_content_type=base_content_type,
content_type=content_type,
object_id=instance.pk,
).first()
self.assertEqual(revision, revision_from_db)
self.assertEqual(instance.get_base_content_type(), base_content_type)
self.assertEqual(instance.get_content_type(), content_type)
def test_content_type_for_page_model(self):
hello_page = self.create_page()
hello_page.content = "Updated world"
revision = hello_page.save_revision()
base_content_type = get_default_page_content_type()
content_type = ContentType.objects.get_for_model(SimplePage)
revision_from_db = Revision.objects.filter(
base_content_type=base_content_type,
content_type=content_type,
object_id=hello_page.pk,
).first()
self.assertEqual(revision, revision_from_db)
self.assertEqual(hello_page.get_base_content_type(), base_content_type)
self.assertEqual(hello_page.get_content_type(), content_type)
def test_as_object(self):
self.instance.text = "updated"
self.instance.save_revision()
self.instance.refresh_from_db()
revision = self.instance.revisions.first()
instance = revision.as_object()
self.assertIsInstance(instance, RevisableModel)
# The instance created from the revision should be updated
self.assertEqual(instance.text, "updated")
# Only saving a revision should not update the instance itself
self.assertEqual(self.instance.text, "foo")
def test_as_object_with_page(self):
hello_page = self.create_page()
hello_page.content = "updated"
hello_page.save_revision()
hello_page.refresh_from_db()
revision = hello_page.revisions.first()
instance = revision.as_object()
# The instance should be of the specific page class.
self.assertIsInstance(instance, SimplePage)
self.assertEqual(instance.content, "updated")
self.assertEqual(hello_page.content, "hello")
def test_is_latest_revision_newer_creation_date_and_id(self):
first = self.instance.save_revision()
self.assertTrue(first.is_latest_revision())
second = self.instance.save_revision()
self.assertFalse(first.is_latest_revision())
self.assertTrue(second.is_latest_revision())
# Normal case, both creation date and id are newer
self.assertLess(first.created_at, second.created_at)
self.assertLess(first.id, second.id)
def test_is_latest_revision_newer_creation_date_older_id(self):
first = self.instance.save_revision()
self.assertTrue(first.is_latest_revision())
second = self.instance.save_revision()
first.created_at = second.created_at + datetime.timedelta(days=9)
first.save()
self.assertTrue(first.is_latest_revision())
self.assertFalse(second.is_latest_revision())
# The creation date takes precedence over the id
self.assertGreater(first.created_at, second.created_at)
self.assertLess(first.id, second.id)
@freeze_time("2023-01-19")
def test_is_latest_revision_same_creation_dates(self):
first = self.instance.save_revision()
self.assertTrue(first.is_latest_revision())
second = self.instance.save_revision()
self.assertFalse(first.is_latest_revision())
self.assertTrue(second.is_latest_revision())
# The id is used as a tie breaker
self.assertEqual(first.created_at, second.created_at)
self.assertLess(first.id, second.id)
def test_revision_cascade_on_object_delete(self):
page = self.create_page()
full_featured_snippet = FullFeaturedSnippet.objects.create(text="foo")
cases = [
# Tuple of (instance, cascades)
# For models that define a GenericRelation to Revision, the revision
# should be deleted when the instance is deleted.
(page, True),
(full_featured_snippet, True),
(self.instance, False), # No GenericRelation to Revision
]
for instance, cascades in cases:
with self.subTest(instance=instance):
revision = instance.save_revision()
query = {
"base_content_type": instance.get_base_content_type(),
"object_id": str(instance.pk),
}
self.assertEqual(Revision.objects.filter(**query).first(), revision)
instance.delete()
self.assertIs(Revision.objects.filter(**query).exists(), not cascades)

View File

@@ -0,0 +1,437 @@
from unittest.mock import patch
from django.forms.models import modelform_factory
from django.test import TestCase, override_settings
from django.utils import translation
from wagtail.fields import RichTextField
from wagtail.models import Locale, Page, Site
from wagtail.rich_text import RichText, RichTextMaxLengthValidator, expand_db_html
from wagtail.rich_text.feature_registry import FeatureRegistry
from wagtail.rich_text.pages import PageLinkHandler
from wagtail.rich_text.rewriters import LinkRewriter, extract_attrs
from wagtail.test.testapp.models import EventIndex, EventPage
from wagtail.test.utils.form_data import rich_text
class TestPageLinktypeHandler(TestCase):
fixtures = ["test.json"]
def test_expand_db_attributes(self):
result = PageLinkHandler.expand_db_attributes(
{"id": Page.objects.get(url_path="/home/events/christmas/").id}
)
self.assertEqual(result, '<a href="/events/christmas/">')
def test_expand_db_attributes_page_does_not_exist(self):
result = PageLinkHandler.expand_db_attributes({"id": 0})
self.assertEqual(result, "<a>")
def test_expand_db_attributes_not_for_editor(self):
result = PageLinkHandler.expand_db_attributes({"id": 1})
self.assertEqual(result, '<a href="None">')
@override_settings(
WAGTAIL_I18N_ENABLED=True,
WAGTAIL_CONTENT_LANGUAGES=[
("en", "English"),
("fr", "French"),
],
ROOT_URLCONF="wagtail.test.urls_multilang",
)
class TestPageLinktypeHandlerWithI18N(TestCase):
fixtures = ["test.json"]
def setUp(self):
self.fr_locale = Locale.objects.create(language_code="fr")
self.event_page = Page.objects.get(url_path="/home/events/christmas/")
self.fr_event_page = self.event_page.copy_for_translation(
self.fr_locale, copy_parents=True
)
self.fr_event_page.slug = "noel"
self.fr_event_page.save(update_fields=["slug"])
self.fr_event_page.save_revision().publish()
def test_expand_db_attributes(self):
result = PageLinkHandler.expand_db_attributes({"id": self.event_page.id})
self.assertEqual(result, '<a href="/en/events/christmas/">')
def test_expand_db_attributes_autolocalizes(self):
# Even though it's linked to the english page in rich text.
# The link should be to the local language version if it's available
with translation.override("fr"):
result = PageLinkHandler.expand_db_attributes({"id": self.event_page.id})
self.assertEqual(result, '<a href="/fr/events/noel/">')
def test_expand_db_attributes_doesnt_autolocalize_unpublished_page(self):
# We shouldn't autolocalize if the translation is unpublished
self.fr_event_page.unpublish()
self.fr_event_page.save()
with translation.override("fr"):
result = PageLinkHandler.expand_db_attributes({"id": self.event_page.id})
self.assertEqual(result, '<a href="/en/events/christmas/">')
class TestExtractAttrs(TestCase):
def test_extract_attr(self):
html = '<a foo="bar" baz="quux">snowman</a>'
result = extract_attrs(html)
self.assertEqual(result, {"foo": "bar", "baz": "quux"})
class TestExpandDbHtml(TestCase):
fixtures = ["test.json"]
def test_expand_db_html_no_linktype(self):
html = '<a id="1">foo</a>'
result = expand_db_html(html)
self.assertEqual(result, '<a id="1">foo</a>')
def test_invalid_linktype_set_to_empty_link(self):
html = '<a id="1" linktype="invalid">foo</a>'
result = expand_db_html(html)
self.assertEqual(result, "<a>foo</a>")
def test_valid_linktype_and_reference(self):
html = '<a id="1" linktype="document">foo</a>'
result = expand_db_html(html)
self.assertEqual(result, '<a href="/documents/1/test.pdf">foo</a>')
def test_valid_linktype_invalid_reference_set_to_empty_link(self):
html = '<a id="9999" linktype="document">foo</a>'
result = expand_db_html(html)
self.assertEqual(result, "<a>foo</a>")
def test_no_embedtype_remove_tag(self):
self.assertEqual(expand_db_html('<embed id="1" />'), "")
def test_invalid_embedtype_remove_tag(self):
self.assertEqual(expand_db_html('<embed id="1" embedtype="invalid" />'), "")
@patch("wagtail.embeds.embeds.get_embed")
def test_expand_db_html_with_embed(self, get_embed):
from wagtail.embeds.models import Embed
get_embed.return_value = Embed(html="test html")
html = '<embed embedtype="media" url="http://www.youtube.com/watch" />'
result = expand_db_html(html)
self.assertIn("test html", result)
# Override CACHES so we don't generate any cache-related SQL queries
# for page site root paths (tests use DatabaseCache otherwise).
@override_settings(
CACHES={
"default": {
"BACKEND": "django.core.cache.backends.locmem.LocMemCache",
},
}
)
def test_expand_db_html_database_queries_pages(self):
Site.clear_site_root_paths_cache()
with self.assertNumQueries(5):
expand_db_html(
"""
This rich text has 8 page links, and this test verifies that the code uses the
minimal number of database queries (5) to expand them.
All of these pages should be retrieved with 4 queries, one to do the base
Page table lookup and then 1 each for the EventIndex, EventPage, and
SimplePage tables.
<a linktype="page" id="3">This links to an EventIndex page.</a>
<a linktype="page" id="4">This links to an EventPage page.</a>
<a linktype="page" id="5">This links to an EventPage page.</a>
<a linktype="page" id="6">This links to an EventPage page.</a>
<a linktype="page" id="9">This links to an EventPage page.</a>
<a linktype="page" id="12">This links to an EventPage page.</a>
<a linktype="page" id="7">This links to a SimplePage page.</a>
<a linktype="page" id="11">This links to a SimplePage page.</a>
Finally there's one additional query needed to do the Site root paths lookup.
"""
)
def test_expand_db_html_database_queries_documents(self):
with self.assertNumQueries(1):
expand_db_html(
html="""
This rich text has 2 document links, and this test verifies that the code uses
the minimal number of database queries (1) to expand them.
Both of these documents should be retrieved with 1 query:
<a linktype="document" id="1">This links to a document.</a>
<a linktype="document" id="2">This links to another document.</a>
"""
)
# Disable rendition cache that might be populated by other tests.
@override_settings(
CACHES={
"renditions": {
"BACKEND": "django.core.cache.backends.dummy.DummyCache",
},
}
)
def test_expand_db_html_database_queries_images(self):
with self.assertNumQueries(3):
expand_db_html(
"""
This rich text has 2 image links, and this test verifies that the code uses the
minimal number of database queries (3) to expand them.
Both of these images should be retrieved with 3 queries, one to fetch the
image objects in bulk and then one per image to fetch their renditions:
This is an image: <embed embedtype="image" id="1" format="left" />
This is another image: <embed embedtype="image" id="2" format="left" />
"""
)
def test_expand_db_html_mixed_link_types(self):
self.assertEqual(
expand_db_html(
'<a href="https://wagtail.org/">foo</a>'
'<a linktype="page" id="3">bar</a>'
),
'<a href="https://wagtail.org/">foo</a><a href="/events/">bar</a>',
)
self.assertEqual(
expand_db_html(
'<a linktype="page" id="3">page</a>'
'<a linktype="document" id="1">document</a>'
'<a linktype="page" id="3">page</a>'
),
(
'<a href="/events/">page</a>'
'<a href="/documents/1/test.pdf">document</a>'
'<a href="/events/">page</a>'
),
)
class TestRichTextValue(TestCase):
fixtures = ["test.json"]
def test_construct_with_none(self):
value = RichText(None)
self.assertEqual(value.source, "")
def test_construct_with_empty_string(self):
value = RichText("")
self.assertEqual(value.source, "")
def test_construct_with_nonempty_string(self):
value = RichText("<p>hello world</p>")
self.assertEqual(value.source, "<p>hello world</p>")
def test_render(self):
value = RichText('<p>Merry <a linktype="page" id="4">Christmas</a>!</p>')
result = str(value)
self.assertEqual(
result, '<p>Merry <a href="/events/christmas/">Christmas</a>!</p>'
)
def test_evaluate_value(self):
value = RichText(None)
self.assertFalse(value)
value = RichText("<p>wagtail</p>")
self.assertTrue(value)
def test_compare_value(self):
value1 = RichText("<p>wagtail</p>")
value2 = RichText("<p>wagtail</p>")
value3 = RichText("<p>django</p>")
self.assertNotEqual(value1, value3)
self.assertNotEqual(value1, 12345)
self.assertEqual(value1, value2)
class TestFeatureRegistry(TestCase):
def test_register_rich_text_features_hook(self):
# testapp/wagtail_hooks.py defines a 'blockquote' rich text feature with a Draftail
# plugin, via the register_rich_text_features hook; test that we can retrieve it here
features = FeatureRegistry()
quotation = features.get_editor_plugin("draftail", "quotation")
self.assertEqual(quotation.js, ["testapp/js/draftail-quotation.js"])
def test_missing_editor_plugin_returns_none(self):
features = FeatureRegistry()
self.assertIsNone(features.get_editor_plugin("made_up_editor", "blockquote"))
self.assertIsNone(features.get_editor_plugin("draftail", "made_up_feature"))
class TestLinkRewriterTagReplacing(TestCase):
def test_should_follow_default_behaviour(self):
# we always have default `page` rules registered.
rules = {"page": lambda attrs: '<a href="/article/{}">'.format(attrs["id"])}
rewriter = LinkRewriter(rules)
page_type_link = rewriter('<a linktype="page" id="3">')
self.assertEqual(page_type_link, '<a href="/article/3">')
# but it should also be able to handle other supported
# link types (email, external, anchor) even if no rules is provided
external_type_link = rewriter('<a href="https://wagtail.org/">')
self.assertEqual(external_type_link, '<a href="https://wagtail.org/">')
email_type_link = rewriter('<a href="mailto:test@wagtail.org">')
self.assertEqual(email_type_link, '<a href="mailto:test@wagtail.org">')
anchor_type_link = rewriter('<a href="#test">')
self.assertEqual(anchor_type_link, '<a href="#test">')
# As well as link which don't have any linktypes
link_without_linktype = rewriter('<a data-link="https://wagtail.org">')
self.assertEqual(link_without_linktype, '<a data-link="https://wagtail.org">')
# But should not handle if a custom linktype is mentioned but no
# associate rules are registered.
link_with_custom_linktype = rewriter(
'<a linktype="custom" href="https://wagtail.org">'
)
self.assertNotEqual(link_with_custom_linktype, '<a href="https://wagtail.org">')
self.assertEqual(link_with_custom_linktype, "<a>")
# And should properly handle mixed linktypes.
self.assertEqual(
rewriter('<a href="https://wagtail.org/"><a linktype="page" id="3">'),
'<a href="https://wagtail.org/"><a href="/article/3">',
)
def test_supported_type_should_follow_given_rules(self):
# we always have `page` rules by default
rules = {
"page": lambda attrs: '<a href="/article/{}">'.format(attrs["id"]),
"external": lambda attrs: '<a rel="nofollow" href="{}">'.format(
attrs["href"]
),
"email": lambda attrs: '<a data-email="true" href="{}">'.format(
attrs["href"]
),
"anchor": lambda attrs: '<a data-anchor="true" href="{}">'.format(
attrs["href"]
),
"custom": lambda attrs: '<a data-phone="true" href="{}">'.format(
attrs["href"]
),
}
rewriter = LinkRewriter(rules)
page_type_link = rewriter('<a linktype="page" id="3">')
self.assertEqual(page_type_link, '<a href="/article/3">')
# It should call appropriate rule supported linktypes (external or email)
# based on the href value
external_type_link = rewriter('<a href="https://wagtail.org/">')
self.assertEqual(
external_type_link, '<a rel="nofollow" href="https://wagtail.org/">'
)
external_type_link_http = rewriter('<a href="http://wagtail.org/">')
self.assertEqual(
external_type_link_http, '<a rel="nofollow" href="http://wagtail.org/">'
)
email_type_link = rewriter('<a href="mailto:test@wagtail.org">')
self.assertEqual(
email_type_link, '<a data-email="true" href="mailto:test@wagtail.org">'
)
anchor_type_link = rewriter('<a href="#test">')
self.assertEqual(anchor_type_link, '<a data-anchor="true" href="#test">')
# But not the unsupported ones.
link_with_no_linktype = rewriter('<a href="tel:+4917640206387">')
self.assertEqual(link_with_no_linktype, '<a href="tel:+4917640206387">')
# Also call the rule if a custom linktype is mentioned.
link_with_custom_linktype = rewriter(
'<a linktype="custom" href="tel:+4917640206387">'
)
self.assertEqual(
link_with_custom_linktype, '<a data-phone="true" href="tel:+4917640206387">'
)
class TestRichTextField(TestCase):
fixtures = ["test.json"]
def test_get_searchable_content(self):
christmas_page = EventPage.objects.get(url_path="/home/events/christmas/")
christmas_page.body = '<p><b>Merry Christmas from <a href="https://wagtail.org/">Wagtail!</a></b> &amp; co.</p>'
christmas_page.save_revision()
body_field = christmas_page._meta.get_field("body")
value = body_field.value_from_object(christmas_page)
result = body_field.get_searchable_content(value)
self.assertEqual(result, ["Merry Christmas from Wagtail! & co."])
def test_get_searchable_content_whitespace(self):
christmas_page = EventPage.objects.get(url_path="/home/events/christmas/")
christmas_page.body = "<p>buttery<br />mashed</p><p>po<i>ta</i>toes</p>"
christmas_page.save_revision()
body_field = christmas_page._meta.get_field("body")
value = body_field.value_from_object(christmas_page)
result = body_field.get_searchable_content(value)
self.assertEqual(result, ["buttery mashed potatoes"])
def test_max_length_validation(self):
EventIndexForm = modelform_factory(model=EventIndex, fields=["intro"])
form = EventIndexForm(
{"intro": rich_text("<p><i>less</i> than 50 characters</p>")}
)
self.assertTrue(form.is_valid())
form = EventIndexForm(
{
"intro": rich_text(
"<p>a piece of text that is considerably longer than the limit of fifty characters of text</p>"
)
}
)
self.assertFalse(form.is_valid())
form = EventIndexForm(
{
"intro": rich_text(
'<p><a href="http://a-domain-name-that-would-put-us-over-the-limit-if-we-were-counting-it.example.com/">less</a> than 50 characters</p>'
)
}
)
self.assertTrue(form.is_valid())
def test_extract_references(self):
self.assertEqual(
list(
RichTextField().extract_references(
'<a linktype="page" id="1">Link to an internal page</a>'
)
),
[(Page, "1", "", "")],
)
class TestRichTextMaxLengthValidator(TestCase):
def test_count_characters(self):
"""Keep those tests up-to-date with MaxLength tests client-side."""
validator = RichTextMaxLengthValidator(50)
self.assertEqual(validator.clean("<p>Plain text</p>"), 10)
# HTML entities should be un-escaped.
self.assertEqual(validator.clean("<p>There&#x27;s quote</p>"), 13)
# BR should be ignored.
self.assertEqual(validator.clean("<p>Line<br/>break</p>"), 9)
# Content over multiple blocks should be treated as a single line of text with no joiner.
self.assertEqual(validator.clean("<p>Multi</p><p>blocks</p>"), 11)
# Empty blocks should be ignored.
self.assertEqual(validator.clean("<p>Empty</p><p></p><p>blocks</p>"), 11)
# HR should be ignored.
self.assertEqual(validator.clean("<p>With</p><hr/><p>HR</p>"), 6)
# Embed blocks should be ignored.
self.assertEqual(validator.clean("<p>With</p><embed/><p>embed</p>"), 9)
# Counts symbols with multiple code units (heart unicode + variation selector).
self.assertEqual(validator.clean("<p>U+2764 U+FE0F ❤️</p>"), 16)
# Counts symbols with zero-width joiners.
self.assertEqual(validator.clean("<p>👨‍👨‍👧</p>"), 5)

View File

@@ -0,0 +1,146 @@
from unittest import mock
from django.conf import settings
from django.test import TestCase
from wagtail.models import Locale, Site
from wagtail.signals import copy_for_translation_done, page_slug_changed
from wagtail.test.testapp.models import SimplePage
from wagtail.test.utils import WagtailTestUtils
class TestPageSlugChangedSignal(WagtailTestUtils, TestCase):
"""
Tests for the `wagtail.signals.page_slug_changed` signal
"""
def setUp(self):
# Find root page
site = Site.objects.select_related("root_page").get(is_default_site=True)
root_page = site.root_page
# Create two sections
self.section_a = SimplePage(
title="Section A", slug="section-a", content="hello"
)
root_page.add_child(instance=self.section_a)
self.section_b = SimplePage(
title="Section B", slug="section-b", content="hello"
)
root_page.add_child(instance=self.section_b)
# Add test page to section A
self.test_page = SimplePage(
title="Hello world! A", slug="hello-world-a", content="hello"
)
self.section_a.add_child(instance=self.test_page)
def test_signal_emitted_on_slug_change(self):
# Connect a mock signal handler to the signal
handler = mock.MagicMock()
page_slug_changed.connect(handler)
old_page = SimplePage.objects.get(id=self.test_page.id)
try:
self.test_page.slug = "updated"
with self.captureOnCommitCallbacks(execute=True):
self.test_page.save()
finally:
# Disconnect mock handler to prevent cross-test pollution
page_slug_changed.disconnect(handler)
# Check the signal was fired
self.assertEqual(handler.call_count, 1)
handler.assert_called_with(
signal=mock.ANY,
sender=SimplePage,
instance=self.test_page,
instance_before=old_page,
)
def test_signal_not_emitted_on_title_change(self):
# Connect a mock signal handler to the signal
handler = mock.MagicMock()
page_slug_changed.connect(handler)
try:
self.test_page.title = "Goodnight Moon!"
# NOTE: Even though we're not expecting anything to happen here,
# we need to invoke the callbacks via captureOnCommitCallbacks the same way
# the same way we do in ``test_signal_emitted_on_slug_change``,
# otherwise this test wouldn't prove anything.
with self.captureOnCommitCallbacks(execute=True):
self.test_page.save()
finally:
# Disconnect mock handler to prevent cross-test pollution
page_slug_changed.disconnect(handler)
# Check the signal was NOT fired
self.assertEqual(handler.call_count, 0)
def test_signal_not_emitted_on_page_move(self):
# Connect a mock signal handler to the signal
handler = mock.MagicMock()
page_slug_changed.connect(handler)
try:
# NOTE: Even though we're not expecting anything to happen here,
# we need to invoke the callbacks via captureOnCommitCallbacks the same way
# the same way we do in ``test_signal_emitted_on_slug_change``,
# otherwise this test wouldn't prove anything.
with self.captureOnCommitCallbacks(execute=True):
self.test_page.move(self.section_b, pos="last-child")
finally:
# Disconnect mock handler to prevent cross-test pollution
page_slug_changed.disconnect(handler)
# Check the signal was NOT fired
self.assertEqual(handler.call_count, 0)
class TestCopyForTranslationDoneSignal(WagtailTestUtils, TestCase):
"""
Tests for the `wagtail.signals.copy_for_translation_done` signal
"""
def setUp(self):
# Find root page
site = Site.objects.select_related("root_page").get(is_default_site=True)
root_page = site.root_page
# Create a subpage
self.subpage = SimplePage(
title="Subpage in english", slug="subpage-in-english", content="hello"
)
root_page.add_child(instance=self.subpage)
# Get the languages and create locales
language_codes = dict(settings.LANGUAGES).keys()
for language_code in language_codes:
Locale.objects.get_or_create(language_code=language_code)
# Get the locales needed
self.locale = Locale.objects.get(language_code="en")
self.another_locale = Locale.objects.get(language_code="fr")
root_page.copy_for_translation(self.another_locale)
def test_signal_emitted_on_copy_for_translation_done(self):
# Connect a mock signal handler to the signal
handler = mock.MagicMock()
copy_for_translation_done.connect(handler)
page_to_translate = SimplePage.objects.get(id=self.subpage.id)
try:
with self.captureOnCommitCallbacks(execute=True):
page_to_translate.copy_for_translation(self.another_locale)
finally:
# Disconnect mock handler to prevent cross-test pollution
copy_for_translation_done.disconnect(handler)
# Check the signal was fired
self.assertEqual(handler.call_count, 1)

View File

@@ -0,0 +1,218 @@
from django.core.exceptions import ValidationError
from django.test import TestCase, override_settings
from wagtail.coreutils import get_dummy_request
from wagtail.models import Page, Site
class TestSiteNaturalKey(TestCase):
def test_natural_key(self):
site = Site(hostname="example.com", port=8080)
self.assertEqual(site.natural_key(), ("example.com", 8080))
def test_get_by_natural_key(self):
site = Site.objects.create(
hostname="example.com", port=8080, root_page=Page.objects.get(pk=2)
)
self.assertEqual(Site.objects.get_by_natural_key("example.com", 8080), site)
class TestSiteUrl(TestCase):
def test_root_url_http(self):
site = Site(hostname="example.com", port=80)
self.assertEqual(site.root_url, "http://example.com")
def test_root_url_https(self):
site = Site(hostname="example.com", port=443)
self.assertEqual(site.root_url, "https://example.com")
def test_root_url_custom_port(self):
site = Site(hostname="example.com", port=8000)
self.assertEqual(site.root_url, "http://example.com:8000")
class TestSiteNameDisplay(TestCase):
def test_site_name_not_default(self):
site = Site(
hostname="example.com",
port=80,
site_name="example dot com",
is_default_site=False,
)
self.assertEqual(site.__str__(), "example dot com")
def test_site_name_default(self):
site = Site(
hostname="example.com",
port=80,
site_name="example dot com",
is_default_site=True,
)
self.assertEqual(site.__str__(), "example dot com [default]")
def test_no_site_name_not_default_port_80(self):
site = Site(hostname="example.com", port=80, is_default_site=False)
self.assertEqual(site.__str__(), "example.com")
def test_no_site_name_default_port_80(self):
site = Site(hostname="example.com", port=80, is_default_site=True)
self.assertEqual(site.__str__(), "example.com [default]")
def test_no_site_name_not_default_port_n(self):
site = Site(hostname="example.com", port=8080, is_default_site=False)
self.assertEqual(site.__str__(), "example.com:8080")
def test_no_site_name_default_port_n(self):
site = Site(hostname="example.com", port=8080, is_default_site=True)
self.assertEqual(site.__str__(), "example.com:8080 [default]")
class TestSiteOrdering(TestCase):
def setUp(self):
self.root_page = Page.objects.get(pk=2)
Site.objects.all().delete() # Drop the initial site.
def test_site_order_by_hostname(self):
site_1 = Site.objects.create(hostname="charly.com", root_page=self.root_page)
site_2 = Site.objects.create(hostname="bravo.com", root_page=self.root_page)
site_3 = Site.objects.create(hostname="alfa.com", root_page=self.root_page)
self.assertEqual(
list(Site.objects.all().values_list("id", flat=True)),
[site_3.id, site_2.id, site_1.id],
)
def test_site_order_by_hostname_upper(self):
site_1 = Site.objects.create(hostname="charly.com", root_page=self.root_page)
site_2 = Site.objects.create(hostname="Bravo.com", root_page=self.root_page)
site_3 = Site.objects.create(hostname="alfa.com", root_page=self.root_page)
self.assertEqual(
list(Site.objects.all().values_list("id", flat=True)),
[site_3.id, site_2.id, site_1.id],
)
def test_site_order_by_hostname_site_name_irrelevant(self):
site_1 = Site.objects.create(
hostname="charly.com", site_name="X-ray", root_page=self.root_page
)
site_2 = Site.objects.create(
hostname="bravo.com", site_name="Yankee", root_page=self.root_page
)
site_3 = Site.objects.create(
hostname="alfa.com", site_name="Zulu", root_page=self.root_page
)
self.assertEqual(
list(Site.objects.all().values_list("id", flat=True)),
[site_3.id, site_2.id, site_1.id],
)
@override_settings(ALLOWED_HOSTS=["example.com", "unknown.com", "127.0.0.1", "[::1]"])
class TestFindSiteForRequest(TestCase):
def setUp(self):
self.default_site = Site.objects.get()
self.site = Site.objects.create(
hostname="example.com", port=80, root_page=Page.objects.get(pk=2)
)
def test_dummy_request(self):
request = get_dummy_request(site=self.site)
self.assertEqual(Site.find_for_request(request), self.site)
def test_with_host(self):
request = get_dummy_request()
request.META.update({"HTTP_HOST": "example.com", "SERVER_PORT": 80})
self.assertEqual(Site.find_for_request(request), self.site)
def test_with_unknown_host(self):
request = get_dummy_request()
request.META.update({"HTTP_HOST": "unknown.com", "SERVER_PORT": 80})
self.assertEqual(Site.find_for_request(request), self.default_site)
def test_with_server_name(self):
request = get_dummy_request()
request.META.update({"SERVER_NAME": "example.com", "SERVER_PORT": 80})
self.assertEqual(Site.find_for_request(request), self.site)
def test_with_x_forwarded_host(self):
with self.settings(USE_X_FORWARDED_HOST=True):
request = get_dummy_request()
request.META.update(
{"HTTP_X_FORWARDED_HOST": "example.com", "SERVER_PORT": 80}
)
self.assertEqual(Site.find_for_request(request), self.site)
def test_ipv4_host(self):
request = get_dummy_request()
request.META.update({"SERVER_NAME": "127.0.0.1", "SERVER_PORT": 80})
self.assertEqual(Site.find_for_request(request), self.default_site)
def test_ipv6_host(self):
request = get_dummy_request()
request.META.update({"SERVER_NAME": "[::1]", "SERVER_PORT": 80})
self.assertEqual(Site.find_for_request(request), self.default_site)
class TestDefaultSite(TestCase):
def test_create_default_site(self):
Site.objects.all().delete()
Site.objects.create(
hostname="test.com", is_default_site=True, root_page=Page.objects.get(pk=2)
)
self.assertTrue(Site.objects.filter(is_default_site=True).exists())
def test_change_default_site(self):
default = Site.objects.get(is_default_site=True)
default.is_default_site = False
default.save()
Site.objects.create(
hostname="test.com", is_default_site=True, root_page=Page.objects.get(pk=2)
)
self.assertTrue(Site.objects.filter(is_default_site=True).exists())
def test_there_can_only_be_one(self):
site = Site(
hostname="test.com", is_default_site=True, root_page=Page.objects.get(pk=2)
)
with self.assertRaises(ValidationError):
site.clean_fields()
def test_oops_there_is_more_than_one(self):
Site.objects.create(
hostname="example.com",
is_default_site=True,
root_page=Page.objects.get(pk=2),
)
site = Site(
hostname="test.com", is_default_site=True, root_page=Page.objects.get(pk=2)
)
with self.assertRaises(Site.MultipleObjectsReturned):
# If there already are multiple default sites, you're in trouble
site.clean_fields()
class TestGetSiteRootPaths(TestCase):
def setUp(self):
self.default_site = Site.objects.get()
self.abc_site = Site.objects.create(
hostname="abc.com", root_page=self.default_site.root_page
)
self.def_site = Site.objects.create(
hostname="def.com", root_page=self.default_site.root_page
)
# Changing the hostname to show that being the default site takes
# promotes a site over the alphabetical ordering of hostname
self.default_site.hostname = "xyz.com"
self.default_site.save()
def test_result_order_when_multiple_sites_share_the_same_root_page(self):
result = Site.get_site_root_paths()
# An entry for the default site should come first
self.assertEqual(result[0][0], self.default_site.id)
# Followed by entries for others in 'host' alphabetical order
self.assertEqual(result[1][0], self.abc_site.id)
self.assertEqual(result[2][0], self.def_site.id)

View File

@@ -0,0 +1,988 @@
import json
import pickle
from django.apps import apps
from django.db import connection, models
from django.template import Context, Template, engines
from django.test import TestCase, skipUnlessDBFeature
from django.utils.safestring import SafeString
from wagtail import blocks
from wagtail.blocks import StreamBlockValidationError, StreamValue
from wagtail.fields import StreamField
from wagtail.images.models import Image
from wagtail.images.tests.utils import get_test_image_file
from wagtail.models import Page
from wagtail.rich_text import RichText
from wagtail.signal_handlers import disable_reference_index_auto_update
from wagtail.test.testapp.models import (
ComplexDefaultStreamPage,
JSONBlockCountsStreamModel,
JSONMinMaxCountStreamModel,
JSONStreamModel,
StreamPage,
)
class TestLazyStreamField(TestCase):
model = JSONStreamModel
def setUp(self):
self.image = Image.objects.create(
title="Test image", file=get_test_image_file()
)
self.with_image = self.model.objects.create(
body=json.dumps(
[
{"type": "image", "value": self.image.pk},
{"type": "text", "value": "foo"},
]
)
)
self.no_image = self.model.objects.create(
body=json.dumps([{"type": "text", "value": "foo"}])
)
self.three_items = self.model.objects.create(
body=json.dumps(
[
{"type": "text", "value": "foo"},
{"type": "image", "value": self.image.pk},
{"type": "text", "value": "bar"},
]
)
)
def test_lazy_load(self):
"""
Getting a single item should lazily load the StreamField, only
accessing the database once the StreamField is accessed
"""
with self.assertNumQueries(1):
# Get the instance. The StreamField should *not* load the image yet
instance = self.model.objects.get(pk=self.with_image.pk)
with self.assertNumQueries(0):
# Access the body. The StreamField should still not get the image.
body = instance.body
with self.assertNumQueries(1):
# Access the image item from the stream. The image is fetched now
body[0].value
with self.assertNumQueries(0):
# Everything has been fetched now, no further database queries.
self.assertEqual(body[0].value, self.image)
self.assertEqual(body[1].value, "foo")
def test_slice(self):
with self.assertNumQueries(1):
instance = self.model.objects.get(pk=self.three_items.pk)
with self.assertNumQueries(1):
# Access the image item from the stream. The image is fetched now
instance.body[1].value
with self.assertNumQueries(0):
# taking a slice of a StreamValue should re-use already-fetched values
values = [block.value for block in instance.body[1:3]]
self.assertEqual(values, [self.image, "bar"])
with self.assertNumQueries(0):
# test slicing with negative indexing
values = [block.value for block in instance.body[-2:]]
self.assertEqual(values, [self.image, "bar"])
with self.assertNumQueries(0):
# test slicing with skips
values = [block.value for block in instance.body[0:3:2]]
self.assertEqual(values, ["foo", "bar"])
def test_lazy_load_no_images(self):
"""
Getting a single item whose StreamField never accesses the database
should behave as expected.
"""
with self.assertNumQueries(1):
# Get the instance, nothing else
instance = self.model.objects.get(pk=self.no_image.pk)
with self.assertNumQueries(0):
# Access the body. The StreamField has no images, so nothing should
# happen
body = instance.body
self.assertEqual(body[0].value, "foo")
def test_lazy_load_queryset(self):
"""
Ensure that lazy loading StreamField works when gotten as part of a
queryset list
"""
with self.assertNumQueries(1):
instances = self.model.objects.filter(
pk__in=[self.with_image.pk, self.no_image.pk]
)
instances_lookup = {instance.pk: instance for instance in instances}
with self.assertNumQueries(1):
instances_lookup[self.with_image.pk].body[0]
with self.assertNumQueries(0):
instances_lookup[self.no_image.pk].body[0]
def test_lazy_load_queryset_bulk(self):
"""
Ensure that lazy loading StreamField works when gotten as part of a
queryset list
"""
file_obj = get_test_image_file()
image_1 = Image.objects.create(title="Test image 1", file=file_obj)
image_3 = Image.objects.create(title="Test image 3", file=file_obj)
with_image = self.model.objects.create(
body=json.dumps(
[
{"type": "image", "value": image_1.pk},
{"type": "image", "value": None},
{"type": "image", "value": image_3.pk},
{"type": "text", "value": "foo"},
]
)
)
with self.assertNumQueries(1):
instance = self.model.objects.get(pk=with_image.pk)
# Prefetch all image blocks
with self.assertNumQueries(1):
instance.body[0]
# 1. Further image block access should not execute any db lookups
# 2. The blank block '1' should be None.
# 3. The values should be in the original order.
with self.assertNumQueries(0):
assert instance.body[0].value.title == "Test image 1"
assert instance.body[1].value is None
assert instance.body[2].value.title == "Test image 3"
def test_lazy_load_get_prep_value(self):
"""
Saving a lazy StreamField that hasn't had its data accessed should not
cause extra database queries by loading and then re-saving block values.
Instead the initial JSON stream data should be written back for any
blocks that have not been accessed.
"""
with self.assertNumQueries(1):
instance = self.model.objects.get(pk=self.with_image.pk)
# Expect a single UPDATE to update the model, without any additional
# SELECT related to the image block that has not been accessed.
with disable_reference_index_auto_update():
with self.assertNumQueries(1):
instance.save()
class TestSystemCheck(TestCase):
def tearDown(self):
# unregister InvalidStreamModel from the overall model registry
# so that it doesn't break tests elsewhere
for package in ("wagtailcore", "wagtail.tests"):
try:
del apps.all_models[package]["invalidstreammodel"]
except KeyError:
pass
apps.clear_cache()
def test_system_check_validates_block(self):
class InvalidStreamModel(models.Model):
body = StreamField(
[
("heading", blocks.CharBlock()),
("rich text", blocks.RichTextBlock()),
],
)
errors = InvalidStreamModel.check()
self.assertEqual(len(errors), 1)
self.assertEqual(errors[0].id, "wagtailcore.E001")
self.assertEqual(errors[0].hint, "Block names cannot contain spaces")
self.assertEqual(errors[0].obj, InvalidStreamModel._meta.get_field("body"))
class TestStreamValueAccess(TestCase):
def setUp(self):
self.json_body = JSONStreamModel.objects.create(
body=json.dumps([{"type": "text", "value": "foo"}])
)
def test_can_assign_as_list(self):
self.json_body.body = [("rich_text", RichText("<h2>hello world</h2>"))]
self.json_body.save()
# the body should now be a stream consisting of a single rich_text block
fetched_body = JSONStreamModel.objects.get(id=self.json_body.id).body
self.assertIsInstance(fetched_body, StreamValue)
self.assertEqual(len(fetched_body), 1)
self.assertIsInstance(fetched_body[0].value, RichText)
self.assertEqual(fetched_body[0].value.source, "<h2>hello world</h2>")
def test_can_append(self):
self.json_body.body.append(("text", "bar"))
self.json_body.save()
fetched_body = JSONStreamModel.objects.get(id=self.json_body.id).body
self.assertIsInstance(fetched_body, StreamValue)
self.assertEqual(len(fetched_body), 2)
self.assertEqual(fetched_body[0].block_type, "text")
self.assertEqual(fetched_body[0].value, "foo")
self.assertEqual(fetched_body[1].block_type, "text")
self.assertEqual(fetched_body[1].value, "bar")
def test_complex_assignment(self):
page = StreamPage(title="Test page", body=[])
page.body = [
("rich_text", "<h2>hello world</h2>"),
(
"books",
[
("title", "Great Expectations"),
("author", "Charles Dickens"),
],
),
]
self.assertEqual(page.body[0].block_type, "rich_text")
self.assertIsInstance(page.body[0].value, RichText)
self.assertEqual(page.body[0].value.source, "<h2>hello world</h2>")
self.assertEqual(page.body[1].block_type, "books")
self.assertIsInstance(page.body[1].value, StreamValue)
self.assertEqual(len(page.body[1].value), 2)
self.assertEqual(page.body[1].value[0].block_type, "title")
self.assertEqual(page.body[1].value[0].value, "Great Expectations")
self.assertEqual(page.body[1].value[1].block_type, "author")
self.assertEqual(page.body[1].value[1].value, "Charles Dickens")
class TestComplexDefault(TestCase):
def setUp(self):
self.page = ComplexDefaultStreamPage(title="Test page")
def test_default_value(self):
self.assertEqual(self.page.body[0].block_type, "rich_text")
self.assertIsInstance(self.page.body[0].value, RichText)
self.assertEqual(
self.page.body[0].value.source, "<p>My <i>lovely</i> books</p>"
)
self.assertEqual(self.page.body[1].block_type, "books")
self.assertIsInstance(self.page.body[1].value, StreamValue)
self.assertEqual(len(self.page.body[1].value), 2)
self.assertEqual(self.page.body[1].value[0].block_type, "title")
self.assertEqual(self.page.body[1].value[0].value, "The Great Gatsby")
self.assertEqual(self.page.body[1].value[1].block_type, "author")
self.assertEqual(self.page.body[1].value[1].value, "F. Scott Fitzgerald")
class TestStreamFieldRenderingBase(TestCase):
model = JSONStreamModel
def setUp(self):
self.image = Image.objects.create(
title="Test image", file=get_test_image_file()
)
self.instance = self.model.objects.create(
body=json.dumps(
[
{"type": "rich_text", "value": "<p>Rich text</p>"},
{"type": "rich_text", "value": "<p>Привет, Микола</p>"},
{"type": "image", "value": self.image.pk},
{"type": "text", "value": "Hello, World!"},
]
)
)
img_tag = self.image.get_rendition("original").img_tag()
self.expected = "".join(
[
'<div class="block-rich_text"><p>Rich text</p></div>',
'<div class="block-rich_text"><p>Привет, Микола</p></div>',
f'<div class="block-image">{img_tag}</div>',
'<div class="block-text">Hello, World!</div>',
]
)
class TestStreamFieldRendering(TestStreamFieldRenderingBase):
def test_to_string(self):
rendered = str(self.instance.body)
self.assertHTMLEqual(rendered, self.expected)
self.assertIsInstance(rendered, SafeString)
def test___html___access(self):
rendered = self.instance.body.__html__()
self.assertHTMLEqual(rendered, self.expected)
self.assertIsInstance(rendered, SafeString)
class TestStreamFieldDjangoRendering(TestStreamFieldRenderingBase):
def render(self, string, context):
return Template(string).render(Context(context))
def test_render(self):
rendered = self.render("{{ instance.body }}", {"instance": self.instance})
self.assertHTMLEqual(rendered, self.expected)
class TestStreamFieldJinjaRendering(TestStreamFieldRenderingBase):
def setUp(self):
super().setUp()
self.engine = engines["jinja2"]
def render(self, string, context):
return self.engine.from_string(string).render(context)
def test_render(self):
rendered = self.render("{{ instance.body }}", {"instance": self.instance})
self.assertHTMLEqual(rendered, self.expected)
class TestRequiredStreamField(TestCase):
def test_non_blank_field_is_required(self):
# passing a block list
field = StreamField(
[("paragraph", blocks.CharBlock())],
blank=False,
)
self.assertTrue(field.stream_block.required)
with self.assertRaises(StreamBlockValidationError):
field.stream_block.clean([])
class MyStreamBlock(blocks.StreamBlock):
paragraph = blocks.CharBlock()
class Meta:
required = False
# passing a block instance
field = StreamField(MyStreamBlock(), blank=False)
self.assertTrue(field.stream_block.required)
with self.assertRaises(StreamBlockValidationError):
field.stream_block.clean([])
field = StreamField(
MyStreamBlock(required=False),
blank=False,
)
self.assertTrue(field.stream_block.required)
with self.assertRaises(StreamBlockValidationError):
field.stream_block.clean([])
# passing a block class
field = StreamField(MyStreamBlock, blank=False)
self.assertTrue(field.stream_block.required)
with self.assertRaises(StreamBlockValidationError):
field.stream_block.clean([])
def test_blank_false_is_implied_by_default(self):
# passing a block list
field = StreamField([("paragraph", blocks.CharBlock())])
self.assertTrue(field.stream_block.required)
with self.assertRaises(StreamBlockValidationError):
field.stream_block.clean([])
class MyStreamBlock(blocks.StreamBlock):
paragraph = blocks.CharBlock()
class Meta:
required = False
# passing a block instance
field = StreamField(MyStreamBlock())
self.assertTrue(field.stream_block.required)
with self.assertRaises(StreamBlockValidationError):
field.stream_block.clean([])
field = StreamField(MyStreamBlock(required=False))
self.assertTrue(field.stream_block.required)
with self.assertRaises(StreamBlockValidationError):
field.stream_block.clean([])
# passing a block class
field = StreamField(MyStreamBlock)
self.assertTrue(field.stream_block.required)
with self.assertRaises(StreamBlockValidationError):
field.stream_block.clean([])
def test_blank_field_is_not_required(self):
# passing a block list
field = StreamField(
[("paragraph", blocks.CharBlock())],
blank=True,
)
self.assertFalse(field.stream_block.required)
field.stream_block.clean([]) # no validation error on empty stream
class MyStreamBlock(blocks.StreamBlock):
paragraph = blocks.CharBlock()
class Meta:
required = True
# passing a block instance
field = StreamField(MyStreamBlock(), blank=True)
self.assertFalse(field.stream_block.required)
field.stream_block.clean([]) # no validation error on empty stream
field = StreamField(MyStreamBlock(required=True), blank=True)
self.assertFalse(field.stream_block.required)
field.stream_block.clean([]) # no validation error on empty stream
# passing a block class
field = StreamField(MyStreamBlock, blank=True)
self.assertFalse(field.stream_block.required)
field.stream_block.clean([]) # no validation error on empty stream
class TestStreamFieldCountValidation(TestCase):
def setUp(self):
self.image = Image.objects.create(
title="Test image", file=get_test_image_file()
)
self.rich_text_body = {"type": "rich_text", "value": "<p>Rich text</p>"}
self.image_body = {"type": "image", "value": self.image.pk}
self.text_body = {"type": "text", "value": "Hello, World!"}
def test_minmax_pass_to_block(self):
instance = JSONMinMaxCountStreamModel.objects.create(body=json.dumps([]))
internal_block = instance.body.stream_block
self.assertEqual(internal_block.meta.min_num, 2)
self.assertEqual(internal_block.meta.max_num, 5)
def test_counts_pass_to_block(self):
instance = JSONBlockCountsStreamModel.objects.create(body=json.dumps([]))
block_counts = instance.body.stream_block.meta.block_counts
self.assertEqual(block_counts.get("text"), {"min_num": 1})
self.assertEqual(block_counts.get("rich_text"), {"max_num": 1})
self.assertEqual(block_counts.get("image"), {"min_num": 1, "max_num": 1})
def test_minimum_count(self):
# Single block should fail validation
body = [self.rich_text_body]
instance = JSONMinMaxCountStreamModel.objects.create(body=json.dumps(body))
with self.assertRaises(StreamBlockValidationError) as catcher:
instance.body.stream_block.clean(instance.body)
self.assertEqual(
catcher.exception.as_json_data(),
{"messages": ["The minimum number of items is 2"]},
)
# 2 blocks okay
body = [self.rich_text_body, self.text_body]
instance = JSONMinMaxCountStreamModel.objects.create(body=json.dumps(body))
self.assertTrue(instance.body.stream_block.clean(instance.body))
def test_maximum_count(self):
# 5 blocks okay
body = [self.rich_text_body] * 5
instance = JSONMinMaxCountStreamModel.objects.create(body=json.dumps(body))
self.assertTrue(instance.body.stream_block.clean(instance.body))
# 6 blocks should fail validation
body = [self.rich_text_body, self.text_body] * 3
instance = JSONMinMaxCountStreamModel.objects.create(body=json.dumps(body))
with self.assertRaises(StreamBlockValidationError) as catcher:
instance.body.stream_block.clean(instance.body)
self.assertEqual(
catcher.exception.as_json_data(),
{"messages": ["The maximum number of items is 5"]},
)
def test_block_counts_minimums(self):
JSONBlockCountsStreamModel.objects.create(body=json.dumps([]))
# Zero blocks should fail validation (requires one text, one image)
instance = JSONBlockCountsStreamModel.objects.create(body=json.dumps([]))
with self.assertRaises(StreamBlockValidationError) as catcher:
instance.body.stream_block.clean(instance.body)
errors = catcher.exception.as_json_data()["messages"]
self.assertIn("This field is required.", errors)
self.assertIn("Text: The minimum number of items is 1", errors)
self.assertIn("Image: The minimum number of items is 1", errors)
self.assertEqual(len(errors), 3)
# One plain text should fail validation
body = [self.text_body]
instance = JSONBlockCountsStreamModel.objects.create(body=json.dumps(body))
with self.assertRaises(StreamBlockValidationError) as catcher:
instance.body.stream_block.clean(instance.body)
self.assertEqual(
catcher.exception.as_json_data(),
{"messages": ["Image: The minimum number of items is 1"]},
)
# One text, one image should be okay
body = [self.text_body, self.image_body]
instance = JSONBlockCountsStreamModel.objects.create(body=json.dumps(body))
self.assertTrue(instance.body.stream_block.clean(instance.body))
def test_block_counts_maximums(self):
JSONBlockCountsStreamModel.objects.create(body=json.dumps([]))
# Base is one text, one image
body = [self.text_body, self.image_body]
instance = JSONBlockCountsStreamModel.objects.create(body=json.dumps(body))
self.assertTrue(instance.body.stream_block.clean(instance.body))
# Two rich text should error
body = [
self.text_body,
self.image_body,
self.rich_text_body,
self.rich_text_body,
]
instance = JSONBlockCountsStreamModel.objects.create(body=json.dumps(body))
with self.assertRaises(StreamBlockValidationError):
instance.body.stream_block.clean(instance.body)
# Two images should error
body = [self.text_body, self.image_body, self.image_body]
instance = JSONBlockCountsStreamModel.objects.create(body=json.dumps(body))
with self.assertRaises(StreamBlockValidationError) as catcher:
instance.body.stream_block.clean(instance.body)
self.assertEqual(
catcher.exception.as_json_data(),
{"messages": ["Image: The maximum number of items is 1"]},
)
# One text, one rich, one image should be okay
body = [self.text_body, self.image_body, self.rich_text_body]
instance = JSONBlockCountsStreamModel.objects.create(body=json.dumps(body))
self.assertTrue(instance.body.stream_block.clean(instance.body))
def test_streamfield_count_argument_precedence(self):
class TestStreamBlock(blocks.StreamBlock):
heading = blocks.CharBlock()
paragraph = blocks.RichTextBlock()
class Meta:
min_num = 2
max_num = 5
block_counts = {"heading": {"max_num": 1}}
# args being picked up from the class definition
field = StreamField(TestStreamBlock)
self.assertEqual(field.stream_block.meta.min_num, 2)
self.assertEqual(field.stream_block.meta.max_num, 5)
self.assertEqual(field.stream_block.meta.block_counts["heading"]["max_num"], 1)
# args being overridden by StreamField
field = StreamField(
TestStreamBlock,
min_num=3,
max_num=6,
block_counts={"heading": {"max_num": 2}},
)
self.assertEqual(field.stream_block.meta.min_num, 3)
self.assertEqual(field.stream_block.meta.max_num, 6)
self.assertEqual(field.stream_block.meta.block_counts["heading"]["max_num"], 2)
# passing None from StreamField should cancel limits set at the block level
field = StreamField(
TestStreamBlock,
min_num=None,
max_num=None,
block_counts=None,
)
self.assertIsNone(field.stream_block.meta.min_num)
self.assertIsNone(field.stream_block.meta.max_num)
self.assertIsNone(field.stream_block.meta.block_counts)
class TestJSONStreamField(TestCase):
@classmethod
def setUpClass(cls):
super().setUpClass()
cls.instance = JSONStreamModel.objects.create(
body=[{"type": "text", "value": "foo"}],
)
def test_internal_type(self):
json = StreamField([("paragraph", blocks.CharBlock())])
self.assertEqual(json.get_internal_type(), "JSONField")
def test_json_body_equals_to_text_body(self):
instance_text = JSONStreamModel.objects.create(
body=json.dumps([{"type": "text", "value": "foo"}]),
)
self.assertEqual(
instance_text.body.render_as_block(), self.instance.body.render_as_block()
)
def test_json_body_create_preserialised_value(self):
instance_preserialised = JSONStreamModel.objects.create(
body=json.dumps([{"type": "text", "value": "foo"}]),
)
self.assertEqual(
instance_preserialised.body.render_as_block(),
self.instance.body.render_as_block(),
)
@skipUnlessDBFeature("supports_json_field_contains")
def test_json_contains_lookup(self):
value = {"value": "foo"}
if connection.features.json_key_contains_list_matching_requires_list:
value = [value]
instance = JSONStreamModel.objects.filter(body__contains=value).first()
self.assertIsNotNone(instance)
self.assertEqual(instance.id, self.instance.id)
class TestStreamFieldPickleSupport(TestCase):
def setUp(self):
# Find root page
self.root_page = Page.objects.get(id=2)
def test_pickle_support(self):
stream_page = StreamPage(title="stream page", body=[("text", "hello")])
self.root_page.add_child(instance=stream_page)
# check that page can be serialized / deserialized
serialized = pickle.dumps(stream_page)
deserialized = pickle.loads(serialized)
# check that serialized page can be serialized / deserialized again
serialized2 = pickle.dumps(deserialized)
deserialized2 = pickle.loads(serialized2)
# check that page data is not corrupted
self.assertEqual(stream_page.body, deserialized.body)
self.assertEqual(stream_page.body, deserialized2.body)
class TestGetBlockByContentPath(TestCase):
def setUp(self):
self.page = StreamPage(
title="Test page",
body=[
{"id": "123", "type": "text", "value": "Hello world"},
{
"id": "234",
"type": "product",
"value": {"name": "Cuddly toy", "price": "$9.95"},
},
{
"id": "345",
"type": "books",
"value": [
{"id": "111", "type": "author", "value": "Charles Dickens"},
{"id": "222", "type": "title", "value": "Great Expectations"},
],
},
{
"id": "456",
"type": "title_list",
"value": [
{"id": "111", "type": "item", "value": "Barnaby Rudge"},
{"id": "222", "type": "item", "value": "A Tale of Two Cities"},
],
},
],
)
def test_get_block_by_content_path(self):
field = self.page._meta.get_field("body")
# top-level blocks
bound_block = field.get_block_by_content_path(self.page.body, ["123"])
self.assertEqual(bound_block.value, "Hello world")
self.assertEqual(bound_block.block.name, "text")
bound_block = field.get_block_by_content_path(self.page.body, ["234"])
self.assertEqual(bound_block.block.name, "product")
bound_block = field.get_block_by_content_path(self.page.body, ["999"])
self.assertIsNone(bound_block)
# StructBlock children
bound_block = field.get_block_by_content_path(self.page.body, ["234", "name"])
self.assertEqual(bound_block.value, "Cuddly toy")
bound_block = field.get_block_by_content_path(self.page.body, ["234", "colour"])
self.assertIsNone(bound_block)
# StreamBlock children
bound_block = field.get_block_by_content_path(self.page.body, ["345", "111"])
self.assertEqual(bound_block.value, "Charles Dickens")
bound_block = field.get_block_by_content_path(self.page.body, ["345", "999"])
self.assertIsNone(bound_block)
# ListBlock children
bound_block = field.get_block_by_content_path(self.page.body, ["456", "111"])
self.assertEqual(bound_block.value, "Barnaby Rudge")
bound_block = field.get_block_by_content_path(self.page.body, ["456", "999"])
self.assertIsNone(bound_block)
class TestConstructStreamFieldFromLookup(TestCase):
def test_construct_block_list_from_lookup(self):
field = StreamField(
[
("heading", 0),
("paragraph", 1),
("button", 3),
],
block_lookup={
0: ("wagtail.blocks.CharBlock", [], {"required": True}),
1: ("wagtail.blocks.RichTextBlock", [], {}),
2: ("wagtail.blocks.PageChooserBlock", [], {}),
3: (
"wagtail.blocks.StructBlock",
[
[
("page", 2),
("link_text", 0),
]
],
{},
),
},
)
stream_block = field.stream_block
self.assertIsInstance(stream_block, blocks.StreamBlock)
self.assertEqual(len(stream_block.child_blocks), 3)
heading_block = stream_block.child_blocks["heading"]
self.assertIsInstance(heading_block, blocks.CharBlock)
self.assertTrue(heading_block.required)
self.assertEqual(heading_block.name, "heading")
paragraph_block = stream_block.child_blocks["paragraph"]
self.assertIsInstance(paragraph_block, blocks.RichTextBlock)
self.assertEqual(paragraph_block.name, "paragraph")
button_block = stream_block.child_blocks["button"]
self.assertIsInstance(button_block, blocks.StructBlock)
self.assertEqual(button_block.name, "button")
self.assertEqual(len(button_block.child_blocks), 2)
page_block = button_block.child_blocks["page"]
self.assertIsInstance(page_block, blocks.PageChooserBlock)
link_text_block = button_block.child_blocks["link_text"]
self.assertIsInstance(link_text_block, blocks.CharBlock)
self.assertEqual(link_text_block.name, "link_text")
def test_construct_top_level_block_from_lookup(self):
field = StreamField(
4,
block_lookup={
0: ("wagtail.blocks.CharBlock", [], {"required": True}),
1: ("wagtail.blocks.RichTextBlock", [], {}),
2: ("wagtail.blocks.PageChooserBlock", [], {}),
3: (
"wagtail.blocks.StructBlock",
[
[
("page", 2),
("link_text", 0),
]
],
{},
),
4: (
"wagtail.blocks.StreamBlock",
[
[
("heading", 0),
("paragraph", 1),
("button", 3),
]
],
{},
),
},
)
stream_block = field.stream_block
self.assertIsInstance(stream_block, blocks.StreamBlock)
self.assertEqual(len(stream_block.child_blocks), 3)
heading_block = stream_block.child_blocks["heading"]
self.assertIsInstance(heading_block, blocks.CharBlock)
self.assertTrue(heading_block.required)
self.assertEqual(heading_block.name, "heading")
paragraph_block = stream_block.child_blocks["paragraph"]
self.assertIsInstance(paragraph_block, blocks.RichTextBlock)
self.assertEqual(paragraph_block.name, "paragraph")
button_block = stream_block.child_blocks["button"]
self.assertIsInstance(button_block, blocks.StructBlock)
self.assertEqual(button_block.name, "button")
self.assertEqual(len(button_block.child_blocks), 2)
page_block = button_block.child_blocks["page"]
self.assertIsInstance(page_block, blocks.PageChooserBlock)
link_text_block = button_block.child_blocks["link_text"]
self.assertIsInstance(link_text_block, blocks.CharBlock)
self.assertEqual(link_text_block.name, "link_text")
# Used by TestDeconstructStreamFieldWithLookup.test_deconstruct_with_listblock_subclass -
# needs to be a module-level definition so that the path returned from deconstruct is valid
class BulletListBlock(blocks.ListBlock):
def __init__(self, **kwargs):
super().__init__(blocks.CharBlock(required=True), **kwargs)
class TestDeconstructStreamFieldWithLookup(TestCase):
def test_deconstruct(self):
class ButtonBlock(blocks.StructBlock):
page = blocks.PageChooserBlock()
link_text = blocks.CharBlock(required=True)
field = StreamField(
[
("heading", blocks.CharBlock(required=True)),
("paragraph", blocks.RichTextBlock()),
("button", ButtonBlock()),
],
blank=True,
)
field.set_attributes_from_name("body")
name, path, args, kwargs = field.deconstruct()
self.assertEqual(name, "body")
self.assertEqual(path, "wagtail.fields.StreamField")
self.assertEqual(
args,
[
[
("heading", 0),
("paragraph", 1),
("button", 3),
]
],
)
self.assertEqual(
kwargs,
{
"blank": True,
"block_lookup": {
0: ("wagtail.blocks.CharBlock", (), {"required": True}),
1: ("wagtail.blocks.RichTextBlock", (), {}),
2: ("wagtail.blocks.PageChooserBlock", (), {}),
3: (
"wagtail.blocks.StructBlock",
[
[
("page", 2),
("link_text", 0),
]
],
{},
),
},
},
)
def test_deconstruct_with_listblock(self):
field = StreamField(
[
("heading", blocks.CharBlock(required=True)),
("bullets", blocks.ListBlock(blocks.CharBlock(required=True))),
],
blank=True,
)
field.set_attributes_from_name("body")
name, path, args, kwargs = field.deconstruct()
self.assertEqual(name, "body")
self.assertEqual(path, "wagtail.fields.StreamField")
self.assertEqual(
args,
[
[
("heading", 0),
("bullets", 1),
]
],
)
self.assertEqual(
kwargs,
{
"blank": True,
"block_lookup": {
0: ("wagtail.blocks.CharBlock", (), {"required": True}),
1: ("wagtail.blocks.ListBlock", (0,), {}),
},
},
)
def test_deconstruct_with_listblock_with_child_block_kwarg(self):
field = StreamField(
[
("heading", blocks.CharBlock(required=True)),
(
"bullets",
blocks.ListBlock(child_block=blocks.CharBlock(required=True)),
),
],
blank=True,
)
field.set_attributes_from_name("body")
name, path, args, kwargs = field.deconstruct()
self.assertEqual(name, "body")
self.assertEqual(path, "wagtail.fields.StreamField")
self.assertEqual(
args,
[
[
("heading", 0),
("bullets", 1),
]
],
)
self.assertEqual(
kwargs,
{
"blank": True,
"block_lookup": {
0: ("wagtail.blocks.CharBlock", (), {"required": True}),
1: ("wagtail.blocks.ListBlock", (), {"child_block": 0}),
},
},
)
def test_deconstruct_with_listblock_subclass(self):
# See https://github.com/wagtail/wagtail/issues/12164 - unlike StructBlock and StreamBlock,
# ListBlock's deconstruct method doesn't reduce subclasses to the base ListBlock class.
# Therefore, if a ListBlock subclass defines its own __init__ method with an incompatible
# signature to the base ListBlock, this custom signature will be preserved in the result of
# deconstruct(), and we cannot rely on the first argument being the child block.
field = StreamField(
[
("heading", blocks.CharBlock(required=True)),
("bullets", BulletListBlock()),
],
blank=True,
)
field.set_attributes_from_name("body")
name, path, args, kwargs = field.deconstruct()
self.assertEqual(name, "body")
self.assertEqual(path, "wagtail.fields.StreamField")
self.assertEqual(
args,
[
[
("heading", 0),
("bullets", 1),
]
],
)
self.assertEqual(
kwargs,
{
"blank": True,
"block_lookup": {
0: ("wagtail.blocks.CharBlock", (), {"required": True}),
1: ("wagtail.tests.test_streamfield.BulletListBlock", (), {}),
},
},
)

View File

@@ -0,0 +1,353 @@
import itertools
from django.test import TestCase
from wagtail.telepath import Adapter, JSContext, register
class Artist:
def __init__(self, name):
self.name = name
class Album:
def __init__(self, title, artists):
self.title = title
self.artists = artists
class ArtistAdapter(Adapter):
js_constructor = "music.Artist"
def js_args(self, obj):
return [obj.name]
register(ArtistAdapter(), Artist)
class AlbumAdapter(Adapter):
js_constructor = "music.Album"
def js_args(self, obj):
return [obj.title, obj.artists]
class Media:
js = ["music_player.js"]
register(AlbumAdapter(), Album)
class TestPacking(TestCase):
def test_pack_object(self):
beyonce = Artist("Beyoncé")
ctx = JSContext()
result = ctx.pack(beyonce)
self.assertEqual(result, {"_type": "music.Artist", "_args": ["Beyoncé"]})
def test_pack_list(self):
destinys_child = [
Artist("Beyoncé"),
Artist("Kelly Rowland"),
Artist("Michelle Williams"),
]
ctx = JSContext()
result = ctx.pack(destinys_child)
self.assertEqual(
result,
[
{"_type": "music.Artist", "_args": ["Beyoncé"]},
{"_type": "music.Artist", "_args": ["Kelly Rowland"]},
{"_type": "music.Artist", "_args": ["Michelle Williams"]},
],
)
def test_pack_dict(self):
glastonbury = {
"pyramid_stage": Artist("Beyoncé"),
"acoustic_stage": Artist("Ed Sheeran"),
}
ctx = JSContext()
result = ctx.pack(glastonbury)
self.assertEqual(
result,
{
"pyramid_stage": {"_type": "music.Artist", "_args": ["Beyoncé"]},
"acoustic_stage": {"_type": "music.Artist", "_args": ["Ed Sheeran"]},
},
)
def test_dict_reserved_words(self):
profile = {
"_artist": Artist("Beyoncé"),
"_type": "R&B",
}
ctx = JSContext()
result = ctx.pack(profile)
self.assertEqual(
result,
{
"_dict": {
"_artist": {"_type": "music.Artist", "_args": ["Beyoncé"]},
"_type": "R&B",
}
},
)
def test_recursive_arg_packing(self):
dangerously_in_love = Album(
"Dangerously in Love",
[
Artist("Beyoncé"),
],
)
ctx = JSContext()
result = ctx.pack(dangerously_in_love)
self.assertEqual(
result,
{
"_type": "music.Album",
"_args": [
"Dangerously in Love",
[
{"_type": "music.Artist", "_args": ["Beyoncé"]},
],
],
},
)
self.assertIn("music_player.js", str(ctx.media))
def test_object_references(self):
beyonce = Artist("Beyoncé")
jay_z = Artist("Jay-Z")
discography = [
Album("Dangerously in Love", [beyonce]),
Album("Everything Is Love", [beyonce, jay_z]),
]
ctx = JSContext()
result = ctx.pack(discography)
self.assertEqual(
result,
[
{
"_type": "music.Album",
"_args": [
"Dangerously in Love",
[
{"_type": "music.Artist", "_args": ["Beyoncé"], "_id": 0},
],
],
},
{
"_type": "music.Album",
"_args": [
"Everything Is Love",
[
{"_ref": 0},
{"_type": "music.Artist", "_args": ["Jay-Z"]},
],
],
},
],
)
self.assertIn("music_player.js", str(ctx.media))
def test_list_references(self):
destinys_child = [
Artist("Beyoncé"),
Artist("Kelly Rowland"),
Artist("Michelle Williams"),
]
discography = [
Album("Destiny's Child", destinys_child),
Album("Survivor", destinys_child),
]
ctx = JSContext()
result = ctx.pack(discography)
self.assertEqual(
result,
[
{
"_type": "music.Album",
"_args": [
"Destiny's Child",
{
"_list": [
{"_type": "music.Artist", "_args": ["Beyoncé"]},
{"_type": "music.Artist", "_args": ["Kelly Rowland"]},
{
"_type": "music.Artist",
"_args": ["Michelle Williams"],
},
],
"_id": 0,
},
],
},
{
"_type": "music.Album",
"_args": [
"Survivor",
{"_ref": 0},
],
},
],
)
def test_primitive_value_references(self):
beyonce_name = "Beyoncé Giselle Knowles-Carter"
beyonce = Artist(beyonce_name)
discography = [
Album("Dangerously in Love", [beyonce]),
Album(beyonce_name, [beyonce]),
]
ctx = JSContext()
result = ctx.pack(discography)
self.assertEqual(
result,
[
{
"_type": "music.Album",
"_args": [
"Dangerously in Love",
[
{
"_type": "music.Artist",
"_args": [
{"_val": "Beyoncé Giselle Knowles-Carter", "_id": 0}
],
"_id": 1,
},
],
],
},
{
"_type": "music.Album",
"_args": [
{"_ref": 0},
[
{"_ref": 1},
],
],
},
],
)
def test_avoid_primitive_value_references_for_short_strings(self):
beyonce_name = "Beyoncé"
beyonce = Artist(beyonce_name)
discography = [
Album("Dangerously in Love", [beyonce]),
Album(beyonce_name, [beyonce]),
]
ctx = JSContext()
result = ctx.pack(discography)
self.assertEqual(
result,
[
{
"_type": "music.Album",
"_args": [
"Dangerously in Love",
[
{
"_type": "music.Artist",
"_args": ["Beyoncé"],
"_id": 1,
},
],
],
},
{
"_type": "music.Album",
"_args": [
"Beyoncé",
[
{"_ref": 1},
],
],
},
],
)
class Ark:
def __init__(self, animals):
self.animals = animals
def animals_by_type(self):
return itertools.groupby(self.animals, lambda animal: animal["type"])
class ArkAdapter(Adapter):
js_constructor = "boats.Ark"
def js_args(self, obj):
return [obj.animals_by_type()]
register(ArkAdapter(), Ark)
class TestIDCollisions(TestCase):
def test_grouper_object_collisions(self):
"""
Certain functions such as itertools.groupby will cause new objects (namely, tuples and
custom itertools._grouper iterables) to be created in the course of iterating over the
object tree. If we're not careful, these will be released and the memory reallocated to
new objects while we're still iterating, leading to ID collisions.
"""
# create 100 Ark objects all with distinct animals (no object references are re-used)
arks = [
Ark(
[
{"type": "lion", "name": "Simba %i" % i},
{"type": "lion", "name": "Nala %i" % i},
{"type": "dog", "name": "Lady %i" % i},
{"type": "dog", "name": "Tramp %i" % i},
]
)
for i in range(0, 100)
]
ctx = JSContext()
result = ctx.pack(arks)
self.assertEqual(len(result), 100)
for i, ark in enumerate(result):
# each object should be represented in full, with no _id or _ref keys
self.assertEqual(
ark,
{
"_type": "boats.Ark",
"_args": [
[
[
"lion",
[
{"type": "lion", "name": "Simba %i" % i},
{"type": "lion", "name": "Nala %i" % i},
],
],
[
"dog",
[
{"type": "dog", "name": "Lady %i" % i},
{"type": "dog", "name": "Tramp %i" % i},
],
],
]
],
},
)

View File

@@ -0,0 +1,458 @@
import json
from django.core.files.uploadedfile import SimpleUploadedFile
from django.test import TestCase
from wagtail.admin.tests.test_contentstate import content_state_equal
from wagtail.models import PAGE_MODEL_CLASSES, Page, Site
from wagtail.test.dummy_external_storage import DummyExternalStorage
from wagtail.test.testapp.models import (
BusinessChild,
BusinessIndex,
BusinessNowherePage,
BusinessSubIndex,
EventIndex,
EventPage,
SectionedRichTextPage,
SimpleChildPage,
SimplePage,
SimpleParentPage,
StreamPage,
)
from wagtail.test.utils import WagtailPageTests, WagtailTestUtils
from wagtail.test.utils.form_data import (
inline_formset,
nested_form_data,
rich_text,
streamfield,
)
class TestAssertTagInHTML(WagtailTestUtils, TestCase):
def test_assert_tag_in_html(self):
haystack = """<ul>
<li class="normal">hugh</li>
<li class="normal">pugh</li>
<li class="really important" lang="en"><em>barney</em> mcgrew</li>
</ul>"""
self.assertTagInHTML('<li lang="en" class="important really">', haystack)
self.assertTagInHTML('<li class="normal">', haystack, count=2)
with self.assertRaises(AssertionError):
self.assertTagInHTML('<div lang="en" class="important really">', haystack)
with self.assertRaises(AssertionError):
self.assertTagInHTML(
'<li lang="en" class="important really">', haystack, count=2
)
with self.assertRaises(AssertionError):
self.assertTagInHTML('<li lang="en" class="important">', haystack)
with self.assertRaises(AssertionError):
self.assertTagInHTML(
'<li lang="en" class="important really" data-extra="boom">', haystack
)
def test_assert_tag_in_html_with_extra_attrs(self):
haystack = """<ul>
<li class="normal">hugh</li>
<li class="normal">pugh</li>
<li class="really important" lang="en"><em>barney</em> mcgrew</li>
</ul>"""
self.assertTagInHTML(
'<li class="important really">', haystack, allow_extra_attrs=True
)
self.assertTagInHTML("<li>", haystack, count=3, allow_extra_attrs=True)
with self.assertRaises(AssertionError):
self.assertTagInHTML(
'<li class="normal" lang="en">', haystack, allow_extra_attrs=True
)
with self.assertRaises(AssertionError):
self.assertTagInHTML(
'<li class="important really">',
haystack,
count=2,
allow_extra_attrs=True,
)
def test_assert_tag_in_template_script(self):
haystack = """<html>
<script type="text/template">
<p class="really important">first template block</p>
</script>
<script type="text/template">
<p class="really important">second template block</p>
</script>
<p class="normal">not in a script tag</p>
</html>"""
self.assertTagInTemplateScript('<p class="important really">', haystack)
self.assertTagInTemplateScript(
'<p class="important really">', haystack, count=2
)
with self.assertRaises(AssertionError):
self.assertTagInTemplateScript('<p class="normal">', haystack)
class TestWagtailPageTests(WagtailPageTests):
def setUp(self):
super().setUp()
site = Site.objects.get(is_default_site=True)
self.root = site.root_page.specific
def test_assert_can_create_at(self):
# It should be possible to create an EventPage under an EventIndex,
self.assertCanCreateAt(EventIndex, EventPage)
self.assertCanCreateAt(Page, EventIndex)
# It should not be possible to create a SimplePage under a BusinessChild
self.assertCanNotCreateAt(SimplePage, BusinessChild)
# This should raise, as it *is not* possible
with self.assertRaises(AssertionError):
self.assertCanCreateAt(SimplePage, BusinessChild)
# This should raise, as it *is* possible
with self.assertRaises(AssertionError):
self.assertCanNotCreateAt(EventIndex, EventPage)
def test_assert_can_create(self):
self.assertFalse(EventIndex.objects.exists())
self.assertCanCreate(
self.root,
EventIndex,
{
"title": "Event Index",
"intro": """{"entityMap": {},"blocks": [
{"inlineStyleRanges": [], "text": "Event intro", "depth": 0, "type": "unstyled", "key": "00000", "entityRanges": []}
]}""",
},
)
self.assertTrue(EventIndex.objects.exists())
self.assertTrue(EventIndex.objects.get().live)
self.assertCanCreate(
self.root,
StreamPage,
{
"title": "Flierp",
"body-0-type": "text",
"body-0-value": "Dit is onze mooie text",
"body-0-order": "0",
"body-0-deleted": "",
"body-1-type": "rich_text",
"body-1-value": """{"entityMap": {},"blocks": [
{"inlineStyleRanges": [], "text": "Dit is onze mooie text in een ferrari", "depth": 0, "type": "unstyled", "key": "00000", "entityRanges": []}
]}""",
"body-1-order": "1",
"body-1-deleted": "",
"body-2-type": "product",
"body-2-value-name": "pegs",
"body-2-value-price": "a pound",
"body-2-order": "2",
"body-2-deleted": "",
"body-count": "3",
},
)
self.assertCanCreate(
self.root,
SectionedRichTextPage,
{
"title": "Fight Club",
"sections-TOTAL_FORMS": "2",
"sections-INITIAL_FORMS": "0",
"sections-MIN_NUM_FORMS": "0",
"sections-MAX_NUM_FORMS": "1000",
"sections-0-body": """{"entityMap": {},"blocks": [
{"inlineStyleRanges": [], "text": "Rule 1: You do not talk about Fight Club", "depth": 0, "type": "unstyled", "key": "00000", "entityRanges": []}
]}""",
"sections-0-ORDER": "0",
"sections-0-DELETE": "",
"sections-1-body": """{"entityMap": {},"blocks": [
{"inlineStyleRanges": [], "text": "Rule 2: You DO NOT talk about Fight Club", "depth": 0, "type": "unstyled", "key": "00000", "entityRanges": []}
]}""",
"sections-1-ORDER": "0",
"sections-1-DELETE": "",
},
)
def test_assert_can_create_for_page_without_publish(self):
self.assertCanCreate(
self.root,
SimplePage,
{"title": "Simple Lorem Page", "content": "Lorem ipsum dolor sit amet"},
publish=False,
)
created_page = Page.objects.get(title="Simple Lorem Page")
self.assertFalse(created_page.live)
def test_assert_can_create_with_form_helpers(self):
# same as test_assert_can_create, but using the helpers from wagtail.test.utils.form_data
# as an end-to-end test
self.assertFalse(EventIndex.objects.exists())
self.assertCanCreate(
self.root,
EventIndex,
nested_form_data(
{"title": "Event Index", "intro": rich_text("<p>Event intro</p>")}
),
)
self.assertTrue(EventIndex.objects.exists())
self.assertCanCreate(
self.root,
StreamPage,
nested_form_data(
{
"title": "Flierp",
"body": streamfield(
[
("text", "Dit is onze mooie text"),
(
"rich_text",
rich_text(
"<p>Dit is onze mooie text in een ferrari</p>"
),
),
("product", {"name": "pegs", "price": "a pound"}),
]
),
}
),
)
self.assertCanCreate(
self.root,
SectionedRichTextPage,
nested_form_data(
{
"title": "Fight Club",
"sections": inline_formset(
[
{
"body": rich_text(
"<p>Rule 1: You do not talk about Fight Club</p>"
)
},
{
"body": rich_text(
"<p>Rule 2: You DO NOT talk about Fight Club</p>"
)
},
]
),
}
),
)
def test_assert_can_create_subpage_rules(self):
simple_page = SimplePage(title="Simple Page", slug="simple", content="hello")
self.root.add_child(instance=simple_page)
# This should raise an error, as a BusinessChild can not be created under a SimplePage
with self.assertRaisesRegex(
AssertionError,
r"Can not create a tests.businesschild under a tests.simplepage",
):
self.assertCanCreate(simple_page, BusinessChild, {})
def test_assert_can_create_validation_error(self):
# This should raise some validation errors, complaining about missing
# title and slug fields
with self.assertRaisesRegex(AssertionError, r"\bslug:\n[\s\S]*\btitle:\n"):
self.assertCanCreate(self.root, SimplePage, {})
def test_assert_allowed_subpage_types(self):
self.assertAllowedSubpageTypes(BusinessIndex, {BusinessChild, BusinessSubIndex})
self.assertAllowedSubpageTypes(BusinessChild, {})
# All page types can be created under the Page model, except those with a parent_page_types
# rule excluding it
all_but_business = set(PAGE_MODEL_CLASSES) - {
BusinessSubIndex,
BusinessChild,
BusinessNowherePage,
SimpleChildPage,
}
self.assertAllowedSubpageTypes(Page, all_but_business)
with self.assertRaises(AssertionError):
self.assertAllowedSubpageTypes(
BusinessSubIndex, {BusinessSubIndex, BusinessChild}
)
def test_assert_allowed_parent_page_types(self):
self.assertAllowedParentPageTypes(
BusinessChild, {BusinessIndex, BusinessSubIndex}
)
self.assertAllowedParentPageTypes(BusinessSubIndex, {BusinessIndex})
# BusinessIndex can be created under all page types that do not have a subpage_types rule
all_but_business = set(PAGE_MODEL_CLASSES) - {
BusinessSubIndex,
BusinessChild,
BusinessIndex,
SimpleParentPage,
}
self.assertAllowedParentPageTypes(BusinessIndex, all_but_business)
with self.assertRaises(AssertionError):
self.assertAllowedParentPageTypes(
BusinessSubIndex, {BusinessSubIndex, BusinessIndex}
)
class TestFormDataHelpers(TestCase):
def test_nested_form_data(self):
result = nested_form_data(
{
"foo": "bar",
"parent": {
"child": "field",
},
}
)
self.assertEqual(result, {"foo": "bar", "parent-child": "field"})
def test_streamfield(self):
result = nested_form_data(
{
"content": streamfield(
[
("text", "Hello, world"),
("text", "Goodbye, world"),
("coffee", {"type": "latte", "milk": "soya"}),
]
)
}
)
self.assertEqual(
result,
{
"content-count": "3",
"content-0-type": "text",
"content-0-value": "Hello, world",
"content-0-order": "0",
"content-0-deleted": "",
"content-1-type": "text",
"content-1-value": "Goodbye, world",
"content-1-order": "1",
"content-1-deleted": "",
"content-2-type": "coffee",
"content-2-value-type": "latte",
"content-2-value-milk": "soya",
"content-2-order": "2",
"content-2-deleted": "",
},
)
def test_inline_formset(self):
result = nested_form_data(
{
"lines": inline_formset(
[
{"text": "Hello"},
{"text": "World"},
]
)
}
)
self.assertEqual(
result,
{
"lines-TOTAL_FORMS": "2",
"lines-INITIAL_FORMS": "0",
"lines-MIN_NUM_FORMS": "0",
"lines-MAX_NUM_FORMS": "1000",
"lines-0-text": "Hello",
"lines-0-ORDER": "0",
"lines-0-DELETE": "",
"lines-1-text": "World",
"lines-1-ORDER": "1",
"lines-1-DELETE": "",
},
)
def test_default_rich_text(self):
result = rich_text("<h2>title</h2><p>para</p>")
self.assertTrue(
content_state_equal(
json.loads(result),
{
"entityMap": {},
"blocks": [
{
"inlineStyleRanges": [],
"text": "title",
"depth": 0,
"type": "header-two",
"key": "00000",
"entityRanges": [],
},
{
"inlineStyleRanges": [],
"text": "para",
"depth": 0,
"type": "unstyled",
"key": "00000",
"entityRanges": [],
},
],
},
)
)
def test_rich_text_with_custom_features(self):
# feature list doesn't allow <h2>, so it should become an unstyled paragraph block
result = rich_text("<h2>title</h2><p>para</p>", features=["p"])
self.assertTrue(
content_state_equal(
json.loads(result),
{
"entityMap": {},
"blocks": [
{
"inlineStyleRanges": [],
"text": "title",
"depth": 0,
"type": "unstyled",
"key": "00000",
"entityRanges": [],
},
{
"inlineStyleRanges": [],
"text": "para",
"depth": 0,
"type": "unstyled",
"key": "00000",
"entityRanges": [],
},
],
},
)
)
def test_rich_text_with_alternative_editor(self):
result = rich_text("<h2>title</h2><p>para</p>", editor="custom")
self.assertEqual(result, "<h2>title</h2><p>para</p>")
class TestDummyExternalStorage(WagtailTestUtils, TestCase):
def test_save_with_incorrect_file_object_position(self):
"""
Test that DummyExternalStorage correctly warns about attempts
to write files that are not rewound to the start
"""
# This is a 1x1 black png
png = (
b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00"
b"\x00\x01\x00\x00\x00\x01\x08\x06\x00\x00\x00"
b"\x1f\x15\xc4\x89\x00\x00\x00\rIDATx\x9cc````"
b"\x00\x00\x00\x05\x00\x01\xa5\xf6E@\x00\x00"
b"\x00\x00IEND\xaeB`\x82"
)
simple_png = SimpleUploadedFile(
name="test.png", content=png, content_type="image/png"
)
simple_png.read()
with self.assertRaisesMessage(
ValueError,
"Content file pointer should be at 0 - got 70 instead",
):
DummyExternalStorage().save("test.png", simple_png)

View File

@@ -0,0 +1,270 @@
from unittest.mock import patch
from django.conf import settings
from django.core import checks
from django.db import models
from django.test import TestCase, override_settings
from wagtail.models import Locale
from wagtail.test.i18n.models import (
ClusterableTestModel,
ClusterableTestModelChild,
ClusterableTestModelTranslatableChild,
InheritedTestModel,
TestModel,
)
def make_test_instance(model=None, **kwargs):
if model is None:
model = TestModel
return model.objects.create(**kwargs)
@override_settings(WAGTAIL_I18N_ENABLED=True)
class TestTranslatableMixin(TestCase):
def setUp(self):
language_codes = dict(settings.LANGUAGES).keys()
for language_code in language_codes:
Locale.objects.get_or_create(language_code=language_code)
# create the locales
self.locale = Locale.objects.get(language_code="en")
self.another_locale = Locale.objects.get(language_code="fr")
# add the main model
self.main_instance = make_test_instance(locale=self.locale, title="Main Model")
# add a translated model
self.translated_model = make_test_instance(
locale=self.another_locale,
translation_key=self.main_instance.translation_key,
title="Translated Model",
)
# add a random model that shouldn't show up anywhere
make_test_instance()
def test_get_translations_inclusive_false(self):
self.assertSequenceEqual(
list(self.main_instance.get_translations()), [self.translated_model]
)
def test_get_translations_inclusive_true(self):
self.assertEqual(
list(self.main_instance.get_translations(inclusive=True)),
[self.main_instance, self.translated_model],
)
def test_get_translation(self):
self.assertEqual(
self.main_instance.get_translation(self.locale), self.main_instance
)
def test_get_translation_using_locale_id(self):
self.assertEqual(
self.main_instance.get_translation(self.locale.id), self.main_instance
)
def test_get_translation_or_none_return_translation(self):
with patch.object(
self.main_instance, "get_translation"
) as mock_get_translation:
mock_get_translation.return_value = self.translated_model
self.assertEqual(
self.main_instance.get_translation_or_none(self.another_locale),
self.translated_model,
)
def test_get_translation_or_none_return_none(self):
self.translated_model.delete()
with patch.object(
self.main_instance, "get_translation"
) as mock_get_translation:
mock_get_translation.side_effect = self.main_instance.DoesNotExist
self.assertIsNone(
self.main_instance.get_translation_or_none(self.another_locale)
)
def test_has_translation_when_exists(self):
self.assertTrue(self.main_instance.has_translation(self.locale))
def test_has_translation_when_exists_using_locale_id(self):
self.assertTrue(self.main_instance.has_translation(self.locale.id))
def test_has_translation_when_none_exists(self):
self.translated_model.delete()
self.assertFalse(self.main_instance.has_translation(self.another_locale))
def test_copy_for_translation(self):
self.translated_model.delete()
copy = self.main_instance.copy_for_translation(locale=self.another_locale)
self.assertNotEqual(copy, self.main_instance)
self.assertEqual(copy.translation_key, self.main_instance.translation_key)
self.assertEqual(copy.locale, self.another_locale)
self.assertEqual("Main Model", copy.title)
def test_get_translation_model(self):
self.assertEqual(self.main_instance.get_translation_model(), TestModel)
# test with a model that inherits from `TestModel`
inherited_model = make_test_instance(model=InheritedTestModel)
self.assertEqual(inherited_model.get_translation_model(), TestModel)
def test_copy_inherited_model_for_translation(self):
instance = make_test_instance(model=InheritedTestModel)
copy = instance.copy_for_translation(locale=self.another_locale)
self.assertNotEqual(copy, instance)
self.assertEqual(copy.translation_key, instance.translation_key)
self.assertEqual(copy.locale, self.another_locale)
def test_copy_clusterable_model_for_translation(self):
instance = ClusterableTestModel.objects.create(
title="A test clusterable model",
children=[
ClusterableTestModelChild(field="A non-translatable child object"),
],
translatable_children=[
ClusterableTestModelTranslatableChild(
field="A translatable child object"
),
],
)
copy = instance.copy_for_translation(locale=self.another_locale)
instance_child = instance.children.get()
copy_child = copy.children.get()
instance_translatable_child = instance.translatable_children.get()
copy_translatable_child = copy.translatable_children.get()
self.assertNotEqual(copy, instance)
self.assertEqual(copy.translation_key, instance.translation_key)
self.assertEqual(copy.locale, self.another_locale)
# Check children were copied
self.assertNotEqual(copy_child, instance_child)
self.assertEqual(copy_child.field, "A non-translatable child object")
self.assertNotEqual(copy_translatable_child, instance_translatable_child)
self.assertEqual(copy_translatable_child.field, "A translatable child object")
# Check the translatable child's locale was updated but translation key is the same
self.assertEqual(
copy_translatable_child.translation_key,
instance_translatable_child.translation_key,
)
self.assertEqual(copy_translatable_child.locale, self.another_locale)
@override_settings(WAGTAIL_I18N_ENABLED=True)
class TestLocalized(TestCase):
def setUp(self):
self.en_locale = Locale.objects.get()
self.fr_locale = Locale.objects.create(language_code="fr")
self.en_instance = make_test_instance(locale=self.en_locale, title="Main Model")
self.fr_instance = make_test_instance(
locale=self.fr_locale,
translation_key=self.en_instance.translation_key,
title="Main Model",
)
def test_localized_same_language(self):
# Shouldn't run an extra query if the instances locale matches the active language
# FIXME: Cache active locale record so this is zero
with self.assertNumQueries(1):
instance = self.en_instance.localized
self.assertEqual(instance, self.en_instance)
def test_localized_different_language(self):
with self.assertNumQueries(2):
instance = self.fr_instance.localized
self.assertEqual(instance, self.en_instance)
class TestSystemChecks(TestCase):
def test_unique_together_raises_no_error(self):
# The default unique_together should not raise an error
errors = TestModel.check()
self.assertEqual(len(errors), 0)
def test_unique_constraint_raises_no_error(self):
# Allow replacing unique_together with UniqueConstraint
# https://github.com/wagtail/wagtail/issues/11098
previous_unique_together = TestModel._meta.unique_together
try:
TestModel._meta.unique_together = []
TestModel._meta.constraints = [
models.UniqueConstraint(
fields=["translation_key", "locale"],
name="unique_translation_key_locale_%(app_label)s_%(class)s",
)
]
errors = TestModel.check()
finally:
TestModel._meta.unique_together = previous_unique_together
TestModel._meta.constraints = []
self.assertEqual(len(errors), 0)
def test_raises_error_if_both_unique_constraint_and_unique_together_are_missing(
self,
):
# The model has unique_together and not UniqueConstraint, remove
# unique_together to trigger the error
previous_unique_together = TestModel._meta.unique_together
try:
TestModel._meta.unique_together = []
errors = TestModel.check()
finally:
TestModel._meta.unique_together = previous_unique_together
self.assertEqual(len(errors), 1)
self.assertIsInstance(errors[0], checks.Error)
self.assertEqual(errors[0].id, "wagtailcore.E003")
self.assertEqual(
errors[0].msg,
"i18n.TestModel is missing a UniqueConstraint for the fields: "
"('translation_key', 'locale').",
)
self.assertEqual(
errors[0].hint,
"Add models.UniqueConstraint(fields=('translation_key', 'locale'), "
"name='unique_translation_key_locale_i18n_testmodel') to "
"TestModel.Meta.constraints.",
)
def test_error_with_both_unique_constraint_and_unique_together(self):
# The model already has unique_together, add a UniqueConstraint
# to trigger the error
try:
TestModel._meta.constraints = [
models.UniqueConstraint(
fields=["translation_key", "locale"],
name="unique_translation_key_locale_%(app_label)s_%(class)s",
)
]
errors = TestModel.check()
finally:
TestModel._meta.constraints = []
self.assertEqual(len(errors), 1)
self.assertIsInstance(errors[0], checks.Error)
self.assertEqual(errors[0].id, "wagtailcore.E003")
self.assertEqual(
errors[0].msg,
"i18n.TestModel should not have both UniqueConstraint and unique_together for: "
"('translation_key', 'locale').",
)
self.assertEqual(
errors[0].hint,
"Remove unique_together in favor of UniqueConstraint.",
)

View File

@@ -0,0 +1,618 @@
import hashlib
import pickle
import unittest
from io import BytesIO
from pathlib import Path
from django.contrib.contenttypes.models import ContentType
from django.core.exceptions import ImproperlyConfigured, SuspiciousOperation
from django.core.files.uploadedfile import SimpleUploadedFile
from django.test import SimpleTestCase, TestCase, override_settings
from django.utils.text import slugify
from django.utils.translation import _trans
from django.utils.translation import gettext_lazy as _
from wagtail.coreutils import (
InvokeViaAttributeShortcut,
accepts_kwarg,
camelcase_to_underscore,
cautious_slugify,
find_available_slug,
get_content_languages,
get_content_type_label,
get_dummy_request,
get_supported_content_language_variant,
multigetattr,
safe_snake_case,
string_to_ascii,
)
from wagtail.models import Page, Site
from wagtail.utils.file import hash_filelike
from wagtail.utils.utils import deep_update, flatten_choices
from wagtail.utils.version import get_main_version
class TestCamelCaseToUnderscore(TestCase):
def test_camelcase_to_underscore(self):
test_cases = [
("HelloWorld", "hello_world"),
("longValueWithVarious subStrings", "long_value_with_various sub_strings"),
]
for original, expected_result in test_cases:
self.assertEqual(camelcase_to_underscore(original), expected_result)
class TestStringToAscii(TestCase):
def test_string_to_ascii(self):
test_cases = [
("30 \U0001d5c4\U0001d5c6/\U0001d5c1", "30 km/h"),
("\u5317\u4EB0", "BeiJing"),
("ぁ あ ぃ い ぅ う ぇ", "a a i i u u e"),
(
"Ա Բ Գ Դ Ե Զ Է Ը Թ Ժ Ի Լ Խ Ծ Կ Հ Ձ Ղ Ճ Մ Յ Ն",
"A B G D E Z E Y T' Zh I L Kh Ts K H Dz Gh Ch M Y N",
),
("Спорт!", "Sport!"),
("Straßenbahn", "Strassenbahn"),
("Hello world", "Hello world"),
("Ā ā Ă ă Ą ą Ć ć Ĉ ĉ Ċ ċ Č č Ď ď Đ", "A a A a A a C c C c C c C c D d D"),
("〔山脈〕", "[ShanMai]"),
]
for original, expected_result in test_cases:
self.assertEqual(string_to_ascii(original), expected_result)
class TestCautiousSlugify(TestCase):
def test_behaves_same_as_slugify_for_latin_chars(self):
test_cases = [
("", ""),
("???", ""),
("Hello world", "hello-world"),
("Hello_world", "hello_world"),
("Hellö wörld", "hello-world"),
("Hello world", "hello-world"),
(" Hello world ", "hello-world"),
("Hello, world!", "hello-world"),
("Hello*world", "helloworld"),
("Hello☃world", "helloworld"),
]
for original, expected_result in test_cases:
self.assertEqual(slugify(original), expected_result)
self.assertEqual(cautious_slugify(original), expected_result)
def test_escapes_non_latin_chars(self):
test_cases = [
("Straßenbahn", "straxdfenbahn"),
("Спорт!", "u0421u043fu043eu0440u0442"),
("〔山脈〕", "u5c71u8108"),
]
for original, expected_result in test_cases:
self.assertEqual(cautious_slugify(original), expected_result)
class TestSafeSnakeCase(TestCase):
def test_strings_with_latin_chars(self):
test_cases = [
("", ""),
("???", ""),
("using-Hyphen", "using_hyphen"),
("endash", "endash"), # unicode non-letter characters stripped
(" em—dash ", "emdash"), # unicode non-letter characters stripped
(
"horizontal―BAR",
"horizontalbar",
), # unicode non-letter characters stripped
("Hello world", "hello_world"),
("Hello_world", "hello_world"),
("Hellö wörld", "hello_world"),
("Hello world", "hello_world"),
(" Hello world ", "hello_world"),
("Hello, world!", "hello_world"),
("Hello*world", "helloworld"),
(
"Screenshot_2020-05-29 Screenshot(1).png",
"screenshot_2020_05_29_screenshot1png",
),
]
for original, expected_result in test_cases:
self.assertEqual(safe_snake_case(original), expected_result)
def test_strings_with__non_latin_chars(self):
test_cases = [
("Straßenbahn Straßenbahn", "straxdfenbahn_straxdfenbahn"),
("Сп орт!", "u0421u043f_u043eu0440u0442"),
]
for original, expected_result in test_cases:
self.assertEqual(safe_snake_case(original), expected_result)
class TestAcceptsKwarg(TestCase):
def test_accepts_kwarg(self):
def func_without_banana(apple, orange=42):
pass
def func_with_banana(apple, banana=42):
pass
def func_with_kwargs(apple, **kwargs):
pass
self.assertFalse(accepts_kwarg(func_without_banana, "banana"))
self.assertTrue(accepts_kwarg(func_with_banana, "banana"))
self.assertTrue(accepts_kwarg(func_with_kwargs, "banana"))
class TestTargetClass:
"""
Used in TestInvokeViaAttributeShortcut (below)
"""
def __init__(self):
self.target_method_called_with = []
def target_method(self, arg):
self.target_method_called_with.append(arg)
class TestInvokeViaAttributeShortcut(SimpleTestCase):
def setUp(self):
self.target_object = TestTargetClass()
self.test_object = InvokeViaAttributeShortcut(
self.target_object, "target_method"
)
def test_basic(self):
for value in ("foo", "bar", "baz"):
# Use the shortcut to call the underlying method
getattr(self.test_object, value)
# Confirm that the underlying method was called
self.assertIn(value, self.target_object.target_method_called_with)
def test_pickleability(self):
try:
pickled = pickle.dumps(self.test_object, -1)
except Exception as e: # noqa: BLE001
raise AssertionError(
"An error occurred when attempting to pickle %r: %s"
% (self.test_object, e)
)
try:
self.test_object = pickle.loads(pickled)
except Exception as e: # noqa: BLE001
raise AssertionError(
"An error occurred when attempting to unpickle %r: %s"
% (self.test_object, e)
)
# Confirm unpickled object works the same
self.target_object = self.test_object.obj
self.test_basic()
class TestFindAvailableSlug(TestCase):
def setUp(self):
self.root_page = Page.objects.get(depth=1)
self.home_page = Page.objects.get(depth=2)
self.second_home_page = self.root_page.add_child(
instance=Page(title="Second homepage", slug="home-1")
)
def test_find_available_slug(self):
with self.assertNumQueries(1):
slug = find_available_slug(self.root_page, "unique-slug")
self.assertEqual(slug, "unique-slug")
def test_find_available_slug_already_used(self):
# Even though the first two slugs are already used, this still requires only one query to find a unique one
with self.assertNumQueries(1):
slug = find_available_slug(self.root_page, "home")
self.assertEqual(slug, "home-2")
def test_find_available_slug_ignore_page_id(self):
with self.assertNumQueries(1):
slug = find_available_slug(
self.root_page, "home", ignore_page_id=self.second_home_page.id
)
self.assertEqual(slug, "home-1")
@override_settings(
USE_I18N=True,
WAGTAIL_I18N_ENABLED=True,
LANGUAGES=[
("en", "English"),
("de", "German"),
("de-at", "Austrian German"),
("pt-br", "Portuguese (Brazil)"),
],
WAGTAIL_CONTENT_LANGUAGES=[
("en", "English"),
("de", "German"),
("de-at", "Austrian German"),
("pt-br", "Portuguese (Brazil)"),
],
)
class TestGetContentLanguages(TestCase):
def test_get_content_languages(self):
self.assertEqual(
get_content_languages(),
{
"de": "German",
"de-at": "Austrian German",
"en": "English",
"pt-br": "Portuguese (Brazil)",
},
)
@override_settings(
WAGTAIL_CONTENT_LANGUAGES=[
("en", "English"),
("de", "German"),
],
)
def test_can_be_different_to_django_languages(self):
self.assertEqual(
get_content_languages(),
{
"de": "German",
"en": "English",
},
)
@override_settings(
WAGTAIL_CONTENT_LANGUAGES=[
("en", _("English")),
("de", _("German")),
],
)
def test_can_be_a_translation_proxy(self):
self.assertEqual(
get_content_languages(),
{
"de": "German",
"en": "English",
},
)
@override_settings(
WAGTAIL_CONTENT_LANGUAGES=[
("en", "English"),
("de", "German"),
("zh", "Chinese"),
],
)
def test_must_be_subset_of_django_languages(self):
with self.assertRaises(ImproperlyConfigured) as e:
get_content_languages()
self.assertEqual(
e.exception.args,
(
"The language zh is specified in WAGTAIL_CONTENT_LANGUAGES but not LANGUAGES. WAGTAIL_CONTENT_LANGUAGES must be a subset of LANGUAGES.",
),
)
def TestGetContentTypeLabel(TestCase):
def test_none(self):
self.assertEqual(get_content_type_label(None), "Unknown content type")
def test_valid_content_type(self):
page_content_type = ContentType.objects.get_for_model(Page)
self.assertEqual(get_content_type_label(page_content_type), "Page")
def test_stale_content_type(self):
stale_content_type = ContentType.objects.create(
app_label="fake_app", model="deleted model"
)
self.assertEqual(get_content_type_label(stale_content_type), "Deleted model")
@override_settings(
USE_I18N=True,
WAGTAIL_I18N_ENABLED=True,
LANGUAGES=[
("en", "English"),
("de", "German"),
("de-at", "Austrian German"),
("pt-br", "Portuguese (Brazil)"),
],
WAGTAIL_CONTENT_LANGUAGES=[
("en", "English"),
("de", "German"),
("de-at", "Austrian German"),
("pt-br", "Portuguese (Brazil)"),
],
)
class TestGetSupportedContentLanguageVariant(TestCase):
# From: https://github.com/django/django/blob/9e57b1efb5205bd94462e9de35254ec5ea6eb04e/tests/i18n/tests.py#L1481
def test_get_supported_content_language_variant(self):
g = get_supported_content_language_variant
self.assertEqual(g("en"), "en")
self.assertEqual(g("en-gb"), "en")
self.assertEqual(g("de"), "de")
self.assertEqual(g("de-at"), "de-at")
self.assertEqual(g("de-ch"), "de")
self.assertEqual(g("pt-br"), "pt-br")
self.assertEqual(g("pt"), "pt-br")
self.assertEqual(g("pt-pt"), "pt-br")
with self.assertRaises(LookupError):
g("pt", strict=True)
with self.assertRaises(LookupError):
g("pt-pt", strict=True)
with self.assertRaises(LookupError):
g("xyz")
with self.assertRaises(LookupError):
g("xy-zz")
@override_settings(
WAGTAIL_CONTENT_LANGUAGES=[
("en", "English"),
("de", "German"),
]
)
def test_uses_wagtail_content_languages(self):
# be sure it's not using Django's LANGUAGES
g = get_supported_content_language_variant
self.assertEqual(g("en"), "en")
self.assertEqual(g("en-gb"), "en")
self.assertEqual(g("de"), "de")
self.assertEqual(g("de-at"), "de")
self.assertEqual(g("de-ch"), "de")
with self.assertRaises(LookupError):
g("pt-br")
with self.assertRaises(LookupError):
g("pt")
with self.assertRaises(LookupError):
g("pt-pt")
with self.assertRaises(LookupError):
g("xyz")
with self.assertRaises(LookupError):
g("xy-zz")
@override_settings(
USE_I18N=False,
WAGTAIL_I18N_ENABLED=False,
WAGTAIL_CONTENT_LANGUAGES=None,
LANGUAGE_CODE="en-us",
)
class TestGetSupportedContentLanguageVariantWithI18nFalse(TestCase):
def setUp(self):
# Need to forcibly clear the django.utils.translation._trans object when overriding
# USE_I18N:
# https://github.com/django/django/blob/3.1/django/utils/translation/__init__.py#L46-L48
_trans.__dict__.clear()
def tearDown(self):
_trans.__dict__.clear()
def test_lookup_language_with_i18n_false(self):
# Make sure we can handle the 'null' USE_I18N=False implementation of
# get_supported_language_variant returning 'en-us' rather than 'en',
# despite 'en-us' not being in LANGUAGES.
# https://github.com/wagtail/wagtail/issues/6539
self.assertEqual(get_supported_content_language_variant("en-us"), "en")
@override_settings(LANGUAGE_CODE="zz")
def test_language_code_not_in_languages(self):
# Ensure we can handle a LANGUAGE_CODE setting that isn't defined in LANGUAGES -
# in this case get_content_languages has to cope with not being able to retrieve
# a display name for the language
self.assertEqual(get_supported_content_language_variant("zz"), "zz")
self.assertEqual(get_supported_content_language_variant("zz-gb"), "zz")
class TestMultigetattr(TestCase):
def setUp(self):
class Thing:
colour = "green"
limbs = {"arms": 2, "legs": 3}
def __init__(self):
self.poke_was_called = False
def speak(self):
return "raaargh"
def feed(self, food):
return "gobble"
def poke(self):
self.poke_was_called = True
raise Exception("don't do that")
poke.alters_data = True
self.thing = Thing()
def test_multigetattr(self):
self.assertEqual(multigetattr(self.thing, "colour"), "green")
self.assertEqual(multigetattr(self, "thing.colour"), "green")
self.assertEqual(multigetattr(self.thing, "limbs.arms"), 2)
self.assertEqual(multigetattr(self.thing, "speak"), "raaargh")
self.assertEqual(multigetattr(self, "thing.speak.0"), "r")
with self.assertRaises(AttributeError):
multigetattr(self.thing, "name")
with self.assertRaises(AttributeError):
multigetattr(self.thing, "limbs.antennae")
with self.assertRaises(AttributeError):
multigetattr(self.thing, "speak.999")
with self.assertRaises(TypeError):
multigetattr(self.thing, "feed")
with self.assertRaises(SuspiciousOperation):
multigetattr(self.thing, "poke")
self.assertFalse(self.thing.poke_was_called)
class TestGetDummyRequest(TestCase):
def test_standard_port(self):
site = Site.objects.first()
site.hostname = "other.example.com"
site.port = 80
site.save()
request = get_dummy_request(site=site)
self.assertEqual(request.get_host(), "other.example.com")
def test_non_standard_port(self):
site = Site.objects.first()
site.hostname = "other.example.com"
site.port = 8888
site.save()
request = get_dummy_request(site=site)
self.assertEqual(request.get_host(), "other.example.com:8888")
def test_server_name_for_wildcard_allowed_hosts(self):
# Django's test runner adds "testserver" at the end of ALLOWED_HOSTS.
with self.settings(ALLOWED_HOSTS=["*", "testserver"]):
request = get_dummy_request()
self.assertEqual(request.get_host(), "example.com")
class TestDeepUpdate(TestCase):
def test_deep_update(self):
val = {
"captain": "picard",
"beverage": {
"type": "coffee",
"temperature": "hot",
},
}
deep_update(
val,
{
"beverage": {
"type": "tea",
"variant": "earl grey",
},
"starship": "enterprise",
},
)
self.assertEqual(
val,
{
"captain": "picard",
"beverage": {
"type": "tea",
"variant": "earl grey",
"temperature": "hot",
},
"starship": "enterprise",
},
)
class HashFileLikeTestCase(SimpleTestCase):
test_file = Path.cwd() / "LICENSE"
def test_hashes_io(self):
self.assertEqual(
hash_filelike(BytesIO(b"test")), "a94a8fe5ccb19ba61c4c0873d391e987982fbbd3"
)
def test_hashes_file(self):
with self.test_file.open(mode="rb") as f:
self.assertEqual(
hash_filelike(f), "9e58400061ca660ef7b5c94338a5205627c77eda"
)
def test_hashes_file_bytes(self):
with self.test_file.open(mode="rb") as f:
self.assertEqual(
hash_filelike(f), "9e58400061ca660ef7b5c94338a5205627c77eda"
)
def test_hashes_django_uploaded_file(self):
"""
Check Django's file shims can be hashed as-is.
`SimpleUploadedFile` inherits the base `UploadedFile`, but is easiest to test against
"""
self.assertEqual(
hash_filelike(SimpleUploadedFile("example.txt", b"test")),
"a94a8fe5ccb19ba61c4c0873d391e987982fbbd3",
)
@unittest.skipIf(
hasattr(hashlib, "file_digest"),
reason="`file_digest` doesn't support this interface",
)
def test_hashes_large_file(self):
class FakeLargeFile:
"""
A class that pretends to be a huge file (~1.3GB)
"""
def __init__(self):
self.iterations = 5000
def read(self, bytes):
self.iterations -= 1
if not self.iterations:
return b""
return b"A" * bytes
self.assertEqual(
hash_filelike(FakeLargeFile()),
"bd36f0c5a02cd6e9e34202ea3ff8db07b533e025",
)
class TestVersion(SimpleTestCase):
def test_get_main_version(self):
cases = [
((6, 2, 0, "final", 0), False, "6.2"),
((6, 2, 1, "final", 0), False, "6.2"),
((6, 2, 0, "final", 0), True, "6.2"),
((6, 2, 1, "final", 0), True, "6.2.1"),
]
for version, include_patch, expected in cases:
with self.subTest(version=version, include_patch=include_patch):
self.assertEqual(get_main_version(version, include_patch), expected)
class TestFlattenChoices(SimpleTestCase):
def test_tuple_choices(self):
choices = [(1, "1st"), (2, "2nd")]
self.assertEqual(flatten_choices(choices), {"1": "1st", "2": "2nd"})
def test_grouped_tuple_choices(self):
choices = [("Group", [(1, "1st"), (2, "2nd")])]
self.assertEqual(flatten_choices(choices), {"1": "1st", "2": "2nd"})
def test_dictionary_choices(self):
choices = {
"Martial Arts": {"judo": "Judo", "karate": "Karate"},
"Racket": {"badminton": "Badminton", "tennis": "Tennis"},
"unknown": "Unknown",
}
self.assertEqual(
flatten_choices(choices),
{
"judo": "Judo",
"karate": "Karate",
"badminton": "Badminton",
"tennis": "Tennis",
"unknown": "Unknown",
},
)

View File

@@ -0,0 +1,100 @@
from unittest import mock
from django.test import TestCase
from django.urls import reverse
from wagtail.coreutils import get_dummy_request
from wagtail.models import Page, Site
from wagtail.test.testapp.models import SimplePage
from wagtail.test.utils import WagtailTestUtils
from wagtail.views import serve
class TestLoginView(WagtailTestUtils, TestCase):
fixtures = ["test.json"]
def setUp(self):
self.user = self.create_test_user()
self.events_index = Page.objects.get(url_path="/home/events/")
def test_get(self):
response = self.client.get(reverse("wagtailcore_login"))
self.assertEqual(response.status_code, 200)
self.assertContains(response, "<h1>Log in</h1>")
self.assertNotContains(
response,
"<p>Your username and password didn't match. Please try again.</p>",
)
def test_post_incorrect_password(self):
response = self.client.post(
reverse("wagtailcore_login"),
{
"username": "test@email.com",
"password": "wrongpassword",
"next": self.events_index.url,
},
)
self.assertEqual(response.status_code, 200)
self.assertContains(response, "<h1>Log in</h1>")
self.assertContains(
response,
"<p>Your username and password didn't match. Please try again.</p>",
)
def test_post_correct_password(self):
response = self.client.post(
reverse("wagtailcore_login"),
{
"username": "test@email.com",
"password": "password",
"next": self.events_index.url,
},
)
self.assertRedirects(response, self.events_index.url)
@mock.patch("wagtail.hooks.get_hooks", mock.Mock(return_value=[]))
class TestServeView(TestCase):
fixtures = ["test.json"]
def test_serve_query_count(self):
request = get_dummy_request()
Site.find_for_request(request)
page, args, kwargs = Page.route_for_request(request, request.path)
with mock.patch.object(page, "serve", wraps=page.serve) as m:
with self.assertNumQueries(0):
serve(request, "/")
m.assert_called_once_with(request, *args, **kwargs)
def test_process_view_by_page_query_count(self):
expected_query_count = 3
site = Site.objects.get()
page = site.root_page.add_child(
instance=SimplePage(title="Simple page", slug="simple", content="Simple")
)
with mock.patch.object(
Page, "route_for_request", wraps=Page.route_for_request
) as m:
with self.modify_settings(
MIDDLEWARE={
"prepend": "wagtail.test.middleware.SimplePageViewInterceptorMiddleware"
}
):
with self.assertNumQueries(expected_query_count):
response_a = self.client.get("/simple/")
self.assertEqual(
response_a.content,
b'\n\n\n\n<!DOCTYPE HTML>\n<html lang="en" dir="ltr">\n <head>\n <title>Simple page</title>\n </head>\n <body>\n \n <h1>Simple page</h1>\n \n <h2>Simple page</h2>\n\n </body>\n</html>\n',
)
self.assertEqual(m.call_count, 2)
page.content = "Intercept me"
page.save_revision().publish()
m.reset_mock()
with self.assertNumQueries(expected_query_count):
# verify the same number of queries are used when the
# middleware activates to demonstrate Page.route_for_request()
# prevents extra database queries for serving pages
response_b = self.client.get("/simple/")
self.assertEqual(response_b.content, b"Intercepted")
self.assertEqual(m.call_count, 1)

View File

@@ -0,0 +1,161 @@
from django.test import TestCase
from wagtail.test.utils import WagtailTestUtils
from wagtail.whitelist import (
Whitelister,
allow_without_attributes,
attribute_rule,
check_url,
)
class TestCheckUrl(TestCase):
def test_allowed_url_schemes(self):
for url_scheme in ["", "http", "https", "ftp", "mailto", "tel"]:
url = url_scheme + "://www.example.com"
self.assertTrue(bool(check_url(url)))
def test_disallowed_url_scheme(self):
self.assertFalse(bool(check_url("invalid://url")))
def test_crafty_disallowed_url_scheme(self):
"""
Some URL parsers do not parse 'jav\tascript:' as a valid scheme.
Browsers, however, do. The checker needs to catch these crafty schemes
"""
self.assertFalse(bool(check_url("jav\tascript:alert('XSS')")))
class TestAttributeRule(WagtailTestUtils, TestCase):
def setUp(self):
self.soup = self.get_soup('<b foo="bar">baz</b>')
def test_no_rule_for_attr(self):
"""
Test that attribute_rule() drops attributes for
which no rule has been defined.
"""
tag = self.soup.b
fn = attribute_rule({"snowman": "barbecue"})
fn(tag)
self.assertEqual(str(tag), "<b>baz</b>")
def test_rule_true_for_attr(self):
"""
Test that attribute_rule() does not change attributes
when the corresponding rule returns True
"""
tag = self.soup.b
fn = attribute_rule({"foo": True})
fn(tag)
self.assertEqual(str(tag), '<b foo="bar">baz</b>')
def test_rule_false_for_attr(self):
"""
Test that attribute_rule() drops attributes
when the corresponding rule returns False
"""
tag = self.soup.b
fn = attribute_rule({"foo": False})
fn(tag)
self.assertEqual(str(tag), "<b>baz</b>")
def test_callable_called_on_attr(self):
"""
Test that when the rule returns a callable,
attribute_rule() replaces the attribute with
the result of calling the callable on the attribute.
"""
tag = self.soup.b
fn = attribute_rule({"foo": len})
fn(tag)
self.assertEqual(str(tag), '<b foo="3">baz</b>')
def test_callable_returns_None(self):
"""
Test that when the rule returns a callable,
attribute_rule() replaces the attribute with
the result of calling the callable on the attribute.
"""
tag = self.soup.b
fn = attribute_rule({"foo": lambda x: None})
fn(tag)
self.assertEqual(str(tag), "<b>baz</b>")
def test_allow_without_attributes(self):
"""
Test that attribute_rule() with will drop all
attributes.
"""
soup = self.get_soup(
'<b foo="bar" baz="quux" snowman="barbecue"></b>',
)
tag = soup.b
allow_without_attributes(tag)
self.assertEqual(str(tag), "<b></b>")
class TestWhitelister(WagtailTestUtils, TestCase):
def setUp(self):
self.whitelister = Whitelister()
def test_clean_unknown_node(self):
"""
Unknown node should remove a node from the parent document
"""
soup = self.get_soup("<foo><bar>baz</bar>quux</foo>")
tag = soup.foo
self.whitelister.clean_unknown_node("", soup.bar)
self.assertEqual(str(tag), "<foo>quux</foo>")
def test_clean_tag_node_cleans_nested_recognised_node(self):
"""
<b> tags are allowed without attributes. This remains true
when tags are nested.
"""
soup = self.get_soup('<b><b class="delete me">foo</b></b>')
tag = soup.b
self.whitelister.clean_tag_node(tag, tag)
self.assertEqual(str(tag), "<b><b>foo</b></b>")
def test_clean_tag_node_disallows_nested_unrecognised_node(self):
"""
<foo> tags should be removed, even when nested.
"""
soup = self.get_soup("<b><foo>bar</foo></b>")
tag = soup.b
self.whitelister.clean_tag_node(tag, tag)
self.assertEqual(str(tag), "<b>bar</b>")
def test_clean_string_node_does_nothing(self):
soup = self.get_soup("<b>bar</b>")
string = soup.b.string
self.whitelister.clean_string_node(string, string)
self.assertEqual(str(string), "bar")
def test_clean_node_does_not_change_navigable_strings(self):
soup = self.get_soup("<b>bar</b>")
string = soup.b.string
self.whitelister.clean_node(string, string)
self.assertEqual(str(string), "bar")
def test_clean(self):
"""
Whitelister.clean should remove disallowed tags and attributes from
a string
"""
string = '<b foo="bar">snowman <barbecue>Yorkshire</barbecue></b>'
cleaned_string = self.whitelister.clean(string)
self.assertEqual(cleaned_string, "<b>snowman Yorkshire</b>")
def test_clean_comments(self):
string = "<b>snowman Yorkshire<!--[if gte mso 10]>MS word junk<![endif]--></b>"
cleaned_string = self.whitelister.clean(string)
self.assertEqual(cleaned_string, "<b>snowman Yorkshire</b>")
def test_quoting(self):
string = '<img alt="Arthur &quot;two sheds&quot; Jackson" sheds="2">'
cleaned_string = self.whitelister.clean(string)
self.assertEqual(
cleaned_string, '<img alt="Arthur &quot;two sheds&quot; Jackson"/>'
)

View File

@@ -0,0 +1,501 @@
import datetime
from django.conf import settings
from django.contrib.auth import get_user_model
from django.contrib.auth.models import Group
from django.contrib.contenttypes.models import ContentType
from django.core.exceptions import ValidationError
from django.db.utils import IntegrityError
from django.test import TestCase, override_settings
from django.utils import timezone
from freezegun import freeze_time
from wagtail.models import (
GroupApprovalTask,
Page,
Task,
TaskState,
Workflow,
WorkflowContentType,
WorkflowPage,
WorkflowState,
WorkflowTask,
)
from wagtail.test.testapp.models import FullFeaturedSnippet, ModeratedModel, SimplePage
from wagtail.test.utils.wagtail_tests import WagtailTestUtils
class TestWorkflowModels(TestCase):
fixtures = ["test.json"]
def test_create_workflow(self):
# test creating and retrieving an empty Workflow from the db
test_workflow = Workflow(name="test_workflow")
test_workflow.save()
retrieved_workflow = Workflow.objects.get(id=test_workflow.id)
self.assertEqual(retrieved_workflow.name, "test_workflow")
def test_create_task(self):
# test creating and retrieving a base Task from the db
test_task = Task(name="test_task")
test_task.save()
retrieved_task = Task.objects.get(id=test_task.id)
self.assertEqual(retrieved_task.name, "test_task")
def test_add_task_to_workflow(self):
workflow = Workflow.objects.create(name="test_workflow")
task = Task.objects.create(name="test_task")
WorkflowTask.objects.create(workflow=workflow, task=task, sort_order=1)
self.assertIn(task, Task.objects.filter(workflow_tasks__workflow=workflow))
self.assertIn(workflow, Workflow.objects.filter(workflow_tasks__task=task))
def test_add_workflow_to_page(self):
# test adding a Workflow to a Page via WorkflowPage
workflow = Workflow.objects.create(name="test_workflow")
homepage = Page.objects.get(url_path="/home/")
WorkflowPage.objects.create(page=homepage, workflow=workflow)
homepage.refresh_from_db()
self.assertEqual(homepage.workflowpage.workflow, workflow)
def test_add_workflow_to_snippet(self):
# test adding a Workflow to a snippet via WorkflowContentType
workflow = Workflow.objects.create(name="test_workflow")
content_type = ContentType.objects.get_for_model(FullFeaturedSnippet)
WorkflowContentType.objects.create(content_type=content_type, workflow=workflow)
snippet = FullFeaturedSnippet.objects.create(text="foo")
# The FullFeaturedSnippet class should now have a default workflow
self.assertEqual(FullFeaturedSnippet.get_default_workflow(), workflow)
# Instances of FullFeaturedSnippet should have a workflow
self.assertEqual(snippet.get_workflow(), workflow)
def test_get_specific_task(self):
# test ability to get instance of subclassed Task type using Task.specific
group_approval_task = GroupApprovalTask.objects.create(
name="test_group_approval"
)
group_approval_task.groups.set(Group.objects.all())
task = Task.objects.get(name="test_group_approval")
specific_task = task.specific
self.assertIsInstance(specific_task, GroupApprovalTask)
def test_get_workflow_from_parent(self):
# test ability to use Page.get_workflow() to retrieve a Workflow from a parent Page if none is set directly
workflow = Workflow.objects.create(name="test_workflow")
homepage = Page.objects.get(url_path="/home/")
WorkflowPage.objects.create(page=homepage, workflow=workflow)
hello_page = SimplePage(
title="Hello world", slug="hello-world", content="hello"
)
homepage.add_child(instance=hello_page)
self.assertEqual(hello_page.get_workflow(), workflow)
self.assertTrue(workflow.all_pages().filter(id=hello_page.id).exists())
def test_get_workflow_from_closest_ancestor(self):
# test that using Page.get_workflow() tries to get the workflow from itself, then the closest ancestor, and does
# not get Workflows from further up the page tree first
workflow_1 = Workflow.objects.create(name="test_workflow_1")
workflow_2 = Workflow.objects.create(name="test_workflow_2")
homepage = Page.objects.get(url_path="/home/")
WorkflowPage.objects.create(page=homepage, workflow=workflow_1)
hello_page = SimplePage(
title="Hello world", slug="hello-world", content="hello"
)
homepage.add_child(instance=hello_page)
WorkflowPage.objects.create(page=hello_page, workflow=workflow_2)
goodbye_page = SimplePage(
title="Goodbye world", slug="goodbye-world", content="goodbye"
)
hello_page.add_child(instance=goodbye_page)
self.assertEqual(hello_page.get_workflow(), workflow_2)
self.assertEqual(goodbye_page.get_workflow(), workflow_2)
# Check the .all_pages() method
self.assertFalse(workflow_1.all_pages().filter(id=hello_page.id).exists())
self.assertFalse(workflow_1.all_pages().filter(id=goodbye_page.id).exists())
self.assertTrue(workflow_2.all_pages().filter(id=hello_page.id).exists())
self.assertTrue(workflow_2.all_pages().filter(id=goodbye_page.id).exists())
class TestPageWorkflows(WagtailTestUtils, TestCase):
fixtures = ["test.json"]
@classmethod
def setUpTestData(cls):
cls.object = Page.objects.get(url_path="/home/")
def create_workflow_and_tasks(self):
workflow = Workflow.objects.create(name="test_workflow")
task_1 = Task.objects.create(name="test_task_1")
task_2 = Task.objects.create(name="test_task_2")
WorkflowTask.objects.create(workflow=workflow, task=task_1, sort_order=1)
WorkflowTask.objects.create(workflow=workflow, task=task_2, sort_order=2)
return workflow, task_1, task_2
def start_workflow(self):
workflow, task_1, task_2 = self.create_workflow_and_tasks()
self.object.save_revision()
user = get_user_model().objects.first()
workflow_state = workflow.start(self.object, user)
return {
"workflow_state": workflow_state,
"user": user,
"object": self.object,
"task_1": task_1,
"task_2": task_2,
"workflow": workflow,
}
@override_settings(WAGTAIL_WORKFLOW_ENABLED=False)
def test_workflow_methods_generate_no_queries_when_disabled(self):
with self.assertNumQueries(0):
self.assertIs(self.object.has_workflow, False)
with self.assertNumQueries(0):
self.assertIsNone(self.object.get_workflow())
with self.assertNumQueries(0):
self.assertIs(self.object.workflow_in_progress, False)
with self.assertNumQueries(0):
self.assertIsNone(self.object.current_workflow_state)
with self.assertNumQueries(0):
self.assertIsNone(self.object.current_workflow_task_state)
with self.assertNumQueries(0):
self.assertIsNone(self.object.current_workflow_task)
@freeze_time("2017-01-01 12:00:00")
def test_start_workflow(self):
# test the first WorkflowState and TaskState models are set up correctly when Workflow.start(object) is used.
data = self.start_workflow()
workflow_state = data["workflow_state"]
self.assertEqual(workflow_state.workflow, data["workflow"])
self.assertEqual(workflow_state.content_object, data["object"])
self.assertEqual(workflow_state.status, "in_progress")
if settings.USE_TZ:
self.assertEqual(
workflow_state.created_at,
datetime.datetime(2017, 1, 1, 12, 0, 0, tzinfo=datetime.timezone.utc),
)
else:
self.assertEqual(
workflow_state.created_at, datetime.datetime(2017, 1, 1, 12, 0, 0)
)
self.assertEqual(workflow_state.requested_by, data["user"])
task_state = workflow_state.current_task_state
self.assertEqual(task_state.task, data["task_1"])
self.assertEqual(task_state.status, "in_progress")
self.assertEqual(task_state.revision, data["object"].get_latest_revision())
if settings.USE_TZ:
self.assertEqual(
task_state.started_at,
datetime.datetime(2017, 1, 1, 12, 0, 0, tzinfo=datetime.timezone.utc),
)
else:
self.assertEqual(
task_state.started_at, datetime.datetime(2017, 1, 1, 12, 0, 0)
)
self.assertIsNone(task_state.finished_at)
@override_settings(WAGTAIL_WORKFLOW_CANCEL_ON_PUBLISH=True)
def test_publishing_cancels_workflow_when_cancel_on_publish_true(self):
data = self.start_workflow()
data["object"].get_latest_revision().publish()
workflow_state = data["workflow_state"]
workflow_state.refresh_from_db()
self.assertEqual(workflow_state.status, WorkflowState.STATUS_CANCELLED)
@override_settings(WAGTAIL_WORKFLOW_CANCEL_ON_PUBLISH=False)
def test_publishing_does_not_cancel_workflow_when_cancel_on_publish_false(
self,
):
data = self.start_workflow()
data["object"].get_latest_revision().publish()
workflow_state = data["workflow_state"]
workflow_state.refresh_from_db()
self.assertEqual(workflow_state.status, WorkflowState.STATUS_IN_PROGRESS)
def test_error_when_starting_multiple_in_progress_workflows(self):
# test trying to start multiple status='in_progress' workflows on a single object will trigger an IntegrityError
self.start_workflow()
with self.assertRaises((IntegrityError, ValidationError)):
self.start_workflow()
@freeze_time("2017-01-01 12:00:00")
def test_approve_workflow(self):
# tests that approving both TaskStates in a Workflow via Task.on_action approves tasks and publishes the revision correctly
data = self.start_workflow()
workflow_state = data["workflow_state"]
task_2 = data["task_2"]
object = data["object"]
task_state = workflow_state.current_task_state
task_state.task.on_action(task_state, user=None, action_name="approve")
if settings.USE_TZ:
self.assertEqual(
task_state.finished_at,
datetime.datetime(2017, 1, 1, 12, 0, 0, tzinfo=datetime.timezone.utc),
)
else:
self.assertEqual(
task_state.finished_at, datetime.datetime(2017, 1, 1, 12, 0, 0)
)
self.assertEqual(task_state.status, "approved")
self.assertEqual(workflow_state.current_task_state.task, task_2)
task_2.on_action(
workflow_state.current_task_state, user=None, action_name="approve"
)
self.assertEqual(workflow_state.status, "approved")
object.refresh_from_db()
self.assertEqual(
object.live_revision, workflow_state.current_task_state.revision
)
@override_settings(WAGTAIL_WORKFLOW_REQUIRE_REAPPROVAL_ON_EDIT=True)
def test_workflow_resets_when_new_revision_created(self):
# test that a Workflow on its second Task returns to its first task (upon WorkflowState.update()) if a new revision is created
data = self.start_workflow()
workflow_state = data["workflow_state"]
task_1 = data["task_1"]
task_2 = data["task_2"]
object = data["object"]
task_state = workflow_state.current_task_state
task_state.task.on_action(task_state, user=None, action_name="approve")
self.assertEqual(workflow_state.current_task_state.task, task_2)
object.save_revision()
workflow_state.refresh_from_db()
task_state = workflow_state.current_task_state
task_state.task.on_action(task_state, user=None, action_name="approve")
workflow_state.refresh_from_db()
task_state = workflow_state.current_task_state
self.assertEqual(task_state.task, task_1)
@override_settings(WAGTAIL_WORKFLOW_REQUIRE_REAPPROVAL_ON_EDIT=False)
def test_workflow_does_not_reset_when_new_revision_created_if_reapproval_turned_off(
self,
):
# test that a Workflow on its second Task does not return to its first task (upon approval) if a new revision is created
data = self.start_workflow()
workflow_state = data["workflow_state"]
task_1 = data["task_1"]
task_2 = data["task_2"]
object = data["object"]
task_state = workflow_state.current_task_state
task_state.task.on_action(task_state, user=None, action_name="approve")
self.assertEqual(workflow_state.current_task_state.task, task_2)
object.save_revision()
workflow_state.refresh_from_db()
task_state = workflow_state.current_task_state
task_state.task.on_action(task_state, user=None, action_name="approve")
workflow_state.refresh_from_db()
task_state = workflow_state.current_task_state
self.assertNotEqual(task_state.task, task_1)
self.assertEqual(workflow_state.status, workflow_state.STATUS_APPROVED)
def test_reject_workflow(self):
# test that TaskState is marked as rejected upon Task.on_action with action=reject
# and the WorkflowState as needs changes
data = self.start_workflow()
workflow_state = data["workflow_state"]
task_state = workflow_state.current_task_state
task_state.task.on_action(task_state, user=None, action_name="reject")
self.assertEqual(task_state.status, task_state.STATUS_REJECTED)
self.assertEqual(workflow_state.status, workflow_state.STATUS_NEEDS_CHANGES)
def test_resume_workflow(self):
# test that a Workflow rejected on its second Task can be resumed on the second task
data = self.start_workflow()
workflow_state = data["workflow_state"]
task_2 = data["task_2"]
workflow_state.current_task_state.approve(user=None)
workflow_state.refresh_from_db()
workflow_state.current_task_state.reject(user=None)
workflow_state.refresh_from_db()
workflow_state.resume(user=None)
self.assertEqual(workflow_state.status, workflow_state.STATUS_IN_PROGRESS)
self.assertEqual(
workflow_state.current_task_state.status,
workflow_state.current_task_state.STATUS_IN_PROGRESS,
)
self.assertEqual(workflow_state.current_task_state.task, task_2)
self.assertTrue(workflow_state.is_active)
def test_tasks_with_status_on_resubmission(self):
# test that a Workflow rejected and resumed shows the status of the latest tasks when _`all_tasks_with_status` is called
data = self.start_workflow()
workflow_state = data["workflow_state"]
tasks = workflow_state.all_tasks_with_status()
self.assertEqual(tasks[0].status, TaskState.STATUS_IN_PROGRESS)
self.assertEqual(tasks[1].status_display, "Not started")
workflow_state.current_task_state.approve(user=None)
workflow_state.refresh_from_db()
workflow_state.current_task_state.reject(user=None)
workflow_state.refresh_from_db()
tasks = workflow_state.all_tasks_with_status()
self.assertEqual(tasks[0].status, TaskState.STATUS_APPROVED)
self.assertEqual(tasks[1].status, TaskState.STATUS_REJECTED)
workflow_state.resume(user=None)
tasks = workflow_state.all_tasks_with_status()
self.assertEqual(tasks[0].status, TaskState.STATUS_APPROVED)
self.assertEqual(tasks[1].status, TaskState.STATUS_IN_PROGRESS)
def test_cancel_workflow(self):
# test that cancelling a workflow state sets both current task state and its own statuses to cancelled, and cancels all in progress states
data = self.start_workflow()
workflow_state = data["workflow_state"]
workflow_state.cancel(user=None)
workflow_state.refresh_from_db()
self.assertEqual(workflow_state.status, WorkflowState.STATUS_CANCELLED)
self.assertEqual(
workflow_state.current_task_state.status, TaskState.STATUS_CANCELLED
)
self.assertFalse(
TaskState.objects.filter(
workflow_state=workflow_state, status=TaskState.STATUS_IN_PROGRESS
).exists()
)
self.assertFalse(workflow_state.is_active)
def test_task_workflows(self):
workflow = Workflow.objects.create(name="test_workflow")
disabled_workflow = Workflow.objects.create(
name="disabled_workflow", active=False
)
task = Task.objects.create(name="test_task")
WorkflowTask.objects.create(workflow=workflow, task=task, sort_order=1)
WorkflowTask.objects.create(workflow=disabled_workflow, task=task, sort_order=1)
self.assertEqual(list(task.workflows), [workflow, disabled_workflow])
self.assertEqual(list(task.active_workflows), [workflow])
def test_is_at_final_task(self):
# test that a Workflow rejected on its second Task can be resumed on the second task
data = self.start_workflow()
workflow_state = data["workflow_state"]
self.assertFalse(workflow_state.is_at_final_task)
workflow_state.current_task_state.approve(user=None)
workflow_state.refresh_from_db()
self.assertTrue(workflow_state.is_at_final_task)
def test_tasks_with_state(self):
data = self.start_workflow()
workflow_state = data["workflow_state"]
tasks = workflow_state.all_tasks_with_state()
self.assertEqual(tasks[0].task_state.status, TaskState.STATUS_IN_PROGRESS)
workflow_state.current_task_state.approve(user=None)
workflow_state.refresh_from_db()
workflow_state.current_task_state.reject(user=None)
workflow_state.refresh_from_db()
tasks = workflow_state.all_tasks_with_state()
self.assertEqual(tasks[0].task_state.status, TaskState.STATUS_APPROVED)
self.assertEqual(tasks[1].task_state.status, TaskState.STATUS_REJECTED)
workflow_state.resume(user=None)
tasks = workflow_state.all_tasks_with_state()
self.assertEqual(tasks[0].task_state.status, TaskState.STATUS_APPROVED)
self.assertEqual(tasks[1].task_state.status, TaskState.STATUS_IN_PROGRESS)
self.assertEqual(
tasks[1].task_state,
TaskState.objects.filter(workflow_state=workflow_state).order_by(
"-started_at", "-id"
)[0],
)
def test_start_workflow_group_approval_task_locked(self):
self.object.locked = True
self.object.locked_at = timezone.now()
self.object.locked_by = self.create_user("user1")
self.object.save()
# Create a workflow with one group approval task for the moderators group
moderators = Group.objects.get(name="Moderators")
workflow = Workflow.objects.create(name="test_workflow_foo")
task_1 = GroupApprovalTask.objects.create(name="test_task_1")
task_1.groups.add(moderators)
WorkflowTask.objects.create(workflow=workflow, task=task_1, sort_order=1)
# The object was locked by a non-moderator
self.assertFalse(self.object.locked_by.groups.filter(id=moderators.id).exists())
# Start the workflow as another user
self.object.save_revision()
workflow_state = workflow.start(self.object, self.create_user("user2"))
self.assertEqual(workflow_state.workflow, workflow)
self.assertEqual(workflow_state.content_object, self.object)
self.assertEqual(workflow_state.status, "in_progress")
self.object.refresh_from_db()
# The lock should be removed as otherwise the object would be stuck
self.assertFalse(self.object.locked)
self.assertIsNone(self.object.locked_at)
self.assertIsNone(self.object.locked_by)
def test_workflow_state_cascade_on_object_delete(self, cascades=True):
data = self.start_workflow()
query = {
"base_content_type": self.object.get_base_content_type(),
"object_id": str(self.object.pk),
}
self.assertEqual(
WorkflowState.objects.filter(**query).first(),
data["workflow_state"],
)
self.object.delete()
self.assertIs(WorkflowState.objects.filter(**query).exists(), not cascades)
class TestSnippetWorkflows(TestPageWorkflows):
fixtures = None
model = FullFeaturedSnippet
@classmethod
def setUpTestData(cls):
cls.object = cls.model.objects.create(text="foo")
class TestSnippetWorkflowsNotLockable(TestSnippetWorkflows):
model = ModeratedModel
def test_start_workflow_group_approval_task_locked(self):
# Test normal GroupApprovalTask.start() as the object is not lockable
# Create a workflow with one group approval task for the moderators group
moderators = Group.objects.get(name="Moderators")
workflow = Workflow.objects.create(name="test_workflow_foo")
task_1 = GroupApprovalTask.objects.create(name="test_task_1")
task_1.groups.add(moderators)
WorkflowTask.objects.create(workflow=workflow, task=task_1, sort_order=1)
# Start the workflow
self.object.save_revision()
workflow_state = workflow.start(self.object, self.create_user("user2"))
self.assertEqual(workflow_state.workflow, workflow)
self.assertEqual(workflow_state.content_object, self.object)
self.assertEqual(workflow_state.status, "in_progress")
def test_workflow_state_cascade_on_object_delete(self):
# We expect the cascade to not happen as the model does not define
# a GenericRelation to WorkflowState. However, workflows should still
# work as expected.
# See https://github.com/wagtail/wagtail/issues/11300 for more details.
return super().test_workflow_state_cascade_on_object_delete(cascades=False)

View File

@@ -0,0 +1,90 @@
from django.apps import apps
from django.core import checks
from django.db import models
from django.test import TestCase
from wagtail.models import DraftStateMixin, LockableMixin, RevisionMixin, WorkflowMixin
class TestWorkflowMixin(TestCase):
def tearDown(self):
# Unregister the models from the overall model registry
# so that it doesn't break tests elsewhere.
# We can probably replace this with Django's @isolate_apps decorator.
for package in ("wagtailcore", "wagtail.tests"):
try:
for model in (
"workflowwithoutrevisionmodel",
"workflowwithoutdraftstatemodel",
"workflowincorrectordermodel1",
"workflowincorrectordermodel2",
"correctworkflowmodel",
"correctnotlockableworkflowmodel",
):
del apps.all_models[package][model]
except KeyError:
pass
apps.clear_cache()
def test_missing_revision_or_draftstate_mixins(self):
error = checks.Error(
"WorkflowMixin requires DraftStateMixin and RevisionMixin (in that order).",
hint=(
"Make sure your model's inheritance order is as follows: "
"WorkflowMixin, DraftStateMixin, RevisionMixin."
),
id="wagtailcore.E006",
)
class WorkflowWithoutRevisionModel(WorkflowMixin, models.Model):
pass
class WorkflowWithoutDraftStateModel(
WorkflowMixin, RevisionMixin, models.Model
):
pass
for model in (WorkflowWithoutRevisionModel, WorkflowWithoutDraftStateModel):
with self.subTest(model=model):
error.obj = model
self.assertEqual(model.check(), [error])
def test_incorrect_mixins_order(self):
error = checks.Error(
"WorkflowMixin requires DraftStateMixin and RevisionMixin (in that order).",
hint=(
"Make sure your model's inheritance order is as follows: "
"WorkflowMixin, DraftStateMixin, RevisionMixin."
),
id="wagtailcore.E006",
)
class WorkflowIncorrectOrderModel1(
DraftStateMixin, WorkflowMixin, RevisionMixin, LockableMixin, models.Model
):
pass
class WorkflowIncorrectOrderModel2(
DraftStateMixin, RevisionMixin, WorkflowMixin, models.Model
):
pass
for model in (WorkflowIncorrectOrderModel1, WorkflowIncorrectOrderModel2):
with self.subTest(model=model):
error.obj = model
self.assertEqual(model.check(), [error])
def test_correct_mixins_order(self):
class CorrectWorkflowModel(
WorkflowMixin, DraftStateMixin, LockableMixin, RevisionMixin, models.Model
):
pass
class CorrectNotLockableWorkflowModel(
WorkflowMixin, DraftStateMixin, RevisionMixin, models.Model
):
pass
for model in (CorrectWorkflowModel, CorrectNotLockableWorkflowModel):
with self.subTest(model=model):
self.assertEqual(model.check(), [])

View File

@@ -0,0 +1,812 @@
import json
from django import template
from django.core.cache import cache
from django.core.cache.utils import make_template_fragment_key
from django.http import HttpRequest
from django.template import TemplateSyntaxError, VariableDoesNotExist
from django.test import TestCase
from django.test.utils import override_settings
from django.urls.exceptions import NoReverseMatch
from django.utils.safestring import SafeString
from django.utils.translation import gettext_lazy
from wagtail.coreutils import (
get_dummy_request,
make_wagtail_template_fragment_key,
resolve_model_string,
)
from wagtail.models import Locale, Page, Site, SiteRootPath
from wagtail.models.sites import (
SITE_ROOT_PATHS_CACHE_KEY,
SITE_ROOT_PATHS_CACHE_VERSION,
)
from wagtail.templatetags.wagtail_cache import WagtailPageCacheNode
from wagtail.templatetags.wagtailcore_tags import richtext, slugurl
from wagtail.test.testapp.models import SimplePage
class TestPageUrlTags(TestCase):
fixtures = ["test.json"]
def setUp(self):
super().setUp()
# Clear caches
cache.clear()
def test_pageurl_tag(self):
response = self.client.get("/events/")
self.assertEqual(response.status_code, 200)
self.assertContains(response, '<a href="/events/christmas/">Christmas</a>')
def test_pageurl_with_named_url_fallback(self):
tpl = template.Template(
"""{% load wagtailcore_tags %}<a href="{% pageurl page fallback='fallback' %}">Fallback</a>"""
)
with self.assertNumQueries(0):
result = tpl.render(template.Context({"page": None}))
self.assertIn('<a href="/fallback/">Fallback</a>', result)
def test_pageurl_with_get_absolute_url_object_fallback(self):
class ObjectWithURLMethod:
def get_absolute_url(self):
return "/object-specific-url/"
tpl = template.Template(
"""{% load wagtailcore_tags %}<a href="{% pageurl page fallback=object_with_url_method %}">Fallback</a>"""
)
result = tpl.render(
template.Context(
{"page": None, "object_with_url_method": ObjectWithURLMethod()}
)
)
self.assertIn('<a href="/object-specific-url/">Fallback</a>', result)
def test_pageurl_with_valid_url_string_fallback(self):
"""
`django.shortcuts.resolve_url` accepts strings containing '.' or '/' as they are.
"""
tpl = template.Template(
"""
{% load wagtailcore_tags %}
<a href="{% pageurl page fallback='.' %}">Same page fallback</a>
<a href="{% pageurl page fallback='/' %}">Homepage fallback</a>
<a href="{% pageurl page fallback='../' %}">Up one step fallback</a>
"""
)
result = tpl.render(template.Context({"page": None}))
self.assertIn('<a href=".">Same page fallback</a>', result)
self.assertIn('<a href="/">Homepage fallback</a>', result)
self.assertIn('<a href="../">Up one step fallback</a>', result)
def test_pageurl_with_invalid_url_string_fallback(self):
"""
Strings not containing '.' or '/', and not matching a named URL will error.
"""
tpl = template.Template(
"""{% load wagtailcore_tags %}<a href="{% pageurl page fallback='not-existing-endpoint' %}">Fallback</a>"""
)
with self.assertRaises(NoReverseMatch):
tpl.render(template.Context({"page": None}))
def test_slugurl_tag(self):
response = self.client.get("/events/christmas/")
self.assertEqual(response.status_code, 200)
self.assertContains(response, '<a href="/events/">Back to events index</a>')
def test_pageurl_without_request_in_context(self):
page = Page.objects.get(url_path="/home/events/")
tpl = template.Template(
"""{% load wagtailcore_tags %}<a href="{% pageurl page %}">{{ page.title }}</a>"""
)
# no 'request' object in context
with self.assertNumQueries(7):
result = tpl.render(template.Context({"page": page}))
self.assertIn('<a href="/events/">Events</a>', result)
# 'request' object in context, but no 'site' attribute
result = tpl.render(
template.Context({"page": page, "request": get_dummy_request()})
)
self.assertIn('<a href="/events/">Events</a>', result)
def test_pageurl_caches(self):
page = Page.objects.get(url_path="/home/events/")
tpl = template.Template(
"""{% load wagtailcore_tags %}<a href="{% pageurl page %}">{{ page.title }}</a>"""
)
request = get_dummy_request()
with self.assertNumQueries(8):
result = tpl.render(template.Context({"page": page, "request": request}))
self.assertIn('<a href="/events/">Events</a>', result)
with self.assertNumQueries(0):
result = tpl.render(template.Context({"page": page, "request": request}))
self.assertIn('<a href="/events/">Events</a>', result)
@override_settings(ALLOWED_HOSTS=["testserver", "localhost", "unknown.example.com"])
def test_pageurl_with_unknown_site(self):
page = Page.objects.get(url_path="/home/events/")
tpl = template.Template(
"""{% load wagtailcore_tags %}<a href="{% pageurl page %}">{{ page.title }}</a>"""
)
# 'request' object in context, but site is None
request = get_dummy_request()
request.META["HTTP_HOST"] = "unknown.example.com"
with self.assertNumQueries(8):
result = tpl.render(template.Context({"page": page, "request": request}))
self.assertIn('<a href="/events/">Events</a>', result)
def test_bad_pageurl(self):
tpl = template.Template(
"""{% load wagtailcore_tags %}<a href="{% pageurl page %}">{{ page.title }}</a>"""
)
with self.assertRaisesRegex(
ValueError, "pageurl tag expected a Page object, got None"
):
tpl.render(template.Context({"page": None}))
def test_bad_slugurl(self):
# no 'request' object in context
result = slugurl(template.Context({}), "bad-slug-doesnt-exist")
self.assertIsNone(result)
# 'request' object in context, but no 'site' attribute
result = slugurl(
context=template.Context({"request": HttpRequest()}),
slug="bad-slug-doesnt-exist",
)
self.assertIsNone(result)
@override_settings(ALLOWED_HOSTS=["testserver", "localhost", "site2.example.com"])
def test_slugurl_tag_returns_url_for_current_site(self):
home_page = Page.objects.get(url_path="/home/")
new_home_page = home_page.copy(
update_attrs={"title": "New home page", "slug": "new-home"}
)
second_site = Site.objects.create(
hostname="site2.example.com", root_page=new_home_page
)
# Add a page to the new site that has a slug that is the same as one on
# the first site, but is in a different position in the treeself.
new_christmas_page = Page(title="Christmas", slug="christmas")
new_home_page.add_child(instance=new_christmas_page)
request = get_dummy_request(site=second_site)
url = slugurl(context=template.Context({"request": request}), slug="christmas")
self.assertEqual(url, "/christmas/")
@override_settings(ALLOWED_HOSTS=["testserver", "localhost", "site2.example.com"])
def test_slugurl_tag_returns_url_for_other_site(self):
home_page = Page.objects.get(url_path="/home/")
new_home_page = home_page.copy(
update_attrs={"title": "New home page", "slug": "new-home"}
)
second_site = Site.objects.create(
hostname="site2.example.com", root_page=new_home_page
)
request = get_dummy_request(site=second_site)
# There is no page with this slug on the current site, so this
# should return an absolute URL for the page on the first site.
url = slugurl(slug="christmas", context=template.Context({"request": request}))
self.assertEqual(url, "http://localhost/events/christmas/")
def test_slugurl_without_request_in_context(self):
# no 'request' object in context
result = slugurl(template.Context({}), "events")
self.assertEqual(result, "/events/")
# 'request' object in context, but no 'site' attribute
with self.assertNumQueries(3):
result = slugurl(
template.Context({"request": get_dummy_request()}), "events"
)
self.assertEqual(result, "/events/")
@override_settings(ALLOWED_HOSTS=["testserver", "localhost", "unknown.example.com"])
def test_slugurl_with_null_site_in_request(self):
# 'request' object in context, but site is None
request = get_dummy_request()
request.META["HTTP_HOST"] = "unknown.example.com"
result = slugurl(template.Context({"request": request}), "events")
self.assertEqual(result, "/events/")
def test_fullpageurl(self):
tpl = template.Template(
"""{% load wagtailcore_tags %}<a href="{% fullpageurl page %}">Events</a>"""
)
page = Page.objects.get(url_path="/home/events/")
with self.assertNumQueries(7):
result = tpl.render(template.Context({"page": page}))
self.assertIn('<a href="http://localhost/events/">Events</a>', result)
def test_fullpageurl_with_named_url_fallback(self):
tpl = template.Template(
"""{% load wagtailcore_tags %}<a href="{% fullpageurl page fallback='fallback' %}">Fallback</a>"""
)
with self.assertNumQueries(0):
result = tpl.render(template.Context({"page": None}))
self.assertIn('<a href="/fallback/">Fallback</a>', result)
def test_fullpageurl_with_absolute_fallback(self):
tpl = template.Template(
"""{% load wagtailcore_tags %}<a href="{% fullpageurl page fallback='fallback' %}">Fallback</a>"""
)
with self.assertNumQueries(0):
result = tpl.render(
template.Context({"page": None, "request": get_dummy_request()})
)
self.assertIn('<a href="http://localhost/fallback/">Fallback</a>', result)
def test_fullpageurl_with_invalid_page(self):
tpl = template.Template(
"""{% load wagtailcore_tags %}<a href="{% fullpageurl page %}">Events</a>"""
)
with self.assertRaises(ValueError):
tpl.render(template.Context({"page": 123}))
def test_pageurl_with_invalid_page(self):
tpl = template.Template(
"""{% load wagtailcore_tags %}<a href="{% pageurl page %}">Events</a>"""
)
with self.assertRaises(ValueError):
tpl.render(template.Context({"page": 123}))
class TestWagtailSiteTag(TestCase):
fixtures = ["test.json"]
def test_wagtail_site_tag(self):
request = get_dummy_request(site=Site.objects.first())
tpl = template.Template(
"""{% load wagtailcore_tags %}{% wagtail_site as current_site %}{{ current_site.hostname }}"""
)
result = tpl.render(template.Context({"request": request}))
self.assertEqual("localhost", result)
def test_wagtail_site_tag_with_missing_request_context(self):
tpl = template.Template(
"""{% load wagtailcore_tags %}{% wagtail_site as current_site %}{{ current_site.hostname }}"""
)
result = tpl.render(template.Context({}))
# should fail silently
self.assertEqual("", result)
class TestSiteRootPathsCache(TestCase):
fixtures = ["test.json"]
def get_cached_site_root_paths(self):
return cache.get(
SITE_ROOT_PATHS_CACHE_KEY, version=SITE_ROOT_PATHS_CACHE_VERSION
)
def test_cache(self):
"""
This tests that the cache is populated when building URLs
"""
# Get homepage
homepage = Page.objects.get(url_path="/home/")
# Warm up the cache by getting the url
_ = homepage.url
# Check that the cache has been set correctly
self.assertEqual(
self.get_cached_site_root_paths(),
[
SiteRootPath(
site_id=1,
root_path="/home/",
root_url="http://localhost",
language_code="en",
)
],
)
def test_cache_backend_uses_json_serialization(self):
"""
This tests that, even if the cache backend uses JSON serialization,
get_site_root_paths() returns a list of SiteRootPath objects.
"""
result = Site.get_site_root_paths()
self.assertEqual(
result,
[
SiteRootPath(
site_id=1,
root_path="/home/",
root_url="http://localhost",
language_code="en",
)
],
)
# Go through JSON (de)serialisation to check that the result is
# still a list of named tuples.
cache.set(
SITE_ROOT_PATHS_CACHE_KEY,
json.loads(json.dumps(result)),
version=SITE_ROOT_PATHS_CACHE_VERSION,
)
result = Site.get_site_root_paths()
self.assertIsInstance(result[0], SiteRootPath)
def test_cache_clears_when_site_saved(self):
"""
This tests that the cache is cleared whenever a site is saved
"""
# Get homepage
homepage = Page.objects.get(url_path="/home/")
# Warm up the cache by getting the url
_ = homepage.url
# Check that the cache has been set
self.assertEqual(
self.get_cached_site_root_paths(),
[
SiteRootPath(
site_id=1,
root_path="/home/",
root_url="http://localhost",
language_code="en",
)
],
)
# Save the site
Site.objects.get(is_default_site=True).save()
# Check that the cache has been cleared
self.assertIsNone(self.get_cached_site_root_paths())
def test_cache_clears_when_site_deleted(self):
"""
This tests that the cache is cleared whenever a site is deleted
"""
# Get homepage
homepage = Page.objects.get(url_path="/home/")
# Warm up the cache by getting the url
_ = homepage.url
# Check that the cache has been set
self.assertEqual(
self.get_cached_site_root_paths(),
[
SiteRootPath(
site_id=1,
root_path="/home/",
root_url="http://localhost",
language_code="en",
)
],
)
# Delete the site
Site.objects.get(is_default_site=True).delete()
# Check that the cache has been cleared
self.assertIsNone(self.get_cached_site_root_paths())
def test_cache_clears_when_site_root_moves(self):
"""
This tests for an issue where if a site root page was moved, all
the page urls in that site would change to None.
The issue was caused by the 'wagtail_site_root_paths' cache
variable not being cleared when a site root page was moved. Which
left all the child pages thinking that they are no longer in the
site and return None as their url.
Fix: d6cce69a397d08d5ee81a8cbc1977ab2c9db2682
Discussion: https://github.com/wagtail/wagtail/issues/7
"""
# Get homepage, root page and site
root_page = Page.objects.get(id=1)
homepage = Page.objects.get(url_path="/home/")
default_site = Site.objects.get(is_default_site=True)
# Create a new homepage under current homepage
new_homepage = SimplePage(
title="New Homepage", slug="new-homepage", content="hello"
)
homepage.add_child(instance=new_homepage)
# Set new homepage as the site root page
default_site.root_page = new_homepage
default_site.save()
# Warm up the cache by getting the url
_ = homepage.url
# Move new homepage to root
new_homepage.move(root_page, pos="last-child")
# Get fresh instance of new_homepage
new_homepage = Page.objects.get(id=new_homepage.id)
# Check url
self.assertEqual(new_homepage.url, "/")
def test_cache_clears_when_site_root_slug_changes(self):
"""
This tests for an issue where if a site root pages slug was
changed, all the page urls in that site would change to None.
The issue was caused by the 'wagtail_site_root_paths' cache
variable not being cleared when a site root page was changed.
Which left all the child pages thinking that they are no longer in
the site and return None as their url.
Fix: d6cce69a397d08d5ee81a8cbc1977ab2c9db2682
Discussion: https://github.com/wagtail/wagtail/issues/157
"""
# Get homepage
homepage = Page.objects.get(url_path="/home/")
# Warm up the cache by getting the url
_ = homepage.url
# Change homepage title and slug
homepage.title = "New home"
homepage.slug = "new-home"
homepage.save()
# Get fresh instance of homepage
homepage = Page.objects.get(id=homepage.id)
# Check url
self.assertEqual(homepage.url, "/")
@override_settings(WAGTAIL_I18N_ENABLED=True)
def test_cache_clears_when_site_root_is_translated_as_alias(self):
# Get homepage
homepage = Page.objects.get(url_path="/home/")
# Warm up the cache by getting the url
_ = homepage.url
# Translate the homepage
translated_homepage = homepage.copy_for_translation(
Locale.objects.create(language_code="fr"), alias=True
)
# Check url
self.assertEqual(translated_homepage.url, "/")
class TestResolveModelString(TestCase):
def test_resolve_from_string(self):
model = resolve_model_string("wagtailcore.Page")
self.assertEqual(model, Page)
def test_resolve_from_string_with_default_app(self):
model = resolve_model_string("Page", default_app="wagtailcore")
self.assertEqual(model, Page)
def test_resolve_from_string_with_different_default_app(self):
model = resolve_model_string("wagtailcore.Page", default_app="wagtailadmin")
self.assertEqual(model, Page)
def test_resolve_from_class(self):
model = resolve_model_string(Page)
self.assertEqual(model, Page)
def test_resolve_from_string_invalid(self):
self.assertRaises(ValueError, resolve_model_string, "wagtail.core.Page")
def test_resolve_from_string_with_incorrect_default_app(self):
self.assertRaises(
LookupError, resolve_model_string, "Page", default_app="wagtailadmin"
)
def test_resolve_from_string_with_unknown_model_string(self):
self.assertRaises(LookupError, resolve_model_string, "wagtailadmin.Page")
def test_resolve_from_string_with_no_default_app(self):
self.assertRaises(ValueError, resolve_model_string, "Page")
def test_resolve_from_class_that_isnt_a_model(self):
model = resolve_model_string(object)
self.assertEqual(model, object)
def test_resolve_from_bad_type(self):
self.assertRaises(ValueError, resolve_model_string, resolve_model_string)
def test_resolve_from_none(self):
self.assertRaises(ValueError, resolve_model_string, None)
class TestRichtextTag(TestCase):
def test_call_with_text(self):
result = richtext("Hello world!")
self.assertEqual(result, "Hello world!")
self.assertIsInstance(result, SafeString)
def test_call_with_lazy(self):
result = richtext(gettext_lazy("test"))
self.assertEqual(result, "test")
def test_call_with_none(self):
result = richtext(None)
self.assertEqual(result, "")
def test_call_with_invalid_value(self):
with self.assertRaisesRegex(
TypeError, "'richtext' template filter received an invalid value"
):
richtext(42)
def test_call_with_bytes(self):
with self.assertRaisesRegex(
TypeError, "'richtext' template filter received an invalid value"
):
richtext(b"Hello world!")
class TestWagtailCacheTag(TestCase):
def setUp(self):
cache.clear()
def test_caches(self):
request = get_dummy_request()
tpl = template.Template(
"""{% load wagtail_cache %}{% wagtailcache 100 test %}{{ foo.bar }}{% endwagtailcache %}"""
)
result = tpl.render(
template.Context({"request": request, "foo": {"bar": "foobar"}})
)
self.assertEqual(result, "foobar")
result2 = tpl.render(
template.Context({"request": request, "foo": {"bar": "baz"}})
)
self.assertEqual(result2, "foobar")
self.assertEqual(cache.get(make_template_fragment_key("test")), "foobar")
def test_caches_on_additional_parameters(self):
request = get_dummy_request()
tpl = template.Template(
"""{% load wagtail_cache %}{% wagtailcache 100 test foo %}{{ foo.bar }}{% endwagtailcache %}"""
)
result = tpl.render(
template.Context({"request": request, "foo": {"bar": "foobar"}})
)
self.assertEqual(result, "foobar")
result2 = tpl.render(
template.Context({"request": request, "foo": {"bar": "baz"}})
)
self.assertEqual(result2, "baz")
self.assertEqual(
cache.get(make_template_fragment_key("test", [{"bar": "foobar"}])), "foobar"
)
self.assertEqual(
cache.get(make_template_fragment_key("test", [{"bar": "baz"}])), "baz"
)
def test_skips_cache_in_preview(self):
request = get_dummy_request()
request.is_preview = True
tpl = template.Template(
"""{% load wagtail_cache %}{% wagtailcache 100 test %}{{ foo.bar }}{% endwagtailcache %}"""
)
result = tpl.render(
template.Context({"request": request, "foo": {"bar": "foobar"}})
)
self.assertEqual(result, "foobar")
result2 = tpl.render(
template.Context({"request": request, "foo": {"bar": "baz"}})
)
self.assertEqual(result2, "baz")
self.assertIsNone(cache.get(make_template_fragment_key("test")))
def test_no_request(self):
tpl = template.Template(
"""{% load wagtail_cache %}{% wagtailcache 100 test %}{{ foo.bar }}{% endwagtailcache %}"""
)
result = tpl.render(template.Context({"foo": {"bar": "foobar"}}))
self.assertEqual(result, "foobar")
result2 = tpl.render(template.Context({"foo": {"bar": "baz"}}))
self.assertEqual(result2, "baz")
self.assertIsNone(cache.get(make_template_fragment_key("test"))) #
def test_invalid_usage(self):
with self.assertRaises(TemplateSyntaxError) as e:
template.Template(
"""{% load wagtail_cache %}{% wagtailcache 100 %}{{ foo.bar }}{% endwagtailcache %}"""
)
self.assertEqual(
e.exception.args[0], "'wagtailcache' tag requires at least 2 arguments."
)
class TestWagtailPageCacheTag(TestCase):
fixtures = ["test.json"]
@classmethod
def setUpTestData(cls):
cls.page_1 = Page.objects.first()
cls.page_2 = Page.objects.all()[2]
cls.site = Site.objects.get(hostname="localhost", port=80)
def test_caches(self):
request = get_dummy_request(site=self.site)
tpl = template.Template(
"""{% load wagtail_cache %}{% wagtailpagecache 100 test %}{{ foo.bar }}{% endwagtailpagecache %}"""
)
result = tpl.render(
template.Context(
{"request": request, "foo": {"bar": "foobar"}, "page": self.page_1}
)
)
self.assertEqual(result, "foobar")
result2 = tpl.render(
template.Context(
{"request": request, "foo": {"bar": "baz"}, "page": self.page_1}
)
)
self.assertEqual(result2, "foobar")
self.assertEqual(
cache.get(
make_wagtail_template_fragment_key("test", self.page_1, self.site)
),
"foobar",
)
def test_caches_additional_parameters(self):
request = get_dummy_request(site=self.site)
tpl = template.Template(
"""{% load wagtail_cache %}{% wagtailpagecache 100 test foo %}{{ foo.bar }}{% endwagtailpagecache %}"""
)
result = tpl.render(
template.Context(
{"request": request, "foo": {"bar": "foobar"}, "page": self.page_1}
)
)
self.assertEqual(result, "foobar")
result2 = tpl.render(
template.Context(
{"request": request, "foo": {"bar": "baz"}, "page": self.page_1}
)
)
self.assertEqual(result2, "baz")
self.assertEqual(
cache.get(
make_wagtail_template_fragment_key(
"test", self.page_1, self.site, [{"bar": "foobar"}]
)
),
"foobar",
)
self.assertEqual(
cache.get(
make_wagtail_template_fragment_key(
"test", self.page_1, self.site, [{"bar": "baz"}]
)
),
"baz",
)
def test_doesnt_pollute_cache(self):
request = get_dummy_request(site=self.site)
tpl = template.Template(
"""{% load wagtail_cache %}{% wagtailpagecache 100 test %}{{ foo.bar }}{% endwagtailpagecache %}"""
)
context = template.Context(
{"request": request, "foo": {"bar": "foobar"}, "page": self.page_1}
)
result = tpl.render(context)
self.assertEqual(result, "foobar")
self.assertNotIn(WagtailPageCacheNode.CACHE_SITE_TEMPLATE_VAR, context)
def test_skips_cache_in_preview(self):
request = get_dummy_request(site=self.site)
request.is_preview = True
tpl = template.Template(
"""{% load wagtail_cache %}{% wagtailpagecache 100 test %}{{ foo.bar }}{% endwagtailpagecache %}"""
)
result = tpl.render(
template.Context(
{"request": request, "foo": {"bar": "foobar"}, "page": self.page_1}
)
)
self.assertEqual(result, "foobar")
result2 = tpl.render(
template.Context(
{"request": request, "foo": {"bar": "baz"}, "page": self.page_1}
)
)
self.assertEqual(result2, "baz")
self.assertIsNone(
cache.get(
make_wagtail_template_fragment_key("test", self.page_1, self.site)
)
)
def test_no_request(self):
tpl = template.Template(
"""{% load wagtail_cache %}{% wagtailpagecache 100 test %}{{ foo.bar }}{% endwagtailpagecache %}"""
)
result = tpl.render(
template.Context({"foo": {"bar": "foobar"}, "page": self.page_1})
)
self.assertEqual(result, "foobar")
result2 = tpl.render(
template.Context({"foo": {"bar": "baz"}, "page": self.page_1})
)
self.assertEqual(result2, "baz")
self.assertIsNone(
cache.get(
make_wagtail_template_fragment_key("test", self.page_1, self.site)
)
)
def test_no_page(self):
request = get_dummy_request()
tpl = template.Template(
"""{% load wagtail_cache %}{% wagtailpagecache 100 test %}{{ foo.bar }}{% endwagtailpagecache %}"""
)
with self.assertRaises(VariableDoesNotExist) as e:
tpl.render(template.Context({"request": request, "foo": {"bar": "foobar"}}))
self.assertEqual(e.exception.params[0], "page")
def test_cache_key(self):
self.assertEqual(
make_wagtail_template_fragment_key("test", self.page_1, self.site),
make_template_fragment_key(
"test", vary_on=[self.page_1.cache_key, self.site.id]
),
)
def test_invalid_usage(self):
with self.assertRaises(TemplateSyntaxError) as e:
template.Template(
"""{% load wagtail_cache %}{% wagtailpagecache 100 %}{{ foo.bar }}{% endwagtailpagecache %}"""
)
self.assertEqual(
e.exception.args[0], "'wagtailpagecache' tag requires at least 2 arguments."
)