Initial commit

This commit is contained in:
2024-08-27 20:33:44 +02:00
commit 1f1832267d
14794 changed files with 1599592 additions and 0 deletions

View File

@@ -0,0 +1,684 @@
import warnings
from collections import OrderedDict
from django.db import DEFAULT_DB_ALIAS, NotSupportedError, connections, transaction
from django.db.models import Case, When
from django.db.models.aggregates import Avg, Count
from django.db.models.constants import LOOKUP_SEP
from django.db.models.expressions import F
from django.db.models.fields import BooleanField, FloatField, TextField
from django.db.models.functions.comparison import Cast
from django.db.models.functions.text import Length
from django.db.models.manager import Manager
from django.db.models.query_utils import Q
from django.utils.encoding import force_str
from django.utils.functional import cached_property
from wagtail.search.backends.base import (
BaseSearchBackend,
BaseSearchQueryCompiler,
BaseSearchResults,
FilterFieldError,
)
from wagtail.search.backends.database.mysql.query import (
Lexeme,
MatchExpression,
SearchQuery,
)
from wagtail.search.index import (
AutocompleteField,
RelatedFields,
SearchField,
get_indexed_models,
)
from wagtail.search.models import IndexEntry
from wagtail.search.query import And, Boost, MatchAll, Not, Or, Phrase, PlainText
from wagtail.search.utils import (
OR,
balanced_reduce,
get_content_type_pk,
get_descendants_content_types_pks,
)
class ObjectIndexer:
"""
Responsible for extracting data from an object to be inserted into the index.
"""
def __init__(self, obj, backend):
self.obj = obj
self.search_fields = obj.get_search_fields()
self.config = backend.config
def prepare_value(self, value):
if isinstance(value, str):
return value
elif isinstance(value, list):
return ", ".join(self.prepare_value(item) for item in value)
elif isinstance(value, dict):
return ", ".join(self.prepare_value(item) for item in value.values())
return force_str(value)
def prepare_field(self, obj, field):
if isinstance(field, SearchField):
yield (field, self.prepare_value(field.get_value(obj)))
elif isinstance(field, AutocompleteField):
yield (field, self.prepare_value(field.get_value(obj)))
elif isinstance(field, RelatedFields):
sub_obj = field.get_value(obj)
if sub_obj is None:
return
if isinstance(sub_obj, Manager):
sub_objs = sub_obj.all()
else:
if callable(sub_obj):
sub_obj = sub_obj()
sub_objs = [sub_obj]
for sub_obj in sub_objs:
for sub_field in field.fields:
yield from self.prepare_field(sub_obj, sub_field)
@cached_property
def id(self):
"""
Returns the value to use as the ID of the record in the index
"""
return force_str(self.obj.pk)
@cached_property
def title(self):
"""
Returns all values to index as "title". This is the value of all SearchFields that have the field_name 'title'
"""
texts = []
for field in self.search_fields:
for current_field, value in self.prepare_field(self.obj, field):
if (
isinstance(current_field, SearchField)
and current_field.field_name == "title"
):
texts.append(value)
return " ".join(texts)
@cached_property
def body(self):
"""
Returns all values to index as "body". This is the value of all SearchFields excluding the title
"""
texts = []
for field in self.search_fields:
for current_field, value in self.prepare_field(self.obj, field):
if (
isinstance(current_field, SearchField)
and not current_field.field_name == "title"
):
texts.append(value)
return " ".join(texts)
@cached_property
def autocomplete(self):
"""
Returns all values to index as "autocomplete". This is the value of all AutocompleteFields
"""
texts = []
for field in self.search_fields:
for current_field, value in self.prepare_field(self.obj, field):
if isinstance(current_field, AutocompleteField):
texts.append(value)
return " ".join(texts)
def as_vector(self, texts, for_autocomplete=False):
"""
Converts an array of strings into a SearchVector that can be indexed.
"""
texts = [(text.strip(), weight) for text, weight in texts]
texts = [(text, weight) for text, weight in texts if text]
return " ".join(texts)
class Index:
def __init__(self, backend, db_alias=None):
self.backend = backend
self.name = self.backend.index_name
self.db_alias = DEFAULT_DB_ALIAS if db_alias is None else db_alias
self.connection = connections[self.db_alias]
if self.connection.vendor != "mysql":
raise NotSupportedError(
"You must select a MySQL database " "to use MySQL search."
)
self.entries = IndexEntry._default_manager.using(self.db_alias)
def add_model(self, model):
pass
def refresh(self):
pass
def _refresh_title_norms(self, full=False):
"""
Refreshes the value of the title_norm field.
This needs to be set to 'lavg/ld' where:
- lavg is the average length of titles in all documents (also in terms)
- ld is the length of the title field in this document (in terms)
"""
lavg = (
self.entries.annotate(title_length=Length("title"))
.filter(title_length__gt=0)
.aggregate(Avg("title_length"))["title_length__avg"]
)
if full:
# Update the whole table
# This is the most accurate option but requires a full table rewrite
# so we can't do it too often as it could lead to locking issues.
entries = self.entries
else:
# Only update entries where title_norm is 1.0
# This is the default value set on new entries.
# It's possible that other entries could have this exact value but there shouldn't be too many of those
entries = self.entries.filter(title_norm=1.0)
entries.annotate(title_length=Length("title")).filter(
title_length__gt=0
).update(title_norm=lavg / F("title_length"))
def delete_stale_model_entries(self, model):
existing_pks = (
model._default_manager.using(self.db_alias)
.annotate(object_id=Cast("pk", TextField()))
.values("object_id")
)
content_types_pks = get_descendants_content_types_pks(model)
stale_entries = self.entries.filter(
content_type_id__in=content_types_pks
).exclude(object_id__in=existing_pks)
stale_entries.delete()
def delete_stale_entries(self):
for model in get_indexed_models():
# We dont need to delete stale entries for non-root models,
# since we already delete them by deleting roots.
if not model._meta.parents:
self.delete_stale_model_entries(model)
def add_item(self, obj):
self.add_items(obj._meta.model, [obj])
def add_items_update_then_create(self, content_type_pk, indexers):
ids_and_data = {}
for indexer in indexers:
ids_and_data[indexer.id] = (
indexer.title,
indexer.autocomplete,
indexer.body,
)
index_entries_for_ct = self.entries.filter(content_type_id=content_type_pk)
indexed_ids = frozenset(
index_entries_for_ct.filter(object_id__in=ids_and_data.keys()).values_list(
"object_id", flat=True
)
)
for indexed_id in indexed_ids:
title, autocomplete, body = ids_and_data[indexed_id]
index_entries_for_ct.filter(object_id=indexed_id).update(
title=title, autocomplete=autocomplete, body=body
)
to_be_created = []
for object_id in ids_and_data.keys():
if object_id not in indexed_ids:
title, autocomplete, body = ids_and_data[object_id]
to_be_created.append(
IndexEntry(
content_type_id=content_type_pk,
object_id=object_id,
title=title,
autocomplete=autocomplete,
body=body,
)
)
self.entries.bulk_create(to_be_created)
self._refresh_title_norms()
def add_items(self, model, objs):
search_fields = model.get_search_fields()
if not search_fields:
return
indexers = [ObjectIndexer(obj, self.backend) for obj in objs]
# TODO: Delete unindexed objects while dealing with proxy models.
if indexers:
content_type_pk = get_content_type_pk(model)
update_method = self.add_items_update_then_create
update_method(content_type_pk, indexers)
def delete_item(self, item):
item.index_entries.all()._raw_delete(using=self.db_alias)
def __str__(self):
return self.name
class MySQLSearchQueryCompiler(BaseSearchQueryCompiler):
DEFAULT_OPERATOR = "and"
LAST_TERM_IS_PREFIX = False
TARGET_SEARCH_FIELD_TYPE = SearchField
FTS_TABLE_FIELDS = ["title", "body"]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
local_search_fields = self.get_search_fields_for_model()
if self.fields is None:
# search over the fields defined on the current model
self.search_fields = local_search_fields
else:
# build a search_fields set from the passed definition,
# which may involve traversing relations
self.search_fields = {
field_lookup: self.get_search_field(
field_lookup, fields=local_search_fields
)
for field_lookup in self.fields
}
def get_config(self, backend):
return backend.config
def get_search_fields_for_model(self):
return self.queryset.model.get_searchable_search_fields()
def get_search_field(self, field_lookup, fields=None):
if fields is None:
fields = self.search_fields
if LOOKUP_SEP in field_lookup:
field_lookup, sub_field_name = field_lookup.split(LOOKUP_SEP, 1)
else:
sub_field_name = None
for field in fields:
if (
isinstance(field, self.TARGET_SEARCH_FIELD_TYPE)
and field.field_name == field_lookup
):
return field
# Note: Searching on a specific related field using
# `.search(fields=…)` is not yet supported by Wagtail.
# This method anticipates by already implementing it.
if isinstance(field, RelatedFields) and field.field_name == field_lookup:
return self.get_search_field(sub_field_name, field.fields)
def build_search_query_content(self, query, invert=False):
if isinstance(query, PlainText):
terms = query.query_string.split()
if not terms:
return None
last_term = terms.pop()
lexemes = Lexeme(last_term, invert=invert, prefix=self.LAST_TERM_IS_PREFIX)
for term in terms:
new_lexeme = Lexeme(term, invert=invert)
if query.operator == "and":
lexemes &= new_lexeme
else:
lexemes |= new_lexeme
return SearchQuery(lexemes)
elif isinstance(query, Phrase):
return SearchQuery(query.query_string, search_type="phrase")
elif isinstance(query, Boost):
# Not supported
msg = "The Boost query is not supported by the MySQL search backend."
warnings.warn(msg, RuntimeWarning)
return self.build_search_query_content(query.subquery, invert=invert)
elif isinstance(query, Not):
return self.build_search_query_content(query.subquery, invert=not invert)
elif isinstance(query, (And, Or)):
# If this part of the query is inverted, we swap the operator and
# pass down the inversion state to the child queries.
# This works thanks to De Morgan's law.
#
# For example, the following query:
#
# Not(And(Term("A"), Term("B")))
#
# Is equivalent to:
#
# Or(Not(Term("A")), Not(Term("B")))
#
# It's simpler to code it this way as we only need to store the
# invert status of the terms rather than all the operators.
subquery_lexemes = [
self.build_search_query_content(subquery, invert=invert)
for subquery in query.subqueries
]
is_and = isinstance(query, And)
if invert:
is_and = not is_and
if is_and:
return balanced_reduce(lambda a, b: a & b, subquery_lexemes)
else:
return balanced_reduce(lambda a, b: a | b, subquery_lexemes)
raise NotImplementedError(
"`%s` is not supported by the MySQL search backend."
% query.__class__.__name__
)
def build_search_query(self, query):
return self.build_search_query_content(query)
def get_index_vectors(self, search_query):
return [
(F("index_entries__title"), F("index_entries__title_norm")),
(F("index_entries__body"), 1.0),
]
def get_fields_vectors(self, search_query):
raise NotImplementedError()
def get_search_vectors(self, search_query):
if self.fields is None:
return self.get_index_vectors(search_query)
else:
return self.get_fields_vectors(search_query)
def _build_rank_expression(self, vectors, config):
rank_expressions = [
self.build_tsrank(vector, self.query, config=config) * boost
for vector, boost in vectors
]
rank_expression = rank_expressions[0]
for other_rank_expression in rank_expressions[1:]:
rank_expression += other_rank_expression
return rank_expression
def search(self, config, start, stop, score_field=None):
# TODO: Handle MatchAll nested inside other search query classes.
if isinstance(self.query, MatchAll):
return self.queryset[start:stop]
elif isinstance(self.query, Not) and isinstance(self.query.subquery, MatchAll):
return self.queryset.none()
if isinstance(
self.query, Not
): # If the outermost operator is a Not, we invert the query. This is done because, if every search term is negated, the Not() will return no results, an we want to match all results instead.
query = self.query.subquery
negated = True
else:
query = self.query
negated = False
search_query = self.build_search_query(query)
match_expression = MatchExpression(
search_query, columns=self.FTS_TABLE_FIELDS, output_field=BooleanField()
) # For example: MATCH (`title`, `body`) AGAINST ('+query' IN BOOLEAN MODE)
# In Django 4.0 the above match expression would produce this SQL WHERE clause:
#
# WHERE ... MATCH (`title`, `body`) AGAINST (query IN BOOLEAN MODE)
#
# In Django 4.1, this behavior was changed:
#
# https://code.djangoproject.com/ticket/32691
# https://github.com/django/django/commit/407fe95cb116599adeb4b9ed01df5673aa5cb1db
#
# so that instead this SQL WHERE clause is generated, explicitly filtering
# against "= True":
#
# WHERE ... MATCH (`title`, `body`) AGAINST (query IN BOOLEAN MODE) = True
#
# This no longer works properly because MATCH returns a floating point score
# as a measurement of the match quality, not a boolean value:
#
# https://dev.mysql.com/doc/refman/8.0/en/fulltext-boolean.html
#
# In order for filtering on "= True" to work, we change the match expression
# SQL to be:
#
# WHERE ... CASE WHEN MATCH (`title`, `body`) AGAINST (query IN BOOLEAN MODE) THEN True ELSE False END = True
match_expression = Case(When(match_expression, then=True), default=False)
score_expression = MatchExpression(
search_query, columns=["title"], output_field=FloatField()
) * F("title_norm") + MatchExpression(
search_query, columns=["body"], output_field=FloatField()
)
index_entries = IndexEntry.objects.annotate(score=score_expression).filter(
content_type_id__in=get_descendants_content_types_pks(self.queryset.model)
)
if not negated:
index_entries = index_entries.filter(match_expression)
if self.order_by_relevance: # Only applies to the case where the outermost query is not a Not(), because if it is, the relevance score is always 0 (anything that matches is excluded from the results).
index_entries = index_entries.order_by(score_expression.desc())
else:
index_entries = index_entries.exclude(match_expression)
index_entries = index_entries[start:stop] # Trim the results
object_ids = {
index_entry.object_id for index_entry in index_entries
} # Get the set of IDs from the indexed objects, removes duplicates too
results = self.queryset.filter(id__in=object_ids)
return results
def _process_lookup(self, field, lookup, value):
lhs = field.get_attname(self.queryset.model) + "__" + lookup
return Q(**{lhs: value})
def _connect_filters(self, filters, connector, negated):
if connector == "AND":
q = Q(*filters)
elif connector == "OR":
q = OR([Q(fil) for fil in filters])
else:
return
if negated:
q = ~q
return q
class MySQLAutocompleteQueryCompiler(MySQLSearchQueryCompiler):
LAST_TERM_IS_PREFIX = True
TARGET_SEARCH_FIELD_TYPE = AutocompleteField
FTS_TABLE_FIELDS = ["autocomplete"]
def get_config(self, backend):
return backend.autocomplete_config
def get_search_fields_for_model(self):
return self.queryset.model.get_autocomplete_search_fields()
def get_index_vectors(self, search_query):
return [(F("index_entries__autocomplete"), 1.0)]
def get_fields_vectors(self, search_query):
raise NotImplementedError()
class MySQLSearchResults(BaseSearchResults):
def get_queryset(self, for_count=False):
if for_count:
start = None
stop = None
else:
start = self.start
stop = self.stop
return self.query_compiler.search(
self.query_compiler.get_config(self.backend),
start,
stop,
score_field=self._score_field,
)
def _do_search(self):
return list(self.get_queryset())
def _do_count(self):
return self.get_queryset(for_count=True).count()
supports_facet = True
def facet(self, field_name):
# Get field
field = self.query_compiler._get_filterable_field(field_name)
if field is None:
raise FilterFieldError(
'Cannot facet search results with field "'
+ field_name
+ "\". Please add index.FilterField('"
+ field_name
+ "') to "
+ self.query_compiler.queryset.model.__name__
+ ".search_fields.",
field_name=field_name,
)
query = self.query_compiler.search(
self.query_compiler.get_config(self.backend), None, None
)
results = (
query.values(field_name).annotate(count=Count("pk")).order_by("-count")
)
return OrderedDict(
[(result[field_name], result["count"]) for result in results]
)
class MySQLSearchRebuilder:
def __init__(self, index):
self.index = index
def start(self):
self.index.delete_stale_entries()
return self.index
def finish(self):
self.index._refresh_title_norms(full=True)
class MySQLSearchAtomicRebuilder(MySQLSearchRebuilder):
def __init__(self, index):
super().__init__(index)
self.transaction = transaction.atomic(using=index.db_alias)
self.transaction_opened = False
def start(self):
self.transaction.__enter__()
self.transaction_opened = True
return super().start()
def finish(self):
self.index._refresh_title_norms(full=True)
self.transaction.__exit__(None, None, None)
self.transaction_opened = False
def __del__(self):
# TODO: Implement a cleaner way to close the connection on failure.
if self.transaction_opened:
self.transaction.needs_rollback = True
self.finish()
class MySQLSearchBackend(BaseSearchBackend):
query_compiler_class = MySQLSearchQueryCompiler
autocomplete_query_compiler_class = MySQLAutocompleteQueryCompiler
results_class = MySQLSearchResults
rebuilder_class = MySQLSearchRebuilder
atomic_rebuilder_class = MySQLSearchAtomicRebuilder
def __init__(self, params):
super().__init__(params)
self.index_name = params.get("INDEX", "default")
# MySQL backend currently has no config options
self.config = None
self.autocomplete_config = None
if params.get("ATOMIC_REBUILD", False):
self.rebuilder_class = self.atomic_rebuilder_class
def get_index_for_model(self, model, db_alias=None):
return Index(self, db_alias)
def get_index_for_object(self, obj):
return self.get_index_for_model(obj._meta.model, obj._state.db)
def reset_index(self):
for connection in [
connection
for connection in connections.all()
if connection.vendor == "mysql"
]:
IndexEntry._default_manager.all()._raw_delete(using=connection.alias)
def add_type(self, model):
pass # Not needed.
def refresh_index(self):
pass # Not needed.
def add(self, obj):
self.get_index_for_object(obj).add_item(obj)
def add_bulk(self, model, obj_list):
if obj_list:
self.get_index_for_object(obj_list[0]).add_items(model, obj_list)
def delete(self, obj):
self.get_index_for_object(obj).delete_item(obj)
SearchBackend = MySQLSearchBackend

