Initial commit
This commit is contained in:
0
env/lib/python3.10/site-packages/wagtail/search/backends/database/sqlite/__init__.py
vendored
Normal file
0
env/lib/python3.10/site-packages/wagtail/search/backends/database/sqlite/__init__.py
vendored
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
294
env/lib/python3.10/site-packages/wagtail/search/backends/database/sqlite/query.py
vendored
Normal file
294
env/lib/python3.10/site-packages/wagtail/search/backends/database/sqlite/query.py
vendored
Normal file
@@ -0,0 +1,294 @@
|
||||
from typing import Any, List, Tuple
|
||||
|
||||
from django.db.backends.base.base import BaseDatabaseWrapper
|
||||
from django.db.models.expressions import CombinedExpression, Expression, Func, Value
|
||||
from django.db.models.fields import BooleanField, Field, FloatField
|
||||
from django.db.models.sql.compiler import SQLCompiler
|
||||
|
||||
from wagtail.search.query import And, MatchAll, Not, Or, Phrase, PlainText, SearchQuery
|
||||
|
||||
|
||||
class BM25(Func):
|
||||
function = "bm25"
|
||||
output_field = FloatField()
|
||||
|
||||
def __init__(self):
|
||||
expressions = ()
|
||||
super().__init__(*expressions)
|
||||
|
||||
def as_sql(
|
||||
self,
|
||||
compiler: SQLCompiler,
|
||||
connection: BaseDatabaseWrapper,
|
||||
function=None,
|
||||
template=None,
|
||||
):
|
||||
sql, params = "bm25(wagtailsearch_indexentry_fts)", []
|
||||
return sql, params
|
||||
|
||||
|
||||
class LexemeCombinable(Expression):
|
||||
BITAND = "AND"
|
||||
BITOR = "OR"
|
||||
|
||||
def _combine(self, other, connector, reversed, node=None):
|
||||
if not isinstance(other, LexemeCombinable):
|
||||
raise TypeError(
|
||||
f"Lexeme can only be combined with other Lexemes, got {type(other)}."
|
||||
)
|
||||
if reversed:
|
||||
return CombinedLexeme(other, connector, self)
|
||||
return CombinedLexeme(self, connector, other)
|
||||
|
||||
# On Combinable, these are not implemented to reduce confusion with Q. In
|
||||
# this case we are actually (ab)using them to do logical combination so
|
||||
# it's consistent with other usage in Django.
|
||||
def bitand(self, other):
|
||||
return self._combine(other, self.BITAND, False)
|
||||
|
||||
def bitor(self, other):
|
||||
return self._combine(other, self.BITOR, False)
|
||||
|
||||
def __or__(self, other):
|
||||
return self._combine(other, self.BITOR, False)
|
||||
|
||||
def __and__(self, other):
|
||||
return self._combine(other, self.BITAND, False)
|
||||
|
||||
|
||||
class SearchQueryField(Field):
|
||||
def db_type(self, connection):
|
||||
return None
|
||||
|
||||
|
||||
class Lexeme(LexemeCombinable, Value):
|
||||
_output_field = SearchQueryField()
|
||||
|
||||
def __init__(self, value, output_field=None, *, prefix=False, weight=None):
|
||||
self.prefix = prefix
|
||||
self.weight = weight
|
||||
super().__init__(value, output_field=output_field)
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
param = self.value.replace("'", "''").replace("\\", "\\\\")
|
||||
|
||||
if self.prefix:
|
||||
template = '"%s"*'
|
||||
else:
|
||||
template = '"%s"'
|
||||
|
||||
return template, [param]
|
||||
|
||||
|
||||
class CombinedLexeme(LexemeCombinable):
|
||||
_output_field = SearchQueryField()
|
||||
|
||||
def __init__(self, lhs, connector, rhs, output_field=None):
|
||||
super().__init__(output_field=output_field)
|
||||
self.connector = connector
|
||||
self.lhs = lhs
|
||||
self.rhs = rhs
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
value_params = []
|
||||
lsql, params = compiler.compile(self.lhs)
|
||||
value_params.extend(params)
|
||||
|
||||
rsql, params = compiler.compile(self.rhs)
|
||||
value_params.extend(params)
|
||||
|
||||
combined_sql = f"{lsql} {self.connector} {rsql}"
|
||||
combined_value = combined_sql % tuple(value_params)
|
||||
return "%s", [combined_value]
|
||||
|
||||
|
||||
class SearchQueryCombinable:
|
||||
BITAND = "AND"
|
||||
BITOR = "OR"
|
||||
|
||||
def _combine(self, other, connector: str, reversed: bool = False):
|
||||
if not isinstance(other, SearchQueryCombinable):
|
||||
raise TypeError(
|
||||
"SearchQuery can only be combined with other SearchQuery "
|
||||
"instances, got %s." % type(other).__name__
|
||||
)
|
||||
if reversed:
|
||||
return CombinedSearchQueryExpression(other, connector, self)
|
||||
return CombinedSearchQueryExpression(self, connector, other)
|
||||
|
||||
# On Combinable, these are not implemented to reduce confusion with Q. In
|
||||
# this case we are actually (ab)using them to do logical combination so
|
||||
# it's consistent with other usage in Django.
|
||||
def __or__(self, other):
|
||||
return self._combine(other, self.BITOR, False)
|
||||
|
||||
def __ror__(self, other):
|
||||
return self._combine(other, self.BITOR, True)
|
||||
|
||||
def __and__(self, other):
|
||||
return self._combine(other, self.BITAND, False)
|
||||
|
||||
def __rand__(self, other):
|
||||
return self._combine(other, self.BITAND, True)
|
||||
|
||||
|
||||
class SearchQueryExpression(SearchQueryCombinable, Expression):
|
||||
def __init__(self, value: LexemeCombinable, using=None, **extra):
|
||||
super().__init__(output_field=SearchQueryField())
|
||||
self.using = using
|
||||
self.extra = extra
|
||||
if isinstance(value, str): # If the value is a string, we assume it's a phrase
|
||||
self.value = Value(
|
||||
'"%s"' % value
|
||||
) # We wrap it in quotes to make sure it's parsed as a phrase
|
||||
else: # Otherwise, we assume it's a lexeme
|
||||
self.value = value
|
||||
|
||||
def as_sql(
|
||||
self,
|
||||
compiler: SQLCompiler,
|
||||
connection: BaseDatabaseWrapper,
|
||||
**extra_context: Any,
|
||||
) -> Tuple[str, List[Any]]:
|
||||
sql, params = compiler.compile(self.value)
|
||||
return (sql, params)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self.value.__repr__()
|
||||
|
||||
|
||||
class CombinedSearchQueryExpression(SearchQueryCombinable, CombinedExpression):
|
||||
def __init__(self, lhs, connector, rhs, output_field=None):
|
||||
super().__init__(lhs, connector, rhs, output_field)
|
||||
|
||||
def __str__(self):
|
||||
return "(%s)" % super().__str__()
|
||||
|
||||
|
||||
class MatchExpression(Expression):
|
||||
filterable = True
|
||||
template = (
|
||||
"wagtailsearch_indexentry_fts MATCH %s" # TODO: Can the table name be inferred?
|
||||
)
|
||||
output_field = BooleanField()
|
||||
|
||||
def __init__(self, columns: List[str], query: SearchQueryCombinable) -> None:
|
||||
super().__init__(output_field=self.output_field)
|
||||
self.columns = columns
|
||||
self.query = query
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
joined_columns = " ".join(
|
||||
self.columns
|
||||
) # The format of the columns is 'column1 column2'
|
||||
compiled_query = compiler.compile(self.query) # Compile the query to a string
|
||||
formatted_query = compiled_query[0] % tuple(
|
||||
compiled_query[1]
|
||||
) # Substitute the params in the query
|
||||
params = [
|
||||
"{{{column}}} : ({query})".format(
|
||||
column=joined_columns, query=formatted_query
|
||||
)
|
||||
] # Build the full MATCH search query. It will be a parameter to the template, so no SQL injections are possible here.
|
||||
return (self.template, params)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<MatchExpression: {self.columns!r} = {self.query!r}>"
|
||||
|
||||
|
||||
class AndNot(SearchQuery):
|
||||
"""
|
||||
This is a binary search query, where there are two subqueries, and the search is done by searching the first, and excluding the second subquery.
|
||||
For example, AndNot(X, Y) would be equivalent to doing And(X, Not(Y)), where X is the first subquery, and Y is the second subquery (the negated one).
|
||||
This is done because the SQLite FTS5 module doesn't support the unary NOT operator.
|
||||
"""
|
||||
|
||||
def __init__(self, subquery_a: SearchQuery, subquery_b: SearchQuery):
|
||||
self.subquery_a = subquery_a
|
||||
self.subquery_b = subquery_b
|
||||
|
||||
def __repr__(self):
|
||||
return f"<{repr(self.subquery_a)} AndNot {repr(self.subquery_b)}>"
|
||||
|
||||
|
||||
def normalize(search_query: SearchQuery) -> Tuple[SearchQuery]:
|
||||
"""
|
||||
Turns this query into a normalized version.
|
||||
For example, And(Not(PlainText("Arepa")), PlainText("Crepe")) would be turned into AndNot(PlainText("Crepe"), PlainText("Arepa")): "Crepe AND NOT Arepa".
|
||||
This is done because we need to get the NOT operator to the front of the query, so it can be used in the search, because the SQLite FTS5 module doesn't support the unary NOT operator. This means that, in order to support the NOT operator, we need to match against the non-negated version of the query, and then return everything that is not in the results of the non-negated query.
|
||||
"""
|
||||
if isinstance(search_query, Phrase):
|
||||
return search_query # We can't normalize a Phrase.
|
||||
if isinstance(search_query, PlainText):
|
||||
return search_query # We can't normalize a PlainText.
|
||||
if isinstance(search_query, And):
|
||||
normalized_subqueries: List[SearchQuery] = [
|
||||
normalize(subquery) for subquery in search_query.subqueries
|
||||
] # This builds a list of normalized subqueries.
|
||||
|
||||
not_negated_subqueries = [
|
||||
subquery
|
||||
for subquery in normalized_subqueries
|
||||
if not isinstance(subquery, Not)
|
||||
] # All the non-negated subqueries.
|
||||
not_negated_subqueries = [
|
||||
subquery
|
||||
for subquery in not_negated_subqueries
|
||||
if not isinstance(subquery, MatchAll)
|
||||
] # We can ignore all MatchAll SearchQueries here, because they are redundant.
|
||||
negated_subqueries = [
|
||||
subquery.subquery
|
||||
for subquery in normalized_subqueries
|
||||
if isinstance(subquery, Not)
|
||||
]
|
||||
|
||||
if (
|
||||
negated_subqueries == []
|
||||
): # If there are no negated subqueries, return an And(), now without the redundant MatchAll subqueries.
|
||||
return And(not_negated_subqueries)
|
||||
|
||||
for subquery in (
|
||||
negated_subqueries
|
||||
): # If there's a negated MatchAll subquery, then nothing will get matched.
|
||||
if isinstance(subquery, MatchAll):
|
||||
return Not(MatchAll())
|
||||
|
||||
return AndNot(And(not_negated_subqueries), Or(negated_subqueries))
|
||||
if isinstance(search_query, Or):
|
||||
normalized_subqueries: List[SearchQuery] = [
|
||||
normalize(subquery) for subquery in search_query.subqueries
|
||||
] # This builds a list of (subquery, negated) tuples.
|
||||
|
||||
negated_subqueries = [
|
||||
subquery.subquery
|
||||
for subquery in normalized_subqueries
|
||||
if isinstance(subquery, Not)
|
||||
]
|
||||
if (
|
||||
negated_subqueries == []
|
||||
): # If there are no negated subqueries, return an Or().
|
||||
return Or(normalized_subqueries)
|
||||
|
||||
for subquery in (
|
||||
negated_subqueries
|
||||
): # If there's a MatchAll subquery, then anything will get matched.
|
||||
if isinstance(subquery, MatchAll):
|
||||
return MatchAll()
|
||||
|
||||
not_negated_subqueries = [
|
||||
subquery
|
||||
for subquery in normalized_subqueries
|
||||
if not isinstance(subquery, Not)
|
||||
] # All the non-negated subqueries.
|
||||
not_negated_subqueries = [
|
||||
subquery
|
||||
for subquery in not_negated_subqueries
|
||||
if not isinstance(subquery, MatchAll)
|
||||
] # We can ignore all MatchAll SearchQueries here, because they are redundant.
|
||||
|
||||
return AndNot(MatchAll(), And(negated_subqueries))
|
||||
if isinstance(search_query, Not):
|
||||
normalized = normalize(search_query.subquery)
|
||||
return Not(normalized) # Normalize the subquery, then invert it.
|
||||
if isinstance(search_query, MatchAll):
|
||||
return search_query # We can't normalize a MatchAll.
|
||||
707
env/lib/python3.10/site-packages/wagtail/search/backends/database/sqlite/sqlite.py
vendored
Normal file
707
env/lib/python3.10/site-packages/wagtail/search/backends/database/sqlite/sqlite.py
vendored
Normal file
@@ -0,0 +1,707 @@
|
||||
from collections import OrderedDict
|
||||
from functools import reduce
|
||||
|
||||
from django.db import DEFAULT_DB_ALIAS, NotSupportedError, connections, transaction
|
||||
from django.db.models import Avg, Count, F, Manager, Q, TextField
|
||||
from django.db.models.constants import LOOKUP_SEP
|
||||
from django.db.models.functions import Cast, Length
|
||||
from django.db.utils import OperationalError
|
||||
from django.utils.encoding import force_str
|
||||
from django.utils.functional import cached_property
|
||||
|
||||
from ....index import AutocompleteField, RelatedFields, SearchField, get_indexed_models
|
||||
from ....models import IndexEntry, SQLiteFTSIndexEntry
|
||||
from ....query import And, MatchAll, Not, Or, Phrase, PlainText
|
||||
from ....utils import (
|
||||
ADD,
|
||||
MUL,
|
||||
OR,
|
||||
get_content_type_pk,
|
||||
get_descendants_content_types_pks,
|
||||
)
|
||||
from ...base import (
|
||||
BaseSearchBackend,
|
||||
BaseSearchQueryCompiler,
|
||||
BaseSearchResults,
|
||||
FilterFieldError,
|
||||
)
|
||||
from .query import (
|
||||
BM25,
|
||||
AndNot,
|
||||
Lexeme,
|
||||
MatchExpression,
|
||||
SearchQueryExpression,
|
||||
normalize,
|
||||
)
|
||||
|
||||
|
||||
class ObjectIndexer:
|
||||
"""
|
||||
Responsible for extracting data from an object to be inserted into the index.
|
||||
"""
|
||||
|
||||
def __init__(self, obj, backend):
|
||||
self.obj = obj
|
||||
self.search_fields = obj.get_search_fields()
|
||||
self.config = backend.config
|
||||
|
||||
def prepare_value(self, value):
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
|
||||
elif isinstance(value, list):
|
||||
return ", ".join(self.prepare_value(item) for item in value)
|
||||
|
||||
elif isinstance(value, dict):
|
||||
return ", ".join(self.prepare_value(item) for item in value.values())
|
||||
|
||||
return force_str(value)
|
||||
|
||||
def prepare_field(self, obj, field):
|
||||
if isinstance(field, SearchField):
|
||||
yield (field, self.prepare_value(field.get_value(obj)))
|
||||
|
||||
elif isinstance(field, AutocompleteField):
|
||||
yield (field, self.prepare_value(field.get_value(obj)))
|
||||
|
||||
elif isinstance(field, RelatedFields):
|
||||
sub_obj = field.get_value(obj)
|
||||
if sub_obj is None:
|
||||
return
|
||||
|
||||
if isinstance(sub_obj, Manager):
|
||||
sub_objs = sub_obj.all()
|
||||
|
||||
else:
|
||||
if callable(sub_obj):
|
||||
sub_obj = sub_obj()
|
||||
|
||||
sub_objs = [sub_obj]
|
||||
|
||||
for sub_obj in sub_objs:
|
||||
for sub_field in field.fields:
|
||||
yield from self.prepare_field(sub_obj, sub_field)
|
||||
|
||||
@cached_property
|
||||
def id(self):
|
||||
"""
|
||||
Returns the value to use as the ID of the record in the index
|
||||
"""
|
||||
return force_str(self.obj.pk)
|
||||
|
||||
@cached_property
|
||||
def title(self):
|
||||
"""
|
||||
Returns all values to index as "title". This is the value of all SearchFields that have the field_name 'title'
|
||||
"""
|
||||
texts = []
|
||||
for field in self.search_fields:
|
||||
for current_field, value in self.prepare_field(self.obj, field):
|
||||
if (
|
||||
isinstance(current_field, SearchField)
|
||||
and current_field.field_name == "title"
|
||||
):
|
||||
texts.append(value)
|
||||
|
||||
return " ".join(texts)
|
||||
|
||||
@cached_property
|
||||
def body(self):
|
||||
"""
|
||||
Returns all values to index as "body". This is the value of all SearchFields excluding the title
|
||||
"""
|
||||
texts = []
|
||||
for field in self.search_fields:
|
||||
for current_field, value in self.prepare_field(self.obj, field):
|
||||
if (
|
||||
isinstance(current_field, SearchField)
|
||||
and not current_field.field_name == "title"
|
||||
):
|
||||
texts.append(value)
|
||||
|
||||
return " ".join(texts)
|
||||
|
||||
@cached_property
|
||||
def autocomplete(self):
|
||||
"""
|
||||
Returns all values to index as "autocomplete". This is the value of all AutocompleteFields
|
||||
"""
|
||||
texts = []
|
||||
for field in self.search_fields:
|
||||
for current_field, value in self.prepare_field(self.obj, field):
|
||||
if isinstance(current_field, AutocompleteField):
|
||||
texts.append(value)
|
||||
|
||||
return " ".join(texts)
|
||||
|
||||
def as_vector(self, texts, for_autocomplete=False):
|
||||
"""
|
||||
Converts an array of strings into a SearchVector that can be indexed.
|
||||
"""
|
||||
texts = [(text.strip(), weight) for text, weight in texts]
|
||||
texts = [(text, weight) for text, weight in texts if text]
|
||||
|
||||
return " ".join(texts)
|
||||
|
||||
|
||||
class Index:
|
||||
def __init__(self, backend, db_alias=None):
|
||||
self.backend = backend
|
||||
self.name = self.backend.index_name
|
||||
self.db_alias = DEFAULT_DB_ALIAS if db_alias is None else db_alias
|
||||
self.connection = connections[self.db_alias]
|
||||
if self.connection.vendor != "sqlite":
|
||||
raise NotSupportedError(
|
||||
"You must select a SQLite database " "to use the SQLite search backend."
|
||||
)
|
||||
|
||||
self.entries = IndexEntry._default_manager.using(self.db_alias)
|
||||
|
||||
def add_model(self, model):
|
||||
pass
|
||||
|
||||
def refresh(self):
|
||||
pass
|
||||
|
||||
def _refresh_title_norms(self, full=False):
|
||||
"""
|
||||
Refreshes the value of the title_norm field.
|
||||
|
||||
This needs to be set to 'lavg/ld' where:
|
||||
- lavg is the average length of titles in all documents (also in terms)
|
||||
- ld is the length of the title field in this document (in terms)
|
||||
"""
|
||||
|
||||
lavg = (
|
||||
self.entries.annotate(title_length=Length("title"))
|
||||
.filter(title_length__gt=0)
|
||||
.aggregate(Avg("title_length"))["title_length__avg"]
|
||||
)
|
||||
|
||||
if full:
|
||||
# Update the whole table
|
||||
# This is the most accurate option but requires a full table rewrite
|
||||
# so we can't do it too often as it could lead to locking issues.
|
||||
entries = self.entries
|
||||
|
||||
else:
|
||||
# Only update entries where title_norm is 1.0
|
||||
# This is the default value set on new entries.
|
||||
# It's possible that other entries could have this exact value but there shouldn't be too many of those
|
||||
entries = self.entries.filter(title_norm=1.0)
|
||||
|
||||
entries.annotate(title_length=Length("title")).filter(
|
||||
title_length__gt=0
|
||||
).update(title_norm=lavg / F("title_length"))
|
||||
|
||||
def delete_stale_model_entries(self, model):
|
||||
existing_pks = (
|
||||
model._default_manager.using(self.db_alias)
|
||||
.annotate(object_id=Cast("pk", TextField()))
|
||||
.values("object_id")
|
||||
)
|
||||
content_types_pks = get_descendants_content_types_pks(model)
|
||||
stale_entries = self.entries.filter(
|
||||
content_type_id__in=content_types_pks
|
||||
).exclude(object_id__in=existing_pks)
|
||||
stale_entries.delete()
|
||||
|
||||
def delete_stale_entries(self):
|
||||
for model in get_indexed_models():
|
||||
# We don’t need to delete stale entries for non-root models,
|
||||
# since we already delete them by deleting roots.
|
||||
if not model._meta.parents:
|
||||
self.delete_stale_model_entries(model)
|
||||
|
||||
def add_item(self, obj):
|
||||
self.add_items(obj._meta.model, [obj])
|
||||
|
||||
def add_items_update_then_create(self, content_type_pk, indexers):
|
||||
ids_and_data = {}
|
||||
for indexer in indexers:
|
||||
ids_and_data[indexer.id] = (
|
||||
indexer.title,
|
||||
indexer.autocomplete,
|
||||
indexer.body,
|
||||
)
|
||||
|
||||
index_entries_for_ct = self.entries.filter(content_type_id=content_type_pk)
|
||||
indexed_ids = frozenset(
|
||||
index_entries_for_ct.filter(object_id__in=ids_and_data.keys()).values_list(
|
||||
"object_id", flat=True
|
||||
)
|
||||
)
|
||||
for indexed_id in indexed_ids:
|
||||
title, autocomplete, body = ids_and_data[indexed_id]
|
||||
index_entries_for_ct.filter(object_id=indexed_id).update(
|
||||
title=title, autocomplete=autocomplete, body=body
|
||||
)
|
||||
|
||||
to_be_created = []
|
||||
for object_id in ids_and_data.keys():
|
||||
if object_id not in indexed_ids:
|
||||
title, autocomplete, body = ids_and_data[object_id]
|
||||
to_be_created.append(
|
||||
IndexEntry(
|
||||
content_type_id=content_type_pk,
|
||||
object_id=object_id,
|
||||
title=title,
|
||||
autocomplete=autocomplete,
|
||||
body=body,
|
||||
)
|
||||
)
|
||||
|
||||
self.entries.bulk_create(to_be_created)
|
||||
|
||||
self._refresh_title_norms()
|
||||
|
||||
def add_items(self, model, objs):
|
||||
search_fields = model.get_search_fields()
|
||||
if not search_fields:
|
||||
return
|
||||
|
||||
indexers = [ObjectIndexer(obj, self.backend) for obj in objs]
|
||||
|
||||
# TODO: Delete unindexed objects while dealing with proxy models.
|
||||
if indexers:
|
||||
content_type_pk = get_content_type_pk(model)
|
||||
|
||||
update_method = self.add_items_update_then_create
|
||||
update_method(content_type_pk, indexers)
|
||||
|
||||
def delete_item(self, item):
|
||||
item.index_entries.all()._raw_delete(using=self.db_alias)
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
|
||||
class SQLiteSearchRebuilder:
|
||||
def __init__(self, index):
|
||||
self.index = index
|
||||
|
||||
def start(self):
|
||||
self.index.delete_stale_entries()
|
||||
return self.index
|
||||
|
||||
def finish(self):
|
||||
self.index._refresh_title_norms(full=True)
|
||||
|
||||
|
||||
class SQLiteSearchAtomicRebuilder(SQLiteSearchRebuilder):
|
||||
def __init__(self, index):
|
||||
super().__init__(index)
|
||||
self.transaction = transaction.atomic(using=index.db_alias)
|
||||
self.transaction_opened = False
|
||||
|
||||
def start(self):
|
||||
self.transaction.__enter__()
|
||||
self.transaction_opened = True
|
||||
return super().start()
|
||||
|
||||
def finish(self):
|
||||
self.index._refresh_title_norms(full=True)
|
||||
|
||||
self.transaction.__exit__(None, None, None)
|
||||
self.transaction_opened = False
|
||||
|
||||
def __del__(self):
|
||||
# TODO: Implement a cleaner way to close the connection on failure.
|
||||
if self.transaction_opened:
|
||||
self.transaction.needs_rollback = True
|
||||
self.finish()
|
||||
|
||||
|
||||
class SQLiteSearchQueryCompiler(BaseSearchQueryCompiler):
|
||||
DEFAULT_OPERATOR = "AND"
|
||||
LAST_TERM_IS_PREFIX = False
|
||||
TARGET_SEARCH_FIELD_TYPE = SearchField
|
||||
FTS_TABLE_FIELDS = ["title", "body"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
local_search_fields = self.get_search_fields_for_model()
|
||||
|
||||
if self.fields is None:
|
||||
# search over the fields defined on the current model
|
||||
self.search_fields = local_search_fields
|
||||
else:
|
||||
# build a search_fields set from the passed definition,
|
||||
# which may involve traversing relations
|
||||
self.search_fields = {
|
||||
field_lookup: self.get_search_field(
|
||||
field_lookup, fields=local_search_fields
|
||||
)
|
||||
for field_lookup in self.fields
|
||||
}
|
||||
|
||||
def get_config(self, backend):
|
||||
return backend.config
|
||||
|
||||
def get_search_fields_for_model(self):
|
||||
return self.queryset.model.get_searchable_search_fields()
|
||||
|
||||
def get_search_field(self, field_lookup, fields=None):
|
||||
if fields is None:
|
||||
fields = self.search_fields
|
||||
|
||||
if LOOKUP_SEP in field_lookup:
|
||||
field_lookup, sub_field_name = field_lookup.split(LOOKUP_SEP, 1)
|
||||
else:
|
||||
sub_field_name = None
|
||||
|
||||
for field in fields:
|
||||
if (
|
||||
isinstance(field, self.TARGET_SEARCH_FIELD_TYPE)
|
||||
and field.field_name == field_lookup
|
||||
):
|
||||
return field
|
||||
|
||||
# Note: Searching on a specific related field using
|
||||
# `.search(fields=…)` is not yet supported by Wagtail.
|
||||
# This method anticipates by already implementing it.
|
||||
if isinstance(field, RelatedFields) and field.field_name == field_lookup:
|
||||
return self.get_search_field(sub_field_name, field.fields)
|
||||
|
||||
def build_search_query_content(self, query, config=None):
|
||||
"""
|
||||
Takes a SearchQuery and returns another SearchQuery object, which can be used to construct the query in SQL.
|
||||
"""
|
||||
if isinstance(query, PlainText):
|
||||
terms = query.query_string.split()
|
||||
if not terms:
|
||||
return None
|
||||
|
||||
last_term = terms.pop()
|
||||
|
||||
lexemes = Lexeme(
|
||||
last_term, prefix=self.LAST_TERM_IS_PREFIX
|
||||
) # Combine all terms into a single lexeme.
|
||||
for term in terms:
|
||||
new_lexeme = Lexeme(term)
|
||||
|
||||
if query.operator.upper() == "AND":
|
||||
lexemes &= new_lexeme
|
||||
else:
|
||||
lexemes |= new_lexeme
|
||||
|
||||
return SearchQueryExpression(lexemes, config=config)
|
||||
|
||||
elif isinstance(query, Phrase):
|
||||
return SearchQueryExpression(query.query_string)
|
||||
|
||||
elif isinstance(query, AndNot):
|
||||
# Combine the two sub-queries into a query of the form `(first) AND NOT (second)`.
|
||||
subquery_a = self.build_search_query_content(
|
||||
query.subquery_a, config=config
|
||||
)
|
||||
subquery_b = self.build_search_query_content(
|
||||
query.subquery_b, config=config
|
||||
)
|
||||
combined_query = subquery_a._combine(subquery_b, "NOT")
|
||||
return combined_query
|
||||
|
||||
elif isinstance(query, (And, Or)):
|
||||
subquery_lexemes = [
|
||||
self.build_search_query_content(subquery, config=config)
|
||||
for subquery in query.subqueries
|
||||
]
|
||||
|
||||
is_and = isinstance(query, And)
|
||||
|
||||
if is_and:
|
||||
return reduce(lambda a, b: a & b, subquery_lexemes)
|
||||
else:
|
||||
return reduce(lambda a, b: a | b, subquery_lexemes)
|
||||
|
||||
raise NotImplementedError(
|
||||
"`%s` is not supported by the SQLite search backend."
|
||||
% query.__class__.__name__
|
||||
)
|
||||
|
||||
def build_search_query(self, query, config=None):
|
||||
if isinstance(query, MatchAll):
|
||||
return query
|
||||
if isinstance(query, Not):
|
||||
unwrapped_query = query.subquery
|
||||
built_query = Not(
|
||||
self.build_search_query(unwrapped_query, config=config)
|
||||
) # We don't take the Not operator into account.
|
||||
else:
|
||||
built_query = self.build_search_query_content(query, config=config)
|
||||
return built_query
|
||||
|
||||
def build_tsrank(self, vector, query, config=None, boost=1.0):
|
||||
if isinstance(query, (Phrase, PlainText, Not)):
|
||||
rank_expression = BM25()
|
||||
|
||||
if boost != 1.0:
|
||||
rank_expression *= boost
|
||||
|
||||
return rank_expression
|
||||
|
||||
elif isinstance(query, And):
|
||||
return (
|
||||
MUL(
|
||||
1 + self.build_tsrank(vector, subquery, config=config, boost=boost)
|
||||
for subquery in query.subqueries
|
||||
)
|
||||
- 1
|
||||
)
|
||||
|
||||
elif isinstance(query, Or):
|
||||
return ADD(
|
||||
self.build_tsrank(vector, subquery, config=config, boost=boost)
|
||||
for subquery in query.subqueries
|
||||
) / (len(query.subqueries) or 1)
|
||||
|
||||
raise NotImplementedError(
|
||||
"`%s` is not supported by the SQLite search backend."
|
||||
% query.__class__.__name__
|
||||
)
|
||||
|
||||
def get_index_vectors(self):
|
||||
return [
|
||||
(F("index_entries__title"), F("index_entries__title_norm")),
|
||||
(F("index_entries__body"), 1.0),
|
||||
]
|
||||
|
||||
def get_search_vectors(self):
|
||||
return self.get_index_vectors()
|
||||
|
||||
def _build_rank_expression(self, vectors, config):
|
||||
# TODO: Come up with my own expression class that compiles down to bm25
|
||||
|
||||
rank_expressions = [
|
||||
self.build_tsrank(vector, self.query, config=config) * boost
|
||||
for vector, boost in vectors
|
||||
]
|
||||
|
||||
rank_expression = rank_expressions[0]
|
||||
for other_rank_expression in rank_expressions[1:]:
|
||||
rank_expression += other_rank_expression
|
||||
|
||||
return rank_expression
|
||||
|
||||
def search(self, config, start, stop, score_field=None):
|
||||
normalized_query = normalize(self.query)
|
||||
|
||||
if isinstance(normalized_query, MatchAll):
|
||||
return self.queryset[start:stop]
|
||||
|
||||
elif isinstance(normalized_query, Not) and isinstance(
|
||||
normalized_query.subquery, MatchAll
|
||||
):
|
||||
return self.queryset.none()
|
||||
|
||||
if isinstance(normalized_query, Not):
|
||||
normalized_query = normalized_query.subquery
|
||||
negated = True
|
||||
else:
|
||||
negated = False
|
||||
|
||||
search_query = self.build_search_query(
|
||||
normalized_query, config=config
|
||||
) # We build a search query here, for example: "%s MATCH '(hello AND world)'"
|
||||
vectors = self.get_search_vectors()
|
||||
rank_expression = self._build_rank_expression(vectors, config)
|
||||
|
||||
combined_vector = vectors[
|
||||
0
|
||||
][
|
||||
0
|
||||
] # We create a combined vector for the search results queryset. We start with the first vector and build from there.
|
||||
for vector, boost in vectors[1:]:
|
||||
combined_vector = combined_vector._combine(
|
||||
vector, " ", False
|
||||
) # We add the subsequent vectors to the combined vector.
|
||||
|
||||
# Build the FTS match expression.
|
||||
expr = MatchExpression(self.fields or self.FTS_TABLE_FIELDS, search_query)
|
||||
# Perform the FTS search. We'll get entries in the SQLiteFTSIndexEntry model.
|
||||
objs = (
|
||||
SQLiteFTSIndexEntry.objects.filter(expr)
|
||||
.select_related("index_entry")
|
||||
.filter(
|
||||
index_entry__content_type__in=get_descendants_content_types_pks(
|
||||
self.queryset.model
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if self.order_by_relevance:
|
||||
objs = objs.order_by(BM25().desc())
|
||||
elif not objs.query.order_by:
|
||||
# Adds a default ordering to avoid issue #3729.
|
||||
queryset = objs.order_by("-pk")
|
||||
rank_expression = F("pk")
|
||||
|
||||
from django.db import connection
|
||||
from django.db.models.sql.subqueries import InsertQuery
|
||||
|
||||
compiler = InsertQuery(IndexEntry).get_compiler(connection=connection)
|
||||
|
||||
try:
|
||||
obj_ids = [
|
||||
obj.index_entry.object_id for obj in objs
|
||||
] # Get the IDs of the objects that matched. They're stored in the IndexEntry model, so we need to get that first.
|
||||
except OperationalError as e:
|
||||
raise OperationalError(
|
||||
str(e)
|
||||
+ " The original query was: "
|
||||
+ compiler.compile(objs.query)[0]
|
||||
+ str(compiler.compile(objs.query)[1])
|
||||
) from e
|
||||
|
||||
if not negated:
|
||||
queryset = self.queryset.filter(
|
||||
id__in=obj_ids
|
||||
) # We need to filter the source queryset to get the objects that matched the search query.
|
||||
else:
|
||||
queryset = self.queryset.exclude(
|
||||
id__in=obj_ids
|
||||
) # We exclude the objects that matched the search query from the source queryset, if the query is negated.
|
||||
|
||||
if score_field is not None:
|
||||
queryset = queryset.annotate(**{score_field: rank_expression})
|
||||
|
||||
return queryset[start:stop]
|
||||
|
||||
def _process_lookup(self, field, lookup, value):
|
||||
lhs = field.get_attname(self.queryset.model) + "__" + lookup
|
||||
return Q(**{lhs: value})
|
||||
|
||||
def _connect_filters(self, filters, connector, negated):
|
||||
if connector == "AND":
|
||||
q = Q(*filters)
|
||||
|
||||
elif connector == "OR":
|
||||
q = OR([Q(fil) for fil in filters])
|
||||
|
||||
else:
|
||||
return
|
||||
|
||||
if negated:
|
||||
q = ~q
|
||||
|
||||
return q
|
||||
|
||||
|
||||
class SQLiteAutocompleteQueryCompiler(SQLiteSearchQueryCompiler):
|
||||
LAST_TERM_IS_PREFIX = True
|
||||
TARGET_SEARCH_FIELD_TYPE = AutocompleteField
|
||||
FTS_TABLE_FIELDS = ["autocomplete"]
|
||||
|
||||
def get_config(self, backend):
|
||||
return backend.autocomplete_config
|
||||
|
||||
def get_search_fields_for_model(self):
|
||||
return self.queryset.model.get_autocomplete_search_fields()
|
||||
|
||||
def get_index_vectors(self):
|
||||
return [(F("index_entries__autocomplete"), 1.0)]
|
||||
|
||||
|
||||
class SQLiteSearchResults(BaseSearchResults):
|
||||
def get_queryset(self, for_count=False):
|
||||
if for_count:
|
||||
start = None
|
||||
stop = None
|
||||
else:
|
||||
start = self.start
|
||||
stop = self.stop
|
||||
|
||||
return self.query_compiler.search(
|
||||
self.query_compiler.get_config(self.backend),
|
||||
start,
|
||||
stop,
|
||||
score_field=self._score_field,
|
||||
)
|
||||
|
||||
def _do_search(self):
|
||||
return list(self.get_queryset())
|
||||
|
||||
def _do_count(self):
|
||||
return self.get_queryset(for_count=True).count()
|
||||
|
||||
supports_facet = True
|
||||
|
||||
def facet(self, field_name):
|
||||
# Get field
|
||||
field = self.query_compiler._get_filterable_field(field_name)
|
||||
if field is None:
|
||||
raise FilterFieldError(
|
||||
'Cannot facet search results with field "'
|
||||
+ field_name
|
||||
+ "\". Please add index.FilterField('"
|
||||
+ field_name
|
||||
+ "') to "
|
||||
+ self.query_compiler.queryset.model.__name__
|
||||
+ ".search_fields.",
|
||||
field_name=field_name,
|
||||
)
|
||||
|
||||
query = self.query_compiler.search(
|
||||
self.query_compiler.get_config(self.backend), None, None
|
||||
)
|
||||
results = (
|
||||
query.values(field_name).annotate(count=Count("pk")).order_by("-count")
|
||||
)
|
||||
|
||||
return OrderedDict(
|
||||
[(result[field_name], result["count"]) for result in results]
|
||||
)
|
||||
|
||||
|
||||
class SQLiteSearchBackend(BaseSearchBackend):
|
||||
query_compiler_class = SQLiteSearchQueryCompiler
|
||||
autocomplete_query_compiler_class = SQLiteAutocompleteQueryCompiler
|
||||
|
||||
results_class = SQLiteSearchResults
|
||||
rebuilder_class = SQLiteSearchRebuilder
|
||||
atomic_rebuilder_class = SQLiteSearchAtomicRebuilder
|
||||
|
||||
def __init__(self, params):
|
||||
super().__init__(params)
|
||||
self.index_name = params.get("INDEX", "default")
|
||||
|
||||
# SQLite backend currently has no config options
|
||||
self.config = None
|
||||
self.autocomplete_config = None
|
||||
|
||||
if params.get("ATOMIC_REBUILD", False):
|
||||
self.rebuilder_class = self.atomic_rebuilder_class
|
||||
|
||||
def get_index_for_model(self, model, db_alias=None):
|
||||
return Index(self, db_alias)
|
||||
|
||||
def get_index_for_object(self, obj):
|
||||
return self.get_index_for_model(obj._meta.model, obj._state.db)
|
||||
|
||||
def reset_index(self):
|
||||
for connection in [
|
||||
connection
|
||||
for connection in connections.all()
|
||||
if connection.vendor == "sqlite"
|
||||
]:
|
||||
IndexEntry._default_manager.all()._raw_delete(using=connection.alias)
|
||||
|
||||
def add_type(self, model):
|
||||
pass # Not needed.
|
||||
|
||||
def refresh_index(self):
|
||||
pass # Not needed.
|
||||
|
||||
def add(self, obj):
|
||||
self.get_index_for_object(obj).add_item(obj)
|
||||
|
||||
def add_bulk(self, model, obj_list):
|
||||
if obj_list:
|
||||
self.get_index_for_object(obj_list[0]).add_items(model, obj_list)
|
||||
|
||||
def delete(self, obj):
|
||||
self.get_index_for_object(obj).delete_item(obj)
|
||||
|
||||
|
||||
SearchBackend = SQLiteSearchBackend
|
||||
35
env/lib/python3.10/site-packages/wagtail/search/backends/database/sqlite/utils.py
vendored
Normal file
35
env/lib/python3.10/site-packages/wagtail/search/backends/database/sqlite/utils.py
vendored
Normal file
@@ -0,0 +1,35 @@
|
||||
import sqlite3
|
||||
|
||||
from django.db import OperationalError
|
||||
|
||||
|
||||
def fts5_available():
|
||||
# based on https://stackoverflow.com/a/36656216/1853523
|
||||
if sqlite3.sqlite_version_info < (3, 19, 0):
|
||||
# Prior to version 3.19, SQLite doesn't support FTS5 queries with
|
||||
# column filters ('{column_1 column_2} : query'), which the sqlite
|
||||
# fulltext backend needs
|
||||
return False
|
||||
|
||||
tmp_db = sqlite3.connect(":memory:")
|
||||
try:
|
||||
tmp_db.execute("CREATE VIRTUAL TABLE fts5test USING fts5 (data);")
|
||||
except sqlite3.OperationalError:
|
||||
return False
|
||||
finally:
|
||||
tmp_db.close()
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def fts_table_exists():
|
||||
from wagtail.search.models import SQLiteFTSIndexEntry
|
||||
|
||||
try:
|
||||
# ignore result of query; we are only interested in the query failing,
|
||||
# not the presence of index entries
|
||||
SQLiteFTSIndexEntry.objects.exists()
|
||||
except OperationalError:
|
||||
return False
|
||||
|
||||
return True
|
||||
Reference in New Issue
Block a user