Initial commit

This commit is contained in:
2024-08-27 20:33:44 +02:00
commit 1f1832267d
14794 changed files with 1599592 additions and 0 deletions

View File

@@ -0,0 +1,165 @@
from __future__ import unicode_literals
from __future__ import absolute_import
from taggit import VERSION as TAGGIT_VERSION
from taggit.managers import TaggableManager, _TaggableManager
from taggit.utils import require_instance_manager
from modelcluster.queryset import FakeQuerySet
if TAGGIT_VERSION < (0, 20, 0):
raise Exception("modelcluster.contrib.taggit requires django-taggit version 0.20 or above")
class _ClusterTaggableManager(_TaggableManager):
@require_instance_manager
def get_tagged_item_manager(self):
"""Return the manager that handles the relation from this instance to the tagged_item class.
If content_object on the tagged_item class is defined as a ParentalKey, this will be a
DeferringRelatedManager which allows writing related objects without committing them
to the database.
"""
rel_name = self.through._meta.get_field('content_object').remote_field.get_accessor_name()
return getattr(self.instance, rel_name)
def get_queryset(self, extra_filters=None):
if self.instance is not None:
tagged_item_manager = self.get_tagged_item_manager()
# If we're already managing tags in memory for this instance,
# we want to return those uncommitted changes. This shouldn't
# require a request to the database.
if tagged_item_manager.is_deferring:
return FakeQuerySet(
self.through.tag_model(),
[tagged_item.tag for tagged_item in tagged_item_manager.all()],
)
# If we don't have any uncommitted changes for this instance,
# we'd ideally like to use the default taggit logic. There's one
# case that we need to handle specially, which is the ability to
# query tags on an unsaved model instance, for example:
#
# class TaggedPlace(TaggedItemBase):
# content_object = ParentalKey(
# "Place",
# related_name="tagged_items",
# on_delete=models.CASCADE,
# )
#
# class Place(ClusterableModel):
# tags = ClusterTaggableManager(
# through=TaggedPlace,
# blank=True,
# )
#
# instance = Place()
# instance.tags.count()
#
# Under the hood this call invokes this get_queryset method with an
# unsaved self.instance, which would trigger this query using the
# default taggit logic:
#
# TaggedPlace.objects.filter(content_object=Place())
#
# This works on Django < 5.0, returning an empty list as expected.
# But as of Django 5.0, passing unsaved model instances to related
# filters is no longer allowed, see
# https://code.djangoproject.com/ticket/31486.
#
# To handle this case we return an empty tag list since there won't
# be any existing tags in the database for an unsaved instance.
elif self.instance.pk is None:
return FakeQuerySet(self.through.tag_model(), [])
# If we've reached this point then either this manager isn't associated
# with a specific model, which probably means it's being invoked within
# a prefetch_related operation:
#
# Place.objects.prefetch_related("tags")
#
# or we're fetching tags for a model instance that doesn't have any
# uncommitted tag changes in memory:
#
# place = Place.objects.first()
# place.tags.all()
#
# In these cases we can fallback to the default taggit manager behavior
# which will fetch the tags from the database.
return super().get_queryset(extra_filters)
@require_instance_manager
def add(self, *tags):
if TAGGIT_VERSION >= (3, 1, 0):
self._remove_prefetched_objects()
if TAGGIT_VERSION >= (1, 3, 0):
tag_objs = self._to_tag_model_instances(tags, {})
else:
tag_objs = self._to_tag_model_instances(tags)
# Now write these to the relation
tagged_item_manager = self.get_tagged_item_manager()
for tag in tag_objs:
if not tagged_item_manager.filter(tag=tag):
# make an instance of the self.through model and add it to the relation
tagged_item = self.through(tag=tag)
tagged_item_manager.add(tagged_item)
@require_instance_manager
def remove(self, *tags):
if TAGGIT_VERSION >= (3, 1, 0):
self._remove_prefetched_objects()
tagged_item_manager = self.get_tagged_item_manager()
tagged_items = [
tagged_item for tagged_item in tagged_item_manager.all()
if tagged_item.tag.name in tags
]
tagged_item_manager.remove(*tagged_items)
@require_instance_manager
def set(self, *args, **kwargs):
# Ignore the 'clear' kwarg (which defaults to False) and override it to be always true;
# this means that set is implemented as a clear then an add, which was the standard behaviour
# prior to django-taggit 0.19 (https://github.com/alex/django-taggit/commit/6542a702b590a5cfb91ea0de218b7f71ffd07c33).
#
# In this way, we avoid a live database lookup that occurs in the clear=False branch.
#
# The clear=True behaviour is fine for our purposes; the distinction only exists in django-taggit
# to ensure that the correct set of m2m_changed signals is fired, and our reimplementation here
# doesn't fire them at all (which makes logical sense, because the whole point of this module is
# that the add/remove/set/clear operations don't write to the database).
#
# super().set() already calls self._remove_prefetched_objects() so we don't need to do so here.
return super().set(*args, clear=True)
@require_instance_manager
def clear(self):
if TAGGIT_VERSION >= (3, 1, 0):
self._remove_prefetched_objects()
self.get_tagged_item_manager().clear()
class ClusterTaggableManager(TaggableManager):
_need_commit_after_assignment = True
def __get__(self, instance, model):
# override TaggableManager's requirement for instance to have a primary key
# before we can access its tags
manager = _ClusterTaggableManager(
through=self.through, model=model, instance=instance, prefetch_cache_name=self.name
)
return manager
def value_from_object(self, instance):
# retrieve the queryset via the related manager on the content object,
# to accommodate the possibility of this having uncommitted changes relative to
# the live database
rel_name = self.through._meta.get_field('content_object').remote_field.get_accessor_name()
ret = getattr(instance, rel_name).all()
if TAGGIT_VERSION >= (1, ): # expects a Tag list instead of TaggedItem List
ret = [tagged_item.tag for tagged_item in ret]
return ret

View File

@@ -0,0 +1,119 @@
import datetime
from django import forms
TIMEFIELD_TRANSFORM_EXPRESSIONS = {"hour", "minute", "second"}
DATEFIELD_TRANSFORM_EXPRESSIONS = {
"year",
"iso_year",
"month",
"day",
"week",
"week_day",
"iso_week_day",
"quarter",
}
DATETIMEFIELD_TRANSFORM_EXPRESSIONS = (
{"date", "time"}
| TIMEFIELD_TRANSFORM_EXPRESSIONS
| DATEFIELD_TRANSFORM_EXPRESSIONS
)
TRANSFORM_FIELD_TYPES = {
"year": forms.IntegerField,
"iso_year": forms.IntegerField,
"month": forms.IntegerField,
"hour": forms.IntegerField,
"minute": forms.IntegerField,
"second": forms.IntegerField,
"day": forms.IntegerField,
"week": forms.IntegerField,
"week_day": forms.IntegerField,
"iso_week_day": forms.IntegerField,
"quarter": forms.IntegerField,
"date": forms.DateField,
"time": forms.TimeField,
}
def derive_from_value(value, expr):
if isinstance(value, datetime.datetime):
return derive_from_datetime(value, expr)
if isinstance(value, datetime.date):
return derive_from_date(value, expr)
if isinstance(value, datetime.time):
return derive_from_time(value, expr)
return None
def derive_from_time(value, expr):
"""
Mimics the behaviour of the ``hour``, ``minute`` and ``second`` lookup
expressions that Django querysets support for ``TimeField`` and
``DateTimeField``, by extracting the relevant value from an in-memory
``time`` or ``datetime`` value.
"""
if expr == "hour":
return value.hour
if expr == "minute":
return value.minute
if expr == "second":
return value.second
raise ValueError(
"Expression '{expression}' is not supported for {value}".format(
expression=expr, value=repr(value)
)
)
def derive_from_date(value, expr):
"""
Mimics the behaviour of the ``year``, ``iso_year`` ``month``, ``day``,
``week``, ``week_day``, ``iso_week_day`` and ``quarter`` lookup
expressions that Django querysets support for ``DateField`` and
``DateTimeField`` columns, by extracting the relevant value from an
in-memory ``date`` or ``datetime`` value.
"""
if expr == "year":
return value.year
if expr == "iso_year":
return value.isocalendar()[0]
if expr == "month":
return value.month
if expr == "day":
return value.day
if expr == "week":
return value.isocalendar()[1]
if expr == "week_day":
v = value.isoweekday()
return 1 if v == 7 else v + 1
if expr == "iso_week_day":
return value.isoweekday()
if expr == "quarter":
return (value.month - 1) // 3 + 1
raise ValueError(
"Expression '{expression}' is not supported for {value}".format(
expression=expr, value=repr(value)
)
)
def derive_from_datetime(value, expr):
"""
Mimics the behaviour of the ``date``, ``time`` and other lookup
expressions that Django querysets support for ``DateTimeField`` columns,
by extracting the relevant value from an in-memory ``datetime`` value.
"""
if expr == "date":
return value.date()
if expr == "time":
return value.time()
if expr in TIMEFIELD_TRANSFORM_EXPRESSIONS:
return derive_from_time(value, expr)
if expr in DATEFIELD_TRANSFORM_EXPRESSIONS:
return derive_from_date(value, expr)
raise ValueError(
"Expression '{expression}' is not supported for {value}".format(
expression=expr, value=repr(value)
)
)