View File

@@ -0,0 +1,254 @@
import re
from typing import Any, List, Tuple, Union
from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.models.expressions import CombinedExpression, Expression, Value
from django.db.models.fields import BooleanField, Field
from django.db.models.sql.compiler import SQLCompiler
class LexemeCombinable(Expression):
BITAND = "+"
BITOR = ""
invert = False # By default, it is not inverted
def _combine(self, other, connector, reversed, node=None):
if not isinstance(other, LexemeCombinable):
raise TypeError(
f"Lexeme can only be combined with other Lexemes, got {type(other)}."
)
if reversed:
return CombinedLexeme(other, connector, self)
return CombinedLexeme(self, connector, other)
# On Combinable, these are not implemented to reduce confusion with Q. In
# this case we are actually (ab)using them to do logical combination so
# it's consistent with other usage in Django.
def bitand(self, other):
return self._combine(other, self.BITAND, False)
def bitor(self, other):
return self._combine(other, self.BITOR, False)
def __or__(self, other):
return self._combine(other, self.BITOR, False)
def __and__(self, other):
return self._combine(other, self.BITAND, False)
class SearchQueryField(Field):
def db_type(self, connection):
return None
class Lexeme(LexemeCombinable, Value):
_output_field = SearchQueryField()
def __init__(
self, value, output_field=None, invert=False, prefix=False, weight=None
):
self.prefix = prefix
self.invert = invert
self.weight = weight
super().__init__(value, output_field=output_field)
def as_sql(self, compiler, connection):
param = re.sub(
r"\W+", " ", self.value
) # Remove non-word characters. This is done to disallow the usage of full text search operators in the MATCH clause, because MySQL doesn't include these kinds of characters in FULLTEXT indexes.
template = "%s"
if self.prefix:
param = f"{param}*"
if self.invert:
param = f"(-{param})"
else:
param = f"{param}"
return template, [param]
class CombinedLexeme(LexemeCombinable):
_output_field = SearchQueryField()
def __init__(self, lhs, connector, rhs, output_field=None):
super().__init__(output_field=output_field)
self.connector = connector
self.lhs = lhs
self.rhs = rhs
def as_sql(self, compiler, connection):
value_params = []
lhs_connector = self.connector
rhs_connector = self.connector
if (
self.lhs.invert and self.connector == "+"
): # NOTE: This is a special case for MySQL. If either side's operator is AND (+), and it is inverted, the operator should become NOT (-). If we did nothing, the result could would be '+X +(-Y)' for And(X, Not(Y)), which seems correct, but produces a wrong result. The solution is to turn the query into '+X -Y', which does work, and therefore this is done here.
# TODO: There may be a better solution than this.
modified_value = self.lhs.value.copy()
modified_value.invert = not modified_value.invert
lhs_connector = "-"
lsql, params = compiler.compile(modified_value)
else:
lsql, params = compiler.compile(self.lhs)
value_params.extend(params)
if self.rhs.invert and self.connector == "+": # Same explanation as above.
modified_value = self.rhs.value.copy()
modified_value.invert = not modified_value.invert
rhs_connector = "-"
rsql, params = compiler.compile(modified_value)
else:
rsql, params = compiler.compile(self.rhs)
value_params.extend(params)
combined_sql = "({}{} {}{})".format(
lhs_connector, lsql, rhs_connector, rsql
) # if self.connector is '+' (AND), then both terms will be ANDed together. We need to repeat the connector to make that work.
combined_value = combined_sql % tuple(value_params)
return "%s", [combined_value]
class SearchQueryCombinable:
BITAND = "+"
BITOR = ""
def _combine(self, other, connector: str, reversed: bool = False):
if not isinstance(other, SearchQueryCombinable):
raise TypeError(
"SearchQuery can only be combined with other SearchQuery "
"instances, got %s." % type(other).__name__
)
if reversed:
return CombinedSearchQuery(other, connector, self)
return CombinedSearchQuery(self, connector, other)
# On Combinable, these are not implemented to reduce confusion with Q. In
# this case we are actually (ab)using them to do logical combination so
# it's consistent with other usage in Django.
def __or__(self, other):
return self._combine(other, self.BITOR, False)
def __ror__(self, other):
return self._combine(other, self.BITOR, True)
def __and__(self, other):
return self._combine(other, self.BITAND, False)
def __rand__(self, other):
return self._combine(other, self.BITAND, True)
class SearchQuery(SearchQueryCombinable, Expression):
def __init__(
self, value: Union[LexemeCombinable, str], search_type: str = "lexeme", **extra
):
super().__init__(output_field=SearchQueryField())
self.extra = extra
if (
isinstance(value, str) or search_type == "phrase"
): # If the value is a string, we assume it's a phrase
safe_string = re.sub(
r"\W+", " ", value
) # Remove non-word characters. This is done to disallow the usage of full text search operators in the MATCH clause, because MySQL doesn't include these kinds of characters in FULLTEXT indexes.
self.value = Value(
'"%s"' % safe_string
) # We wrap it in quotes to make sure it's parsed as a phrase
else: # Otherwise, we assume it's a lexeme
self.value = value
def as_sql(
self,
compiler: SQLCompiler,
connection: BaseDatabaseWrapper,
**extra_context: Any,
) -> Tuple[str, List[Any]]:
sql, params = compiler.compile(self.value)
return (sql, params)
def __repr__(self) -> str:
return self.value.__repr__()
class CombinedSearchQuery(SearchQueryCombinable, CombinedExpression):
def __init__(self, lhs, connector, rhs, output_field=None):
super().__init__(lhs, connector, rhs, output_field)
def __str__(self):
return "%s" % super().__str__()
def as_sql(self, compiler, connection):
value_params = []
lhs_connector = self.connector
rhs_connector = self.connector
if (
isinstance(self.lhs, SearchQuery)
and isinstance(self.lhs.value, Lexeme)
and self.lhs.value.invert
and self.connector == "+"
): # NOTE: The explanation for this special case is the same as above, in the CombinedLexeme class.
modified_value = self.lhs.value.copy()
modified_value.invert = not modified_value.invert
lhs_connector = "-"
lsql, params = compiler.compile(modified_value)
else:
lsql, params = compiler.compile(self.lhs)
value_params.extend(params)
if (
isinstance(self.rhs, SearchQuery)
and isinstance(self.rhs.value, Lexeme)
and self.rhs.value.invert
and self.connector == "+"
): # NOTE: The explanation for this special case is the same as above, in the CombinedLexeme class.
modified_value = self.rhs.value.copy()
modified_value.invert = not modified_value.invert
rhs_connector = "-"
rsql, params = compiler.compile(modified_value)
else:
rsql, params = compiler.compile(self.rhs)
value_params.extend(params)
combined_sql = "({}{} {}{})".format(
lhs_connector, lsql, rhs_connector, rsql
) # if self.connector is '+' (AND), then both terms will be ANDed together. We need to repeat the connector to make that work.
combined_value = combined_sql % tuple(value_params)
return "%s", [combined_value]
class MatchExpression(Expression):
filterable = True
template = "MATCH (%s) AGAINST (%s IN BOOLEAN MODE)"
def __init__(
self,
query: SearchQueryCombinable,
columns: List[str] = None,
output_field: Field = BooleanField(),
) -> None:
super().__init__(output_field=output_field)
self.query = query
self.columns = (
columns
or [
"title",
"body",
]
) # We need to provide a default list of columns if the user doesn't specify one. We have a joint index for for 'title' and 'body' (see wagtail.search.migrations.0006_customise_indexentry), so we'll pick that one.
def as_sql(self, compiler, connection):
compiled_query = compiler.compile(self.query) # Compile the query to a string
formatted_query = compiled_query[0] % tuple(
compiled_query[1]
) # Substitute the params in the query
column_list = ", ".join(
[f"`{column}`" for column in self.columns]
) # ['title', 'body'] becomes '`title`, `body`'
params = [formatted_query]
return (self.template % (column_list, "%s"), params)