353 lines
14 KiB
Python
353 lines
14 KiB
Python
import json
|
|
import logging
|
|
from collections import OrderedDict
|
|
from django.db.models import JSONField, F, Q, Subquery, OuterRef
|
|
from django.db.models.functions import Cast
|
|
from django.db.migrations import RunPython
|
|
from django.utils.functional import cached_property
|
|
from wagtail.blocks import StreamValue
|
|
from wagtail.blocks.migrations import utils
|
|
|
|
# Module-level logger; revision migration failures that are tolerated are reported here.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class MigrateStreamData(RunPython):
    """Subclass of RunPython for streamfield data migration operations.

    The forward migration rewrites the raw stream data stored on every row of the given model,
    then the stream data stored in the selected revisions. The reverse migration is a no-op.
    """

    def __init__(
        self,
        app_name,
        model_name,
        field_name,
        operations_and_block_paths,
        revisions_from=None,
        chunk_size=1024,
        **kwargs,
    ):
        """MigrateStreamData constructor

        Args:
            app_name (str): Name of the app.
            model_name (str): Name of the model.
            field_name (str): Name of the streamfield.
            operations_and_block_paths (:obj:`list` of :obj:`tuple` of (:obj:`operation`, :obj:`str`)):
                List of operations and corresponding block paths to apply.
            revisions_from (:obj:`datetime`, optional): Only revisions created from this date
                onwards will be updated. Passing `None` updates all revisions. Defaults to `None`.
                Note that live and latest revisions will be updated regardless of what value this
                takes.
            chunk_size (:obj:`int`, optional): chunk size for queryset.iterator and bulk_update.
                Defaults to 1024.
            **kwargs: atomic, elidable, hints for superclass RunPython can be given

        Example:
            Renaming a block named `field1` to `block1`::

                MigrateStreamData(
                    app_name="blog",
                    model_name="BlogPage",
                    field_name="content",
                    operations_and_block_paths=[
                        (RenameStreamChildrenOperation(old_name="field1", new_name="block1"), ""),
                    ],
                    revisions_from=datetime.date(2022, 7, 25),
                ),
        """
        self.app_name = app_name
        self.model_name = model_name
        self.field_name = field_name
        self.operations_and_block_paths = operations_and_block_paths
        self.revisions_from = revisions_from
        self.chunk_size = chunk_size

        # TODO add reverse code when needed, will probably need another input (reversible?)
        # Until then, reversing this operation does nothing.
        # super class kwargs - atomic, elidable, hints
        super().__init__(
            code=self.migrate_stream_data_forward,
            reverse_code=RunPython.noop,
            **kwargs,
        )

    def deconstruct(self):
        """Return ``(name, args, kwargs)`` from which this operation can be rebuilt.

        ``code`` and ``reverse_code`` are removed from the superclass kwargs. ``__init__``
        always passes both to ``RunPython.__init__`` itself, so feeding them back in through
        ``**kwargs`` would raise a TypeError for duplicate keyword arguments.
        """
        _, args, kwargs = super().deconstruct()
        kwargs.pop("code", None)
        kwargs.pop("reverse_code", None)

        kwargs["app_name"] = self.app_name
        kwargs["model_name"] = self.model_name
        kwargs["field_name"] = self.field_name
        kwargs["operations_and_block_paths"] = self.operations_and_block_paths
        kwargs["revisions_from"] = self.revisions_from
        kwargs["chunk_size"] = self.chunk_size

        return (self.__class__.__qualname__, args, kwargs)

    @property
    def migration_name_fragment(self):
        """Migration name suffix made of the distinct operation name fragments, in order."""
        # OrderedDict.fromkeys acts as an ordered set: repeated fragments appear once and the
        # generated name is consistent between runs.
        fragments = OrderedDict.fromkeys(
            operation.operation_name_fragment
            for operation, _ in self.operations_and_block_paths
        )
        return "_".join(fragments)

    def migrate_stream_data_forward(self, apps, schema_editor):
        """Apply the configured operations to the model rows, then to their revisions.

        Args:
            apps: App registry passed in by RunPython; the model is looked up from it.
            schema_editor: Not used.

        Raises:
            utils.InvalidBlockDefError: if the stream data of a model row, or of a live or
                latest revision, cannot be migrated.
        """
        model = apps.get_model(self.app_name, self.model_name)
        revision_query_maker = self._get_revision_query_maker(apps, model)
        # The field's class attribute is the same for every row and revision; look it up once.
        streamfield = getattr(model, self.field_name)

        self._migrate_instances(model, streamfield, revision_query_maker)

        # For models without revisions
        if not revision_query_maker.has_revisions:
            return

        self._migrate_revisions(streamfield, revision_query_maker)

    def _get_revision_query_maker(self, apps, model):
        """Pick the revision query maker for the Wagtail version of the migration state."""
        # Here we can't directly check the wagtail version, rather we need to check the wagtail
        # version at the project state when the migration is being applied. The generic
        # `Revision` model is what DefaultRevisionQueryMaker (Wagtail 4+) relies on.
        try:
            apps.get_model("wagtailcore", "Revision")
        except LookupError:
            return Wagtail3RevisionQueryMaker(apps, model, self.revisions_from)
        return DefaultRevisionQueryMaker(apps, model, self.revisions_from)

    def _migrate_instances(self, model, streamfield, revision_query_maker):
        """Rewrite the stream data stored on the model rows, saving in chunks of ``chunk_size``."""
        model_queryset = model.objects.annotate(
            raw_content=Cast(F(self.field_name), JSONField())
        )

        updated_model_instances_buffer = []
        for instance in model_queryset.iterator(chunk_size=self.chunk_size):
            # Record revision data before the null check so that the live and latest revisions
            # of rows whose field is currently null are still migrated.
            revision_query_maker.append_instance_data_for_revision_query(instance)

            if instance.raw_content is None:
                continue

            raw_data = instance.raw_content
            try:
                for operation, block_path_str in self.operations_and_block_paths:
                    raw_data = utils.apply_changes_to_raw_data(
                        raw_data=raw_data,
                        block_path_str=block_path_str,
                        operation=operation,
                        streamfield=streamfield,
                    )
                    # - TODO add a return value to util to know if changes were made
                    # - TODO save changed only
            except utils.InvalidBlockDefError as e:
                raise utils.InvalidBlockDefError(instance=instance) from e

            stream_block = getattr(instance, self.field_name).stream_block
            setattr(
                instance,
                self.field_name,
                StreamValue(stream_block, raw_data, is_lazy=True),
            )
            updated_model_instances_buffer.append(instance)

            if len(updated_model_instances_buffer) == self.chunk_size:
                model.objects.bulk_update(
                    updated_model_instances_buffer, [self.field_name]
                )
                updated_model_instances_buffer = []

        if updated_model_instances_buffer:
            # For any remaining chunks
            model.objects.bulk_update(updated_model_instances_buffer, [self.field_name])

    def _migrate_revisions(self, streamfield, revision_query_maker):
        """Rewrite the stream data stored in the revisions selected by the query maker."""
        revision_queryset = revision_query_maker.get_revision_queryset()

        updated_revisions_buffer = []
        for revision in revision_queryset.iterator(chunk_size=self.chunk_size):
            # Revisions saved before the field existed have no entry for it, and null stream
            # data has nothing to migrate (model rows with null data are skipped the same way).
            serialized_data = revision.content.get(self.field_name)
            raw_data = None if serialized_data is None else json.loads(serialized_data)
            if raw_data is None:
                continue

            for operation, block_path_str in self.operations_and_block_paths:
                try:
                    raw_data = utils.apply_changes_to_raw_data(
                        raw_data=raw_data,
                        block_path_str=block_path_str,
                        operation=operation,
                        streamfield=streamfield,
                    )
                except utils.InvalidBlockDefError as e:
                    # Only the revision is reported: no matching model instance is in scope.
                    error = utils.InvalidBlockDefError(revision=revision)
                    if revision_query_maker.get_is_live_or_latest_revision(revision):
                        raise error from e
                    # Older revisions are best effort: log the failure and carry on with the
                    # remaining operations; the revision is still saved below.
                    logger.exception(error)
                # - TODO add a return value to util to know if changes were made
                # - TODO save changed only

            revision.content[self.field_name] = json.dumps(raw_data)
            updated_revisions_buffer.append(revision)

            if len(updated_revisions_buffer) == self.chunk_size:
                revision_query_maker.bulk_update(updated_revisions_buffer)
                updated_revisions_buffer = []

        if updated_revisions_buffer:
            revision_query_maker.bulk_update(updated_revisions_buffer)