View File

@@ -0,0 +1,529 @@
from __future__ import unicode_literals
from django.core import checks
from django.db import IntegrityError, connections, router
from django.db.models import CASCADE
from django.db.models.fields.related import ForeignKey, ManyToManyField
from django.utils.functional import cached_property
from django.db.models.fields.related import ReverseManyToOneDescriptor, ManyToManyDescriptor
from modelcluster.utils import sort_by_fields
from modelcluster.queryset import FakeQuerySet
def create_deferring_foreign_related_manager(related, original_manager_cls):
"""
Create a DeferringRelatedManager class that wraps an ordinary RelatedManager
with 'deferring' behaviour: any updates to the object set (via e.g. add() or clear())
are written to a holding area rather than committed to the database immediately.
Writing to the database is deferred until the model is saved.
"""
relation_name = related.get_accessor_name()
rel_field = related.field
rel_model = related.related_model
superclass = rel_model._default_manager.__class__
class DeferringRelatedManager(superclass):
def __init__(self, instance):
super().__init__()
self.model = rel_model
self.instance = instance
@property
def is_deferring(self):
return relation_name in getattr(
self.instance, '_cluster_related_objects', {}
)
def _get_cluster_related_objects(self):
# Helper to retrieve the instance's _cluster_related_objects dict,
# creating it if it does not already exist
try:
return self.instance._cluster_related_objects
except AttributeError:
cluster_related_objects = {}
self.instance._cluster_related_objects = cluster_related_objects
return cluster_related_objects
def get_live_query_set(self):
# deprecated; renamed to get_live_queryset to match the move from
# get_query_set to get_queryset in Django 1.6
return self.get_live_queryset()
def get_live_queryset(self):
"""
return the original manager's queryset, which reflects the live database
"""
return original_manager_cls(self.instance).get_queryset()
def get_queryset(self):
"""
return the current object set with any updates applied,
wrapped up in a FakeQuerySet if it doesn't match the database state
"""
try:
results = self.instance._cluster_related_objects[relation_name]
except (AttributeError, KeyError):
if self.instance.pk is None:
# use an empty fake queryset if the instance is unsaved
results = []
else:
return self.get_live_queryset()
return FakeQuerySet(related.related_model, results)
def _apply_rel_filters(self, queryset):
# Implemented as empty for compatibility sake
# But there is probably a better implementation of this function
#
# NOTE: _apply_rel_filters() must return a copy of the queryset
# to work correctly with prefetch
return queryset._next_is_sticky().all()
def get_prefetch_queryset(self, instances, queryset=None):
if queryset is None:
db = self._db or router.db_for_read(self.model, instance=instances[0])
queryset = super().get_queryset().using(db)
rel_obj_attr = rel_field.get_local_related_value
instance_attr = rel_field.get_foreign_related_value
instances_dict = dict((instance_attr(inst), inst) for inst in instances)
query = {'%s__in' % rel_field.name: instances}
qs = queryset.filter(**query)
# Since we just bypassed this class' get_queryset(), we must manage
# the reverse relation manually.
for rel_obj in qs:
instance = instances_dict[rel_obj_attr(rel_obj)]
setattr(rel_obj, rel_field.name, instance)
cache_name = rel_field.related_query_name()
return qs, rel_obj_attr, instance_attr, False, cache_name, False
def get_object_list(self):
"""
return the mutable list that forms the current in-memory state of
this relation. If there is no such list (i.e. the manager is returning
querysets from the live database instead), one is created, populating it
with the live database state
"""
cluster_related_objects = self._get_cluster_related_objects()
try:
object_list = cluster_related_objects[relation_name]
except KeyError:
if self.instance.pk is None:
object_list = []
else:
object_list = list(self.get_live_queryset())
cluster_related_objects[relation_name] = object_list
return object_list
def add(self, *new_items):
"""
Add the passed items to the stored object set, but do not commit them
to the database
"""
items = self.get_object_list()
for target in new_items:
item_matched = False
for i, item in enumerate(items):
if item == target:
# Replace the matched item with the new one. This ensures that any
# modifications to that item's fields take effect within the recordset -
# i.e. we can perform a virtual UPDATE to an object in the list
# by calling add(updated_object). Which is semantically a bit dubious,
# but it does the job...
items[i] = target
item_matched = True
break
if not item_matched:
items.append(target)
# update the foreign key on the added item to point back to the parent instance
setattr(target, related.field.name, self.instance)
# Sort list
if rel_model._meta.ordering and len(items) > 1:
sort_by_fields(items, rel_model._meta.ordering)
def remove(self, *items_to_remove):
"""
Remove the passed items from the stored object set, but do not commit the change
to the database
"""
items = self.get_object_list()
# filter items list in place: see http://stackoverflow.com/a/1208792/1853523
items[:] = [item for item in items if item not in items_to_remove]
def create(self, **kwargs):
items = self.get_object_list()
new_item = related.related_model(**kwargs)
items.append(new_item)
return new_item
def clear(self):
"""
Clear the stored object set, without affecting the database
"""
self.set([])
def set(self, objs, bulk=True, clear=False):
# cast objs to a list so that:
# 1) we can call len() on it (which we can't do on, say, a queryset)
# 2) if we need to sort it, we can do so without mutating the original
objs = list(objs)
cluster_related_objects = self._get_cluster_related_objects()
for obj in objs:
# update the foreign key on the added item to point back to the parent instance
setattr(obj, related.field.name, self.instance)
# Clone and sort the 'objs' list, if necessary
if rel_model._meta.ordering and len(objs) > 1:
sort_by_fields(objs, rel_model._meta.ordering)
cluster_related_objects[relation_name] = objs
def commit(self):
"""
Apply any changes made to the stored object set to the database.
Any objects removed from the initial set will be deleted entirely
from the database.
"""
if self.instance.pk is None:
raise IntegrityError("Cannot commit relation %r on an unsaved model" % relation_name)
try:
final_items = self.instance._cluster_related_objects[relation_name]
except (AttributeError, KeyError):
# _cluster_related_objects entry never created => no changes to make
return
original_manager = original_manager_cls(self.instance)
live_items = list(original_manager.get_queryset())
for item in live_items:
if item not in final_items:
item.delete()
for item in final_items:
# Django 1.9+ bulk updates items by default which assumes
# that they have already been saved to the database.
# Disable this behaviour.
# https://code.djangoproject.com/ticket/18556
# https://github.com/django/django/commit/adc0c4fbac98f9cb975e8fa8220323b2de638b46
original_manager.add(item, bulk=False)
# purge the _cluster_related_objects entry, so we switch back to live SQL
del self.instance._cluster_related_objects[relation_name]
return DeferringRelatedManager
class ChildObjectsDescriptor(ReverseManyToOneDescriptor):
def __get__(self, instance, instance_type=None):
if instance is None:
return self
return self.child_object_manager_cls(instance)
def __set__(self, instance, value):
manager = self.__get__(instance)
manager.set(value)
@cached_property
def child_object_manager_cls(self):
return create_deferring_foreign_related_manager(self.rel, self.related_manager_cls)
class ParentalKey(ForeignKey):
related_accessor_class = ChildObjectsDescriptor
def __init__(self, *args, **kwargs):
kwargs.setdefault('on_delete', CASCADE)
super().__init__(*args, **kwargs)
def check(self, **kwargs):
from modelcluster.models import ClusterableModel
errors = super().check(**kwargs)
# Check that the destination model is a subclass of ClusterableModel.
# If self.rel.to is a string at this point, it means that Django has been unable
# to resolve it as a model name; if so, skip this test so that Django's own
# system checks can report the appropriate error
if isinstance(self.remote_field.model, type) and not issubclass(self.remote_field.model, ClusterableModel):
errors.append(
checks.Error(
'ParentalKey must point to a subclass of ClusterableModel.',
hint='Change {model_name} into a ClusterableModel or use a ForeignKey instead.'.format(
model_name=self.remote_field.model._meta.app_label + '.' + self.remote_field.model.__name__,
),
obj=self,
id='modelcluster.E001',
)
)
# ParentalKeys must have an accessor name (#49)
if self.remote_field.get_accessor_name() == '+':
errors.append(
checks.Error(
"related_name='+' is not allowed on ParentalKey fields",
hint="Either change it to a valid name or remove it",
obj=self,
id='modelcluster.E002',
)
)
return errors
def create_deferring_forward_many_to_many_manager(rel, original_manager_cls):
rel_field = rel.field
relation_name = rel_field.name
query_field_name = rel_field.related_query_name()
source_field_name = rel_field.m2m_field_name()
rel_model = rel.model
superclass = rel_model._default_manager.__class__
rel_through = rel.through
class DeferringManyRelatedManager(superclass):
def __init__(self, instance=None):
super().__init__()
self.model = rel_model
self.through = rel_through
self.instance = instance
def get_original_manager(self):
return original_manager_cls(self.instance)
def get_live_queryset(self):
"""
return the original manager's queryset, which reflects the live database
"""
return self.get_original_manager().get_queryset()
def _get_cluster_related_objects(self):
# Helper to retrieve the instance's _cluster_related_objects dict,
# creating it if it does not already exist
try:
return self.instance._cluster_related_objects
except AttributeError:
cluster_related_objects = {}
self.instance._cluster_related_objects = cluster_related_objects
return cluster_related_objects
def get_queryset(self):
"""
return the current object set with any updates applied,
wrapped up in a FakeQuerySet if it doesn't match the database state
"""
try:
results = self.instance._cluster_related_objects[relation_name]
except (AttributeError, KeyError):
if self.instance.pk:
return self.get_live_queryset()
else:
# the standard M2M manager fails on unsaved instances,
# so bypass it and return an empty queryset
return rel_model.objects.none()
return FakeQuerySet(rel_model, results)
def get_prefetch_queryset(self, instances, queryset=None):
# Derived from Django's ManyRelatedManager.get_prefetch_queryset.
if queryset is None:
queryset = super().get_queryset()
queryset._add_hints(instance=instances[0])
queryset = queryset.using(queryset._db or self._db)
query = {'%s__in' % query_field_name: instances}
queryset = queryset._next_is_sticky().filter(**query)
fk = self.through._meta.get_field(source_field_name)
join_table = fk.model._meta.db_table
connection = connections[queryset.db]
qn = connection.ops.quote_name
queryset = queryset.extra(select={
'_prefetch_related_val_%s' % f.attname:
'%s.%s' % (qn(join_table), qn(f.column)) for f in fk.local_related_fields})
return (
queryset,
lambda result: tuple(
getattr(result, '_prefetch_related_val_%s' % f.attname)
for f in fk.local_related_fields
),
lambda inst: tuple(
f.get_db_prep_value(getattr(inst, f.attname), connection)
for f in fk.foreign_related_fields
),
False,
relation_name,
False,
)
def _apply_rel_filters(self, queryset):
# Required for get_prefetch_queryset.
return queryset._next_is_sticky()
def get_object_list(self):
"""
return the mutable list that forms the current in-memory state of
this relation. If there is no such list (i.e. the manager is returning
querysets from the live database instead), one is created, populating it
with the live database state
"""
cluster_related_objects = self._get_cluster_related_objects()
try:
object_list = cluster_related_objects[relation_name]
except KeyError:
object_list = list(self.get_live_queryset())
cluster_related_objects[relation_name] = object_list
return object_list
def add(self, *new_items):
"""
Add the passed items to the stored object set, but do not commit them
to the database
"""
items = self.get_object_list()
for target in new_items:
if target.pk is None:
raise ValueError('"%r" needs to have a primary key value before '
'it can be added to a parental many-to-many relation.' % target)
item_matched = False
for i, item in enumerate(items):
if item == target:
# Replace the matched item with the new one. This ensures that any
# modifications to that item's fields take effect within the recordset -
# i.e. we can perform a virtual UPDATE to an object in the list
# by calling add(updated_object). Which is semantically a bit dubious,
# but it does the job...
items[i] = target
item_matched = True
break
if not item_matched:
items.append(target)
# Sort list
if rel_model._meta.ordering and len(items) > 1:
sort_by_fields(items, rel_model._meta.ordering)
def clear(self):
"""
Clear the stored object set, without affecting the database
"""
self.set([])
def set(self, objs, bulk=True, clear=False):
# cast objs to a list so that:
# 1) we can call len() on it (which we can't do on, say, a queryset)
# 2) if we need to sort it, we can do so without mutating the original
objs = list(objs)
if objs and not isinstance(objs[0], rel_model):
# assume objs is a list of pks (like when loading data from a
# fixture), and allow the orignal manager to handle things
original_manager = self.get_original_manager()
original_manager.set(objs)
return
cluster_related_objects = self._get_cluster_related_objects()
# Clone and sort the 'objs' list, if necessary
if rel_model._meta.ordering and len(objs) > 1:
sort_by_fields(objs, rel_model._meta.ordering)
cluster_related_objects[relation_name] = objs
def remove(self, *items_to_remove):
"""
Remove the passed items from the stored object set, but do not commit the change
to the database
"""
items = self.get_object_list()
# filter items list in place: see http://stackoverflow.com/a/1208792/1853523
items[:] = [item for item in items if item not in items_to_remove]
def commit(self):
"""
Apply any changes made to the stored object set to the database.
"""
if not self.instance.pk:
raise IntegrityError("Cannot commit relation %r on an unsaved model" % relation_name)
try:
final_items = self.instance._cluster_related_objects[relation_name]
except (AttributeError, KeyError):
# _cluster_related_objects entry never created => no changes to make
return
original_manager = self.get_original_manager()
live_items = list(original_manager.get_queryset())
items_to_remove = [item for item in live_items if item not in final_items]
items_to_add = [item for item in final_items if item not in live_items]
if items_to_remove:
original_manager.remove(*items_to_remove)
if items_to_add:
original_manager.add(*items_to_add)
# purge the _cluster_related_objects entry, so we switch back to live SQL
del self.instance._cluster_related_objects[relation_name]
return DeferringManyRelatedManager
class ParentalManyToManyDescriptor(ManyToManyDescriptor):
def __get__(self, instance, instance_type=None):
if instance is None:
return self
return self.child_object_manager_cls(instance)
def __set__(self, instance, value):
manager = self.__get__(instance)
manager.set(value)
@cached_property
def child_object_manager_cls(self):
rel = self.rel
return create_deferring_forward_many_to_many_manager(rel, self.related_manager_cls)
class ParentalManyToManyField(ManyToManyField):
related_accessor_class = ParentalManyToManyDescriptor
_need_commit_after_assignment = True
def contribute_to_class(self, cls, name, **kwargs):
# ManyToManyField does not (as of Django 1.10) respect related_accessor_class,
# but hard-codes ManyToManyDescriptor instead:
# https://github.com/django/django/blob/6157cd6da1b27716e8f3d1ed692a6e33d970ae46/django/db/models/fields/related.py#L1538
# So, we'll let the original contribute_to_class do its thing, and then overwrite
# the accessor...
super().contribute_to_class(cls, name, **kwargs)
setattr(cls, self.name, self.related_accessor_class(self.remote_field))
def value_from_object(self, obj):
# In Django >=1.10, ManyToManyField.value_from_object special-cases objects with no PK,
# returning an empty list on the basis that unsaved objects can't have related objects.
# Remove that special case.
return getattr(obj, self.attname).all()

