Initial commit

This commit is contained in:
2024-08-27 20:33:44 +02:00
commit 1f1832267d
14794 changed files with 1599592 additions and 0 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,350 @@
"""
Base model definitions for audit logging. These may be subclassed to accommodate specific models
such as Page, but the definitions here should remain generic and not depend on the base
wagtail.models module or specific models such as Page.
"""
from collections import defaultdict
from django.conf import settings
from django.contrib.auth import get_user_model
from django.contrib.auth.models import Permission
from django.contrib.contenttypes.models import ContentType
from django.core.exceptions import ValidationError
from django.core.serializers.json import DjangoJSONEncoder
from django.db import models
from django.utils import timezone
from django.utils.functional import cached_property
from django.utils.translation import gettext_lazy as _
from wagtail.log_actions import registry as log_action_registry
from wagtail.users.utils import get_deleted_user_display_name
class LogEntryQuerySet(models.QuerySet):
def get_actions(self):
"""
Returns a set of actions used by at least one log entry in this QuerySet
"""
return set(self.order_by().values_list("action", flat=True).distinct())
def get_user_ids(self):
"""
Returns a set of user IDs of users who have created at least one log entry in this QuerySet
"""
return set(self.order_by().values_list("user_id", flat=True).distinct())
def get_users(self):
"""
Returns a QuerySet of Users who have created at least one log entry in this QuerySet.
The returned queryset is ordered by the username.
"""
User = get_user_model()
return User.objects.filter(pk__in=self.get_user_ids()).order_by(
User.USERNAME_FIELD
)
def get_content_type_ids(self):
"""
Returns a set of IDs of content types with logged actions in this QuerySet
"""
return set(self.order_by().values_list("content_type_id", flat=True).distinct())
def filter_on_content_type(self, content_type):
# custom method for filtering by content type, to allow overriding on log entry models
# that have a concept of object types that doesn't correspond directly to ContentType
# instances (e.g. PageLogEntry, which treats all page types as a single Page type)
return self.filter(content_type_id=content_type.id)
def with_instances(self):
# return an iterable of (log_entry, instance) tuples for all log entries in this queryset.
# instance is None if the instance does not exist.
# Note: This is an expensive operation and should only be done on small querysets
# (e.g. after pagination).
# evaluate the queryset in full now, as we'll be iterating over it multiple times
log_entries = list(self)
ids_by_content_type = defaultdict(list)
for log_entry in log_entries:
ids_by_content_type[log_entry.content_type_id].append(log_entry.object_id)
instances_by_id = {} # lookup of (content_type_id, stringified_object_id) to instance
for content_type_id, object_ids in ids_by_content_type.items():
try:
content_type = ContentType.objects.get_for_id(content_type_id)
model = content_type.model_class()
except ContentType.DoesNotExist:
model = None
if model:
model_instances = model.objects.in_bulk(object_ids)
else:
# The model class for the logged instance no longer exists,
# so we have no instance to return. Return None instead.
model_instances = {object_id: None for object_id in object_ids}
for object_id, instance in model_instances.items():
instances_by_id[(content_type_id, str(object_id))] = instance
for log_entry in log_entries:
lookup_key = (log_entry.content_type_id, str(log_entry.object_id))
yield (log_entry, instances_by_id.get(lookup_key))
class BaseLogEntryManager(models.Manager):
def get_queryset(self):
return LogEntryQuerySet(self.model, using=self._db)
def get_instance_title(self, instance):
return str(instance)
def log_action(self, instance, action, **kwargs):
"""
:param instance: The model instance we are logging an action for
:param action: The action. Should be namespaced to app (e.g. wagtail.create, wagtail.workflow.start)
:param kwargs: Addition fields to for the model deriving from BaseLogEntry
- user: The user performing the action
- uuid: uuid shared between log entries from the same user action
- title: the instance title
- data: any additional metadata
- content_changed, deleted - Boolean flags
:return: The new log entry
"""
if instance.pk is None:
raise ValueError(
"Attempted to log an action for object %r with empty primary key"
% (instance,)
)
data = kwargs.pop("data", None) or {}
title = kwargs.pop("title", None)
if not title:
title = self.get_instance_title(instance)
timestamp = kwargs.pop("timestamp", timezone.now())
return self.model.objects.create(
content_type=ContentType.objects.get_for_model(
instance, for_concrete_model=False
),
label=title,
action=action,
timestamp=timestamp,
data=data,
**kwargs,
)
def viewable_by_user(self, user):
if user.is_superuser:
return self.all()
# This will be called multiple times per request, so we cache those ids once.
if not hasattr(user, "_allowed_content_type_ids"):
# 1) Only query those permissions, where log entries exist for their content
# types.
used_content_type_ids = self.values_list(
"content_type_id", flat=True
).distinct()
permissions = Permission.objects.filter(
content_type_id__in=used_content_type_ids
)
# 2) If the user has at least one permission for a content type, we add its
# id to the allowed-set.
allowed_content_type_ids = set()
for permission in permissions:
if permission.content_type_id in allowed_content_type_ids:
continue
content_type = ContentType.objects.get_for_id(
permission.content_type_id
)
if user.has_perm(
"%s.%s" % (content_type.app_label, permission.codename)
):
allowed_content_type_ids.add(permission.content_type_id)
user._allowed_content_type_ids = allowed_content_type_ids
return self.filter(content_type_id__in=user._allowed_content_type_ids)
def get_for_model(self, model):
# Return empty queryset if the given object is not valid.
if not issubclass(model, models.Model):
return self.none()
ct = ContentType.objects.get_for_model(model)
return self.filter(content_type=ct)
def get_for_user(self, user_id):
return self.filter(user=user_id)
def for_instance(self, instance):
"""
Return a queryset of log entries from this log model that relate to the given object instance
"""
raise NotImplementedError # must be implemented by subclass
class BaseLogEntry(models.Model):
content_type = models.ForeignKey(
ContentType,
models.SET_NULL,
verbose_name=_("content type"),
blank=True,
null=True,
related_name="+",
)
label = models.TextField()
action = models.CharField(max_length=255, blank=True, db_index=True)
data = models.JSONField(blank=True, default=dict, encoder=DjangoJSONEncoder)
timestamp = models.DateTimeField(verbose_name=_("timestamp (UTC)"), db_index=True)
uuid = models.UUIDField(
blank=True,
null=True,
editable=False,
help_text="Log entries that happened as part of the same user action are assigned the same UUID",
)
user = models.ForeignKey(
settings.AUTH_USER_MODEL,
null=True, # Null if actioned by system
blank=True,
on_delete=models.DO_NOTHING,
db_constraint=False,
related_name="+",
)
# Pointer to a specific page revision
revision = models.ForeignKey(
"wagtailcore.Revision",
null=True,
blank=True,
on_delete=models.DO_NOTHING,
db_constraint=False,
related_name="+",
)
# Flags for additional context to the 'action' made by the user (or system).
content_changed = models.BooleanField(default=False, db_index=True)
deleted = models.BooleanField(default=False)
objects = BaseLogEntryManager()
wagtail_reference_index_ignore = True
class Meta:
abstract = True
verbose_name = _("log entry")
verbose_name_plural = _("log entries")
ordering = ["-timestamp"]
def save(self, *args, **kwargs):
self.full_clean()
return super().save(*args, **kwargs)
def clean(self):
if not log_action_registry.action_exists(self.action):
raise ValidationError(
{
"action": _(
"The log action '%(action_name)s' has not been registered."
)
% {"action_name": self.action}
}
)
def __str__(self):
return "LogEntry %d: '%s' on '%s'" % (
self.pk,
self.action,
self.object_verbose_name(),
)
@cached_property
def user_display_name(self):
"""
Returns the display name of the associated user;
get_full_name if available and non-empty, otherwise get_username.
Defaults to 'system' when none is provided
"""
if self.user_id:
user = self.user
if user is None:
return get_deleted_user_display_name(self.user_id)
try:
full_name = user.get_full_name().strip()
except AttributeError:
full_name = ""
return full_name or user.get_username()
else:
return _("system")
@cached_property
def object_verbose_name(self):
model_class = self.content_type.model_class()
if model_class is None:
return self.content_type_id
return model_class._meta.verbose_name.title
def object_id(self):
raise NotImplementedError
@cached_property
def formatter(self):
return log_action_registry.get_formatter(self)
@cached_property
def message(self):
if self.formatter:
return self.formatter.format_message(self)
else:
return _("Unknown %(action)s") % {"action": self.action}
@cached_property
def comment(self):
if self.formatter:
return self.formatter.format_comment(self)
else:
return ""
class ModelLogEntryManager(BaseLogEntryManager):
def log_action(self, instance, action, **kwargs):
kwargs.update(object_id=str(instance.pk))
return super().log_action(instance, action, **kwargs)
def for_instance(self, instance):
return self.filter(
content_type=ContentType.objects.get_for_model(
instance, for_concrete_model=False
),
object_id=str(instance.pk),
)
class ModelLogEntry(BaseLogEntry):
"""
Simple logger for generic Django models
"""
object_id = models.CharField(max_length=255, blank=False, db_index=True)
objects = ModelLogEntryManager()
class Meta:
ordering = ["-timestamp", "-id"]
verbose_name = _("model log entry")
verbose_name_plural = _("model log entries")
def __str__(self):
return "ModelLogEntry %d: '%s' on '%s' with id %s" % (
self.pk,
self.action,
self.object_verbose_name(),
self.object_id,
)

