Initial commit

This commit is contained in:
2024-08-27 20:33:44 +02:00
commit 1f1832267d
14794 changed files with 1599592 additions and 0 deletions

View File

@@ -0,0 +1,784 @@
import warnings
from collections import OrderedDict
from functools import reduce
from django.contrib.postgres.search import SearchQuery, SearchRank, SearchVector
from django.db import DEFAULT_DB_ALIAS, NotSupportedError, connections, transaction
from django.db.models import Avg, Count, F, Manager, Q, TextField, Value
from django.db.models.constants import LOOKUP_SEP
from django.db.models.functions import Cast, Length
from django.db.models.sql.subqueries import InsertQuery
from django.utils.encoding import force_str
from django.utils.functional import cached_property
from ....index import AutocompleteField, RelatedFields, SearchField, get_indexed_models
from ....models import IndexEntry
from ....query import And, Boost, MatchAll, Not, Or, Phrase, PlainText
from ....utils import (
ADD,
MUL,
OR,
get_content_type_pk,
get_descendants_content_types_pks,
)
from ...base import (
BaseSearchBackend,
BaseSearchQueryCompiler,
BaseSearchResults,
FilterFieldError,
)
from .query import Lexeme
from .weights import get_sql_weights, get_weight
EMPTY_VECTOR = SearchVector(Value("", output_field=TextField()))
class ObjectIndexer:
"""
Responsible for extracting data from an object to be inserted into the index.
"""
def __init__(self, obj, backend):
self.obj = obj
self.search_fields = obj.get_search_fields()
self.config = backend.config
self.autocomplete_config = backend.autocomplete_config
def prepare_value(self, value):
if isinstance(value, str):
return value
elif isinstance(value, list):
return ", ".join(self.prepare_value(item) for item in value)
elif isinstance(value, dict):
return ", ".join(self.prepare_value(item) for item in value.values())
return force_str(value)
def prepare_field(self, obj, field):
if isinstance(field, SearchField):
yield (
field,
get_weight(field.boost),
self.prepare_value(field.get_value(obj)),
)
elif isinstance(field, AutocompleteField):
# AutocompleteField does not define a boost parameter, so use a base weight of 'D'
yield (field, "D", self.prepare_value(field.get_value(obj)))
elif isinstance(field, RelatedFields):
sub_obj = field.get_value(obj)
if sub_obj is None:
return
if isinstance(sub_obj, Manager):
sub_objs = sub_obj.all()
else:
if callable(sub_obj):
sub_obj = sub_obj()
sub_objs = [sub_obj]
for sub_obj in sub_objs:
for sub_field in field.fields:
yield from self.prepare_field(sub_obj, sub_field)
def as_vector(self, texts, for_autocomplete=False):
"""
Converts an array of strings into a SearchVector that can be indexed.
"""
texts = [(text.strip(), weight) for text, weight in texts]
texts = [(text, weight) for text, weight in texts if text]
if not texts:
return EMPTY_VECTOR
search_config = self.autocomplete_config if for_autocomplete else self.config
return ADD(
[
SearchVector(
Value(text, output_field=TextField()),
weight=weight,
config=search_config,
)
for text, weight in texts
]
)
@cached_property
def id(self):
"""
Returns the value to use as the ID of the record in the index
"""
return force_str(self.obj.pk)
@cached_property
def title(self):
"""
Returns all values to index as "title". This is the value of all SearchFields that have the field_name 'title'
"""
texts = []
for field in self.search_fields:
for current_field, boost, value in self.prepare_field(self.obj, field):
if (
isinstance(current_field, SearchField)
and current_field.field_name == "title"
):
texts.append((value, boost))
return self.as_vector(texts)
@cached_property
def body(self):
"""
Returns all values to index as "body". This is the value of all SearchFields excluding the title
"""
texts = []
for field in self.search_fields:
for current_field, boost, value in self.prepare_field(self.obj, field):
if (
isinstance(current_field, SearchField)
and not current_field.field_name == "title"
):
texts.append((value, boost))
return self.as_vector(texts)
@cached_property
def autocomplete(self):
"""
Returns all values to index as "autocomplete". This is the value of all AutocompleteFields
"""
texts = []
for field in self.search_fields:
for current_field, boost, value in self.prepare_field(self.obj, field):
if isinstance(current_field, AutocompleteField):
texts.append((value, boost))
return self.as_vector(texts, for_autocomplete=True)
class Index:
def __init__(self, backend, db_alias=None):
self.backend = backend
self.name = self.backend.index_name
self.db_alias = DEFAULT_DB_ALIAS if db_alias is None else db_alias
self.connection = connections[self.db_alias]
if self.connection.vendor != "postgresql":
raise NotSupportedError(
"You must select a PostgreSQL database " "to use PostgreSQL search."
)
# Whether to allow adding items via the faster upsert method available in Postgres >=9.5
self._enable_upsert = self.connection.pg_version >= 90500
self.entries = IndexEntry._default_manager.using(self.db_alias)
def add_model(self, model):
pass
def refresh(self):
pass
def _refresh_title_norms(self, full=False):
"""
Refreshes the value of the title_norm field.
This needs to be set to 'lavg/ld' where:
- lavg is the average length of titles in all documents (also in terms)
- ld is the length of the title field in this document (in terms)
"""
lavg = (
self.entries.annotate(title_length=Length("title"))
.filter(title_length__gt=0)
.aggregate(Avg("title_length"))["title_length__avg"]
)
if full:
# Update the whole table
# This is the most accurate option but requires a full table rewrite
# so we can't do it too often as it could lead to locking issues.
entries = self.entries
else:
# Only update entries where title_norm is 1.0
# This is the default value set on new entries.
# It's possible that other entries could have this exact value but there shouldn't be too many of those
entries = self.entries.filter(title_norm=1.0)
entries.annotate(title_length=Length("title")).filter(
title_length__gt=0
).update(title_norm=lavg / F("title_length"))
def delete_stale_model_entries(self, model):
existing_pks = (
model._default_manager.using(self.db_alias)
.annotate(object_id=Cast("pk", TextField()))
.values("object_id")
)
content_types_pks = get_descendants_content_types_pks(model)
stale_entries = self.entries.filter(
content_type_id__in=content_types_pks
).exclude(object_id__in=existing_pks)
stale_entries.delete()
def delete_stale_entries(self):
for model in get_indexed_models():
# We dont need to delete stale entries for non-root models,
# since we already delete them by deleting roots.
if not model._meta.parents:
self.delete_stale_model_entries(model)
def add_item(self, obj):
self.add_items(obj._meta.model, [obj])
def add_items_upsert(self, content_type_pk, indexers):
compiler = InsertQuery(IndexEntry).get_compiler(connection=self.connection)
title_sql = []
autocomplete_sql = []
body_sql = []
data_params = []
for indexer in indexers:
data_params.extend((content_type_pk, indexer.id))
# Compile title value
value = compiler.prepare_value(
IndexEntry._meta.get_field("title"), indexer.title
)
sql, params = value.as_sql(compiler, self.connection)
title_sql.append(sql)
data_params.extend(params)
# Compile autocomplete value
value = compiler.prepare_value(
IndexEntry._meta.get_field("autocomplete"), indexer.autocomplete
)
sql, params = value.as_sql(compiler, self.connection)
autocomplete_sql.append(sql)
data_params.extend(params)
# Compile body value
value = compiler.prepare_value(
IndexEntry._meta.get_field("body"), indexer.body
)
sql, params = value.as_sql(compiler, self.connection)
body_sql.append(sql)
data_params.extend(params)
data_sql = ", ".join(
[
f"(%s, %s, {a}, {b}, {c}, 1.0)"
for a, b, c in zip(title_sql, autocomplete_sql, body_sql)
]
)
with self.connection.cursor() as cursor:
cursor.execute(
"""
INSERT INTO %s (content_type_id, object_id, title, autocomplete, body, title_norm)
(VALUES %s)
ON CONFLICT (content_type_id, object_id)
DO UPDATE SET title = EXCLUDED.title,
title_norm = 1.0,
autocomplete = EXCLUDED.autocomplete,
body = EXCLUDED.body
"""
% (IndexEntry._meta.db_table, data_sql),
data_params,
)
self._refresh_title_norms()
def add_items_update_then_create(self, content_type_pk, indexers):
ids_and_data = {}
for indexer in indexers:
ids_and_data[indexer.id] = (
indexer.title,
indexer.autocomplete,
indexer.body,
)
index_entries_for_ct = self.entries.filter(content_type_id=content_type_pk)
indexed_ids = frozenset(
index_entries_for_ct.filter(object_id__in=ids_and_data.keys()).values_list(
"object_id", flat=True
)
)
for indexed_id in indexed_ids:
title, autocomplete, body = ids_and_data[indexed_id]
index_entries_for_ct.filter(object_id=indexed_id).update(
title=title, autocomplete=autocomplete, body=body
)
to_be_created = []
for object_id in ids_and_data.keys():
if object_id not in indexed_ids:
title, autocomplete, body = ids_and_data[object_id]
to_be_created.append(
IndexEntry(
content_type_id=content_type_pk,
object_id=object_id,
title=title,
autocomplete=autocomplete,
body=body,
)
)
self.entries.bulk_create(to_be_created)
self._refresh_title_norms()
def add_items(self, model, objs):
search_fields = model.get_search_fields()
if not search_fields:
return
indexers = [ObjectIndexer(obj, self.backend) for obj in objs]
# TODO: Delete unindexed objects while dealing with proxy models.
if indexers:
content_type_pk = get_content_type_pk(model)
update_method = (
self.add_items_upsert
if self._enable_upsert
else self.add_items_update_then_create
)
update_method(content_type_pk, indexers)
def delete_item(self, item):
item.index_entries.all()._raw_delete(using=self.db_alias)
def __str__(self):
return self.name
class PostgresSearchQueryCompiler(BaseSearchQueryCompiler):
DEFAULT_OPERATOR = "and"
LAST_TERM_IS_PREFIX = False
TARGET_SEARCH_FIELD_TYPE = SearchField
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
local_search_fields = self.get_search_fields_for_model()
# Due to a Django bug, arrays are not automatically converted
# when we use WEIGHTS_VALUES.
self.sql_weights = get_sql_weights()
if self.fields is None:
# search over the fields defined on the current model
self.search_fields = local_search_fields
else:
# build a search_fields set from the passed definition,
# which may involve traversing relations
self.search_fields = {
field_lookup: self.get_search_field(
field_lookup, fields=local_search_fields
)
for field_lookup in self.fields
}
def get_config(self, backend):
return backend.config
def get_search_fields_for_model(self):
return self.queryset.model.get_searchable_search_fields()
def get_search_field(self, field_lookup, fields=None):
if fields is None:
fields = self.search_fields
if LOOKUP_SEP in field_lookup:
field_lookup, sub_field_name = field_lookup.split(LOOKUP_SEP, 1)
else:
sub_field_name = None
for field in fields:
if (
isinstance(field, self.TARGET_SEARCH_FIELD_TYPE)
and field.field_name == field_lookup
):
return field
# Note: Searching on a specific related field using
# `.search(fields=…)` is not yet supported by Wagtail.
# This method anticipates by already implementing it.
if isinstance(field, RelatedFields) and field.field_name == field_lookup:
return self.get_search_field(sub_field_name, field.fields)
def build_tsquery_content(self, query, config=None, invert=False):
if isinstance(query, PlainText):
terms = query.query_string.split()
if not terms:
return None
last_term = terms.pop()
lexemes = Lexeme(last_term, invert=invert, prefix=self.LAST_TERM_IS_PREFIX)
for term in terms:
new_lexeme = Lexeme(term, invert=invert)
if query.operator == "and":
lexemes &= new_lexeme
else:
lexemes |= new_lexeme
return SearchQuery(lexemes, search_type="raw", config=config)
elif isinstance(query, Phrase):
return SearchQuery(query.query_string, search_type="phrase", config=config)
elif isinstance(query, Boost):
# Not supported
msg = "The Boost query is not supported by the PostgreSQL search backend."
warnings.warn(msg, RuntimeWarning)
return self.build_tsquery_content(
query.subquery, config=config, invert=invert
)
elif isinstance(query, Not):
return self.build_tsquery_content(
query.subquery, config=config, invert=not invert
)
elif isinstance(query, (And, Or)):
# If this part of the query is inverted, we swap the operator and
# pass down the inversion state to the child queries.
# This works thanks to De Morgan's law.
#
# For example, the following query:
#
# Not(And(Term("A"), Term("B")))
#
# Is equivalent to:
#
# Or(Not(Term("A")), Not(Term("B")))
#
# It's simpler to code it this way as we only need to store the
# invert status of the terms rather than all the operators.
subquery_lexemes = [
self.build_tsquery_content(subquery, config=config, invert=invert)
for subquery in query.subqueries
]
is_and = isinstance(query, And)
if invert:
is_and = not is_and
if is_and:
return reduce(lambda a, b: a & b, subquery_lexemes)
else:
return reduce(lambda a, b: a | b, subquery_lexemes)
raise NotImplementedError(
"`%s` is not supported by the PostgreSQL search backend."
% query.__class__.__name__
)
def build_tsquery(self, query, config=None):
return self.build_tsquery_content(query, config=config)
def build_tsrank(self, vector, query, config=None, boost=1.0):
if isinstance(query, (Phrase, PlainText, Not)):
rank_expression = SearchRank(
vector,
self.build_tsquery(query, config=config),
weights=self.sql_weights,
)
if boost != 1.0:
rank_expression *= boost
return rank_expression
elif isinstance(query, Boost):
boost *= query.boost
return self.build_tsrank(vector, query.subquery, config=config, boost=boost)
elif isinstance(query, And):
return (
MUL(
1 + self.build_tsrank(vector, subquery, config=config, boost=boost)
for subquery in query.subqueries
)
- 1
)
elif isinstance(query, Or):
return ADD(
self.build_tsrank(vector, subquery, config=config, boost=boost)
for subquery in query.subqueries
) / (len(query.subqueries) or 1)
raise NotImplementedError(
"`%s` is not supported by the PostgreSQL search backend."
% query.__class__.__name__
)
def get_index_vectors(self, search_query):
return [
(F("index_entries__title"), F("index_entries__title_norm")),
(F("index_entries__body"), 1.0),
]
def get_fields_vectors(self, search_query):
return [
(
SearchVector(
field_lookup,
config=search_query.config,
),
search_field.boost,
)
for field_lookup, search_field in self.search_fields.items()
]
def get_search_vectors(self, search_query):
if self.fields is None:
return self.get_index_vectors(search_query)
else:
return self.get_fields_vectors(search_query)
def _build_rank_expression(self, vectors, config):
rank_expressions = [
self.build_tsrank(vector, self.query, config=config) * boost
for vector, boost in vectors
]
rank_expression = rank_expressions[0]
for other_rank_expression in rank_expressions[1:]:
rank_expression += other_rank_expression
return rank_expression
def search(self, config, start, stop, score_field=None):
# TODO: Handle MatchAll nested inside other search query classes.
if isinstance(self.query, MatchAll):
return self.queryset[start:stop]
elif isinstance(self.query, Not) and isinstance(self.query.subquery, MatchAll):
return self.queryset.none()
search_query = self.build_tsquery(self.query, config=config)
vectors = self.get_search_vectors(search_query)
rank_expression = self._build_rank_expression(vectors, config)
combined_vector = vectors[0][0]
for vector, boost in vectors[1:]:
combined_vector = combined_vector._combine(vector, "||", False)
queryset = self.queryset.annotate(_vector_=combined_vector).filter(
_vector_=search_query
)
if self.order_by_relevance:
queryset = queryset.order_by(rank_expression.desc(), "-pk")
elif not queryset.query.order_by:
# Adds a default ordering to avoid issue #3729.
queryset = queryset.order_by("-pk")
rank_expression = F("pk")
if score_field is not None:
queryset = queryset.annotate(**{score_field: rank_expression})
return queryset[start:stop]
def _process_lookup(self, field, lookup, value):
lhs = field.get_attname(self.queryset.model) + "__" + lookup
return Q(**{lhs: value})
def _connect_filters(self, filters, connector, negated):
if connector == "AND":
q = Q(*filters)
elif connector == "OR":
q = OR([Q(fil) for fil in filters])
else:
return
if negated:
q = ~q
return q
class PostgresAutocompleteQueryCompiler(PostgresSearchQueryCompiler):
LAST_TERM_IS_PREFIX = True
TARGET_SEARCH_FIELD_TYPE = AutocompleteField
def get_config(self, backend):
return backend.autocomplete_config
def get_search_fields_for_model(self):
return self.queryset.model.get_autocomplete_search_fields()
def get_index_vectors(self, search_query):
return [(F("index_entries__autocomplete"), 1.0)]
def get_fields_vectors(self, search_query):
return [
(
SearchVector(
field_lookup,
config=search_query.config,
weight="D",
),
1.0,
)
for field_lookup, search_field in self.search_fields.items()
]
class PostgresSearchResults(BaseSearchResults):
def get_queryset(self, for_count=False):
if for_count:
start = None
stop = None
else:
start = self.start
stop = self.stop
return self.query_compiler.search(
self.query_compiler.get_config(self.backend),
start,
stop,
score_field=self._score_field,
)
def _do_search(self):
return list(self.get_queryset())
def _do_count(self):
return self.get_queryset(for_count=True).count()
supports_facet = True
def facet(self, field_name):
# Get field
field = self.query_compiler._get_filterable_field(field_name)
if field is None:
raise FilterFieldError(
'Cannot facet search results with field "'
+ field_name
+ "\". Please add index.FilterField('"
+ field_name
+ "') to "
+ self.query_compiler.queryset.model.__name__
+ ".search_fields.",
field_name=field_name,
)
query = self.query_compiler.search(
self.query_compiler.get_config(self.backend), None, None
)
results = (
query.values(field_name).annotate(count=Count("pk")).order_by("-count")
)
return OrderedDict(
[(result[field_name], result["count"]) for result in results]
)
class PostgresSearchRebuilder:
def __init__(self, index):
self.index = index
def start(self):
self.index.delete_stale_entries()
return self.index
def finish(self):
self.index._refresh_title_norms(full=True)
class PostgresSearchAtomicRebuilder(PostgresSearchRebuilder):
def __init__(self, index):
super().__init__(index)
self.transaction = transaction.atomic(using=index.db_alias)
self.transaction_opened = False
def start(self):
self.transaction.__enter__()
self.transaction_opened = True
return super().start()
def finish(self):
self.index._refresh_title_norms(full=True)
self.transaction.__exit__(None, None, None)
self.transaction_opened = False
def __del__(self):
# TODO: Implement a cleaner way to close the connection on failure.
if self.transaction_opened:
self.transaction.needs_rollback = True
self.finish()
class PostgresSearchBackend(BaseSearchBackend):
query_compiler_class = PostgresSearchQueryCompiler
autocomplete_query_compiler_class = PostgresAutocompleteQueryCompiler
results_class = PostgresSearchResults
rebuilder_class = PostgresSearchRebuilder
atomic_rebuilder_class = PostgresSearchAtomicRebuilder
def __init__(self, params):
super().__init__(params)
self.index_name = params.get("INDEX", "default")
self.config = params.get("SEARCH_CONFIG")
# Use 'simple' config for autocomplete to disable stemming
# A good description for why this is important can be found at:
# https://www.postgresql.org/docs/9.1/datatype-textsearch.html#DATATYPE-TSQUERY
self.autocomplete_config = params.get("AUTOCOMPLETE_SEARCH_CONFIG", "simple")
if params.get("ATOMIC_REBUILD", False):
self.rebuilder_class = self.atomic_rebuilder_class
def get_index_for_model(self, model, db_alias=None):
return Index(self, db_alias)
def get_index_for_object(self, obj):
return self.get_index_for_model(obj._meta.model, obj._state.db)
def reset_index(self):
for connection in [
connection
for connection in connections.all()
if connection.vendor == "postgresql"
]:
IndexEntry._default_manager.all()._raw_delete(using=connection.alias)
def add_type(self, model):
pass # Not needed.
def refresh_index(self):
pass # Not needed.
def add(self, obj):
self.get_index_for_object(obj).add_item(obj)
def add_bulk(self, model, obj_list):
if obj_list:
self.get_index_for_object(obj_list[0]).add_items(model, obj_list)
def delete(self, obj):
self.get_index_for_object(obj).delete_item(obj)
SearchBackend = PostgresSearchBackend