View File

@@ -0,0 +1,439 @@
from __future__ import unicode_literals
from django.forms import ValidationError
from django.core.exceptions import NON_FIELD_ERRORS
from django.forms.formsets import TOTAL_FORM_COUNT
from django.forms.models import (
BaseModelFormSet, modelformset_factory,
ModelForm, _get_foreign_key, ModelFormMetaclass, ModelFormOptions
)
from django.db.models.fields.related import ForeignObjectRel
from django.utils.html import format_html_join
from modelcluster.models import get_all_child_relations
class BaseTransientModelFormSet(BaseModelFormSet):
""" A ModelFormSet that doesn't assume that all its initial data instances exist in the db """
def _construct_form(self, i, **kwargs):
# Need to override _construct_form to avoid calling to_python on an empty string PK value
if self.is_bound and i < self.initial_form_count():
pk_key = "%s-%s" % (self.add_prefix(i), self.model._meta.pk.name)
pk = self.data[pk_key]
if pk == '':
kwargs['instance'] = self.model()
else:
pk_field = self.model._meta.pk
to_python = self._get_to_python(pk_field)
pk = to_python(pk)
kwargs['instance'] = self._existing_object(pk)
if i < self.initial_form_count() and 'instance' not in kwargs:
kwargs['instance'] = self.get_queryset()[i]
if i >= self.initial_form_count() and self.initial_extra:
# Set initial values for extra forms
try:
kwargs['initial'] = self.initial_extra[i - self.initial_form_count()]
except IndexError:
pass
# bypass BaseModelFormSet's own _construct_form
return super(BaseModelFormSet, self)._construct_form(i, **kwargs)
def save_existing_objects(self, commit=True):
# Need to override _construct_form so that it doesn't skip over initial forms whose instance
# has a blank PK (which is taken as an indication that the form was constructed with an
# instance not present in our queryset)
self.changed_objects = []
self.deleted_objects = []
if not self.initial_forms:
return []
saved_instances = []
forms_to_delete = self.deleted_forms
for form in self.initial_forms:
obj = form.instance
if form in forms_to_delete:
if obj.pk is None:
# no action to be taken to delete an object which isn't in the database
continue
self.deleted_objects.append(obj)
self.delete_existing(obj, commit=commit)
elif form.has_changed():
self.changed_objects.append((obj, form.changed_data))
saved_instances.append(self.save_existing(form, obj, commit=commit))
if not commit:
self.saved_forms.append(form)
return saved_instances
def transientmodelformset_factory(model, formset=BaseTransientModelFormSet, **kwargs):
return modelformset_factory(model, formset=formset, **kwargs)
class BaseChildFormSet(BaseTransientModelFormSet):
inherit_kwargs = None
def __init__(self, data=None, files=None, instance=None, queryset=None, **kwargs):
if instance is None:
self.instance = self.fk.remote_field.model()
else:
self.instance = instance
self.rel_name = ForeignObjectRel(self.fk, self.fk.remote_field.model, related_name=self.fk.remote_field.related_name).get_accessor_name()
if queryset is None:
queryset = getattr(self.instance, self.rel_name).all()
super().__init__(data, files, queryset=queryset, **kwargs)
def save(self, commit=True):
# The base ModelFormSet's save(commit=False) will populate the lists
# self.changed_objects, self.deleted_objects and self.new_objects;
# use these to perform the appropriate updates on the relation's manager.
saved_instances = super().save(commit=False)
manager = getattr(self.instance, self.rel_name)
# if model has a sort_order_field defined, assign order indexes to the attribute
# named in it
if self.can_order and hasattr(self.model, 'sort_order_field'):
sort_order_field = getattr(self.model, 'sort_order_field')
for i, form in enumerate(self.ordered_forms):
setattr(form.instance, sort_order_field, i)
# If the manager has existing instances with a blank ID, we have no way of knowing
# whether these correspond to items in the submitted data. We'll assume that they do,
# as that's the most common case (i.e. the formset contains the full set of child objects,
# not just a selection of additions / updates) and so we delete all ID-less objects here
# on the basis that they will be re-added by the formset saving mechanism.
no_id_instances = [obj for obj in manager.all() if obj.pk is None]
if no_id_instances:
manager.remove(*no_id_instances)
manager.add(*saved_instances)
manager.remove(*self.deleted_objects)
self.save_m2m() # ensures any parental-m2m fields are saved.
if commit:
manager.commit()
return saved_instances
def clean(self, *args, **kwargs):
self.validate_unique()
return super().clean(*args, **kwargs)
def validate_unique(self):
'''This clean method will check for unique_together condition'''
# Collect unique_checks and to run from all the forms.
all_unique_checks = set()
all_date_checks = set()
forms_to_delete = self.deleted_forms
valid_forms = [form for form in self.forms if form.is_valid() and form not in forms_to_delete]
for form in valid_forms:
unique_checks, date_checks = form.instance._get_unique_checks()
all_unique_checks.update(unique_checks)
all_date_checks.update(date_checks)
errors = []
# Do each of the unique checks (unique and unique_together)
for uclass, unique_check in all_unique_checks:
seen_data = set()
for form in valid_forms:
# Get the data for the set of fields that must be unique among the forms.
row_data = (
field if field in self.unique_fields else form.cleaned_data[field]
for field in unique_check if field in form.cleaned_data
)
# Reduce Model instances to their primary key values
row_data = tuple(d._get_pk_val() if hasattr(d, '_get_pk_val') else d
for d in row_data)
if row_data and None not in row_data:
# if we've already seen it then we have a uniqueness failure
if row_data in seen_data:
# poke error messages into the right places and mark
# the form as invalid
errors.append(self.get_unique_error_message(unique_check))
form._errors[NON_FIELD_ERRORS] = self.error_class([self.get_form_error()])
# remove the data from the cleaned_data dict since it was invalid
for field in unique_check:
if field in form.cleaned_data:
del form.cleaned_data[field]
# mark the data as seen
seen_data.add(row_data)
if errors:
raise ValidationError(errors)
def childformset_factory(
parent_model, model, form=ModelForm,
formset=BaseChildFormSet, fk_name=None, fields=None, exclude=None,
extra=3, can_order=False, can_delete=True, max_num=None, validate_max=False,
formfield_callback=None, widgets=None, min_num=None, validate_min=False,
inherit_kwargs=None, formsets=None, exclude_formsets=None
):
fk = _get_foreign_key(parent_model, model, fk_name=fk_name)
# enforce a max_num=1 when the foreign key to the parent model is unique.
if fk.unique:
max_num = 1
validate_max = True
if exclude is None:
exclude = []
exclude += [fk.name]
if issubclass(form, ClusterForm) and (formsets is not None or exclude_formsets is not None):
# the modelformset_factory helper that we ultimately hand off to doesn't recognise
# formsets / exclude_formsets, so we need to prepare a specific subclass of our `form`
# class, with these pre-embedded in Meta, to use as the base form
# If parent form class already has an inner Meta, the Meta we're
# creating needs to inherit from the parent's inner meta.
bases = (form.Meta,) if hasattr(form, "Meta") else ()
Meta = type("Meta", bases, {
'formsets': formsets,
'exclude_formsets': exclude_formsets,
})
# Instantiate type(form) in order to use the same metaclass as form.
form = type(form)("_ClusterForm", (form,), {"Meta": Meta})
kwargs = {
'form': form,
'formfield_callback': formfield_callback,
'formset': formset,
'extra': extra,
'can_delete': can_delete,
# if the model supplies a sort_order_field, enable ordering regardless of
# the current setting of can_order
'can_order': (can_order or hasattr(model, 'sort_order_field')),
'fields': fields,
'exclude': exclude,
'max_num': max_num,
'validate_max': validate_max,
'widgets': widgets,
'min_num': min_num,
'validate_min': validate_min,
}
FormSet = transientmodelformset_factory(model, **kwargs)
FormSet.fk = fk
# A list of keyword argument names that should be passed on from ClusterForm's constructor
# to child forms in this formset
FormSet.inherit_kwargs = inherit_kwargs
return FormSet
class ClusterFormOptions(ModelFormOptions):
def __init__(self, options=None):
super().__init__(options=options)
self.formsets = getattr(options, 'formsets', None)
self.exclude_formsets = getattr(options, 'exclude_formsets', None)
class ClusterFormMetaclass(ModelFormMetaclass):
extra_form_count = 3
@classmethod
def child_form(cls):
return ClusterForm
def __new__(cls, name, bases, attrs):
try:
parents = [b for b in bases if issubclass(b, ClusterForm)]
except NameError:
# We are defining ClusterForm itself.
parents = None
# grab any formfield_callback that happens to be defined in attrs -
# so that we can pass it on to child formsets - before ModelFormMetaclass deletes it.
# BAD METACLASS NO BISCUIT.
formfield_callback = attrs.get('formfield_callback')
new_class = super().__new__(cls, name, bases, attrs)
if not parents:
return new_class
# ModelFormMetaclass will have set up new_class._meta as a ModelFormOptions instance;
# replace that with ClusterFormOptions so that we can access _meta.formsets
opts = new_class._meta = ClusterFormOptions(getattr(new_class, 'Meta', None))
if opts.model:
formsets = {}
for rel in get_all_child_relations(opts.model):
# to build a childformset class from this relation, we need to specify:
# - the base model (opts.model)
# - the child model (rel.field.model)
# - the fk_name from the child model to the base (rel.field.name)
rel_name = rel.get_accessor_name()
# apply 'formsets' and 'exclude_formsets' rules from meta
if opts.exclude_formsets is not None and rel_name in opts.exclude_formsets:
# formset is explicitly excluded
continue
elif opts.formsets is not None and rel_name not in opts.formsets:
# a formset list has been specified and this isn't on it
continue
elif opts.formsets is None and opts.exclude_formsets is None:
# neither formsets nor exclude_formsets has been specified - no formsets at all
continue
try:
widgets = opts.widgets.get(rel_name)
except AttributeError: # thrown if opts.widgets is None
widgets = None
kwargs = {
'extra': cls.extra_form_count,
'form': cls.child_form(),
'formfield_callback': formfield_callback,
'fk_name': rel.field.name,
'widgets': widgets,
'formset_name': rel_name
}
# see if opts.formsets looks like a dict; if so, allow the value
# to override kwargs
try:
kwargs.update(opts.formsets.get(rel_name))
except AttributeError:
pass
formset_name = kwargs.pop('formset_name')
formset = childformset_factory(opts.model, rel.field.model, **kwargs)
formsets[formset_name] = formset
new_class.formsets = formsets
return new_class
class ClusterForm(ModelForm, metaclass=ClusterFormMetaclass):
def __init__(self, data=None, files=None, instance=None, prefix=None, **kwargs):
super().__init__(data, files, instance=instance, prefix=prefix, **kwargs)
self.formsets = {}
for rel_name, formset_class in self.__class__.formsets.items():
if prefix:
formset_prefix = "%s-%s" % (prefix, rel_name)
else:
formset_prefix = rel_name
child_form_kwargs = {}
if formset_class.inherit_kwargs:
for kwarg_name in formset_class.inherit_kwargs:
child_form_kwargs[kwarg_name] = getattr(self, kwarg_name, None)
self.formsets[rel_name] = formset_class(
data, files, instance=instance, prefix=formset_prefix, form_kwargs=child_form_kwargs
)
def as_p(self):
form_as_p = super().as_p()
return form_as_p + format_html_join('', '{}', [(formset.as_p(),) for formset in self.formsets.values()])
def is_valid(self):
form_is_valid = super().is_valid()
formsets_are_valid = all(formset.is_valid() for formset in self.formsets.values())
return form_is_valid and formsets_are_valid
def is_multipart(self):
return (
super().is_multipart()
or any(formset.is_multipart() for formset in self.formsets.values())
)
@property
def media(self):
media = super().media
for formset in self.formsets.values():
media = media + formset.media
return media
def save(self, commit=True):
# do we have any fields that expect us to call save_m2m immediately?
save_m2m_now = False
exclude = self._meta.exclude
fields = self._meta.fields
for f in self.instance._meta.get_fields():
if fields and f.name not in fields:
continue
if exclude and f.name in exclude:
continue
if getattr(f, '_need_commit_after_assignment', False):
save_m2m_now = True
break
instance = super().save(commit=(commit and not save_m2m_now))
# The M2M-like fields designed for use with ClusterForm (currently
# ParentalManyToManyField and ClusterTaggableManager) will manage their own in-memory
# relations, and not immediately write to the database when we assign to them.
# For these fields (identified by the _need_commit_after_assignment
# flag), save_m2m() is a safe operation that does not affect the database and is thus
# valid for commit=False. In the commit=True case, committing to the database happens
# in the subsequent instance.save (so this needs to happen after save_m2m to ensure
# we have the updated relation data in place).
# For annoying legacy reasons we sometimes need to accommodate 'classic' M2M fields
# (particularly taggit.TaggableManager) within ClusterForm. These fields
# generally do require our instance to exist in the database at the point we call
# save_m2m() - for this reason, we only proceed with the customisation described above
# (i.e. postpone the instance.save() operation until after save_m2m) if there's a
# _need_commit_after_assignment field on the form that demands it.
if save_m2m_now:
self.save_m2m()
if commit:
instance.save()
for formset in self.formsets.values():
formset.instance = instance
formset.save(commit=commit)
return instance
def has_changed(self):
"""Return True if data differs from initial."""
# Need to recurse over nested formsets so that the form is saved if there are changes
# to child forms but not the parent
if self.formsets:
for formset in self.formsets.values():
for form in formset.forms:
if form.has_changed():
return True
return bool(self.changed_data)
def clusterform_factory(model, form=ClusterForm, **kwargs):
# Same as Django's modelform_factory, but arbitrary kwargs are accepted and passed on to the
# Meta class.
# Build up a list of attributes that the Meta object will have.
meta_class_attrs = kwargs
meta_class_attrs["model"] = model
# If parent form class already has an inner Meta, the Meta we're
# creating needs to inherit from the parent's inner meta.
bases = (form.Meta,) if hasattr(form, "Meta") else ()
Meta = type("Meta", bases, meta_class_attrs)
formfield_callback = meta_class_attrs.get('formfield_callback')
if formfield_callback:
Meta.formfield_callback = staticmethod(formfield_callback)
# Give this new form class a reasonable name.
class_name = model.__name__ + "Form"
# Class attributes for the new form class.
form_class_attrs = {"Meta": Meta, "formfield_callback": formfield_callback}
# Instantiate type(form) in order to use the same metaclass as form.
return type(form)(class_name, (form,), form_class_attrs)