|
|
|
|
|
|
class AbstractRevisionQueryMaker:
    """Base helper that builds the revision query needed for a stream data migration.

    Subclasses supply the revision model and decide whether the model has revisions. They
    record what each migrated instance contributes to the query, build the query filter, and
    say which revisions are live or latest.
    """

    def __init__(self, apps, model, revisions_from):
        self.apps = apps
        self.model = model
        self.revisions_from = revisions_from
        self.RevisionModel = self.get_revision_model()
        self.has_revisions = self.get_has_revisions()
        if not self.has_revisions:
            return
        # Live or latest revision ids that are readable straight off the instances are
        # collected here while the instances are iterated.
        self.instance_field_revision_ids = set()

    # Hooks implemented by subclasses.

    def get_revision_model(self):
        """Return the model class that stores the revisions."""
        raise NotImplementedError

    def get_has_revisions(self):
        """Return whether ``self.model`` has revisions to migrate."""
        raise NotImplementedError

    def append_instance_data_for_revision_query(self, instance):
        """Record what ``instance`` contributes to the revision query."""
        raise NotImplementedError

    def _make_revision_query(self):
        """Return the filter handed to ``RevisionModel.objects.filter``."""
        raise NotImplementedError

    def get_is_live_or_latest_revision(self, revision):
        """Return whether ``revision`` is a live or latest revision."""
        raise NotImplementedError

    # Behaviour shared by every subclass.

    def get_revision_queryset(self):
        """Return the revisions matched by ``_make_revision_query``."""
        return self.RevisionModel.objects.filter(self._make_revision_query())

    def bulk_update(self, data):
        """Save the ``content`` field of the given revisions with ``bulk_update``."""
        self.RevisionModel.objects.bulk_update(data, ["content"])