View File

@@ -0,0 +1,83 @@
from django.contrib.postgres.search import SearchQueryField
from django.db.models.expressions import Expression, Value
class LexemeCombinable(Expression):
BITAND = "&"
BITOR = "|"
def _combine(self, other, connector, reversed, node=None):
if not isinstance(other, LexemeCombinable):
raise TypeError(
f"Lexeme can only be combined with other Lexemes, got {type(other)}."
)
if reversed:
return CombinedLexeme(other, connector, self)
return CombinedLexeme(self, connector, other)
# On Combinable, these are not implemented to reduce confusion with Q. In
# this case we are actually (ab)using them to do logical combination so
# it's consistent with other usage in Django.
def bitand(self, other):
return self._combine(other, self.BITAND, False)
def bitor(self, other):
return self._combine(other, self.BITOR, False)
def __or__(self, other):
return self._combine(other, self.BITOR, False)
def __and__(self, other):
return self._combine(other, self.BITAND, False)
class Lexeme(LexemeCombinable, Value):
_output_field = SearchQueryField()
def __init__(
self, value, output_field=None, *, invert=False, prefix=False, weight=None
):
self.prefix = prefix
self.invert = invert
self.weight = weight
super().__init__(value, output_field=output_field)
def as_sql(self, compiler, connection):
param = "'%s'" % self.value.replace("'", "''").replace("\\", "\\\\")
template = "%s"
label = ""
if self.prefix:
label += "*"
if self.weight:
label += self.weight
if label:
param = f"{param}:{label}"
if self.invert:
param = f"!{param}"
return template, [param]
class CombinedLexeme(LexemeCombinable):
_output_field = SearchQueryField()
def __init__(self, lhs, connector, rhs, output_field=None):
super().__init__(output_field=output_field)
self.connector = connector
self.lhs = lhs
self.rhs = rhs
def as_sql(self, compiler, connection):
value_params = []
lsql, params = compiler.compile(self.lhs)
value_params.extend(params)
rsql, params = compiler.compile(self.rhs)
value_params.extend(params)
combined_sql = f"({lsql} {self.connector} {rsql})"
combined_value = combined_sql % tuple(value_params)
return "%s", [combined_value]