View File

@@ -0,0 +1,421 @@
from __future__ import unicode_literals
import json
import datetime
from django.core.exceptions import FieldDoesNotExist
from django.db import models, transaction
from django.db.models.fields.related import ForeignObjectRel
from django.utils.encoding import is_protected_type
from django.core.serializers.json import DjangoJSONEncoder
from django.conf import settings
from django.utils import timezone
from modelcluster.fields import ParentalKey, ParentalManyToManyField
def get_field_value(field, model):
if field.remote_field is None:
value = field.pre_save(model, add=model.pk is None)
# Make datetimes timezone aware
# https://github.com/django/django/blob/master/django/db/models/fields/__init__.py#L1394-L1403
if isinstance(value, datetime.datetime) and settings.USE_TZ:
if timezone.is_naive(value):
default_timezone = timezone.get_default_timezone()
value = timezone.make_aware(value, default_timezone).astimezone(datetime.timezone.utc)
else:
# convert to UTC
value = timezone.localtime(value, datetime.timezone.utc)
if is_protected_type(value):
return value
else:
return field.value_to_string(model)
else:
return getattr(model, field.get_attname())
def get_serializable_data_for_fields(model):
"""
Return a serialised version of the model's fields which exist as local database
columns (i.e. excluding m2m and incoming foreign key relations)
"""
pk_field = model._meta.pk
# If model is a child via multitable inheritance, use parent's pk
while pk_field.remote_field and pk_field.remote_field.parent_link:
pk_field = pk_field.remote_field.model._meta.pk
obj = {'pk': get_field_value(pk_field, model)}
for field in model._meta.fields:
if field.serialize:
obj[field.name] = get_field_value(field, model)
return obj
def model_from_serializable_data(model, data, check_fks=True, strict_fks=False):
pk_field = model._meta.pk
kwargs = {}
# If model is a child via multitable inheritance, we need to set ptr_id fields all the way up
# to the main PK field, as Django won't populate these for us automatically.
while pk_field.remote_field and pk_field.remote_field.parent_link:
kwargs[pk_field.attname] = data['pk']
pk_field = pk_field.remote_field.model._meta.pk
kwargs[pk_field.attname] = data['pk']
for field_name, field_value in data.items():
try:
field = model._meta.get_field(field_name)
except FieldDoesNotExist:
continue
# Filter out reverse relations
if isinstance(field, ForeignObjectRel):
continue
if field.remote_field and isinstance(field.remote_field, models.ManyToManyRel):
related_objects = field.remote_field.model._default_manager.filter(pk__in=field_value)
kwargs[field.attname] = list(related_objects)
elif field.remote_field and isinstance(field.remote_field, models.ManyToOneRel):
if field_value is None:
kwargs[field.attname] = None
else:
clean_value = field.remote_field.model._meta.get_field(field.remote_field.field_name).to_python(field_value)
kwargs[field.attname] = clean_value
if check_fks:
try:
field.remote_field.model._default_manager.get(**{field.remote_field.field_name: clean_value})
except field.remote_field.model.DoesNotExist:
if field.remote_field.on_delete == models.DO_NOTHING:
pass
elif field.remote_field.on_delete == models.CASCADE:
if strict_fks:
return None
else:
kwargs[field.attname] = None
elif field.remote_field.on_delete == models.SET_NULL:
kwargs[field.attname] = None
else:
raise Exception("can't currently handle on_delete types other than CASCADE, SET_NULL and DO_NOTHING")
else:
value = field.to_python(field_value)
# Make sure datetimes are converted to localtime
if isinstance(field, models.DateTimeField) and settings.USE_TZ and value is not None:
default_timezone = timezone.get_default_timezone()
if timezone.is_aware(value):
value = timezone.localtime(value, default_timezone)
else:
value = timezone.make_aware(value, default_timezone)
kwargs[field.name] = value
obj = model(**kwargs)
if data['pk'] is not None:
# Set state to indicate that this object has come from the database, so that
# ModelForm validation doesn't try to enforce a uniqueness check on the primary key
obj._state.adding = False
return obj
def get_all_child_relations(model):
"""
Return a list of RelatedObject records for child relations of the given model,
including ones attached to ancestors of the model
"""
return [
field for field in model._meta.get_fields()
if isinstance(field.remote_field, ParentalKey)
]
def get_all_child_m2m_relations(model):
"""
Return a list of ParentalManyToManyFields on the given model,
including ones attached to ancestors of the model
"""
return [
field for field in model._meta.get_fields()
if isinstance(field, ParentalManyToManyField)
]
class ClusterableModel(models.Model):
def __init__(self, *args, **kwargs):
"""
Extend the standard model constructor to allow child object lists to be passed in
via kwargs
"""
child_relation_names = (
[rel.get_accessor_name() for rel in get_all_child_relations(self)] +
[field.name for field in get_all_child_m2m_relations(self)]
)
if any(name in kwargs for name in child_relation_names):
# One or more child relation values is being passed in the constructor; need to
# separate these from the standard field kwargs to be passed to 'super'
kwargs_for_super = kwargs.copy()
relation_assignments = {}
for rel_name in child_relation_names:
if rel_name in kwargs:
relation_assignments[rel_name] = kwargs_for_super.pop(rel_name)
super().__init__(*args, **kwargs_for_super)
for (field_name, related_instances) in relation_assignments.items():
setattr(self, field_name, related_instances)
else:
super().__init__(*args, **kwargs)
def save(self, **kwargs):
"""
Save the model and commit all child relations.
"""
child_relation_names = [rel.get_accessor_name() for rel in get_all_child_relations(self)]
child_m2m_field_names = [field.name for field in get_all_child_m2m_relations(self)]
update_fields = kwargs.pop('update_fields', None)
if update_fields is None:
real_update_fields = None
relations_to_commit = child_relation_names
m2m_fields_to_commit = child_m2m_field_names
else:
real_update_fields = []
relations_to_commit = []
m2m_fields_to_commit = []
for field in update_fields:
if field in child_relation_names:
relations_to_commit.append(field)
elif field in child_m2m_field_names:
m2m_fields_to_commit.append(field)
else:
real_update_fields.append(field)
super().save(update_fields=real_update_fields, **kwargs)
for relation in relations_to_commit:
getattr(self, relation).commit()
for field in m2m_fields_to_commit:
getattr(self, field).commit()
def serializable_data(self):
obj = get_serializable_data_for_fields(self)
for rel in get_all_child_relations(self):
rel_name = rel.get_accessor_name()
children = getattr(self, rel_name).all()
if hasattr(rel.related_model, 'serializable_data'):
obj[rel_name] = [child.serializable_data() for child in children]
else:
obj[rel_name] = [get_serializable_data_for_fields(child) for child in children]
for field in get_all_child_m2m_relations(self):
if field.serialize:
children = getattr(self, field.name).all()
obj[field.name] = [child.pk for child in children]
return obj
def to_json(self):
return json.dumps(self.serializable_data(), cls=DjangoJSONEncoder)
@classmethod
def from_serializable_data(cls, data, check_fks=True, strict_fks=False):
"""
Build an instance of this model from the JSON-like structure passed in,
recursing into related objects as required.
If check_fks is true, it will check whether referenced foreign keys still
exist in the database.
- dangling foreign keys on related objects are dealt with by either nullifying the key or
dropping the related object, according to the 'on_delete' setting.
- dangling foreign keys on the base object will be nullified, unless strict_fks is true,
in which case any dangling foreign keys with on_delete=CASCADE will cause None to be
returned for the entire object.
"""
obj = model_from_serializable_data(cls, data, check_fks=check_fks, strict_fks=strict_fks)
if obj is None:
return None
child_relations = get_all_child_relations(cls)
for rel in child_relations:
rel_name = rel.get_accessor_name()
try:
child_data_list = data[rel_name]
except KeyError:
continue
related_model = rel.related_model
if hasattr(related_model, 'from_serializable_data'):
children = [
related_model.from_serializable_data(child_data, check_fks=check_fks, strict_fks=True)
for child_data in child_data_list
]
else:
children = [
model_from_serializable_data(related_model, child_data, check_fks=check_fks, strict_fks=True)
for child_data in child_data_list
]
children = filter(lambda child: child is not None, children)
setattr(obj, rel_name, children)
return obj
@classmethod
def from_json(cls, json_data, check_fks=True, strict_fks=False):
return cls.from_serializable_data(json.loads(json_data), check_fks=check_fks, strict_fks=strict_fks)
@transaction.atomic
def copy_child_relation(self, child_relation, target, commit=False, append=False):
"""
Copies all of the objects in the accessor_name to the target object.
For example, say we have an event with speakers (my_event) and we need to copy these to another event (my_other_event):
my_event.copy_child_relation('speakers', my_other_event)
By default, this copies the child objects without saving them. Set the commit paremter to True to save the objects
but note that this would cause an exception if the target object is not saved.
This will overwrite the child relation on the target object. This is to avoid any issues with unique keys
and/or sort_order. If you want it to append. set the `append` parameter to True.
This method returns a dictionary mapping the child relation/primary key on the source object to the new object created for the
target object.
"""
# A dict that maps child objects from their old IDs to their new objects
child_object_map = {}
if isinstance(child_relation, str):
child_relation = self._meta.get_field(child_relation)
if not isinstance(child_relation.remote_field, ParentalKey):
raise LookupError("copy_child_relation can only be used for relationships defined with a ParentalKey")
# The name of the ParentalKey field on the child model
parental_key_name = child_relation.field.attname
# Get managers for both the source and target objects
source_manager = getattr(self, child_relation.get_accessor_name())
target_manager = getattr(target, child_relation.get_accessor_name())
if not append:
target_manager.clear()
for child_object in source_manager.all().order_by('pk'):
old_pk = child_object.pk
is_saved = old_pk is not None
child_object.pk = None
setattr(child_object, parental_key_name, target.id)
target_manager.add(child_object)
# Add mapping to object
# If the PK is none, add them into a list since there may be multiple of these
if old_pk is not None:
child_object_map[(child_relation, old_pk)] = child_object
else:
if (child_relation, None) not in child_object_map:
child_object_map[(child_relation, None)] = []
child_object_map[(child_relation, None)].append(child_object)
if commit:
target_manager.commit()
return child_object_map
def copy_all_child_relations(self, target, exclude=None, commit=False, append=False):
"""
Copies all of the objects in all child relations to the target object.
This will overwrite all of the child relations on the target object.
Set exclude to a list of child relation accessor names that shouldn't be copied.
This method returns a dictionary mapping the child_relation/primary key on the source object to the new object created for the
target object.
"""
exclude = exclude or []
child_object_map = {}
for child_relation in get_all_child_relations(self):
if child_relation.get_accessor_name() in exclude:
continue
child_object_map.update(self.copy_child_relation(child_relation, target, commit=commit, append=append))
return child_object_map
def copy_cluster(self, exclude_fields=None):
"""
Makes a copy of this object and all child relations.
Includes all field data including child relations and parental many to many fields.
Doesn't include non-parental many to many.
The result of this method is unsaved.
"""
exclude_fields = exclude_fields or []
# Extract field data from self into a dictionary
data_dict = {}
for field in self._meta.get_fields():
# Ignore explicitly excluded fields
if field.name in exclude_fields:
continue
# Ignore reverse relations
if field.auto_created:
continue
# Copy parental m2m relations
# Otherwise add them to the m2m dict to be set after saving
if field.many_to_many:
if isinstance(field, ParentalManyToManyField):
parental_field = getattr(self, field.name)
if hasattr(parental_field, 'all'):
values = parental_field.all()
if values:
data_dict[field.name] = values
continue
# Ignore parent links (page_ptr)
if isinstance(field, models.OneToOneField) and field.remote_field.parent_link:
continue
if isinstance(field, models.ForeignKey):
# Use attname to copy the ID instead of retrieving the instance
# Note: We first need to set the field to None to unset any object
# that's there already just setting _id on its own won't change the
# field until its saved.
data_dict[field.name] = None
data_dict[field.attname] = getattr(self, field.attname)
else:
data_dict[field.name] = getattr(self, field.name)
# Create copy
copy = self.__class__(**data_dict)
# Copy child relations
child_object_map = self.copy_all_child_relations(copy, exclude=exclude_fields)
return copy, child_object_map
class Meta:
abstract = True