View File

@@ -0,0 +1,5 @@
# wagtail.models.collections was moved to wagtail.models.media in #11555;
# this import is retained to accommodate migration files importing from the old location.
# See #11874
from wagtail.models.media import get_root_collection_id # noqa

View File

@@ -0,0 +1,113 @@
from django.contrib.contenttypes.fields import GenericRelation
from django.db import models
from modelcluster.fields import ParentalKey, ParentalManyToManyField
from modelcluster.models import ClusterableModel
def _extract_field_data(source, exclude_fields=None):
"""
Get dictionaries representing the model's field data.
This excludes many to many fields (which are handled by _copy_m2m_relations)'
"""
exclude_fields = exclude_fields or []
data_dict = {}
for field in source._meta.get_fields():
# Ignore explicitly excluded fields
if field.name in exclude_fields:
continue
# Ignore reverse relations
if field.auto_created:
continue
# Ignore reverse generic relations
if isinstance(field, GenericRelation):
continue
# Copy parental m2m relations
if field.many_to_many:
if isinstance(field, ParentalManyToManyField):
parental_field = getattr(source, field.name)
if hasattr(parental_field, "all"):
values = parental_field.all()
if values:
data_dict[field.name] = values
continue
# Ignore parent links (page_ptr)
if isinstance(field, models.OneToOneField) and field.remote_field.parent_link:
continue
if isinstance(field, models.ForeignKey):
# Use attname to copy the ID instead of retrieving the instance
# Note: We first need to set the field to None to unset any object
# that's there already just setting _id on its own won't change the
# field until its saved.
data_dict[field.name] = None
data_dict[field.attname] = getattr(source, field.attname)
else:
data_dict[field.name] = getattr(source, field.name)
return data_dict
def _copy_m2m_relations(source, target, exclude_fields=None, update_attrs=None):
"""
Copies non-ParentalManyToMany m2m relations
"""
update_attrs = update_attrs or {}
exclude_fields = exclude_fields or []
for field in source._meta.get_fields():
# Copy m2m relations. Ignore explicitly excluded fields, reverse relations, and Parental m2m fields.
if (
field.many_to_many
and field.name not in exclude_fields
and not field.auto_created
and not isinstance(field, ParentalManyToManyField)
):
try:
# Do not copy m2m links with a through model that has a ParentalKey to the model being copied - these will be copied as child objects
through_model_parental_links = [
field
for field in field.through._meta.get_fields()
if isinstance(field, ParentalKey)
and issubclass(source.__class__, field.related_model)
]
if through_model_parental_links:
continue
except AttributeError:
pass
if field.name in update_attrs:
value = update_attrs[field.name]
else:
value = getattr(source, field.name).all()
getattr(target, field.name).set(value)
def _copy(source, exclude_fields=None, update_attrs=None):
data_dict = _extract_field_data(source, exclude_fields=exclude_fields)
target = source.__class__(**data_dict)
if update_attrs:
for field, value in update_attrs.items():
if field not in data_dict:
continue
setattr(target, field, value)
if isinstance(source, ClusterableModel):
child_object_map = source.copy_all_child_relations(
target, exclude=exclude_fields
)
else:
child_object_map = {}
return target, child_object_map

View File