|
|
|
|
|
|
class Wagtail3RevisionQueryMaker(AbstractRevisionQueryMaker):
    """Revision Query maker to support Wagtail 3 (page revisions stored in `PageRevision`)."""

    def __init__(self, apps, model, revisions_from):
        # Ids of the pages seen during the data migration; revisions are selected by page.
        self.page_ids = []

        super().__init__(apps, model, revisions_from)

    def get_revision_model(self):
        return self.apps.get_model("wagtailcore", "PageRevision")

    def get_has_revisions(self):
        # Only page models have revisions here.
        return issubclass(self.model, self.apps.get_model("wagtailcore", "Page"))

    def append_instance_data_for_revision_query(self, instance):
        if self.has_revisions:
            self.page_ids.append(instance.id)
            self.instance_field_revision_ids.add(instance.live_revision_id)

    def _latest_revision_query(self):
        """Q object matching the latest revision of each page recorded in ``page_ids``.

        A revision is the latest one when it comes first among its page's revisions ordered
        by `created_at` (then `id`) descending. The correlated subquery is compared with `=`
        rather than `__in`: it returns at most one row, and MySQL rejects LIMIT inside an IN
        subquery.
        """
        latest_for_same_page = (
            self.RevisionModel.objects.filter(page_id=OuterRef("page_id"))
            .order_by("-created_at", "-id")
            .values_list("id", flat=True)[:1]
        )
        return Q(id=Subquery(latest_for_same_page), page_id__in=self.page_ids)

    def _make_revision_query(self):
        if self.revisions_from is not None:
            # All revisions created after the given date.
            revision_query = Q(
                created_at__gte=self.revisions_from,
                page_id__in=self.page_ids,
            )
            # All live revisions.
            revision_query |= Q(id__in=self.instance_field_revision_ids)
            # All latest revisions.
            revision_query |= self._latest_revision_query()
            return revision_query

        # otherwise query all revisions for the page
        return Q(page_id__in=self.page_ids)

    def get_is_live_or_latest_revision(self, revision):
        return (
            revision.id in self.instance_field_revision_ids
            or revision.id in self._latest_revision_ids
        )

    @cached_property
    def _latest_revision_ids(self):
        # Evaluated once on first access and cached, so it must only be read after every
        # instance has been recorded. A set gives O(1) membership checks instead of scanning
        # the query results on every lookup.
        return set(
            self.RevisionModel.objects.filter(
                self._latest_revision_query()
            ).values_list("id", flat=True)
        )
|
|
|
|
|
|
class DefaultRevisionQueryMaker(AbstractRevisionQueryMaker):
    """Revision query maker for Wagtail 4 and later, built on the generic `Revision` model."""

    def __init__(self, apps, model, revisions_from):
        # Both flags are filled in by get_has_revisions(), which the base constructor calls.
        self.has_live_revisions = self.has_latest_revisions = False

        super().__init__(apps, model, revisions_from)

    def get_revision_model(self):
        return self.apps.get_model("wagtailcore", "Revision")

    def _relation_targets_revision_model(self, relation_name):
        """Whether the model has `relation_name` and it is a relation to the Revision model."""
        if not hasattr(self.model, relation_name):
            return False
        relation = getattr(self.model, relation_name)
        return relation.field.remote_field.model is self.RevisionModel

    def get_has_revisions(self):
        # `latest_revision` is present on models with `RevisionMixin`; `live_revision` is
        # present on models with `DraftStateMixin`. Both must point to the Revision model.
        self.has_latest_revisions = self._relation_targets_revision_model("latest_revision")
        self.has_live_revisions = self._relation_targets_revision_model("live_revision")
        return self.has_latest_revisions or self.has_live_revisions

    def append_instance_data_for_revision_query(self, instance):
        if not self.has_revisions:
            return
        # Non-page models may have live or latest revisions without necessarily having
        # both, so each relation is checked on its own.
        if self.has_latest_revisions:
            self.instance_field_revision_ids.add(instance.latest_revision_id)
        if self.has_live_revisions:
            self.instance_field_revision_ids.add(instance.live_revision_id)

    def _make_revision_query(self):
        ContentType = self.apps.get_model("contenttypes", "ContentType")
        contenttype_id = ContentType.objects.get_for_model(self.model).id

        if self.revisions_from is None:
            # No cut-off date: every revision of this model is migrated.
            return Q(content_type_id=contenttype_id)

        # Revisions created from `revisions_from` onwards, plus the live and latest revisions
        # recorded from the instances whatever their creation date.
        recent_revisions = Q(
            created_at__gte=self.revisions_from,
            content_type_id=contenttype_id,
        )
        return recent_revisions | Q(id__in=self.instance_field_revision_ids)

    def get_is_live_or_latest_revision(self, revision):
        # Every live and latest revision id was recorded from the instances themselves.
        return revision.id in self.instance_field_revision_ids
|