View File

@@ -0,0 +1,558 @@
from __future__ import unicode_literals
import re
from django.core.exceptions import FieldDoesNotExist
from django.db.models import Model, Q, prefetch_related_objects
from modelcluster.utils import NullRelationshipValueEncountered, extract_field_value, get_model_field, sort_by_fields
# Constructor for test functions that determine whether an object passes some boolean condition
def test_exact(model, attribute_name, value):
if isinstance(value, Model):
if value.pk is None:
# comparing against an unsaved model, so objects need to match by reference
def _test(obj):
try:
other_value = extract_field_value(obj, attribute_name)
except NullRelationshipValueEncountered:
return False
return other_value is value
return _test
else:
# comparing against a saved model; objects need to match by type and ID.
# Additionally, where model inheritance is involved, we need to treat it as a
# positive match if one is a subclass of the other
def _test(obj):
try:
other_value = extract_field_value(obj, attribute_name)
except NullRelationshipValueEncountered:
return False
return value.pk == other_value.pk and (
isinstance(value, other_value.__class__)
or isinstance(other_value, value.__class__)
)
return _test
else:
field = get_model_field(model, attribute_name)
# convert value to the correct python type for this field
typed_value = field.to_python(value)
# just a plain Python value = do a normal equality check
def _test(obj):
try:
other_value = extract_field_value(obj, attribute_name)
except NullRelationshipValueEncountered:
return False
return other_value == typed_value
return _test
def test_iexact(model, attribute_name, match_value):
field = get_model_field(model, attribute_name)
match_value = field.to_python(match_value)
if match_value is None:
def _test(obj):
try:
val = extract_field_value(obj, attribute_name)
except NullRelationshipValueEncountered:
return False
return val is None
else:
match_value = match_value.upper()
def _test(obj):
try:
val = extract_field_value(obj, attribute_name)
except NullRelationshipValueEncountered:
return False
return val is not None and val.upper() == match_value
return _test
def test_contains(model, attribute_name, value):
field = get_model_field(model, attribute_name)
match_value = field.to_python(value)
def _test(obj):
try:
val = extract_field_value(obj, attribute_name)
except NullRelationshipValueEncountered:
return False
return val is not None and match_value in val
return _test
def test_icontains(model, attribute_name, value):
field = get_model_field(model, attribute_name)
match_value = field.to_python(value).upper()
def _test(obj):
try:
val = extract_field_value(obj, attribute_name)
except NullRelationshipValueEncountered:
return False
return val is not None and match_value in val.upper()
return _test
def test_lt(model, attribute_name, value):
field = get_model_field(model, attribute_name)
match_value = field.to_python(value)
def _test(obj):
try:
val = extract_field_value(obj, attribute_name)
except NullRelationshipValueEncountered:
return False
return val is not None and val < match_value
return _test
def test_lte(model, attribute_name, value):
field = get_model_field(model, attribute_name)
match_value = field.to_python(value)
def _test(obj):
try:
val = extract_field_value(obj, attribute_name)
except NullRelationshipValueEncountered:
return False
return val is not None and val <= match_value
return _test
def test_gt(model, attribute_name, value):
field = get_model_field(model, attribute_name)
match_value = field.to_python(value)
def _test(obj):
try:
val = extract_field_value(obj, attribute_name)
except NullRelationshipValueEncountered:
return False
return val is not None and val > match_value
return _test
def test_gte(model, attribute_name, value):
field = get_model_field(model, attribute_name)
match_value = field.to_python(value)
def _test(obj):
try:
val = extract_field_value(obj, attribute_name)
except NullRelationshipValueEncountered:
return False
return val is not None and val >= match_value
return _test
def test_in(model, attribute_name, value_list):
field = get_model_field(model, attribute_name)
match_values = set(field.to_python(val) for val in value_list)
def _test(obj):
try:
val = extract_field_value(obj, attribute_name)
except NullRelationshipValueEncountered:
return False
return val in match_values
return _test
def test_startswith(model, attribute_name, value):
field = get_model_field(model, attribute_name)
match_value = field.to_python(value)
def _test(obj):
try:
val = extract_field_value(obj, attribute_name)
except NullRelationshipValueEncountered:
return False
return val is not None and val.startswith(match_value)
return _test
def test_istartswith(model, attribute_name, value):
field = get_model_field(model, attribute_name)
match_value = field.to_python(value).upper()
def _test(obj):
try:
val = extract_field_value(obj, attribute_name)
except NullRelationshipValueEncountered:
return False
return val is not None and val.upper().startswith(match_value)
return _test
def test_endswith(model, attribute_name, value):
field = get_model_field(model, attribute_name)
match_value = field.to_python(value)
def _test(obj):
try:
val = extract_field_value(obj, attribute_name)
except NullRelationshipValueEncountered:
return False
return val is not None and val.endswith(match_value)
return _test
def test_iendswith(model, attribute_name, value):
field = get_model_field(model, attribute_name)
match_value = field.to_python(value).upper()
def _test(obj):
try:
val = extract_field_value(obj, attribute_name)
except NullRelationshipValueEncountered:
return False
return val is not None and val.upper().endswith(match_value)
return _test
def test_range(model, attribute_name, range_val):
field = get_model_field(model, attribute_name)
start_val = field.to_python(range_val[0])
end_val = field.to_python(range_val[1])
def _test(obj):
try:
val = extract_field_value(obj, attribute_name)
except NullRelationshipValueEncountered:
return False
return (val is not None and val >= start_val and val <= end_val)
return _test
def test_isnull(model, attribute_name, sense):
def _test(obj):
try:
val = extract_field_value(obj, attribute_name)
except NullRelationshipValueEncountered:
return False
if sense:
return val is None
else:
return val is not None
return _test
def test_regex(model, attribute_name, regex_string):
regex = re.compile(regex_string)
def _test(obj):
try:
val = extract_field_value(obj, attribute_name)
except NullRelationshipValueEncountered:
return False
return val is not None and regex.search(val)
return _test
def test_iregex(model, attribute_name, regex_string):
regex = re.compile(regex_string, re.I)
def _test(obj):
try:
val = extract_field_value(obj, attribute_name)
except NullRelationshipValueEncountered:
return False
return val is not None and regex.search(val)
return _test
FILTER_EXPRESSION_TOKENS = {
'exact': test_exact,
'iexact': test_iexact,
'contains': test_contains,
'icontains': test_icontains,
'lt': test_lt,
'lte': test_lte,
'gt': test_gt,
'gte': test_gte,
'in': test_in,
'startswith': test_startswith,
'istartswith': test_istartswith,
'endswith': test_endswith,
'iendswith': test_iendswith,
'range': test_range,
'isnull': test_isnull,
'regex': test_regex,
'iregex': test_iregex,
}
def _build_test_function_from_filter(model, key_clauses, val):
# Translate a filter kwarg rule (e.g. foo__bar__exact=123) into a function which can
# take a model instance and return a boolean indicating whether it passes the rule
try:
get_model_field(model, "__".join(key_clauses))
except FieldDoesNotExist:
# it is safe to assume the last clause indicates the type of test
field_match_found = False
else:
field_match_found = True
if not field_match_found and key_clauses[-1] in FILTER_EXPRESSION_TOKENS:
constructor = FILTER_EXPRESSION_TOKENS[key_clauses.pop()]
else:
constructor = test_exact
# recombine the remaining items to be interpretted
# by get_model_field() and extract_field_value()
attribute_name = "__".join(key_clauses)
return constructor(model, attribute_name, val)
class FakeQuerySetIterable:
def __init__(self, queryset):
self.queryset = queryset
class ModelIterable(FakeQuerySetIterable):
def __iter__(self):
yield from self.queryset.results
class DictIterable(FakeQuerySetIterable):
def __iter__(self):
field_names = self.queryset.dict_fields or [field.name for field in self.queryset.model._meta.fields]
for obj in self.queryset.results:
yield {
field_name: extract_field_value(obj, field_name, pk_only=True, suppress_fielddoesnotexist=True, suppress_nullrelationshipvalueencountered=True)
for field_name in field_names
}
class ValuesListIterable(FakeQuerySetIterable):
def __iter__(self):
field_names = self.queryset.tuple_fields or [field.name for field in self.queryset.model._meta.fields]
for obj in self.queryset.results:
yield tuple([extract_field_value(obj, field_name, pk_only=True, suppress_fielddoesnotexist=True, suppress_nullrelationshipvalueencountered=True) for field_name in field_names])
class FlatValuesListIterable(FakeQuerySetIterable):
def __iter__(self):
field_name = self.queryset.tuple_fields[0]
for obj in self.queryset.results:
yield extract_field_value(obj, field_name, pk_only=True, suppress_fielddoesnotexist=True, suppress_nullrelationshipvalueencountered=True)
class FakeQuerySet(object):
def __init__(self, model, results):
self.model = model
self.results = results
self.dict_fields = []
self.tuple_fields = []
self.iterable_class = ModelIterable
def all(self):
return self
def get_clone(self, results = None):
new = FakeQuerySet(self.model, results if results is not None else self.results)
new.dict_fields = self.dict_fields
new.tuple_fields = self.tuple_fields
new.iterable_class = self.iterable_class
return new
def resolve_q_object(self, q_object):
connector = q_object.connector
filters = []
def test(filters):
def test_inner(obj):
result = False
if connector == Q.AND:
result = all([test(obj) for test in filters])
elif connector == Q.OR:
result = any([test(obj) for test in filters])
else:
result = sum([test(obj) for test in filters]) == 1
if q_object.negated:
return not result
return result
return test_inner
for child in q_object.children:
if isinstance(child, Q):
filters.append(self.resolve_q_object(child))
else:
key_clauses, val = child
filters.append(_build_test_function_from_filter(self.model, key_clauses.split('__'), val))
return test(filters)
def _get_filters(self, *args, **kwargs):
# a list of test functions; objects must pass all tests to be included
# in the filtered list
filters = []
for q_object in args:
filters.append(self.resolve_q_object(q_object))
for key, val in kwargs.items():
filters.append(
_build_test_function_from_filter(self.model, key.split('__'), val)
)
return filters
def filter(self, *args, **kwargs):
filters = self._get_filters(*args, **kwargs)
clone = self.get_clone(results=[
obj for obj in self.results
if all([test(obj) for test in filters])
])
return clone
def exclude(self, *args, **kwargs):
filters = self._get_filters(*args, **kwargs)
clone = self.get_clone(results=[
obj for obj in self.results
if not all([test(obj) for test in filters])
])
return clone
def get(self, *args, **kwargs):
clone = self.filter(*args, **kwargs)
result_count = clone.count()
if result_count == 0:
raise self.model.DoesNotExist("%s matching query does not exist." % self.model._meta.object_name)
elif result_count == 1:
for result in clone:
return result
else:
raise self.model.MultipleObjectsReturned(
"get() returned more than one %s -- it returned %s!" % (self.model._meta.object_name, result_count)
)
def count(self):
return len(self.results)
def exists(self):
return bool(self.results)
def first(self):
for result in self:
return result
def last(self):
if self.results:
clone = self.get_clone(results=reversed(self.results))
for result in clone:
return result
def select_related(self, *args):
# has no meaningful effect on non-db querysets
return self
def prefetch_related(self, *args):
prefetch_related_objects(self.results, *args)
return self
def only(self, *args):
# has no meaningful effect on non-db querysets
return self
def defer(self, *args):
# has no meaningful effect on non-db querysets
return self
def values(self, *fields):
clone = self.get_clone()
clone.dict_fields = fields
# Ensure all 'fields' are available model fields
for f in fields:
get_model_field(self.model, f)
clone.iterable_class = DictIterable
return clone
def values_list(self, *fields, flat=None):
clone = self.get_clone()
clone.tuple_fields = fields
# Ensure all 'fields' are available model fields
for f in fields:
get_model_field(self.model, f)
if flat:
if len(fields) > 1:
raise TypeError("'flat' is not valid when values_list is called with more than one field.")
clone.iterable_class = FlatValuesListIterable
else:
clone.iterable_class = ValuesListIterable
return clone
def order_by(self, *fields):
clone = self.get_clone(results=self.results[:])
sort_by_fields(clone.results, fields)
return clone
def distinct(self, *fields):
unique_results = []
if not fields:
fields = [field.name for field in self.model._meta.fields if not field.primary_key]
seen_keys = set()
for result in self.results:
key = tuple(str(extract_field_value(result, field)) for field in fields)
if key not in seen_keys:
seen_keys.add(key)
unique_results.append(result)
return self.get_clone(results=unique_results)
# a standard QuerySet will store the results in _result_cache on running the query;
# this is effectively the same as self.results on a FakeQuerySet, and so we'll make
# _result_cache an alias of self.results for the benefit of Django internals that
# exploit it
def _get_result_cache(self):
return self.results
def _set_result_cache(self, val):
self.results = list(val)
_result_cache = property(_get_result_cache, _set_result_cache)
def __getitem__(self, k):
return self.results[k]
def __iter__(self):
iterator = self.iterable_class(self)
yield from iterator
def __nonzero__(self):
return bool(self.results)
def __repr__(self):
return repr(list(self))
def __len__(self):
return len(self.results)
ordered = True # results are returned in a consistent order