@@ -0,0 +1,487 @@
import uuid
from typing import Dict
from django.apps import apps
from django.conf import settings
from django.core import checks
from django.db import migrations, models, transaction
from django.db.models.signals import pre_save
from django.dispatch import receiver
from django.utils import translation
from django.utils.encoding import force_str
from django.utils.translation import gettext_lazy as _
from modelcluster.fields import ParentalKey
from wagtail.actions.copy_for_translation import CopyForTranslationAction
from wagtail.coreutils import (
get_content_languages,
get_supported_content_language_variant,
)
from wagtail.signals import pre_validate_delete
def pk(obj):
if isinstance(obj, models.Model):
return obj.pk
else:
return obj
class LocaleManager(models.Manager):
def get_for_language(self, language_code):
"""
Gets a Locale from a language code.
"""
return self.get(
language_code=get_supported_content_language_variant(language_code)
)
class Locale(models.Model):
#: The language code that represents this locale
#:
#: The language code can either be a language code on its own (such as ``en``, ``fr``),
#: or it can include a region code (such as ``en-gb``, ``fr-fr``).
language_code = models.CharField(max_length=100, unique=True)
# Objects excludes any Locales that have been removed from LANGUAGES, This effectively disables them
# The Locale management UI needs to be able to see these so we provide a separate manager `all_objects`
objects = LocaleManager()
all_objects = models.Manager()
class Meta:
ordering = [
"language_code",
]
@classmethod
def get_default(cls):
"""
Returns the default Locale based on the site's LANGUAGE_CODE setting
"""
return cls.objects.get_for_language(settings.LANGUAGE_CODE)
@classmethod
def get_active(cls):
"""
Returns the Locale that corresponds to the currently activated language in Django.
"""
try:
return cls.objects.get_for_language(translation.get_language())
except (cls.DoesNotExist, LookupError):
return cls.get_default()
@transaction.atomic
def delete(self, *args, **kwargs):
# Provide a signal like pre_delete, but sent before on_delete validation.
# This allows us to use the signal to fix up references to the locale to be deleted
# that would otherwise fail validation.
# Workaround for https://code.djangoproject.com/ticket/6870
pre_validate_delete.send(sender=Locale, instance=self)
return super().delete(*args, **kwargs)
def language_code_is_valid(self):
return self.language_code in get_content_languages()
def get_display_name(self) -> str:
try:
return get_content_languages()[self.language_code]
except KeyError:
pass
try:
return self.language_name
except KeyError:
pass
return self.language_code
def __str__(self):
return force_str(self.get_display_name())
def _get_language_info(self) -> Dict[str, str]:
return translation.get_language_info(self.language_code)
@property
def language_info(self):
return translation.get_language_info(self.language_code)
@property
def language_name(self):
"""
Uses data from ``django.conf.locale`` to return the language name in
English. For example, if the object's ``language_code`` were ``"fr"``,
the return value would be ``"French"``.
Raises ``KeyError`` if ``django.conf.locale`` has no information
for the object's ``language_code`` value.
"""
return self.language_info["name"]
@property
def language_name_local(self):
"""
Uses data from ``django.conf.locale`` to return the language name in
the language itself. For example, if the ``language_code`` were
``"fr"`` (French), the return value would be ``"français"``.
Raises ``KeyError`` if ``django.conf.locale`` has no information
for the object's ``language_code`` value.
"""
return self.language_info["name_local"]
@property
def language_name_localized(self):
"""
Uses data from ``django.conf.locale`` to return the language name in
the currently active language. For example, if ``language_code`` were
``"fr"`` (French), and the active language were ``"da"`` (Danish), the
return value would be ``"Fransk"``.
Raises ``KeyError`` if ``django.conf.locale`` has no information
for the object's ``language_code`` value.
"""
return translation.gettext(self.language_name)
@property
def is_bidi(self) -> bool:
"""
Returns a boolean indicating whether the language is bi-directional.
"""
return self.language_code in settings.LANGUAGES_BIDI
@property
def is_default(self) -> bool:
"""
Returns a boolean indicating whether this object is the default locale.
"""
try:
return self.language_code == get_supported_content_language_variant(
settings.LANGUAGE_CODE
)
except LookupError:
return False
@property
def is_active(self) -> bool:
"""
Returns a boolean indicating whether this object is the currently active locale.
"""
try:
return self.language_code == get_supported_content_language_variant(
translation.get_language()
)
except LookupError:
return self.is_default
class TranslatableMixin(models.Model):
translation_key = models.UUIDField(default=uuid.uuid4, editable=False)
locale = models.ForeignKey(
Locale,
on_delete=models.PROTECT,
related_name="+",
editable=False,
verbose_name=_("locale"),
)
locale.wagtail_reference_index_ignore = True
class Meta:
abstract = True
unique_together = [("translation_key", "locale")]
@classmethod
def check(cls, **kwargs):
errors = super().check(**kwargs)
# No need to check on multi-table-inheritance children as it only needs to be applied to
# the table that has the translation_key/locale fields
is_translation_model = cls.get_translation_model() is cls
if not is_translation_model:
return errors
unique_constraint_fields = ("translation_key", "locale")
has_unique_constraint = any(
isinstance(constraint, models.UniqueConstraint)
and set(constraint.fields) == set(unique_constraint_fields)
for constraint in cls._meta.constraints
)
has_unique_together = unique_constraint_fields in cls._meta.unique_together
# Raise error if subclass has removed constraints
if not (has_unique_constraint or has_unique_together):
errors.append(
checks.Error(
"%s is missing a UniqueConstraint for the fields: %s."
% (cls._meta.label, unique_constraint_fields),
hint=(
"Add models.UniqueConstraint(fields=%s, "
"name='unique_translation_key_locale_%s_%s') to %s.Meta.constraints."
% (
unique_constraint_fields,
cls._meta.app_label,
cls._meta.model_name,
cls.__name__,
)
),
obj=cls,
id="wagtailcore.E003",
)
)
# Raise error if subclass has both UniqueConstraint and unique_together
if has_unique_constraint and has_unique_together:
errors.append(
checks.Error(
"%s should not have both UniqueConstraint and unique_together for: %s."
% (cls._meta.label, unique_constraint_fields),
hint="Remove unique_together in favor of UniqueConstraint.",
obj=cls,
id="wagtailcore.E003",
)
)
return errors
@property
def localized(self):
"""
Finds the translation in the current active language.
If there is no translation in the active language, self is returned.
Note: This will not return the translation if it is in draft.
If you want to include drafts, use the ``.localized_draft`` attribute instead.
"""
from wagtail.models import DraftStateMixin
localized = self.localized_draft
if isinstance(self, DraftStateMixin) and not localized.live:
return self
return localized
@property
def localized_draft(self):
"""
Finds the translation in the current active language.
If there is no translation in the active language, self is returned.
Note: This will return translations that are in draft. If you want to exclude
these, use the ``.localized`` attribute.
"""
if not getattr(settings, "WAGTAIL_I18N_ENABLED", False):
return self
try:
locale = Locale.get_active()
except (LookupError, Locale.DoesNotExist):
return self
if locale.id == self.locale_id:
return self
return self.get_translation_or_none(locale) or self
def get_translations(self, inclusive=False):
"""
Returns a queryset containing the translations of this instance.
"""
translations = self.__class__.objects.filter(
translation_key=self.translation_key
)
if inclusive is False:
translations = translations.exclude(id=self.id)
return translations
def get_translation(self, locale):
"""
Finds the translation in the specified locale.
If there is no translation in that locale, this raises a ``model.DoesNotExist`` exception.
"""
return self.get_translations(inclusive=True).get(locale_id=pk(locale))
def get_translation_or_none(self, locale):
"""
Finds the translation in the specified locale.
If there is no translation in that locale, this returns None.
"""
try:
return self.get_translation(locale)
except self.__class__.DoesNotExist:
return None
def has_translation(self, locale):
"""
Returns True if a translation exists in the specified locale.
"""
return (
self.get_translations(inclusive=True).filter(locale_id=pk(locale)).exists()
)
def copy_for_translation(self, locale, exclude_fields=None):
"""
Creates a copy of this instance with the specified locale.
Note that the copy is initially unsaved.
"""
return CopyForTranslationAction(
self,
locale,
exclude_fields=exclude_fields,
).execute()
def get_default_locale(self):
"""
Finds the default locale to use for this object.
This will be called just before the initial save.
"""
# Check if the object has any parental keys to another translatable model
# If so, take the locale from the object referenced in that parental key
parental_keys = [
field
for field in self._meta.get_fields()
if isinstance(field, ParentalKey)
and issubclass(field.related_model, TranslatableMixin)
]
if parental_keys:
parent_id = parental_keys[0].value_from_object(self)
return (
parental_keys[0]
.related_model.objects.defer()
.select_related("locale")
.get(id=parent_id)
.locale
)
return Locale.get_default()
@classmethod
def get_translation_model(cls):
"""
Returns this model's "Translation model".
The "Translation model" is the model that has the ``locale`` and
``translation_key`` fields.
Typically this would be the current model, but it may be a
super-class if multi-table inheritance is in use (as is the case
for ``wagtailcore.Page``).
"""
return cls._meta.get_field("locale").model
def bootstrap_translatable_model(model, locale):
"""
This function populates the "translation_key", and "locale" fields on model instances that were created
before wagtail-localize was added to the site.
This can be called from a data migration, or instead you could use the "bootstrap_translatable_models"
management command.
"""
for instance in (
model.objects.filter(translation_key__isnull=True).defer().iterator()
):
instance.translation_key = uuid.uuid4()
instance.locale = locale
instance.save(update_fields=["translation_key", "locale"])
class BootstrapTranslatableModel(migrations.RunPython):
def __init__(self, model_string, language_code=None):
if language_code is None:
language_code = get_supported_content_language_variant(
settings.LANGUAGE_CODE
)
def forwards(apps, schema_editor):
model = apps.get_model(model_string)
Locale = apps.get_model("wagtailcore.Locale")
locale = Locale.objects.get(language_code=language_code)
bootstrap_translatable_model(model, locale)
def backwards(apps, schema_editor):
pass
super().__init__(forwards, backwards)
class BootstrapTranslatableMixin(TranslatableMixin):
"""
A version of TranslatableMixin without uniqueness constraints.
This is to make it easy to transition existing models to being translatable.
The process is as follows:
- Add BootstrapTranslatableMixin to the model
- Run makemigrations
- Create a data migration for each app, then use the BootstrapTranslatableModel operation in
wagtail.models on each model in that app
- Change BootstrapTranslatableMixin to TranslatableMixin
- Run makemigrations again
- Migrate!
"""
translation_key = models.UUIDField(null=True, editable=False)
locale = models.ForeignKey(
Locale, on_delete=models.PROTECT, null=True, related_name="+", editable=False
)
@classmethod
def check(cls, **kwargs):
# skip the check in TranslatableMixin that enforces the unique-together constraint
return super(TranslatableMixin, cls).check(**kwargs)
class Meta:
abstract = True
def get_translatable_models(include_subclasses=False):
"""
Returns a list of all concrete models that inherit from TranslatableMixin.
By default, this only includes models that are direct children of TranslatableMixin,
to get all models, set the include_subclasses attribute to True.
"""
translatable_models = [
model
for model in apps.get_models()
if issubclass(model, TranslatableMixin) and not model._meta.abstract
]
if include_subclasses is False:
# Exclude models that inherit from another translatable model
root_translatable_models = set()
for model in translatable_models:
root_translatable_models.add(model.get_translation_model())
translatable_models = [
model for model in translatable_models if model in root_translatable_models
]
return translatable_models
@receiver(pre_save)
def set_locale_on_new_instance(sender, instance, **kwargs):
if not isinstance(instance, TranslatableMixin):
return
if instance.locale_id is not None:
return
# If this is a fixture load, use the global default Locale
# as the page tree is probably in flux
if kwargs["raw"]:
instance.locale = Locale.get_default()
return
instance.locale = instance.get_default_locale()