View File

@@ -0,0 +1,63 @@
from itertools import zip_longest
from django.apps import apps
from wagtail.search.index import Indexed
from wagtail.search.utils import get_search_fields
# This file contains the implementation of weights for PostgreSQL tsvectors. Only PostgreSQL has support for them, so that's why we define them here.
WEIGHTS = "ABCD"
WEIGHTS_COUNT = len(WEIGHTS)
# These are filled when apps are ready.
BOOSTS_WEIGHTS = []
WEIGHTS_VALUES = []
def get_boosts():
boosts = set()
for model in apps.get_models():
if issubclass(model, Indexed):
for search_field in get_search_fields(model.get_search_fields()):
boost = search_field.boost
if boost is not None:
boosts.add(boost)
return boosts
def determine_boosts_weights(boosts=()):
if not boosts:
boosts = get_boosts()
boosts = sorted(boosts, reverse=True)
min_boost = boosts[-1]
if len(boosts) <= WEIGHTS_COUNT:
return list(zip_longest(boosts, WEIGHTS, fillvalue=min(min_boost, 0)))
max_boost = boosts[0]
boost_step = (max_boost - min_boost) / (WEIGHTS_COUNT - 1)
return [(max_boost - (i * boost_step), weight) for i, weight in enumerate(WEIGHTS)]
def set_weights():
BOOSTS_WEIGHTS.extend(determine_boosts_weights())
weights = [w for w, c in BOOSTS_WEIGHTS]
min_weight = min(weights)
if min_weight <= 0:
if min_weight == 0:
min_weight = -0.1
weights = [w - min_weight for w in weights]
max_weight = max(weights)
WEIGHTS_VALUES.extend([w / max_weight for w in reversed(weights)])
def get_weight(boost):
if boost is None:
return WEIGHTS[-1]
for max_boost, weight in BOOSTS_WEIGHTS:
if boost >= max_boost:
return weight
return weight
def get_sql_weights():
return "{" + ",".join(map(str, WEIGHTS_VALUES)) + "}"