View File

@@ -0,0 +1,8 @@
import warnings
from modelcluster.contrib.taggit import * # NOQA
warnings.warn(
"The modelcluster.tags module has been moved to "
"modelcluster.contrib.taggit", DeprecationWarning)

View File

@@ -0,0 +1,216 @@
import datetime
from functools import lru_cache
import random
from django.core.exceptions import FieldDoesNotExist
from django.db.models import (
DateField,
DateTimeField,
ManyToManyField,
ManyToManyRel,
Model,
TimeField,
)
from modelcluster import datetime_utils
REL_DELIMETER = "__"
class ManyToManyTraversalError(ValueError):
pass
class NullRelationshipValueEncountered(Exception):
pass
class TraversedRelationship:
__slots__ = ['from_model', 'field']
def __init__(self, from_model, field):
self.from_model = from_model
self.field = field
@property
def field_name(self) -> str:
return self.field.name
@property
def to_model(self):
return self.field.target_model
@lru_cache(maxsize=None)
def get_model_field(model, name):
"""
Returns a model field matching the supplied ``name``, which can include
double-underscores (`'__'`) to indicate relationship traversal - in which
case, the model field will be lookuped up from the related model.
Multiple traversals for the same field are supported, but at this
moment in time, only traversal of many-to-one and one-to-one relationships
is supported.
Details of any relationships traversed in order to reach the returned
field are made available as `field.traversals`. The value is a tuple of
``TraversedRelationship`` instances.
Raises ``FieldDoesNotExist`` if the name cannot be mapped to a model field.
"""
subject_model = model
traversals = []
field = None
for field_name in name.split(REL_DELIMETER):
if field is not None:
if isinstance(field, (ManyToManyField, ManyToManyRel)):
raise ManyToManyTraversalError(
"The lookup '{name}' from {model} cannot be replicated "
"by modelcluster, because the '{field_name}' "
"relationship from {subject_model} is a many-to-many, "
"and traversal is only supported for one-to-one or "
"many-to-one relationships."
.format(
name=name,
model=model,
field_name=field_name,
subject_model=subject_model,
)
)
elif getattr(field, "related_model", None):
traversals.append(TraversedRelationship(subject_model, field))
subject_model = field.related_model
elif (
(
isinstance(field, DateTimeField)
and field_name in datetime_utils.DATETIMEFIELD_TRANSFORM_EXPRESSIONS
) or (
isinstance(field, DateField)
and field_name in datetime_utils.DATEFIELD_TRANSFORM_EXPRESSIONS
) or (
isinstance(field, TimeField)
and field_name in datetime_utils.TIMEFIELD_TRANSFORM_EXPRESSIONS
)
):
transform_field_type = datetime_utils.TRANSFORM_FIELD_TYPES[field_name]
field = transform_field_type()
break
else:
raise FieldDoesNotExist(
"Failed attempting to traverse from {from_field} (a {from_field_type}) to '{to_field}'."
.format(
from_field=subject_model._meta.label + '.' + field.name,
from_field_type=type(field),
to_field=field_name,
)
)
try:
field = subject_model._meta.get_field(field_name)
except FieldDoesNotExist:
if field_name.endswith("_id"):
field = subject_model._meta.get_field(field_name[:-3]).target_field
raise
field.traversals = tuple(traversals)
return field
def extract_field_value(obj, key, pk_only=False, suppress_fielddoesnotexist=False, suppress_nullrelationshipvalueencountered=False):
"""
Attempts to extract a field value from ``obj`` matching the ``key`` - which,
can contain double-underscores (`'__'`) to indicate traversal of relationships
to related objects.
For keys that specify ``ForeignKey`` or ``OneToOneField`` field values, full
related objects are returned by default. If only the primary key values are
required ((.g. when ordering, or using ``values()`` or ``values_list()``)),
call the function with ``pk_only=True``.
By default, ``FieldDoesNotExist`` is raised if the key cannot be mapped to
a model field. Call the function with ``suppress_fielddoesnotexist=True``
to instead receive a ``None`` value when this occurs.
By default, ``NullRelationshipValueEncountered`` is raised if a ``None``
value is encountered while attempting to traverse relationships in order to
access further fields. Call the function with
``suppress_nullrelationshipvalueencountered`` to instead receive a ``None``
value when this occurs.
"""
source = obj
latest_obj = obj
segments = key.split(REL_DELIMETER)
for i, segment in enumerate(segments, start=1):
if (
(
isinstance(source, datetime.datetime)
and segment in datetime_utils.DATETIMEFIELD_TRANSFORM_EXPRESSIONS
)
or (
isinstance(source, datetime.date)
and segment in datetime_utils.DATEFIELD_TRANSFORM_EXPRESSIONS
)
or (
isinstance(source, datetime.time)
and segment in datetime_utils.TIMEFIELD_TRANSFORM_EXPRESSIONS
)
):
source = datetime_utils.derive_from_value(source, segment)
value = source
elif hasattr(source, segment):
value = getattr(source, segment)
if isinstance(value, Model):
latest_obj = value
if value is None and i < len(segments):
if suppress_nullrelationshipvalueencountered:
return None
raise NullRelationshipValueEncountered(
"'{key}' cannot be reached for {obj} because {model_class}.{field_name} "
"is null.".format(
key=key,
obj=repr(obj),
model_class=latest_obj._meta.label,
field_name=segment,
)
)
source = value
elif suppress_fielddoesnotexist:
return None
else:
raise FieldDoesNotExist(
"'{name}' is not a valid field name for {model}".format(
name=segment, model=type(source)
)
)
if pk_only and hasattr(value, 'pk'):
return value.pk
return value
def sort_by_fields(items, fields):
"""
Sort a list of objects on the given fields. The field list works analogously to
queryset.order_by(*fields): each field is either a property of the object,
or is prefixed by '-' (e.g. '-name') to indicate reverse ordering.
"""
# To get the desired behaviour, we need to order by keys in reverse order
# See: https://docs.python.org/2/howto/sorting.html#sort-stability-and-complex-sorts
for key in reversed(fields):
if key == '?':
random.shuffle(items)
continue
# Check if this key has been reversed
reverse = False
if key[0] == '-':
reverse = True
key = key[1:]
def get_sort_value(item):
# Use a tuple of (v is not None, v) as the key, to ensure that None sorts before other values,
# as comparing directly with None breaks on python3
value = extract_field_value(item, key, pk_only=True, suppress_fielddoesnotexist=True, suppress_nullrelationshipvalueencountered=True)
return (value is not None, value)
# Sort items
items.sort(key=get_sort_value, reverse=reverse)