View File

@@ -0,0 +1,221 @@
from django.conf import settings
from django.contrib.auth.models import Group, Permission
from django.contrib.contenttypes.models import ContentType
from django.db import models
from django.utils.html import format_html
from django.utils.safestring import mark_safe
from django.utils.translation import gettext_lazy as _
from treebeard.mp_tree import MP_Node
from wagtail.query import TreeQuerySet
from wagtail.search import index
from .view_restrictions import BaseViewRestriction
class CollectionQuerySet(TreeQuerySet):
def get_min_depth(self):
return self.aggregate(models.Min("depth"))["depth__min"] or 2
def get_indented_choices(self):
"""
Return a list of (id, label) tuples for use as a list of choices in a collection chooser
dropdown, where the label is formatted with get_indented_name to provide a tree layout.
The indent level is chosen to place the minimum-depth collection at indent 0.
"""
min_depth = self.get_min_depth()
return [
(collection.pk, collection.get_indented_name(min_depth, html=True))
for collection in self
]
class BaseCollectionManager(models.Manager):
def get_queryset(self):
return CollectionQuerySet(self.model).order_by("path")
CollectionManager = BaseCollectionManager.from_queryset(CollectionQuerySet)
class CollectionViewRestriction(BaseViewRestriction):
collection = models.ForeignKey(
"Collection",
verbose_name=_("collection"),
related_name="view_restrictions",
on_delete=models.CASCADE,
)
passed_view_restrictions_session_key = "passed_collection_view_restrictions"
class Meta:
verbose_name = _("collection view restriction")
verbose_name_plural = _("collection view restrictions")
class Collection(MP_Node):
"""
A location in which resources such as images and documents can be grouped
"""
name = models.CharField(max_length=255, verbose_name=_("name"))
objects = CollectionManager()
# Tell treebeard to order Collections' paths such that they are ordered by name at each level.
node_order_by = ["name"]
def __str__(self):
return self.name
def get_ancestors(self, inclusive=False):
return Collection.objects.ancestor_of(self, inclusive)
def get_descendants(self, inclusive=False):
return Collection.objects.descendant_of(self, inclusive)
def get_siblings(self, inclusive=True):
return Collection.objects.sibling_of(self, inclusive)
def get_next_siblings(self, inclusive=False):
return self.get_siblings(inclusive).filter(path__gte=self.path).order_by("path")
def get_prev_siblings(self, inclusive=False):
return (
self.get_siblings(inclusive).filter(path__lte=self.path).order_by("-path")
)
def get_view_restrictions(self):
"""Return a query set of all collection view restrictions that apply to this collection"""
return CollectionViewRestriction.objects.filter(
collection__in=self.get_ancestors(inclusive=True)
)
def get_indented_name(self, indentation_start_depth=2, html=False):
"""
Renders this Collection's name as a formatted string that displays its hierarchical depth via indentation.
If indentation_start_depth is supplied, the Collection's depth is rendered relative to that depth.
indentation_start_depth defaults to 2, the depth of the first non-Root Collection.
Pass html=True to get an HTML representation, instead of the default plain-text.
Example text output: " ↳ Pies"
Example HTML output: "    &#x21b3 Pies"
"""
display_depth = self.depth - indentation_start_depth
# A Collection with a display depth of 0 or less (Root's can be -1), should have no indent.
if display_depth <= 0:
return self.name
# Indent each level of depth by 4 spaces (the width of the ↳ character in our admin font), then add ↳
# before adding the name.
if html:
# NOTE: &#x21b3 is the hex HTML entity for ↳.
return format_html(
"{indent}{icon} {name}",
indent=mark_safe("&nbsp;" * 4 * display_depth),
icon=mark_safe("&#x21b3"),
name=self.name,
)
# Output unicode plain-text version
return "{}{}".format(" " * 4 * display_depth, self.name)
class Meta:
verbose_name = _("collection")
verbose_name_plural = _("collections")
def get_root_collection_id():
return Collection.get_first_root_node().id
class CollectionMember(models.Model):
"""
Base class for models that are categorised into collections
"""
collection = models.ForeignKey(
Collection,
default=get_root_collection_id,
verbose_name=_("collection"),
related_name="+",
on_delete=models.CASCADE,
)
search_fields = [
index.FilterField("collection"),
]
class Meta:
abstract = True
class GroupCollectionPermissionManager(models.Manager):
def get_by_natural_key(self, group, collection, permission):
return self.get(group=group, collection=collection, permission=permission)
class GroupCollectionPermission(models.Model):
"""
A rule indicating that a group has permission for some action (e.g. "create document")
within a specified collection.
"""
group = models.ForeignKey(
Group,
verbose_name=_("group"),
related_name="collection_permissions",
on_delete=models.CASCADE,
)
collection = models.ForeignKey(
Collection,
verbose_name=_("collection"),
related_name="group_permissions",
on_delete=models.CASCADE,
)
permission = models.ForeignKey(
Permission, verbose_name=_("permission"), on_delete=models.CASCADE
)
def __str__(self):
return "Group %d ('%s') has permission '%s' on collection %d ('%s')" % (
self.group.id,
self.group,
self.permission,
self.collection.id,
self.collection,
)
def natural_key(self):
return (self.group, self.collection, self.permission)
objects = GroupCollectionPermissionManager()
class Meta:
unique_together = ("group", "collection", "permission")
verbose_name = _("group collection permission")
verbose_name_plural = _("group collection permissions")
class UploadedFile(models.Model):
"""
Temporary storage for media fields uploaded through the multiple image/document uploader.
When validation rules (e.g. required metadata fields) prevent creating an Image/Document object from the file alone.
In this case, the file is stored against this model, to be turned into an Image/Document object once the full form
has been filled in.
"""
for_content_type = models.ForeignKey(
ContentType,
verbose_name=_("for content type"),
related_name="uploads",
on_delete=models.CASCADE,
null=True,
)
file = models.FileField(upload_to="wagtail_uploads", max_length=200)
uploaded_by_user = models.ForeignKey(
settings.AUTH_USER_MODEL,
verbose_name=_("uploaded by user"),
null=True,
blank=True,
editable=False,
on_delete=models.SET_NULL,
)

View File

@@ -0,0 +1,720 @@
import uuid
from django.contrib.contenttypes.fields import GenericForeignKey, GenericRel
from django.contrib.contenttypes.models import ContentType
from django.db import connection, models
from django.utils.functional import cached_property
from django.utils.text import capfirst
from django.utils.translation import gettext_lazy as _
from modelcluster.fields import ParentalKey
from modelcluster.models import ClusterableModel, get_all_child_relations
from taggit.models import ItemBase
from wagtail.blocks import StreamBlock
from wagtail.fields import StreamField
class ReferenceGroups:
"""
Groups records in a ReferenceIndex queryset by their source object.
Args:
qs: (QuerySet[ReferenceIndex]) A QuerySet on the ReferenceIndex model
Yields:
A tuple (source_object, references) for each source object that appears
in the queryset. source_object is the model instance of the source object
and references is a list of references that occur in the QuerySet from
that source object.
"""
def __init__(self, qs):
self.qs = qs.order_by("base_content_type", "object_id")
def __iter__(self):
reference_fk = None
references = []
for reference in self.qs:
if reference_fk != (reference.base_content_type_id, reference.object_id):
if reference_fk is not None:
content_type = ContentType.objects.get_for_id(reference_fk[0])
object = content_type.get_object_for_this_type(pk=reference_fk[1])
yield object, references
references = []
reference_fk = (reference.base_content_type_id, reference.object_id)
references.append(reference)
if references:
content_type = ContentType.objects.get_for_id(reference_fk[0])
object = content_type.get_object_for_this_type(pk=reference_fk[1])
yield object, references
def __len__(self):
return self._count
@cached_property
def _count(self):
return self.qs.values("base_content_type", "object_id").distinct().count()
@cached_property
def is_protected(self):
return any(reference.on_delete == models.PROTECT for reference in self.qs)
def count(self):
"""
Returns the number of rows that will be returned by iterating this
ReferenceGroups.
Just calls len(self) internally, this method only exists to allow
instances of this class to be used in a Paginator.
"""
return len(self)
def __getitem__(self, key):
return list(self)[key]
class ReferenceIndexQuerySet(models.QuerySet):
def group_by_source_object(self):
"""
Returns a ReferenceGroups object for this queryset that will yield
references grouped by their source instance.
"""
return ReferenceGroups(self)
class ReferenceIndex(models.Model):
"""
Records references between objects for quick retrieval of object usage.
References are extracted from Foreign Keys, Chooser Blocks in StreamFields, and links in Rich Text Fields.
This index allows us to efficiently find all of the references to a particular object from all of these sources.
"""
# The object where the reference was extracted from
# content_type represents the content type of the model that contains
# the field where the reference came from. If the model sub-classes another
# concrete model (such as Page), that concrete model will be set in
# base_content_type, otherwise it would be the same as content_type
content_type = models.ForeignKey(
ContentType, on_delete=models.CASCADE, related_name="+"
)
base_content_type = models.ForeignKey(
ContentType, on_delete=models.CASCADE, related_name="+"
)
object_id = models.CharField(
max_length=255,
verbose_name=_("object id"),
)
# The object that has been referenced
# to_content_type is always the base content type of the referenced object
to_content_type = models.ForeignKey(
ContentType, on_delete=models.CASCADE, related_name="+"
)
to_object_id = models.CharField(
max_length=255,
verbose_name=_("object id"),
)
# The model_path is the path to the field on content_type where the reference was extracted from.
# the content_path is the path to a specific block on the instance where the reference is extracted from.
# These are dotted path, always starting with a field or child relation name. If
# the reference was extracted from an inline panel or streamfield, other components
# of the path can be used to locate where the reference was extracted.
#
# For example, say we have a StreamField called 'body' which has a struct block type
# called 'my_struct_block' that has a field called 'my_field'. If we extracted a
# reference from that field, the model_path would be set to the following:
#
# 'body.my_struct_block.my_field'
#
# The content path would follow the same format, but anything repeatable would be replaced by an ID.
# For example:
#
# 'body.bdc70d8b-e7a2-4c2a-bf43-2a3e3fcbbe86.my_field'
#
# We can use the model_path with the 'content_type' to find the original definition of
# the field block and display information to the user about where the reference was
# extracted from.
#
# We can use the content_path to link the user directly to the block/field that contains
# the reference.
model_path = models.TextField()
content_path = models.TextField()
# We need a separate hash field for content_path in order to use it in a unique key because
# MySQL has a limit to the size of fields that are included in unique keys
content_path_hash = models.UUIDField()
objects = ReferenceIndexQuerySet.as_manager()
wagtail_reference_index_ignore = True
# The set of models that should have signals attached to watch for outbound references.
# This includes those registered with `register_model`, as well as their child models
# linked by a ParentalKey.
tracked_models = set()
# The set of models that can appear as the 'from' object in the reference index.
# This only includes those registered with `register_model`, and NOT child models linked
# by ParentalKey (object references on those are recorded under the parent).
indexed_models = set()
class Meta:
unique_together = [
(
"base_content_type",
"object_id",
"to_content_type",
"to_object_id",
"content_path_hash",
)
]
@classmethod
def _get_base_content_type(cls, model_or_object):
"""
Returns the ContentType record that represents the base model of the
given model or object.
For a model that uses multi-table-inheritance, this returns the model
that contains the primary key. For example, for any page object, this
will return the content type of the Page model.
"""
parents = model_or_object._meta.get_parent_list()
if parents:
return ContentType.objects.get_for_model(
parents[-1], for_concrete_model=False
)
else:
return ContentType.objects.get_for_model(
model_or_object, for_concrete_model=False
)
@classmethod
def model_is_indexable(cls, model, allow_child_models=False):
"""
Returns True if the given model may have outbound references that we would be interested in recording in the index.
Args:
model (type): a Django model class
allow_child_models (boolean): Child models are not indexable on their own. If you are looking at
a child model from the perspective of indexing it through its parent,
set this to True to disable checking for this. Default False.
"""
if getattr(model, "wagtail_reference_index_ignore", False):
return False
# Don't check any models that have a parental key, references from these will be collected from the parent
if not allow_child_models and any(
isinstance(field, ParentalKey) for field in model._meta.get_fields()
):
return False
for field in model._meta.get_fields():
if field.is_relation and field.many_to_one:
if getattr(field, "wagtail_reference_index_ignore", False):
continue
if getattr(
field.related_model, "wagtail_reference_index_ignore", False
):
continue
if isinstance(field, (ParentalKey, GenericRel)):
continue
return True
if hasattr(field, "extract_references"):
return True
if issubclass(model, ClusterableModel):
for child_relation in get_all_child_relations(model):
if cls.model_is_indexable(
child_relation.related_model,
allow_child_models=True,
):
return True
return False
@classmethod
def register_model(cls, model):
"""
Registers the model for indexing.
"""
if model in cls.indexed_models:
return
if cls.model_is_indexable(model):
cls.indexed_models.add(model)
cls._register_as_tracked_model(model)
@classmethod
def _register_as_tracked_model(cls, model):
"""
Add the model and all of its ParentalKey-linked children to the set of
models to be tracked by signal handlers.
"""
if model in cls.tracked_models:
return
from wagtail.signal_handlers import (
connect_reference_index_signal_handlers_for_model,
)
cls.tracked_models.add(model)
connect_reference_index_signal_handlers_for_model(model)
for child_relation in get_all_child_relations(model):
if cls.model_is_indexable(
child_relation.related_model,
allow_child_models=True,
):
cls._register_as_tracked_model(child_relation.related_model)
@classmethod
def is_indexed(cls, model):
return model in cls.indexed_models
@classmethod
def _extract_references_from_object(cls, object):
"""
Generator that scans the given object and yields any references it finds.
Args:
object (Model): an instance of a Django model to scan for references
Yields:
A tuple (content_type_id, object_id, model_path, content_path) for each
reference found.
content_type_id (int): The ID of the ContentType record representing
the model of the referenced object
object_id (str): The primary key of the referenced object, converted
to a string
model_path (str): The path to the field on the model of the source
object where the reference was found
content_path (str): The path to the piece of content on the source
object instance where the reference was found
"""
# Extract references from fields
for field in object._meta.get_fields():
if field.is_relation and field.many_to_one:
if getattr(field, "wagtail_reference_index_ignore", False):
continue
if getattr(
field.related_model, "wagtail_reference_index_ignore", False
):
continue
if isinstance(field, (ParentalKey, GenericRel)):
continue
if isinstance(field, GenericForeignKey):
ct_field = object._meta.get_field(field.ct_field)
fk_field = object._meta.get_field(field.fk_field)
ct_value = ct_field.value_from_object(object)
fk_value = fk_field.value_from_object(object)
if ct_value is not None and fk_value is not None:
# The content type ID referenced by the GenericForeignKey might be a subclassed
# model, but the reference index requires us to index it under the base model's
# content type, as that's what will be used for lookups. So, we need to convert
# the content type back to a model class so that _get_base_content_type can
# select the appropriate superclass if necessary, before converting back to a
# content type.
model = ContentType.objects.get_for_id(ct_value).model_class()
yield (
cls._get_base_content_type(model).id,
str(fk_value),
field.name,
field.name,
)
continue
if isinstance(field, GenericRel):
continue
value = field.value_from_object(object)
if value is not None:
yield (
cls._get_base_content_type(field.related_model).id,
str(value),
field.name,
field.name,
)
if hasattr(field, "extract_references"):
value = field.value_from_object(object)
if value is not None:
yield from (
(
cls._get_base_content_type(to_model).id,
to_object_id,
f"{field.name}.{model_path}",
f"{field.name}.{content_path}",
)
for to_model, to_object_id, model_path, content_path in field.extract_references(
value
)
)
# Extract references from child relations
if isinstance(object, ClusterableModel):
for child_relation in get_all_child_relations(object):
relation_name = child_relation.get_accessor_name()
child_objects = getattr(object, relation_name).all()
for child_object in child_objects:
yield from (
(
to_content_type_id,
to_object_id,
f"{relation_name}.item.{model_path}",
f"{relation_name}.{str(child_object.id)}.{content_path}",
)
for to_content_type_id, to_object_id, model_path, content_path in cls._extract_references_from_object(
child_object
)
)
@classmethod
def _get_content_path_hash(cls, content_path):
"""
Returns a UUID for the given content path. Used to enforce uniqueness.
Note: MySQL has a limit on the length of fields that are used in unique keys so
we need a separate hash field to allow us to support long content paths.
Args:
content_path (str): The content path to get a hash for
Returns:
A UUID instance containing the hash of the given content path
"""
return uuid.uuid5(
uuid.UUID("bdc70d8b-e7a2-4c2a-bf43-2a3e3fcbbe86"), content_path
)
@classmethod
def create_or_update_for_object(cls, object):
"""
Creates or updates ReferenceIndex records for the given object.
This method will extract any outbound references from the given object
and insert/update them in the database.
Note: This method must be called within a `django.db.transaction.atomic()` block.
Args:
object (Model): The model instance to create/update ReferenceIndex records for
"""
# For the purpose of this method, a "reference record" is a tuple of
# (to_content_type_id, to_object_id, model_path, content_path) - the properties that
# uniquely define a reference
# Extract new references and construct a set of reference records
references = set(cls._extract_references_from_object(object))
# Find content types for this model and all of its ancestor classes,
# ordered from most to least specific
content_types = [
ContentType.objects.get_for_model(model_or_object, for_concrete_model=False)
for model_or_object in ([object] + object._meta.get_parent_list())
]
content_type = content_types[0]
base_content_type = content_types[-1]
known_content_type_ids = [ct.id for ct in content_types]
# Find existing references in the database so we know what to add/delete.
# Construct a dict mapping reference records to the (content_type_id, id) pair that the
# existing database entry is found under
existing_references = {
(to_content_type_id, to_object_id, model_path, content_path): (
content_type_id,
id,
)
for id, content_type_id, to_content_type_id, to_object_id, model_path, content_path in cls.objects.filter(
base_content_type=base_content_type, object_id=object.pk
).values_list(
"id",
"content_type_id",
"to_content_type",
"to_object_id",
"model_path",
"content_path",
)
}
# Construct the set of reference records that have been found on the object but are not
# already present in the database
new_references = references - set(existing_references.keys())
bulk_create_kwargs = {}
if connection.features.supports_ignore_conflicts:
bulk_create_kwargs["ignore_conflicts"] = True
# Create database records for those reference records
cls.objects.bulk_create(
[
cls(
content_type=content_type,
base_content_type=base_content_type,
object_id=object.pk,
to_content_type_id=to_content_type_id,
to_object_id=to_object_id,
model_path=model_path,
content_path=content_path,
content_path_hash=cls._get_content_path_hash(content_path),
)
for to_content_type_id, to_object_id, model_path, content_path in new_references
],
**bulk_create_kwargs,
)
# Delete removed references
deleted_reference_ids = []
# Look at the reference record and the supporting content_type / id for each existing
# reference in the database
for reference_data, (content_type_id, id) in existing_references.items():
if reference_data in references:
# Do not delete this reference, as it is still present in the new set
continue
if content_type_id not in known_content_type_ids:
# The content type for the existing record does not match the current model or any
# superclass. We can infer that the existing record is for a more specific subclass
# than the one we're currently indexing - e.g. we are indexing <Page id=123> while
# the existing reference was recorded against <BlogPage id=123>. In this case, do
# not treat the missing reference as a deletion - it likely still exists, but on a
# relation which can only be seen on the more specific model.
continue
# If we reach here, this is a legitimate deletion - add it to the list of IDs to delete
deleted_reference_ids.append(id)
# Perform the deletion
cls.objects.filter(id__in=deleted_reference_ids).delete()
@classmethod
def remove_for_object(cls, object):
"""
Deletes all outbound references for the given object.
Use this before deleting the object itself.
Args:
object (Model): The model instance to delete ReferenceIndex records for
"""
base_content_type = cls._get_base_content_type(object)
cls.objects.filter(
base_content_type=base_content_type, object_id=object.pk
).delete()
@classmethod
def get_references_for_object(cls, object):
"""
Returns all outbound references for the given object.
Args:
object (Model): The model instance to fetch ReferenceIndex records for
Returns:
A QuerySet of ReferenceIndex records
"""
return cls.objects.filter(
base_content_type_id=cls._get_base_content_type(object),
object_id=object.pk,
)
@classmethod
def get_references_to(cls, object):
"""
Returns all inbound references for the given object.
Args:
object (Model): The model instance to fetch ReferenceIndex records for
Returns:
A QuerySet of ReferenceIndex records
"""
return cls.objects.filter(
to_content_type_id=cls._get_base_content_type(object),
to_object_id=object.pk,
)
@classmethod
def get_grouped_references_to(cls, object):
"""
Returns all inbound references for the given object, grouped by the object
they are found on.
Args:
object (Model): The model instance to fetch ReferenceIndex records for
Returns:
A ReferenceGroups object
"""
return cls.get_references_to(object).group_by_source_object()
@property
def _content_type(self):
# Accessing a ContentType from a ForeignKey does not make use of the
# ContentType manager's cache, so we use this property to make use of
# the cache.
return ContentType.objects.get_for_id(self.content_type_id)
@cached_property
def model_name(self):
"""
The model name of the object from which the reference was extracted.
For most cases, this is also where the reference exists on the database
(i.e. ``related_field_model_name``). However, for ClusterableModels, the
reference is extracted from the parent model.
Example:
A relationship between a BlogPage, BlogPageGalleryImage, and Image
is extracted from the BlogPage model, but the reference is stored on
on the BlogPageGalleryImage model.
"""
return self._content_type.name
@cached_property
def related_field_model_name(self):
"""
The model name where the reference exists on the database.
"""
return self.related_field.model._meta.verbose_name
@cached_property
def on_delete(self):
try:
return self.reverse_related_field.on_delete
except AttributeError:
# It might be a custom field/relation that doesn't have an on_delete attribute,
# or other reference collected from extract_references(), e.g. StreamField.
return models.SET_NULL
@cached_property
def source_field(self):
"""
The field from which the reference was extracted.
This may be a related field (e.g. ForeignKey), a reverse related field
(e.g. ManyToOneRel), a StreamField, or any other field that defines
extract_references().
"""
model_path_components = self.model_path.split(".")
field_name = model_path_components[0]
field = self._content_type.model_class()._meta.get_field(field_name)
return field
@cached_property
def related_field(self):
# The field stored on the reference index can be a related field or a
# reverse related field, depending on whether the reference was extracted
# directly from a ForeignKey or through a parent ClusterableModel. This
# property normalises to the related field.
if isinstance(self.source_field, models.ForeignObjectRel):
return self.source_field.remote_field
return self.source_field
@cached_property
def reverse_related_field(self):
# This property normalises to the reverse related field, which is where
# the on_delete attribute is stored.
return self.related_field.remote_field
def describe_source_field(self):
"""
Returns a string describing the field that this reference was extracted from.
For StreamField, this returns the label of the block that contains the reference.
For other fields, this returns the verbose name of the field.
"""
field = self.source_field
model_path_components = self.model_path.split(".")
# ManyToOneRel (reverse accessor for ParentalKey) does not have a verbose name. So get the name of the child field instead
if isinstance(field, models.ManyToOneRel):
child_field = field.related_model._meta.get_field(model_path_components[2])
return capfirst(child_field.verbose_name)
elif isinstance(field, StreamField):
label = f"{capfirst(field.verbose_name)}"
block = field.stream_block
block_idx = 1
while isinstance(block, StreamBlock):
block = block.child_blocks[model_path_components[block_idx]]
block_label = capfirst(block.label)
label += f"{block_label}"
block_idx += 1
return label
else:
try:
field_name = field.verbose_name
except AttributeError:
# generate verbose name from field name in the same way that Django does:
# https://github.com/django/django/blob/7b94847e384b1a8c05a7d4c8778958c0290bdf9a/django/db/models/fields/__init__.py#L858
field_name = field.name.replace("_", " ")
return capfirst(field_name)
def describe_on_delete(self):
"""
Returns a string describing the action that will be taken when the referenced object is deleted.
"""
if self.on_delete == models.CASCADE:
return _("the %(model_name)s will also be deleted") % {
"model_name": self.related_field_model_name,
}
if self.on_delete == models.PROTECT:
return _("prevents deletion")
if self.on_delete == models.SET_DEFAULT:
return _("will be set to the default %(model_name)s") % {
"model_name": self.related_field_model_name,
}
if self.on_delete == models.DO_NOTHING:
return _("will do nothing")
# It's technically possible to know whether RESTRICT will prevent the
# deletion or not, but the only way to reliably do so is to use Django's
# internal Collector class, which is not publicly documented.
# It also uses its own logic to find the references in real-time, which
# may be slower than our ReferenceIndex. For now, we'll just say that
# RESTRICT *may* prevent deletion, but we do not add any safe guards
# around the possible exception.
if self.on_delete == models.RESTRICT:
return _("may prevent deletion")
# SET is a function that returns the actual callable used for on_delete,
# so we need to check for it by inspecting the deconstruct() result.
if (
hasattr(self.on_delete, "deconstruct")
and self.on_delete.deconstruct()[0] == "django.db.models.SET"
):
return _("will be set to a %(model_name)s specified by the system") % {
"model_name": self.related_field_model_name,
}
# It's either models.SET_NULL or a custom value, but we cannot be sure what
# will happen with the latter, so assume that the reference will be unset.
return _("will unset the reference")
# Ignore relations formed by any django-taggit 'through' model, as this causes any tag attached to
# a tagged object to appear as a reference to that object. Ideally we would follow the reference to
# the Tag model so that we can use the references index to find uses of a tag, but doing that
# correctly will require support for ManyToMany relations with through models:
# https://github.com/wagtail/wagtail/issues/9629
ItemBase.wagtail_reference_index_ignore = True

View File

@@ -0,0 +1,269 @@
from collections import namedtuple
from django.apps import apps
from django.conf import settings
from django.core.cache import cache
from django.core.exceptions import ValidationError
from django.db import models
from django.db.models import Case, IntegerField, Q, When
from django.db.models.functions import Lower
from django.http.request import split_domain_port
from django.utils.translation import gettext_lazy as _
MATCH_HOSTNAME_PORT = 0
MATCH_HOSTNAME_DEFAULT = 1
MATCH_DEFAULT = 2
MATCH_HOSTNAME = 3
def get_site_for_hostname(hostname, port):
"""Return the wagtailcore.Site object for the given hostname and port."""
Site = apps.get_model("wagtailcore.Site")
sites = list(
Site.objects.annotate(
match=Case(
# annotate the results by best choice descending
# put exact hostname+port match first
When(hostname=hostname, port=port, then=MATCH_HOSTNAME_PORT),
# then put hostname+default (better than just hostname or just default)
When(
hostname=hostname, is_default_site=True, then=MATCH_HOSTNAME_DEFAULT
),
# then match default with different hostname. there is only ever
# one default, so order it above (possibly multiple) hostname
# matches so we can use sites[0] below to access it
When(is_default_site=True, then=MATCH_DEFAULT),
# because of the filter below, if it's not default then its a hostname match
default=MATCH_HOSTNAME,
output_field=IntegerField(),
)
)
.filter(Q(hostname=hostname) | Q(is_default_site=True))
.order_by("match")
.select_related("root_page")
)
if sites:
# if there's a unique match or hostname (with port or default) match
if len(sites) == 1 or sites[0].match in (
MATCH_HOSTNAME_PORT,
MATCH_HOSTNAME_DEFAULT,
):
return sites[0]
# if there is a default match with a different hostname, see if
# there are many hostname matches. if only 1 then use that instead
# otherwise we use the default
if sites[0].match == MATCH_DEFAULT:
return sites[len(sites) == 2]
raise Site.DoesNotExist()
class SiteManager(models.Manager):
def get_queryset(self):
return super().get_queryset().order_by(Lower("hostname"))
def get_by_natural_key(self, hostname, port):
return self.get(hostname=hostname, port=port)
SiteRootPath = namedtuple("SiteRootPath", "site_id root_path root_url language_code")
SITE_ROOT_PATHS_CACHE_KEY = "wagtail_site_root_paths"
# Increase the cache version whenever the structure SiteRootPath tuple changes
SITE_ROOT_PATHS_CACHE_VERSION = 2
class Site(models.Model):
hostname = models.CharField(
verbose_name=_("hostname"), max_length=255, db_index=True
)
port = models.IntegerField(
verbose_name=_("port"),
default=80,
help_text=_(
"Set this to something other than 80 if you need a specific port number to appear in URLs"
" (e.g. development on port 8000). Does not affect request handling (so port forwarding still works)."
),
)
site_name = models.CharField(
verbose_name=_("site name"),
max_length=255,
blank=True,
help_text=_("Human-readable name for the site."),
)
root_page = models.ForeignKey(
"Page",
verbose_name=_("root page"),
related_name="sites_rooted_here",
on_delete=models.CASCADE,
)
is_default_site = models.BooleanField(
verbose_name=_("is default site"),
default=False,
help_text=_(
"If true, this site will handle requests for all other hostnames that do not have a site entry of their own"
),
)
objects = SiteManager()
class Meta:
unique_together = ("hostname", "port")
verbose_name = _("site")
verbose_name_plural = _("sites")
def natural_key(self):
return (self.hostname, self.port)
def __str__(self):
default_suffix = " [{}]".format(_("default"))
if self.site_name:
return self.site_name + (default_suffix if self.is_default_site else "")
else:
return (
self.hostname
+ ("" if self.port == 80 else (":%d" % self.port))
+ (default_suffix if self.is_default_site else "")
)
def clean(self):
self.hostname = self.hostname.lower()
@staticmethod
def find_for_request(request):
"""
Find the site object responsible for responding to this HTTP
request object. Try:
* unique hostname first
* then hostname and port
* if there is no matching hostname at all, or no matching
hostname:port combination, fall back to the unique default site,
or raise an exception
NB this means that high-numbered ports on an extant hostname may
still be routed to a different hostname which is set as the default
The site will be cached via request._wagtail_site
"""
if request is None:
return None
if not hasattr(request, "_wagtail_site"):
site = Site._find_for_request(request)
setattr(request, "_wagtail_site", site)
return request._wagtail_site
@staticmethod
def _find_for_request(request):
hostname = split_domain_port(request.get_host())[0]
port = request.get_port()
site = None
try:
site = get_site_for_hostname(hostname, port)
except Site.DoesNotExist:
pass
# copy old SiteMiddleware behaviour
return site
@property
def root_url(self):
if self.port == 80:
return "http://%s" % self.hostname
elif self.port == 443:
return "https://%s" % self.hostname
else:
return "http://%s:%d" % (self.hostname, self.port)
def clean_fields(self, exclude=None):
super().clean_fields(exclude)
# Only one site can have the is_default_site flag set
try:
default = Site.objects.get(is_default_site=True)
except Site.DoesNotExist:
pass
except Site.MultipleObjectsReturned:
raise
else:
if self.is_default_site and self.pk != default.pk:
raise ValidationError(
{
"is_default_site": [
_(
"%(hostname)s is already configured as the default site."
" You must unset that before you can save this site as default."
)
% {"hostname": default.hostname}
]
}
)
@staticmethod
def get_site_root_paths():
"""
Return a list of `SiteRootPath` instances, most specific path
first - used to translate url_paths into actual URLs with hostnames
Each root path is an instance of the `SiteRootPath` named tuple,
and have the following attributes:
- `site_id` - The ID of the Site record
- `root_path` - The internal URL path of the site's home page (for example '/home/')
- `root_url` - The scheme/domain name of the site (for example 'https://www.example.com/')
- `language_code` - The language code of the site (for example 'en')
"""
result = cache.get(
SITE_ROOT_PATHS_CACHE_KEY, version=SITE_ROOT_PATHS_CACHE_VERSION
)
if result is None:
result = []
for site in Site.objects.select_related(
"root_page", "root_page__locale"
).order_by("-root_page__url_path", "-is_default_site", "hostname"):
if getattr(settings, "WAGTAIL_I18N_ENABLED", False):
result.extend(
[
SiteRootPath(
site.id,
root_page.url_path,
site.root_url,
root_page.locale.language_code,
)
for root_page in site.root_page.get_translations(
inclusive=True
).select_related("locale")
]
)
else:
result.append(
SiteRootPath(
site.id,
site.root_page.url_path,
site.root_url,
site.root_page.locale.language_code,
)
)
cache.set(
SITE_ROOT_PATHS_CACHE_KEY,
result,
3600,
version=SITE_ROOT_PATHS_CACHE_VERSION,
)
else:
# Convert the cache result to a list of SiteRootPath tuples, as some
# cache backends (e.g. Redis) don't support named tuples.
result = [SiteRootPath(*result) for result in result]
return result
@staticmethod
def clear_site_root_paths_cache():
cache.delete(SITE_ROOT_PATHS_CACHE_KEY, version=SITE_ROOT_PATHS_CACHE_VERSION)

View File

@@ -0,0 +1,128 @@
from django.contrib.contenttypes.models import ContentType
from django.db.models import DEFERRED
from django.utils.functional import cached_property
class SpecificMixin:
"""
Mixin for models that support multi-table inheritance and provide a
``content_type`` field pointing to the specific model class, to provide
methods and properties for retrieving the specific instance of the model.
"""
def get_specific(self, deferred=False, copy_attrs=None, copy_attrs_exclude=None):
"""
Return this object in its most specific subclassed form.
By default, a database query is made to fetch all field values for the
specific object. If you only require access to custom methods or other
non-field attributes on the specific object, you can use
``deferred=True`` to avoid this query. However, any attempts to access
specific field values from the returned object will trigger additional
database queries.
By default, references to all non-field attribute values are copied
from current object to the returned one. This includes:
* Values set by a queryset, for example: annotations, or values set as
a result of using ``select_related()`` or ``prefetch_related()``.
* Any ``cached_property`` values that have been evaluated.
* Attributes set elsewhere in Python code.
For fine-grained control over which non-field values are copied to the
returned object, you can use ``copy_attrs`` to specify a complete list
of attribute names to include. Alternatively, you can use
``copy_attrs_exclude`` to specify a list of attribute names to exclude.
If called on an object that is already an instance of the most specific
class, the object will be returned as is, and no database queries or
other operations will be triggered.
If the object was originally created using a model that has since
been removed from the codebase, an instance of the base class will be
returned (without any custom field values or other functionality
present on the original class). Usually, deleting these objects is the
best course of action, but there is currently no safe way for Wagtail
to do that at migration time.
"""
model_class = self.specific_class
if model_class is None:
# The codebase and database are out of sync (e.g. the model exists
# on a different git branch and migrations were not applied or
# reverted before switching branches). So, the best we can do is
# return the page in it's current form.
return self
if isinstance(self, model_class):
# self is already an instance of the most specific class.
return self
if deferred:
# Generate a tuple of values in the order expected by __init__(),
# with missing values substituted with DEFERRED ()
values = tuple(
getattr(self, f.attname, self.pk if f.primary_key else DEFERRED)
for f in model_class._meta.concrete_fields
)
# Create object from known attribute values
specific_obj = model_class(*values)
specific_obj._state.adding = self._state.adding
else:
# Fetch object from database
specific_obj = model_class._default_manager.get(id=self.id)
# Copy non-field attribute values
if copy_attrs is not None:
for attr in (attr for attr in copy_attrs if attr in self.__dict__):
setattr(specific_obj, attr, getattr(self, attr))
else:
exclude = copy_attrs_exclude or ()
for k, v in ((k, v) for k, v in self.__dict__.items() if k not in exclude):
# only set values that haven't already been set
specific_obj.__dict__.setdefault(k, v)
return specific_obj
@cached_property
def specific(self):
"""
Returns this object in its most specific subclassed form with all field
values fetched from the database. The result is cached in memory.
"""
return self.get_specific()
@cached_property
def specific_deferred(self):
"""
Returns this object in its most specific subclassed form without any
additional field values being fetched from the database. The result
is cached in memory.
"""
return self.get_specific(deferred=True)
@cached_property
def specific_class(self):
"""
Return the class that this object would be if instantiated in its
most specific form.
If the model class can no longer be found in the codebase, and the
relevant ``ContentType`` has been removed by a database migration,
the return value will be ``None``.
If the model class can no longer be found in the codebase, but the
relevant ``ContentType`` is still present in the database (usually a
result of switching between git branches without running or reverting
database migrations beforehand), the return value will be ``None``.
"""
return self.cached_content_type.model_class()
@property
def cached_content_type(self):
"""
Return this object's ``content_type`` value from the ``ContentType``
model's cached manager, which will avoid a database query if the
content type is already in memory.
"""
return ContentType.objects.get_for_id(self.content_type_id)

View File

@@ -0,0 +1,81 @@
"""
Base model definitions for validating front-end user access to resources such as pages and
documents. These may be subclassed to accommodate specific models such as Page or Collection,
but the definitions here should remain generic and not depend on the base wagtail.models
module or specific models defined there.
"""
from django.conf import settings
from django.contrib.auth.models import Group
from django.db import models
from django.utils.translation import gettext_lazy as _
class BaseViewRestriction(models.Model):
NONE = "none"
PASSWORD = "password"
GROUPS = "groups"
LOGIN = "login"
RESTRICTION_CHOICES = (
(NONE, _("Public")),
(PASSWORD, _("Private, accessible with a shared password")),
(LOGIN, _("Private, accessible to any logged-in users")),
(GROUPS, _("Private, accessible to users in specific groups")),
)
restriction_type = models.CharField(max_length=20, choices=RESTRICTION_CHOICES)
password = models.CharField(
verbose_name=_("shared password"),
max_length=255,
blank=True,
help_text=_(
"Shared passwords should not be used to protect sensitive content. Anyone who has this password will be able to view the content."
),
)
groups = models.ManyToManyField(Group, verbose_name=_("groups"), blank=True)
def accept_request(self, request):
if self.restriction_type == BaseViewRestriction.PASSWORD:
passed_restrictions = request.session.get(
self.passed_view_restrictions_session_key, []
)
if self.id not in passed_restrictions:
return False
elif self.restriction_type == BaseViewRestriction.LOGIN:
if not request.user.is_authenticated:
return False
elif self.restriction_type == BaseViewRestriction.GROUPS:
if not request.user.is_superuser:
current_user_groups = request.user.groups.all()
if not any(group in current_user_groups for group in self.groups.all()):
return False
return True
def mark_as_passed(self, request):
"""
Update the session data in the request to mark the user as having passed this
view restriction
"""
has_existing_session = settings.SESSION_COOKIE_NAME in request.COOKIES
passed_restrictions = request.session.setdefault(
self.passed_view_restrictions_session_key, []
)
if self.id not in passed_restrictions:
passed_restrictions.append(self.id)
request.session[
self.passed_view_restrictions_session_key
] = passed_restrictions
if not has_existing_session:
# if this is a session we've created, set it to expire at the end
# of the browser session
request.session.set_expiry(0)
class Meta:
abstract = True
verbose_name = _("view restriction")
verbose_name_plural = _("view restrictions")