Initial commit
This commit is contained in:
113
env/lib/python3.10/site-packages/wagtail/search/backends/__init__.py
vendored
Normal file
113
env/lib/python3.10/site-packages/wagtail/search/backends/__init__.py
vendored
Normal file
@@ -0,0 +1,113 @@
|
||||
# Backend loading
|
||||
# Based on the Django cache framework
|
||||
# https://github.com/django/django/blob/5d263dee304fdaf95e18d2f0619d6925984a7f02/django/core/cache/__init__.py
|
||||
|
||||
from importlib import import_module
|
||||
|
||||
from django.conf import settings
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from django.utils.module_loading import import_string
|
||||
|
||||
|
||||
class InvalidSearchBackendError(ImproperlyConfigured):
|
||||
pass
|
||||
|
||||
|
||||
def get_search_backend_config():
|
||||
search_backends = getattr(settings, "WAGTAILSEARCH_BACKENDS", {})
|
||||
|
||||
# Make sure the default backend is always defined
|
||||
search_backends.setdefault(
|
||||
"default",
|
||||
{
|
||||
"BACKEND": "wagtail.search.backends.database",
|
||||
},
|
||||
)
|
||||
|
||||
return search_backends
|
||||
|
||||
|
||||
def import_backend(dotted_path):
|
||||
"""
|
||||
There's two formats for the dotted_path.
|
||||
One with the backend class (old) and one without (new)
|
||||
eg:
|
||||
old: wagtail.search.backends.elasticsearch.ElasticsearchSearchBackend
|
||||
new: wagtail.search.backends.elasticsearch
|
||||
|
||||
If a new style dotted path was specified, this function would
|
||||
look for a backend class from the "SearchBackend" attribute.
|
||||
"""
|
||||
try:
|
||||
# New
|
||||
backend_module = import_module(dotted_path)
|
||||
return backend_module.SearchBackend
|
||||
except ImportError as e:
|
||||
try:
|
||||
# Old
|
||||
return import_string(dotted_path)
|
||||
except ImportError:
|
||||
raise ImportError from e
|
||||
|
||||
|
||||
def get_search_backend(backend="default", **kwargs):
|
||||
search_backends = get_search_backend_config()
|
||||
|
||||
# Try to find the backend
|
||||
try:
|
||||
# Try to get the WAGTAILSEARCH_BACKENDS entry for the given backend name first
|
||||
conf = search_backends[backend]
|
||||
except KeyError:
|
||||
try:
|
||||
# Trying to import the given backend, in case it's a dotted path
|
||||
import_backend(backend)
|
||||
except ImportError as e:
|
||||
raise InvalidSearchBackendError(f"Could not find backend '{backend}': {e}")
|
||||
params = kwargs
|
||||
else:
|
||||
# Backend is a conf entry
|
||||
params = conf.copy()
|
||||
params.update(kwargs)
|
||||
backend = params.pop("BACKEND")
|
||||
|
||||
# Try to import the backend
|
||||
try:
|
||||
backend_cls = import_backend(backend)
|
||||
except ImportError as e:
|
||||
raise InvalidSearchBackendError(f"Could not find backend '{backend}': {e}")
|
||||
|
||||
# Create backend
|
||||
return backend_cls(params)
|
||||
|
||||
|
||||
def _backend_requires_auto_update(backend_name, params):
|
||||
if params.get("AUTO_UPDATE", True):
|
||||
return True
|
||||
|
||||
# _WAGTAILSEARCH_FORCE_AUTO_UPDATE is only used by Wagtail tests. It allows
|
||||
# us to test AUTO_UPDATE behaviour against Elasticsearch without having to
|
||||
# have AUTO_UPDATE enabed for every test.
|
||||
force_auto_update = getattr(settings, "_WAGTAILSEARCH_FORCE_AUTO_UPDATE", [])
|
||||
if backend_name in force_auto_update:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def get_search_backends_with_name(with_auto_update=False):
|
||||
search_backends = get_search_backend_config()
|
||||
for backend, params in search_backends.items():
|
||||
if with_auto_update and _backend_requires_auto_update(backend, params) is False:
|
||||
continue
|
||||
|
||||
yield backend, get_search_backend(backend)
|
||||
|
||||
|
||||
def get_search_backends(with_auto_update=False):
|
||||
# For backwards compatibility
|
||||
return (
|
||||
backend
|
||||
for _, backend in get_search_backends_with_name(
|
||||
with_auto_update=with_auto_update
|
||||
)
|
||||
)
|
||||
BIN
env/lib/python3.10/site-packages/wagtail/search/backends/__pycache__/__init__.cpython-310.pyc
vendored
Normal file
BIN
env/lib/python3.10/site-packages/wagtail/search/backends/__pycache__/__init__.cpython-310.pyc
vendored
Normal file
Binary file not shown.
BIN
env/lib/python3.10/site-packages/wagtail/search/backends/__pycache__/base.cpython-310.pyc
vendored
Normal file
BIN
env/lib/python3.10/site-packages/wagtail/search/backends/__pycache__/base.cpython-310.pyc
vendored
Normal file
Binary file not shown.
BIN
env/lib/python3.10/site-packages/wagtail/search/backends/__pycache__/elasticsearch7.cpython-310.pyc
vendored
Normal file
BIN
env/lib/python3.10/site-packages/wagtail/search/backends/__pycache__/elasticsearch7.cpython-310.pyc
vendored
Normal file
Binary file not shown.
BIN
env/lib/python3.10/site-packages/wagtail/search/backends/__pycache__/elasticsearch8.cpython-310.pyc
vendored
Normal file
BIN
env/lib/python3.10/site-packages/wagtail/search/backends/__pycache__/elasticsearch8.cpython-310.pyc
vendored
Normal file
Binary file not shown.
529
env/lib/python3.10/site-packages/wagtail/search/backends/base.py
vendored
Normal file
529
env/lib/python3.10/site-packages/wagtail/search/backends/base.py
vendored
Normal file
@@ -0,0 +1,529 @@
|
||||
import datetime
|
||||
from warnings import warn
|
||||
|
||||
from django.db.models.functions.datetime import Extract as ExtractDate
|
||||
from django.db.models.functions.datetime import ExtractYear
|
||||
from django.db.models.lookups import Lookup
|
||||
from django.db.models.query import QuerySet
|
||||
from django.db.models.sql.where import SubqueryConstraint, WhereNode
|
||||
|
||||
from wagtail.search.index import class_is_indexed, get_indexed_models
|
||||
from wagtail.search.query import MATCH_ALL, PlainText
|
||||
|
||||
|
||||
class FilterError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class FieldError(Exception):
|
||||
def __init__(self, *args, field_name=None, **kwargs):
|
||||
self.field_name = field_name
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class SearchFieldError(FieldError):
|
||||
pass
|
||||
|
||||
|
||||
class FilterFieldError(FieldError):
|
||||
pass
|
||||
|
||||
|
||||
class OrderByFieldError(FieldError):
|
||||
pass
|
||||
|
||||
|
||||
class BaseSearchQueryCompiler:
|
||||
DEFAULT_OPERATOR = "or"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
queryset,
|
||||
query,
|
||||
fields=None,
|
||||
operator=None,
|
||||
order_by_relevance=True,
|
||||
):
|
||||
self.queryset = queryset
|
||||
if query is None:
|
||||
warn(
|
||||
"Querying `None` is deprecated, use `MATCH_ALL` instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
query = MATCH_ALL
|
||||
elif isinstance(query, str):
|
||||
query = PlainText(query, operator=operator or self.DEFAULT_OPERATOR)
|
||||
self.query = query
|
||||
self.fields = fields
|
||||
self.order_by_relevance = order_by_relevance
|
||||
|
||||
def _get_filterable_field(self, field_attname):
|
||||
# Get field
|
||||
field = {
|
||||
field.get_attname(self.queryset.model): field
|
||||
for field in self.queryset.model.get_filterable_search_fields()
|
||||
}.get(field_attname, None)
|
||||
|
||||
return field
|
||||
|
||||
def _process_lookup(self, field, lookup, value):
|
||||
raise NotImplementedError
|
||||
|
||||
def _connect_filters(self, filters, connector, negated):
|
||||
raise NotImplementedError
|
||||
|
||||
def _process_filter(self, field_attname, lookup, value, check_only=False):
|
||||
# Get the field
|
||||
field = self._get_filterable_field(field_attname)
|
||||
|
||||
if field is None:
|
||||
raise FilterFieldError(
|
||||
'Cannot filter search results with field "'
|
||||
+ field_attname
|
||||
+ "\". Please add index.FilterField('"
|
||||
+ field_attname
|
||||
+ "') to "
|
||||
+ self.queryset.model.__name__
|
||||
+ ".search_fields.",
|
||||
field_name=field_attname,
|
||||
)
|
||||
|
||||
# Process the lookup
|
||||
if not check_only:
|
||||
result = self._process_lookup(field, lookup, value)
|
||||
|
||||
if result is None:
|
||||
raise FilterError(
|
||||
'Could not apply filter on search results: "'
|
||||
+ field_attname
|
||||
+ "__"
|
||||
+ lookup
|
||||
+ " = "
|
||||
+ str(value)
|
||||
+ '". Lookup "'
|
||||
+ lookup
|
||||
+ '"" not recognised.'
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def _get_filters_from_where_node(self, where_node, check_only=False):
|
||||
# Check if this is a leaf node
|
||||
if isinstance(where_node, Lookup):
|
||||
if isinstance(where_node.lhs, ExtractDate):
|
||||
if not isinstance(where_node.lhs, ExtractYear):
|
||||
raise FilterError(
|
||||
'Cannot apply filter on search results: "'
|
||||
+ where_node.lhs.lookup_name
|
||||
+ '" queries are not supported.'
|
||||
)
|
||||
else:
|
||||
field_attname = where_node.lhs.lhs.target.attname
|
||||
lookup = where_node.lookup_name
|
||||
if lookup == "gte":
|
||||
# filter on year(date) >= value
|
||||
# i.e. date >= Jan 1st of that year
|
||||
value = datetime.date(int(where_node.rhs), 1, 1)
|
||||
elif lookup == "gt":
|
||||
# filter on year(date) > value
|
||||
# i.e. date >= Jan 1st of the next year
|
||||
value = datetime.date(int(where_node.rhs) + 1, 1, 1)
|
||||
lookup = "gte"
|
||||
elif lookup == "lte":
|
||||
# filter on year(date) <= value
|
||||
# i.e. date < Jan 1st of the next year
|
||||
value = datetime.date(int(where_node.rhs) + 1, 1, 1)
|
||||
lookup = "lt"
|
||||
elif lookup == "lt":
|
||||
# filter on year(date) < value
|
||||
# i.e. date < Jan 1st of that year
|
||||
value = datetime.date(int(where_node.rhs), 1, 1)
|
||||
elif lookup == "exact":
|
||||
# filter on year(date) == value
|
||||
# i.e. date >= Jan 1st of that year and date < Jan 1st of the next year
|
||||
filter1 = self._process_filter(
|
||||
field_attname,
|
||||
"gte",
|
||||
datetime.date(int(where_node.rhs), 1, 1),
|
||||
check_only=check_only,
|
||||
)
|
||||
filter2 = self._process_filter(
|
||||
field_attname,
|
||||
"lt",
|
||||
datetime.date(int(where_node.rhs) + 1, 1, 1),
|
||||
check_only=check_only,
|
||||
)
|
||||
if check_only:
|
||||
return
|
||||
else:
|
||||
return self._connect_filters(
|
||||
[filter1, filter2], "AND", False
|
||||
)
|
||||
else:
|
||||
raise FilterError(
|
||||
'Cannot apply filter on search results: "'
|
||||
+ where_node.lhs.lookup_name
|
||||
+ '" queries are not supported.'
|
||||
)
|
||||
else:
|
||||
field_attname = where_node.lhs.target.attname
|
||||
lookup = where_node.lookup_name
|
||||
value = where_node.rhs
|
||||
|
||||
# Ignore pointer fields that show up in specific page type queries
|
||||
if field_attname.endswith("_ptr_id"):
|
||||
return
|
||||
|
||||
# Process the filter
|
||||
return self._process_filter(
|
||||
field_attname, lookup, value, check_only=check_only
|
||||
)
|
||||
|
||||
elif isinstance(where_node, SubqueryConstraint):
|
||||
raise FilterError(
|
||||
"Could not apply filter on search results: Subqueries are not allowed."
|
||||
)
|
||||
|
||||
elif isinstance(where_node, WhereNode):
|
||||
# Get child filters
|
||||
connector = where_node.connector
|
||||
child_filters = [
|
||||
self._get_filters_from_where_node(child)
|
||||
for child in where_node.children
|
||||
]
|
||||
|
||||
if not check_only:
|
||||
child_filters = [
|
||||
child_filter for child_filter in child_filters if child_filter
|
||||
]
|
||||
return self._connect_filters(
|
||||
child_filters, connector, where_node.negated
|
||||
)
|
||||
|
||||
else:
|
||||
raise FilterError(
|
||||
"Could not apply filter on search results: Unknown where node: "
|
||||
+ str(type(where_node))
|
||||
)
|
||||
|
||||
def _get_filters_from_queryset(self):
|
||||
return self._get_filters_from_where_node(self.queryset.query.where)
|
||||
|
||||
def _get_order_by(self):
|
||||
if self.order_by_relevance:
|
||||
return
|
||||
|
||||
for field_name in self.queryset.query.order_by:
|
||||
reverse = False
|
||||
|
||||
if field_name.startswith("-"):
|
||||
reverse = True
|
||||
field_name = field_name[1:]
|
||||
|
||||
field = self._get_filterable_field(field_name)
|
||||
|
||||
if field is None:
|
||||
raise OrderByFieldError(
|
||||
'Cannot sort search results with field "'
|
||||
+ field_name
|
||||
+ "\". Please add index.FilterField('"
|
||||
+ field_name
|
||||
+ "') to "
|
||||
+ self.queryset.model.__name__
|
||||
+ ".search_fields.",
|
||||
field_name=field_name,
|
||||
)
|
||||
|
||||
yield reverse, field
|
||||
|
||||
def check(self):
|
||||
# Check search fields
|
||||
if self.fields:
|
||||
allowed_fields = {
|
||||
field.field_name
|
||||
for field in self.queryset.model.get_searchable_search_fields()
|
||||
}
|
||||
|
||||
for field_name in self.fields:
|
||||
if field_name not in allowed_fields:
|
||||
raise SearchFieldError(
|
||||
'Cannot search with field "'
|
||||
+ field_name
|
||||
+ "\". Please add index.SearchField('"
|
||||
+ field_name
|
||||
+ "') to "
|
||||
+ self.queryset.model.__name__
|
||||
+ ".search_fields.",
|
||||
field_name=field_name,
|
||||
)
|
||||
|
||||
# Check where clause
|
||||
# Raises FilterFieldError if an unindexed field is being filtered on
|
||||
self._get_filters_from_where_node(self.queryset.query.where, check_only=True)
|
||||
|
||||
# Check order by
|
||||
# Raises OrderByFieldError if an unindexed field is being used to order by
|
||||
list(self._get_order_by())
|
||||
|
||||
|
||||
class BaseSearchResults:
|
||||
supports_facet = False
|
||||
|
||||
def __init__(self, backend, query_compiler, prefetch_related=None):
|
||||
self.backend = backend
|
||||
self.query_compiler = query_compiler
|
||||
self.prefetch_related = prefetch_related
|
||||
self.start = 0
|
||||
self.stop = None
|
||||
self._results_cache = None
|
||||
self._count_cache = None
|
||||
self._score_field = None
|
||||
|
||||
def _set_limits(self, start=None, stop=None):
|
||||
if stop is not None:
|
||||
if self.stop is not None:
|
||||
self.stop = min(self.stop, self.start + stop)
|
||||
else:
|
||||
self.stop = self.start + stop
|
||||
|
||||
if start is not None:
|
||||
if self.stop is not None:
|
||||
self.start = min(self.stop, self.start + start)
|
||||
else:
|
||||
self.start = self.start + start
|
||||
|
||||
def _clone(self):
|
||||
klass = self.__class__
|
||||
new = klass(
|
||||
self.backend, self.query_compiler, prefetch_related=self.prefetch_related
|
||||
)
|
||||
new.start = self.start
|
||||
new.stop = self.stop
|
||||
new._score_field = self._score_field
|
||||
return new
|
||||
|
||||
def _do_search(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def _do_count(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def results(self):
|
||||
if self._results_cache is None:
|
||||
self._results_cache = list(self._do_search())
|
||||
return self._results_cache
|
||||
|
||||
def count(self):
|
||||
if self._count_cache is None:
|
||||
if self._results_cache is not None:
|
||||
self._count_cache = len(self._results_cache)
|
||||
else:
|
||||
self._count_cache = self._do_count()
|
||||
return self._count_cache
|
||||
|
||||
def __getitem__(self, key):
|
||||
new = self._clone()
|
||||
|
||||
if isinstance(key, slice):
|
||||
# Set limits
|
||||
start = int(key.start) if key.start is not None else None
|
||||
stop = int(key.stop) if key.stop is not None else None
|
||||
new._set_limits(start, stop)
|
||||
|
||||
# Copy results cache
|
||||
if self._results_cache is not None:
|
||||
new._results_cache = self._results_cache[key]
|
||||
|
||||
return new
|
||||
else:
|
||||
if self._results_cache is not None:
|
||||
return self._results_cache[key]
|
||||
|
||||
new.start = self.start + key
|
||||
new.stop = self.start + key + 1
|
||||
return list(new)[0]
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.results())
|
||||
|
||||
def __len__(self):
|
||||
return len(self.results())
|
||||
|
||||
def __repr__(self):
|
||||
data = list(self[:21])
|
||||
if len(data) > 20:
|
||||
data[-1] = "...(remaining elements truncated)..."
|
||||
return "<SearchResults %r>" % data
|
||||
|
||||
def annotate_score(self, field_name):
|
||||
clone = self._clone()
|
||||
clone._score_field = field_name
|
||||
return clone
|
||||
|
||||
def facet(self, field_name):
|
||||
raise NotImplementedError("This search backend does not support faceting")
|
||||
|
||||
|
||||
class EmptySearchResults(BaseSearchResults):
|
||||
def __init__(self):
|
||||
super().__init__(None, None)
|
||||
|
||||
def _clone(self):
|
||||
return self.__class__()
|
||||
|
||||
def _do_search(self):
|
||||
return []
|
||||
|
||||
def _do_count(self):
|
||||
return 0
|
||||
|
||||
|
||||
class NullIndex:
|
||||
"""
|
||||
Index class that provides do-nothing implementations of the indexing operations required by
|
||||
BaseSearchBackend. Use this for search backends that do not maintain an index, such as the
|
||||
database backend.
|
||||
"""
|
||||
|
||||
def add_model(self, model):
|
||||
pass
|
||||
|
||||
def refresh(self):
|
||||
pass
|
||||
|
||||
def add_item(self, item):
|
||||
pass
|
||||
|
||||
def add_items(self, model, items):
|
||||
pass
|
||||
|
||||
def delete_item(self, item):
|
||||
pass
|
||||
|
||||
|
||||
class BaseSearchBackend:
|
||||
query_compiler_class = None
|
||||
autocomplete_query_compiler_class = None
|
||||
results_class = None
|
||||
rebuilder_class = None
|
||||
catch_indexing_errors = False
|
||||
|
||||
def __init__(self, params):
|
||||
pass
|
||||
|
||||
def get_index_for_model(self, model):
|
||||
return NullIndex()
|
||||
|
||||
def get_rebuilder(self):
|
||||
return None
|
||||
|
||||
def reset_index(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def add_type(self, model):
|
||||
self.get_index_for_model(model).add_model(model)
|
||||
|
||||
def refresh_index(self):
|
||||
refreshed_indexes = []
|
||||
for model in get_indexed_models():
|
||||
index = self.get_index_for_model(model)
|
||||
if index not in refreshed_indexes:
|
||||
index.refresh()
|
||||
refreshed_indexes.append(index)
|
||||
|
||||
def add(self, obj):
|
||||
self.get_index_for_model(type(obj)).add_item(obj)
|
||||
|
||||
def add_bulk(self, model, obj_list):
|
||||
self.get_index_for_model(model).add_items(model, obj_list)
|
||||
|
||||
def delete(self, obj):
|
||||
self.get_index_for_model(type(obj)).delete_item(obj)
|
||||
|
||||
def _search(self, query_compiler_class, query, model_or_queryset, **kwargs):
|
||||
# Find model/queryset
|
||||
if isinstance(model_or_queryset, QuerySet):
|
||||
model = model_or_queryset.model
|
||||
queryset = model_or_queryset
|
||||
else:
|
||||
model = model_or_queryset
|
||||
queryset = model_or_queryset.objects.all()
|
||||
|
||||
# Model must be a class that is in the index
|
||||
if not class_is_indexed(model):
|
||||
return EmptySearchResults()
|
||||
|
||||
# Check that there's still a query string after the clean up
|
||||
if query == "":
|
||||
return EmptySearchResults()
|
||||
|
||||
# Search
|
||||
search_query_compiler = query_compiler_class(queryset, query, **kwargs)
|
||||
|
||||
# Check the query
|
||||
search_query_compiler.check()
|
||||
|
||||
return self.results_class(self, search_query_compiler)
|
||||
|
||||
def search(
|
||||
self,
|
||||
query,
|
||||
model_or_queryset,
|
||||
fields=None,
|
||||
operator=None,
|
||||
order_by_relevance=True,
|
||||
):
|
||||
return self._search(
|
||||
self.query_compiler_class,
|
||||
query,
|
||||
model_or_queryset,
|
||||
fields=fields,
|
||||
operator=operator,
|
||||
order_by_relevance=order_by_relevance,
|
||||
)
|
||||
|
||||
def autocomplete(
|
||||
self,
|
||||
query,
|
||||
model_or_queryset,
|
||||
fields=None,
|
||||
operator=None,
|
||||
order_by_relevance=True,
|
||||
):
|
||||
if self.autocomplete_query_compiler_class is None:
|
||||
raise NotImplementedError(
|
||||
"This search backend does not support the autocomplete API"
|
||||
)
|
||||
|
||||
return self._search(
|
||||
self.autocomplete_query_compiler_class,
|
||||
query,
|
||||
model_or_queryset,
|
||||
fields=fields,
|
||||
operator=operator,
|
||||
order_by_relevance=order_by_relevance,
|
||||
)
|
||||
|
||||
|
||||
def get_model_root(model):
|
||||
"""
|
||||
This function finds the root model for any given model. The root model is
|
||||
the highest concrete model that it descends from. If the model doesn't
|
||||
descend from another concrete model then the model is it's own root model so
|
||||
it is returned.
|
||||
|
||||
Examples:
|
||||
>>> get_model_root(wagtailcore.Page)
|
||||
wagtailcore.Page
|
||||
|
||||
>>> get_model_root(myapp.HomePage)
|
||||
wagtailcore.Page
|
||||
|
||||
>>> get_model_root(wagtailimages.Image)
|
||||
wagtailimages.Image
|
||||
"""
|
||||
if model._meta.parents:
|
||||
parent_model = list(model._meta.parents.items())[0][0]
|
||||
return get_model_root(parent_model)
|
||||
|
||||
return model
|
||||
50
env/lib/python3.10/site-packages/wagtail/search/backends/database/__init__.py
vendored
Normal file
50
env/lib/python3.10/site-packages/wagtail/search/backends/database/__init__.py
vendored
Normal file
@@ -0,0 +1,50 @@
|
||||
import warnings
|
||||
|
||||
from django.db import connection
|
||||
|
||||
USE_SQLITE_FTS = None # True if sqlite FTS is available, False if not, None if untested
|
||||
|
||||
|
||||
def SearchBackend(params):
|
||||
"""
|
||||
Returns the appropriate search backend for the current 'default' database system
|
||||
"""
|
||||
if connection.vendor == "postgresql":
|
||||
from .postgres.postgres import PostgresSearchBackend
|
||||
|
||||
return PostgresSearchBackend(params)
|
||||
elif connection.vendor == "mysql":
|
||||
from .mysql.mysql import MySQLSearchBackend
|
||||
|
||||
return MySQLSearchBackend(params)
|
||||
elif connection.vendor == "sqlite":
|
||||
global USE_SQLITE_FTS
|
||||
|
||||
if USE_SQLITE_FTS is None:
|
||||
from .sqlite.utils import fts5_available, fts_table_exists
|
||||
|
||||
if not fts5_available():
|
||||
USE_SQLITE_FTS = False
|
||||
elif not fts_table_exists():
|
||||
USE_SQLITE_FTS = False
|
||||
warnings.warn(
|
||||
"The installed SQLite library supports full-text search, but the table for storing "
|
||||
"searchable content is missing. This probably means SQLite was upgraded after the "
|
||||
"migration was applied. To enable full-text search, reapply wagtailsearch migration 0006 "
|
||||
"or create the table manually."
|
||||
)
|
||||
else:
|
||||
USE_SQLITE_FTS = True
|
||||
|
||||
if USE_SQLITE_FTS:
|
||||
from .sqlite.sqlite import SQLiteSearchBackend
|
||||
|
||||
return SQLiteSearchBackend(params)
|
||||
else:
|
||||
from .fallback import DatabaseSearchBackend
|
||||
|
||||
return DatabaseSearchBackend(params)
|
||||
else:
|
||||
from .fallback import DatabaseSearchBackend
|
||||
|
||||
return DatabaseSearchBackend(params)
|
||||
Binary file not shown.
Binary file not shown.
249
env/lib/python3.10/site-packages/wagtail/search/backends/database/fallback.py
vendored
Normal file
249
env/lib/python3.10/site-packages/wagtail/search/backends/database/fallback.py
vendored
Normal file
@@ -0,0 +1,249 @@
|
||||
from collections import OrderedDict
|
||||
from warnings import warn
|
||||
|
||||
from django.core.exceptions import FieldDoesNotExist
|
||||
from django.db import models
|
||||
from django.db.models import Count
|
||||
from django.db.models.expressions import Value
|
||||
|
||||
from wagtail.search.backends.base import (
|
||||
BaseSearchBackend,
|
||||
BaseSearchQueryCompiler,
|
||||
BaseSearchResults,
|
||||
FilterFieldError,
|
||||
)
|
||||
from wagtail.search.query import And, Boost, MatchAll, Not, Or, Phrase, PlainText
|
||||
from wagtail.search.utils import AND, OR
|
||||
|
||||
# This file implements a database search backend using basic substring matching, and no
|
||||
# database-specific full-text search capabilities. It will be used in the following cases:
|
||||
# * The current default database is SQLite <3.19, or SQLite built without fulltext
|
||||
# extensions, or something other than PostgreSQL, MySQL or SQLite
|
||||
# * 'wagtail.search.backends.database.fallback' is specified directly as the search backend
|
||||
|
||||
|
||||
MATCH_ALL = "_ALL_"
|
||||
MATCH_NONE = "_NONE_"
|
||||
|
||||
|
||||
class DatabaseSearchQueryCompiler(BaseSearchQueryCompiler):
|
||||
DEFAULT_OPERATOR = "and"
|
||||
OPERATORS = {
|
||||
"and": AND,
|
||||
"or": OR,
|
||||
}
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.fields_names = list(self.get_fields_names())
|
||||
|
||||
def get_fields_names(self):
|
||||
model = self.queryset.model
|
||||
fields_names = self.fields or [
|
||||
field.field_name for field in model.get_searchable_search_fields()
|
||||
]
|
||||
# Check if the field exists (this will filter out indexed callables)
|
||||
for field_name in fields_names:
|
||||
try:
|
||||
model._meta.get_field(field_name)
|
||||
except FieldDoesNotExist:
|
||||
continue
|
||||
else:
|
||||
yield field_name
|
||||
|
||||
def _process_lookup(self, field, lookup, value):
|
||||
return models.Q(
|
||||
**{field.get_attname(self.queryset.model) + "__" + lookup: value}
|
||||
)
|
||||
|
||||
def _connect_filters(self, filters, connector, negated):
|
||||
if connector == "AND":
|
||||
q = models.Q(*filters)
|
||||
elif connector == "OR":
|
||||
q = OR([models.Q(fil) for fil in filters])
|
||||
else:
|
||||
return
|
||||
|
||||
if negated:
|
||||
q = ~q
|
||||
|
||||
return q
|
||||
|
||||
def build_single_term_filter(self, term):
|
||||
term_query = models.Q()
|
||||
for field_name in self.fields_names:
|
||||
term_query |= models.Q(**{field_name + "__icontains": term})
|
||||
return term_query
|
||||
|
||||
def check_boost(self, query, boost=1.0):
|
||||
if query.boost * boost != 1.0:
|
||||
warn("Database search backend does not support term boosting.")
|
||||
|
||||
def build_database_filter(self, query, boost=1.0):
|
||||
if isinstance(query, PlainText):
|
||||
self.check_boost(query, boost=boost)
|
||||
|
||||
operator = self.OPERATORS[query.operator]
|
||||
|
||||
return operator(
|
||||
[
|
||||
self.build_single_term_filter(term)
|
||||
for term in query.query_string.split()
|
||||
]
|
||||
)
|
||||
|
||||
if isinstance(query, Phrase):
|
||||
q = models.Q()
|
||||
for field_name in self.fields_names:
|
||||
q |= models.Q(**{field_name + "__icontains": query.query_string})
|
||||
return q
|
||||
|
||||
if isinstance(query, Boost):
|
||||
boost *= query.boost
|
||||
return self.build_database_filter(query.subquery, boost=boost)
|
||||
|
||||
if isinstance(query, MatchAll):
|
||||
return MATCH_ALL
|
||||
|
||||
if isinstance(query, Not):
|
||||
q = self.build_database_filter(query.subquery, boost=boost)
|
||||
|
||||
if q == MATCH_ALL:
|
||||
return MATCH_NONE
|
||||
|
||||
elif q == MATCH_NONE:
|
||||
return MATCH_ALL
|
||||
|
||||
else:
|
||||
return ~q
|
||||
|
||||
if isinstance(query, And):
|
||||
subqueries = [
|
||||
self.build_database_filter(subquery, boost=boost)
|
||||
for subquery in query.subqueries
|
||||
]
|
||||
|
||||
# If there's a MATCH_NONE, return MATCH_NONE
|
||||
if MATCH_NONE in subqueries:
|
||||
return MATCH_NONE
|
||||
|
||||
# Ignore MATCH_ALL
|
||||
subqueries = [q for q in subqueries if q != MATCH_ALL]
|
||||
|
||||
return AND(subqueries)
|
||||
|
||||
if isinstance(query, Or):
|
||||
subqueries = [
|
||||
self.build_database_filter(subquery, boost=boost)
|
||||
for subquery in query.subqueries
|
||||
]
|
||||
|
||||
# If there's a MATCH_ALL, return MATCH_ALL
|
||||
if MATCH_ALL in subqueries:
|
||||
return MATCH_ALL
|
||||
|
||||
# Ignore MATCH_NONE
|
||||
subqueries = [q for q in subqueries if q != MATCH_NONE]
|
||||
|
||||
return OR(subqueries)
|
||||
|
||||
raise NotImplementedError(
|
||||
"`%s` is not supported by the database search backend."
|
||||
% query.__class__.__name__
|
||||
)
|
||||
|
||||
|
||||
class DatabaseAutocompleteQueryCompiler(DatabaseSearchQueryCompiler):
|
||||
# The fallback backend doesn't handle word boundaries, so standard searches are
|
||||
# essentially equivalent to autocomplete queries anyhow. However, to provide a
|
||||
# consistent API with other backends, we provide both endpoints.
|
||||
pass
|
||||
|
||||
|
||||
class DatabaseSearchResults(BaseSearchResults):
|
||||
iterator_chunk_size = 2000
|
||||
|
||||
def get_queryset(self):
|
||||
queryset = self.query_compiler.queryset
|
||||
|
||||
# Run _get_filters_from_queryset to test that no fields that are not
|
||||
# a FilterField have been used in the query.
|
||||
self.query_compiler._get_filters_from_queryset()
|
||||
|
||||
q = self.query_compiler.build_database_filter(self.query_compiler.query)
|
||||
|
||||
if q == MATCH_ALL:
|
||||
pass
|
||||
elif q == MATCH_NONE:
|
||||
queryset = queryset.none()
|
||||
else:
|
||||
queryset = queryset.filter(q)
|
||||
|
||||
return queryset.distinct()[self.start : self.stop]
|
||||
|
||||
def _do_search(self):
|
||||
queryset = self.get_queryset()
|
||||
|
||||
if self._score_field:
|
||||
queryset = queryset.annotate(
|
||||
**{self._score_field: Value(None, output_field=models.FloatField())}
|
||||
)
|
||||
|
||||
return queryset.iterator(self.iterator_chunk_size)
|
||||
|
||||
def _do_count(self):
|
||||
return self.get_queryset().count()
|
||||
|
||||
supports_facet = True
|
||||
|
||||
def facet(self, field_name):
|
||||
# Get field
|
||||
field = self.query_compiler._get_filterable_field(field_name)
|
||||
if field is None:
|
||||
raise FilterFieldError(
|
||||
'Cannot facet search results with field "'
|
||||
+ field_name
|
||||
+ "\". Please add index.FilterField('"
|
||||
+ field_name
|
||||
+ "') to "
|
||||
+ self.query_compiler.queryset.model.__name__
|
||||
+ ".search_fields.",
|
||||
field_name=field_name,
|
||||
)
|
||||
|
||||
query = self.get_queryset()
|
||||
results = (
|
||||
query.values(field_name).annotate(count=Count("pk")).order_by("-count")
|
||||
)
|
||||
|
||||
return OrderedDict(
|
||||
[(result[field_name], result["count"]) for result in results]
|
||||
)
|
||||
|
||||
|
||||
class DatabaseSearchBackend(BaseSearchBackend):
|
||||
query_compiler_class = DatabaseSearchQueryCompiler
|
||||
autocomplete_query_compiler_class = DatabaseSearchQueryCompiler
|
||||
results_class = DatabaseSearchResults
|
||||
|
||||
def reset_index(self):
|
||||
pass # Not needed
|
||||
|
||||
def add_type(self, model):
|
||||
pass # Not needed
|
||||
|
||||
def refresh_index(self):
|
||||
pass # Not needed
|
||||
|
||||
def add(self, obj):
|
||||
pass # Not needed
|
||||
|
||||
def add_bulk(self, model, obj_list):
|
||||
return # Not needed
|
||||
|
||||
def delete(self, obj):
|
||||
pass # Not needed
|
||||
|
||||
|
||||
# This line allows using 'wagtail.search.backends.database.fallback' as the backend, bypassing the automatic selection of the backend that would get run if the user chose 'wagtail.search.backends.database'
|
||||
SearchBackend = DatabaseSearchBackend
|
||||
0
env/lib/python3.10/site-packages/wagtail/search/backends/database/mysql/__init__.py
vendored
Normal file
0
env/lib/python3.10/site-packages/wagtail/search/backends/database/mysql/__init__.py
vendored
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
684
env/lib/python3.10/site-packages/wagtail/search/backends/database/mysql/mysql.py
vendored
Normal file
684
env/lib/python3.10/site-packages/wagtail/search/backends/database/mysql/mysql.py
vendored
Normal file
@@ -0,0 +1,684 @@
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
|
||||
from django.db import DEFAULT_DB_ALIAS, NotSupportedError, connections, transaction
|
||||
from django.db.models import Case, When
|
||||
from django.db.models.aggregates import Avg, Count
|
||||
from django.db.models.constants import LOOKUP_SEP
|
||||
from django.db.models.expressions import F
|
||||
from django.db.models.fields import BooleanField, FloatField, TextField
|
||||
from django.db.models.functions.comparison import Cast
|
||||
from django.db.models.functions.text import Length
|
||||
from django.db.models.manager import Manager
|
||||
from django.db.models.query_utils import Q
|
||||
from django.utils.encoding import force_str
|
||||
from django.utils.functional import cached_property
|
||||
|
||||
from wagtail.search.backends.base import (
|
||||
BaseSearchBackend,
|
||||
BaseSearchQueryCompiler,
|
||||
BaseSearchResults,
|
||||
FilterFieldError,
|
||||
)
|
||||
from wagtail.search.backends.database.mysql.query import (
|
||||
Lexeme,
|
||||
MatchExpression,
|
||||
SearchQuery,
|
||||
)
|
||||
from wagtail.search.index import (
|
||||
AutocompleteField,
|
||||
RelatedFields,
|
||||
SearchField,
|
||||
get_indexed_models,
|
||||
)
|
||||
from wagtail.search.models import IndexEntry
|
||||
from wagtail.search.query import And, Boost, MatchAll, Not, Or, Phrase, PlainText
|
||||
from wagtail.search.utils import (
|
||||
OR,
|
||||
balanced_reduce,
|
||||
get_content_type_pk,
|
||||
get_descendants_content_types_pks,
|
||||
)
|
||||
|
||||
|
||||
class ObjectIndexer:
|
||||
"""
|
||||
Responsible for extracting data from an object to be inserted into the index.
|
||||
"""
|
||||
|
||||
def __init__(self, obj, backend):
|
||||
self.obj = obj
|
||||
self.search_fields = obj.get_search_fields()
|
||||
self.config = backend.config
|
||||
|
||||
def prepare_value(self, value):
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
|
||||
elif isinstance(value, list):
|
||||
return ", ".join(self.prepare_value(item) for item in value)
|
||||
|
||||
elif isinstance(value, dict):
|
||||
return ", ".join(self.prepare_value(item) for item in value.values())
|
||||
|
||||
return force_str(value)
|
||||
|
||||
def prepare_field(self, obj, field):
|
||||
if isinstance(field, SearchField):
|
||||
yield (field, self.prepare_value(field.get_value(obj)))
|
||||
|
||||
elif isinstance(field, AutocompleteField):
|
||||
yield (field, self.prepare_value(field.get_value(obj)))
|
||||
|
||||
elif isinstance(field, RelatedFields):
|
||||
sub_obj = field.get_value(obj)
|
||||
if sub_obj is None:
|
||||
return
|
||||
|
||||
if isinstance(sub_obj, Manager):
|
||||
sub_objs = sub_obj.all()
|
||||
|
||||
else:
|
||||
if callable(sub_obj):
|
||||
sub_obj = sub_obj()
|
||||
|
||||
sub_objs = [sub_obj]
|
||||
|
||||
for sub_obj in sub_objs:
|
||||
for sub_field in field.fields:
|
||||
yield from self.prepare_field(sub_obj, sub_field)
|
||||
|
||||
@cached_property
|
||||
def id(self):
|
||||
"""
|
||||
Returns the value to use as the ID of the record in the index
|
||||
"""
|
||||
return force_str(self.obj.pk)
|
||||
|
||||
@cached_property
|
||||
def title(self):
|
||||
"""
|
||||
Returns all values to index as "title". This is the value of all SearchFields that have the field_name 'title'
|
||||
"""
|
||||
texts = []
|
||||
for field in self.search_fields:
|
||||
for current_field, value in self.prepare_field(self.obj, field):
|
||||
if (
|
||||
isinstance(current_field, SearchField)
|
||||
and current_field.field_name == "title"
|
||||
):
|
||||
texts.append(value)
|
||||
|
||||
return " ".join(texts)
|
||||
|
||||
@cached_property
|
||||
def body(self):
|
||||
"""
|
||||
Returns all values to index as "body". This is the value of all SearchFields excluding the title
|
||||
"""
|
||||
texts = []
|
||||
for field in self.search_fields:
|
||||
for current_field, value in self.prepare_field(self.obj, field):
|
||||
if (
|
||||
isinstance(current_field, SearchField)
|
||||
and not current_field.field_name == "title"
|
||||
):
|
||||
texts.append(value)
|
||||
|
||||
return " ".join(texts)
|
||||
|
||||
@cached_property
|
||||
def autocomplete(self):
|
||||
"""
|
||||
Returns all values to index as "autocomplete". This is the value of all AutocompleteFields
|
||||
"""
|
||||
texts = []
|
||||
for field in self.search_fields:
|
||||
for current_field, value in self.prepare_field(self.obj, field):
|
||||
if isinstance(current_field, AutocompleteField):
|
||||
texts.append(value)
|
||||
|
||||
return " ".join(texts)
|
||||
|
||||
def as_vector(self, texts, for_autocomplete=False):
|
||||
"""
|
||||
Converts an array of strings into a SearchVector that can be indexed.
|
||||
"""
|
||||
texts = [(text.strip(), weight) for text, weight in texts]
|
||||
texts = [(text, weight) for text, weight in texts if text]
|
||||
|
||||
return " ".join(texts)
|
||||
|
||||
|
||||
class Index:
|
||||
def __init__(self, backend, db_alias=None):
|
||||
self.backend = backend
|
||||
self.name = self.backend.index_name
|
||||
self.db_alias = DEFAULT_DB_ALIAS if db_alias is None else db_alias
|
||||
self.connection = connections[self.db_alias]
|
||||
if self.connection.vendor != "mysql":
|
||||
raise NotSupportedError(
|
||||
"You must select a MySQL database " "to use MySQL search."
|
||||
)
|
||||
|
||||
self.entries = IndexEntry._default_manager.using(self.db_alias)
|
||||
|
||||
def add_model(self, model):
|
||||
pass
|
||||
|
||||
def refresh(self):
|
||||
pass
|
||||
|
||||
def _refresh_title_norms(self, full=False):
|
||||
"""
|
||||
Refreshes the value of the title_norm field.
|
||||
|
||||
This needs to be set to 'lavg/ld' where:
|
||||
- lavg is the average length of titles in all documents (also in terms)
|
||||
- ld is the length of the title field in this document (in terms)
|
||||
"""
|
||||
|
||||
lavg = (
|
||||
self.entries.annotate(title_length=Length("title"))
|
||||
.filter(title_length__gt=0)
|
||||
.aggregate(Avg("title_length"))["title_length__avg"]
|
||||
)
|
||||
|
||||
if full:
|
||||
# Update the whole table
|
||||
# This is the most accurate option but requires a full table rewrite
|
||||
# so we can't do it too often as it could lead to locking issues.
|
||||
entries = self.entries
|
||||
|
||||
else:
|
||||
# Only update entries where title_norm is 1.0
|
||||
# This is the default value set on new entries.
|
||||
# It's possible that other entries could have this exact value but there shouldn't be too many of those
|
||||
entries = self.entries.filter(title_norm=1.0)
|
||||
|
||||
entries.annotate(title_length=Length("title")).filter(
|
||||
title_length__gt=0
|
||||
).update(title_norm=lavg / F("title_length"))
|
||||
|
||||
def delete_stale_model_entries(self, model):
|
||||
existing_pks = (
|
||||
model._default_manager.using(self.db_alias)
|
||||
.annotate(object_id=Cast("pk", TextField()))
|
||||
.values("object_id")
|
||||
)
|
||||
content_types_pks = get_descendants_content_types_pks(model)
|
||||
stale_entries = self.entries.filter(
|
||||
content_type_id__in=content_types_pks
|
||||
).exclude(object_id__in=existing_pks)
|
||||
stale_entries.delete()
|
||||
|
||||
def delete_stale_entries(self):
|
||||
for model in get_indexed_models():
|
||||
# We don’t need to delete stale entries for non-root models,
|
||||
# since we already delete them by deleting roots.
|
||||
if not model._meta.parents:
|
||||
self.delete_stale_model_entries(model)
|
||||
|
||||
def add_item(self, obj):
|
||||
self.add_items(obj._meta.model, [obj])
|
||||
|
||||
def add_items_update_then_create(self, content_type_pk, indexers):
|
||||
ids_and_data = {}
|
||||
for indexer in indexers:
|
||||
ids_and_data[indexer.id] = (
|
||||
indexer.title,
|
||||
indexer.autocomplete,
|
||||
indexer.body,
|
||||
)
|
||||
|
||||
index_entries_for_ct = self.entries.filter(content_type_id=content_type_pk)
|
||||
indexed_ids = frozenset(
|
||||
index_entries_for_ct.filter(object_id__in=ids_and_data.keys()).values_list(
|
||||
"object_id", flat=True
|
||||
)
|
||||
)
|
||||
for indexed_id in indexed_ids:
|
||||
title, autocomplete, body = ids_and_data[indexed_id]
|
||||
index_entries_for_ct.filter(object_id=indexed_id).update(
|
||||
title=title, autocomplete=autocomplete, body=body
|
||||
)
|
||||
|
||||
to_be_created = []
|
||||
for object_id in ids_and_data.keys():
|
||||
if object_id not in indexed_ids:
|
||||
title, autocomplete, body = ids_and_data[object_id]
|
||||
to_be_created.append(
|
||||
IndexEntry(
|
||||
content_type_id=content_type_pk,
|
||||
object_id=object_id,
|
||||
title=title,
|
||||
autocomplete=autocomplete,
|
||||
body=body,
|
||||
)
|
||||
)
|
||||
|
||||
self.entries.bulk_create(to_be_created)
|
||||
|
||||
self._refresh_title_norms()
|
||||
|
||||
def add_items(self, model, objs):
|
||||
search_fields = model.get_search_fields()
|
||||
if not search_fields:
|
||||
return
|
||||
|
||||
indexers = [ObjectIndexer(obj, self.backend) for obj in objs]
|
||||
|
||||
# TODO: Delete unindexed objects while dealing with proxy models.
|
||||
if indexers:
|
||||
content_type_pk = get_content_type_pk(model)
|
||||
|
||||
update_method = self.add_items_update_then_create
|
||||
update_method(content_type_pk, indexers)
|
||||
|
||||
def delete_item(self, item):
|
||||
item.index_entries.all()._raw_delete(using=self.db_alias)
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
|
||||
class MySQLSearchQueryCompiler(BaseSearchQueryCompiler):
|
||||
DEFAULT_OPERATOR = "and"
|
||||
LAST_TERM_IS_PREFIX = False
|
||||
TARGET_SEARCH_FIELD_TYPE = SearchField
|
||||
FTS_TABLE_FIELDS = ["title", "body"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
local_search_fields = self.get_search_fields_for_model()
|
||||
|
||||
if self.fields is None:
|
||||
# search over the fields defined on the current model
|
||||
self.search_fields = local_search_fields
|
||||
else:
|
||||
# build a search_fields set from the passed definition,
|
||||
# which may involve traversing relations
|
||||
self.search_fields = {
|
||||
field_lookup: self.get_search_field(
|
||||
field_lookup, fields=local_search_fields
|
||||
)
|
||||
for field_lookup in self.fields
|
||||
}
|
||||
|
||||
def get_config(self, backend):
|
||||
return backend.config
|
||||
|
||||
def get_search_fields_for_model(self):
|
||||
return self.queryset.model.get_searchable_search_fields()
|
||||
|
||||
def get_search_field(self, field_lookup, fields=None):
|
||||
if fields is None:
|
||||
fields = self.search_fields
|
||||
|
||||
if LOOKUP_SEP in field_lookup:
|
||||
field_lookup, sub_field_name = field_lookup.split(LOOKUP_SEP, 1)
|
||||
else:
|
||||
sub_field_name = None
|
||||
|
||||
for field in fields:
|
||||
if (
|
||||
isinstance(field, self.TARGET_SEARCH_FIELD_TYPE)
|
||||
and field.field_name == field_lookup
|
||||
):
|
||||
return field
|
||||
|
||||
# Note: Searching on a specific related field using
|
||||
# `.search(fields=…)` is not yet supported by Wagtail.
|
||||
# This method anticipates by already implementing it.
|
||||
if isinstance(field, RelatedFields) and field.field_name == field_lookup:
|
||||
return self.get_search_field(sub_field_name, field.fields)
|
||||
|
||||
def build_search_query_content(self, query, invert=False):
|
||||
if isinstance(query, PlainText):
|
||||
terms = query.query_string.split()
|
||||
if not terms:
|
||||
return None
|
||||
|
||||
last_term = terms.pop()
|
||||
|
||||
lexemes = Lexeme(last_term, invert=invert, prefix=self.LAST_TERM_IS_PREFIX)
|
||||
for term in terms:
|
||||
new_lexeme = Lexeme(term, invert=invert)
|
||||
|
||||
if query.operator == "and":
|
||||
lexemes &= new_lexeme
|
||||
else:
|
||||
lexemes |= new_lexeme
|
||||
|
||||
return SearchQuery(lexemes)
|
||||
|
||||
elif isinstance(query, Phrase):
|
||||
return SearchQuery(query.query_string, search_type="phrase")
|
||||
|
||||
elif isinstance(query, Boost):
|
||||
# Not supported
|
||||
msg = "The Boost query is not supported by the MySQL search backend."
|
||||
warnings.warn(msg, RuntimeWarning)
|
||||
|
||||
return self.build_search_query_content(query.subquery, invert=invert)
|
||||
|
||||
elif isinstance(query, Not):
|
||||
return self.build_search_query_content(query.subquery, invert=not invert)
|
||||
|
||||
elif isinstance(query, (And, Or)):
|
||||
# If this part of the query is inverted, we swap the operator and
|
||||
# pass down the inversion state to the child queries.
|
||||
# This works thanks to De Morgan's law.
|
||||
#
|
||||
# For example, the following query:
|
||||
#
|
||||
# Not(And(Term("A"), Term("B")))
|
||||
#
|
||||
# Is equivalent to:
|
||||
#
|
||||
# Or(Not(Term("A")), Not(Term("B")))
|
||||
#
|
||||
# It's simpler to code it this way as we only need to store the
|
||||
# invert status of the terms rather than all the operators.
|
||||
|
||||
subquery_lexemes = [
|
||||
self.build_search_query_content(subquery, invert=invert)
|
||||
for subquery in query.subqueries
|
||||
]
|
||||
|
||||
is_and = isinstance(query, And)
|
||||
|
||||
if invert:
|
||||
is_and = not is_and
|
||||
|
||||
if is_and:
|
||||
return balanced_reduce(lambda a, b: a & b, subquery_lexemes)
|
||||
else:
|
||||
return balanced_reduce(lambda a, b: a | b, subquery_lexemes)
|
||||
|
||||
raise NotImplementedError(
|
||||
"`%s` is not supported by the MySQL search backend."
|
||||
% query.__class__.__name__
|
||||
)
|
||||
|
||||
def build_search_query(self, query):
|
||||
return self.build_search_query_content(query)
|
||||
|
||||
def get_index_vectors(self, search_query):
|
||||
return [
|
||||
(F("index_entries__title"), F("index_entries__title_norm")),
|
||||
(F("index_entries__body"), 1.0),
|
||||
]
|
||||
|
||||
def get_fields_vectors(self, search_query):
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_search_vectors(self, search_query):
|
||||
if self.fields is None:
|
||||
return self.get_index_vectors(search_query)
|
||||
|
||||
else:
|
||||
return self.get_fields_vectors(search_query)
|
||||
|
||||
def _build_rank_expression(self, vectors, config):
|
||||
rank_expressions = [
|
||||
self.build_tsrank(vector, self.query, config=config) * boost
|
||||
for vector, boost in vectors
|
||||
]
|
||||
|
||||
rank_expression = rank_expressions[0]
|
||||
for other_rank_expression in rank_expressions[1:]:
|
||||
rank_expression += other_rank_expression
|
||||
|
||||
return rank_expression
|
||||
|
||||
def search(self, config, start, stop, score_field=None):
|
||||
# TODO: Handle MatchAll nested inside other search query classes.
|
||||
if isinstance(self.query, MatchAll):
|
||||
return self.queryset[start:stop]
|
||||
|
||||
elif isinstance(self.query, Not) and isinstance(self.query.subquery, MatchAll):
|
||||
return self.queryset.none()
|
||||
|
||||
if isinstance(
|
||||
self.query, Not
|
||||
): # If the outermost operator is a Not, we invert the query. This is done because, if every search term is negated, the Not() will return no results, an we want to match all results instead.
|
||||
query = self.query.subquery
|
||||
negated = True
|
||||
else:
|
||||
query = self.query
|
||||
negated = False
|
||||
|
||||
search_query = self.build_search_query(query)
|
||||
match_expression = MatchExpression(
|
||||
search_query, columns=self.FTS_TABLE_FIELDS, output_field=BooleanField()
|
||||
) # For example: MATCH (`title`, `body`) AGAINST ('+query' IN BOOLEAN MODE)
|
||||
|
||||
# In Django 4.0 the above match expression would produce this SQL WHERE clause:
|
||||
#
|
||||
# WHERE ... MATCH (`title`, `body`) AGAINST (query IN BOOLEAN MODE)
|
||||
#
|
||||
# In Django 4.1, this behavior was changed:
|
||||
#
|
||||
# https://code.djangoproject.com/ticket/32691
|
||||
# https://github.com/django/django/commit/407fe95cb116599adeb4b9ed01df5673aa5cb1db
|
||||
#
|
||||
# so that instead this SQL WHERE clause is generated, explicitly filtering
|
||||
# against "= True":
|
||||
#
|
||||
# WHERE ... MATCH (`title`, `body`) AGAINST (query IN BOOLEAN MODE) = True
|
||||
#
|
||||
# This no longer works properly because MATCH returns a floating point score
|
||||
# as a measurement of the match quality, not a boolean value:
|
||||
#
|
||||
# https://dev.mysql.com/doc/refman/8.0/en/fulltext-boolean.html
|
||||
#
|
||||
# In order for filtering on "= True" to work, we change the match expression
|
||||
# SQL to be:
|
||||
#
|
||||
# WHERE ... CASE WHEN MATCH (`title`, `body`) AGAINST (query IN BOOLEAN MODE) THEN True ELSE False END = True
|
||||
match_expression = Case(When(match_expression, then=True), default=False)
|
||||
|
||||
score_expression = MatchExpression(
|
||||
search_query, columns=["title"], output_field=FloatField()
|
||||
) * F("title_norm") + MatchExpression(
|
||||
search_query, columns=["body"], output_field=FloatField()
|
||||
)
|
||||
|
||||
index_entries = IndexEntry.objects.annotate(score=score_expression).filter(
|
||||
content_type_id__in=get_descendants_content_types_pks(self.queryset.model)
|
||||
)
|
||||
if not negated:
|
||||
index_entries = index_entries.filter(match_expression)
|
||||
if self.order_by_relevance: # Only applies to the case where the outermost query is not a Not(), because if it is, the relevance score is always 0 (anything that matches is excluded from the results).
|
||||
index_entries = index_entries.order_by(score_expression.desc())
|
||||
else:
|
||||
index_entries = index_entries.exclude(match_expression)
|
||||
|
||||
index_entries = index_entries[start:stop] # Trim the results
|
||||
|
||||
object_ids = {
|
||||
index_entry.object_id for index_entry in index_entries
|
||||
} # Get the set of IDs from the indexed objects, removes duplicates too
|
||||
|
||||
results = self.queryset.filter(id__in=object_ids)
|
||||
|
||||
return results
|
||||
|
||||
def _process_lookup(self, field, lookup, value):
|
||||
lhs = field.get_attname(self.queryset.model) + "__" + lookup
|
||||
return Q(**{lhs: value})
|
||||
|
||||
def _connect_filters(self, filters, connector, negated):
|
||||
if connector == "AND":
|
||||
q = Q(*filters)
|
||||
|
||||
elif connector == "OR":
|
||||
q = OR([Q(fil) for fil in filters])
|
||||
|
||||
else:
|
||||
return
|
||||
|
||||
if negated:
|
||||
q = ~q
|
||||
|
||||
return q
|
||||
|
||||
|
||||
class MySQLAutocompleteQueryCompiler(MySQLSearchQueryCompiler):
|
||||
LAST_TERM_IS_PREFIX = True
|
||||
TARGET_SEARCH_FIELD_TYPE = AutocompleteField
|
||||
FTS_TABLE_FIELDS = ["autocomplete"]
|
||||
|
||||
def get_config(self, backend):
|
||||
return backend.autocomplete_config
|
||||
|
||||
def get_search_fields_for_model(self):
|
||||
return self.queryset.model.get_autocomplete_search_fields()
|
||||
|
||||
def get_index_vectors(self, search_query):
|
||||
return [(F("index_entries__autocomplete"), 1.0)]
|
||||
|
||||
def get_fields_vectors(self, search_query):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class MySQLSearchResults(BaseSearchResults):
|
||||
def get_queryset(self, for_count=False):
|
||||
if for_count:
|
||||
start = None
|
||||
stop = None
|
||||
else:
|
||||
start = self.start
|
||||
stop = self.stop
|
||||
|
||||
return self.query_compiler.search(
|
||||
self.query_compiler.get_config(self.backend),
|
||||
start,
|
||||
stop,
|
||||
score_field=self._score_field,
|
||||
)
|
||||
|
||||
def _do_search(self):
|
||||
return list(self.get_queryset())
|
||||
|
||||
def _do_count(self):
|
||||
return self.get_queryset(for_count=True).count()
|
||||
|
||||
supports_facet = True
|
||||
|
||||
def facet(self, field_name):
|
||||
# Get field
|
||||
field = self.query_compiler._get_filterable_field(field_name)
|
||||
if field is None:
|
||||
raise FilterFieldError(
|
||||
'Cannot facet search results with field "'
|
||||
+ field_name
|
||||
+ "\". Please add index.FilterField('"
|
||||
+ field_name
|
||||
+ "') to "
|
||||
+ self.query_compiler.queryset.model.__name__
|
||||
+ ".search_fields.",
|
||||
field_name=field_name,
|
||||
)
|
||||
|
||||
query = self.query_compiler.search(
|
||||
self.query_compiler.get_config(self.backend), None, None
|
||||
)
|
||||
results = (
|
||||
query.values(field_name).annotate(count=Count("pk")).order_by("-count")
|
||||
)
|
||||
|
||||
return OrderedDict(
|
||||
[(result[field_name], result["count"]) for result in results]
|
||||
)
|
||||
|
||||
|
||||
class MySQLSearchRebuilder:
|
||||
def __init__(self, index):
|
||||
self.index = index
|
||||
|
||||
def start(self):
|
||||
self.index.delete_stale_entries()
|
||||
return self.index
|
||||
|
||||
def finish(self):
|
||||
self.index._refresh_title_norms(full=True)
|
||||
|
||||
|
||||
class MySQLSearchAtomicRebuilder(MySQLSearchRebuilder):
|
||||
def __init__(self, index):
|
||||
super().__init__(index)
|
||||
self.transaction = transaction.atomic(using=index.db_alias)
|
||||
self.transaction_opened = False
|
||||
|
||||
def start(self):
|
||||
self.transaction.__enter__()
|
||||
self.transaction_opened = True
|
||||
return super().start()
|
||||
|
||||
def finish(self):
|
||||
self.index._refresh_title_norms(full=True)
|
||||
|
||||
self.transaction.__exit__(None, None, None)
|
||||
self.transaction_opened = False
|
||||
|
||||
def __del__(self):
|
||||
# TODO: Implement a cleaner way to close the connection on failure.
|
||||
if self.transaction_opened:
|
||||
self.transaction.needs_rollback = True
|
||||
self.finish()
|
||||
|
||||
|
||||
class MySQLSearchBackend(BaseSearchBackend):
|
||||
query_compiler_class = MySQLSearchQueryCompiler
|
||||
autocomplete_query_compiler_class = MySQLAutocompleteQueryCompiler
|
||||
|
||||
results_class = MySQLSearchResults
|
||||
rebuilder_class = MySQLSearchRebuilder
|
||||
atomic_rebuilder_class = MySQLSearchAtomicRebuilder
|
||||
|
||||
def __init__(self, params):
|
||||
super().__init__(params)
|
||||
self.index_name = params.get("INDEX", "default")
|
||||
|
||||
# MySQL backend currently has no config options
|
||||
self.config = None
|
||||
self.autocomplete_config = None
|
||||
|
||||
if params.get("ATOMIC_REBUILD", False):
|
||||
self.rebuilder_class = self.atomic_rebuilder_class
|
||||
|
||||
def get_index_for_model(self, model, db_alias=None):
|
||||
return Index(self, db_alias)
|
||||
|
||||
def get_index_for_object(self, obj):
|
||||
return self.get_index_for_model(obj._meta.model, obj._state.db)
|
||||
|
||||
def reset_index(self):
|
||||
for connection in [
|
||||
connection
|
||||
for connection in connections.all()
|
||||
if connection.vendor == "mysql"
|
||||
]:
|
||||
IndexEntry._default_manager.all()._raw_delete(using=connection.alias)
|
||||
|
||||
def add_type(self, model):
|
||||
pass # Not needed.
|
||||
|
||||
def refresh_index(self):
|
||||
pass # Not needed.
|
||||
|
||||
def add(self, obj):
|
||||
self.get_index_for_object(obj).add_item(obj)
|
||||
|
||||
def add_bulk(self, model, obj_list):
|
||||
if obj_list:
|
||||
self.get_index_for_object(obj_list[0]).add_items(model, obj_list)
|
||||
|
||||
def delete(self, obj):
|
||||
self.get_index_for_object(obj).delete_item(obj)
|
||||
|
||||
|
||||
SearchBackend = MySQLSearchBackend
|
||||
254
env/lib/python3.10/site-packages/wagtail/search/backends/database/mysql/query.py
vendored
Normal file
254
env/lib/python3.10/site-packages/wagtail/search/backends/database/mysql/query.py
vendored
Normal file
@@ -0,0 +1,254 @@
|
||||
import re
|
||||
from typing import Any, List, Tuple, Union
|
||||
|
||||
from django.db.backends.base.base import BaseDatabaseWrapper
|
||||
from django.db.models.expressions import CombinedExpression, Expression, Value
|
||||
from django.db.models.fields import BooleanField, Field
|
||||
from django.db.models.sql.compiler import SQLCompiler
|
||||
|
||||
|
||||
class LexemeCombinable(Expression):
|
||||
BITAND = "+"
|
||||
BITOR = ""
|
||||
invert = False # By default, it is not inverted
|
||||
|
||||
def _combine(self, other, connector, reversed, node=None):
|
||||
if not isinstance(other, LexemeCombinable):
|
||||
raise TypeError(
|
||||
f"Lexeme can only be combined with other Lexemes, got {type(other)}."
|
||||
)
|
||||
if reversed:
|
||||
return CombinedLexeme(other, connector, self)
|
||||
return CombinedLexeme(self, connector, other)
|
||||
|
||||
# On Combinable, these are not implemented to reduce confusion with Q. In
|
||||
# this case we are actually (ab)using them to do logical combination so
|
||||
# it's consistent with other usage in Django.
|
||||
def bitand(self, other):
|
||||
return self._combine(other, self.BITAND, False)
|
||||
|
||||
def bitor(self, other):
|
||||
return self._combine(other, self.BITOR, False)
|
||||
|
||||
def __or__(self, other):
|
||||
return self._combine(other, self.BITOR, False)
|
||||
|
||||
def __and__(self, other):
|
||||
return self._combine(other, self.BITAND, False)
|
||||
|
||||
|
||||
class SearchQueryField(Field):
|
||||
def db_type(self, connection):
|
||||
return None
|
||||
|
||||
|
||||
class Lexeme(LexemeCombinable, Value):
|
||||
_output_field = SearchQueryField()
|
||||
|
||||
def __init__(
|
||||
self, value, output_field=None, invert=False, prefix=False, weight=None
|
||||
):
|
||||
self.prefix = prefix
|
||||
self.invert = invert
|
||||
self.weight = weight
|
||||
super().__init__(value, output_field=output_field)
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
param = re.sub(
|
||||
r"\W+", " ", self.value
|
||||
) # Remove non-word characters. This is done to disallow the usage of full text search operators in the MATCH clause, because MySQL doesn't include these kinds of characters in FULLTEXT indexes.
|
||||
|
||||
template = "%s"
|
||||
|
||||
if self.prefix:
|
||||
param = f"{param}*"
|
||||
if self.invert:
|
||||
param = f"(-{param})"
|
||||
else:
|
||||
param = f"{param}"
|
||||
|
||||
return template, [param]
|
||||
|
||||
|
||||
class CombinedLexeme(LexemeCombinable):
|
||||
_output_field = SearchQueryField()
|
||||
|
||||
def __init__(self, lhs, connector, rhs, output_field=None):
|
||||
super().__init__(output_field=output_field)
|
||||
self.connector = connector
|
||||
self.lhs = lhs
|
||||
self.rhs = rhs
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
value_params = []
|
||||
|
||||
lhs_connector = self.connector
|
||||
rhs_connector = self.connector
|
||||
|
||||
if (
|
||||
self.lhs.invert and self.connector == "+"
|
||||
): # NOTE: This is a special case for MySQL. If either side's operator is AND (+), and it is inverted, the operator should become NOT (-). If we did nothing, the result could would be '+X +(-Y)' for And(X, Not(Y)), which seems correct, but produces a wrong result. The solution is to turn the query into '+X -Y', which does work, and therefore this is done here.
|
||||
# TODO: There may be a better solution than this.
|
||||
modified_value = self.lhs.value.copy()
|
||||
modified_value.invert = not modified_value.invert
|
||||
lhs_connector = "-"
|
||||
lsql, params = compiler.compile(modified_value)
|
||||
else:
|
||||
lsql, params = compiler.compile(self.lhs)
|
||||
value_params.extend(params)
|
||||
|
||||
if self.rhs.invert and self.connector == "+": # Same explanation as above.
|
||||
modified_value = self.rhs.value.copy()
|
||||
modified_value.invert = not modified_value.invert
|
||||
rhs_connector = "-"
|
||||
rsql, params = compiler.compile(modified_value)
|
||||
else:
|
||||
rsql, params = compiler.compile(self.rhs)
|
||||
value_params.extend(params)
|
||||
|
||||
combined_sql = "({}{} {}{})".format(
|
||||
lhs_connector, lsql, rhs_connector, rsql
|
||||
) # if self.connector is '+' (AND), then both terms will be ANDed together. We need to repeat the connector to make that work.
|
||||
combined_value = combined_sql % tuple(value_params)
|
||||
return "%s", [combined_value]
|
||||
|
||||
|
||||
class SearchQueryCombinable:
|
||||
BITAND = "+"
|
||||
BITOR = ""
|
||||
|
||||
def _combine(self, other, connector: str, reversed: bool = False):
|
||||
if not isinstance(other, SearchQueryCombinable):
|
||||
raise TypeError(
|
||||
"SearchQuery can only be combined with other SearchQuery "
|
||||
"instances, got %s." % type(other).__name__
|
||||
)
|
||||
if reversed:
|
||||
return CombinedSearchQuery(other, connector, self)
|
||||
return CombinedSearchQuery(self, connector, other)
|
||||
|
||||
# On Combinable, these are not implemented to reduce confusion with Q. In
|
||||
# this case we are actually (ab)using them to do logical combination so
|
||||
# it's consistent with other usage in Django.
|
||||
def __or__(self, other):
|
||||
return self._combine(other, self.BITOR, False)
|
||||
|
||||
def __ror__(self, other):
|
||||
return self._combine(other, self.BITOR, True)
|
||||
|
||||
def __and__(self, other):
|
||||
return self._combine(other, self.BITAND, False)
|
||||
|
||||
def __rand__(self, other):
|
||||
return self._combine(other, self.BITAND, True)
|
||||
|
||||
|
||||
class SearchQuery(SearchQueryCombinable, Expression):
|
||||
def __init__(
|
||||
self, value: Union[LexemeCombinable, str], search_type: str = "lexeme", **extra
|
||||
):
|
||||
super().__init__(output_field=SearchQueryField())
|
||||
self.extra = extra
|
||||
if (
|
||||
isinstance(value, str) or search_type == "phrase"
|
||||
): # If the value is a string, we assume it's a phrase
|
||||
safe_string = re.sub(
|
||||
r"\W+", " ", value
|
||||
) # Remove non-word characters. This is done to disallow the usage of full text search operators in the MATCH clause, because MySQL doesn't include these kinds of characters in FULLTEXT indexes.
|
||||
self.value = Value(
|
||||
'"%s"' % safe_string
|
||||
) # We wrap it in quotes to make sure it's parsed as a phrase
|
||||
else: # Otherwise, we assume it's a lexeme
|
||||
self.value = value
|
||||
|
||||
def as_sql(
|
||||
self,
|
||||
compiler: SQLCompiler,
|
||||
connection: BaseDatabaseWrapper,
|
||||
**extra_context: Any,
|
||||
) -> Tuple[str, List[Any]]:
|
||||
sql, params = compiler.compile(self.value)
|
||||
return (sql, params)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self.value.__repr__()
|
||||
|
||||
|
||||
class CombinedSearchQuery(SearchQueryCombinable, CombinedExpression):
|
||||
def __init__(self, lhs, connector, rhs, output_field=None):
|
||||
super().__init__(lhs, connector, rhs, output_field)
|
||||
|
||||
def __str__(self):
|
||||
return "%s" % super().__str__()
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
value_params = []
|
||||
|
||||
lhs_connector = self.connector
|
||||
rhs_connector = self.connector
|
||||
|
||||
if (
|
||||
isinstance(self.lhs, SearchQuery)
|
||||
and isinstance(self.lhs.value, Lexeme)
|
||||
and self.lhs.value.invert
|
||||
and self.connector == "+"
|
||||
): # NOTE: The explanation for this special case is the same as above, in the CombinedLexeme class.
|
||||
modified_value = self.lhs.value.copy()
|
||||
modified_value.invert = not modified_value.invert
|
||||
lhs_connector = "-"
|
||||
lsql, params = compiler.compile(modified_value)
|
||||
else:
|
||||
lsql, params = compiler.compile(self.lhs)
|
||||
value_params.extend(params)
|
||||
|
||||
if (
|
||||
isinstance(self.rhs, SearchQuery)
|
||||
and isinstance(self.rhs.value, Lexeme)
|
||||
and self.rhs.value.invert
|
||||
and self.connector == "+"
|
||||
): # NOTE: The explanation for this special case is the same as above, in the CombinedLexeme class.
|
||||
modified_value = self.rhs.value.copy()
|
||||
modified_value.invert = not modified_value.invert
|
||||
rhs_connector = "-"
|
||||
rsql, params = compiler.compile(modified_value)
|
||||
else:
|
||||
rsql, params = compiler.compile(self.rhs)
|
||||
value_params.extend(params)
|
||||
|
||||
combined_sql = "({}{} {}{})".format(
|
||||
lhs_connector, lsql, rhs_connector, rsql
|
||||
) # if self.connector is '+' (AND), then both terms will be ANDed together. We need to repeat the connector to make that work.
|
||||
combined_value = combined_sql % tuple(value_params)
|
||||
return "%s", [combined_value]
|
||||
|
||||
|
||||
class MatchExpression(Expression):
|
||||
filterable = True
|
||||
template = "MATCH (%s) AGAINST (%s IN BOOLEAN MODE)"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
query: SearchQueryCombinable,
|
||||
columns: List[str] = None,
|
||||
output_field: Field = BooleanField(),
|
||||
) -> None:
|
||||
super().__init__(output_field=output_field)
|
||||
self.query = query
|
||||
self.columns = (
|
||||
columns
|
||||
or [
|
||||
"title",
|
||||
"body",
|
||||
]
|
||||
) # We need to provide a default list of columns if the user doesn't specify one. We have a joint index for for 'title' and 'body' (see wagtail.search.migrations.0006_customise_indexentry), so we'll pick that one.
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
compiled_query = compiler.compile(self.query) # Compile the query to a string
|
||||
formatted_query = compiled_query[0] % tuple(
|
||||
compiled_query[1]
|
||||
) # Substitute the params in the query
|
||||
column_list = ", ".join(
|
||||
[f"`{column}`" for column in self.columns]
|
||||
) # ['title', 'body'] becomes '`title`, `body`'
|
||||
params = [formatted_query]
|
||||
return (self.template % (column_list, "%s"), params)
|
||||
0
env/lib/python3.10/site-packages/wagtail/search/backends/database/postgres/__init__.py
vendored
Normal file
0
env/lib/python3.10/site-packages/wagtail/search/backends/database/postgres/__init__.py
vendored
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
784
env/lib/python3.10/site-packages/wagtail/search/backends/database/postgres/postgres.py
vendored
Normal file
784
env/lib/python3.10/site-packages/wagtail/search/backends/database/postgres/postgres.py
vendored
Normal file
@@ -0,0 +1,784 @@
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
from functools import reduce
|
||||
|
||||
from django.contrib.postgres.search import SearchQuery, SearchRank, SearchVector
|
||||
from django.db import DEFAULT_DB_ALIAS, NotSupportedError, connections, transaction
|
||||
from django.db.models import Avg, Count, F, Manager, Q, TextField, Value
|
||||
from django.db.models.constants import LOOKUP_SEP
|
||||
from django.db.models.functions import Cast, Length
|
||||
from django.db.models.sql.subqueries import InsertQuery
|
||||
from django.utils.encoding import force_str
|
||||
from django.utils.functional import cached_property
|
||||
|
||||
from ....index import AutocompleteField, RelatedFields, SearchField, get_indexed_models
|
||||
from ....models import IndexEntry
|
||||
from ....query import And, Boost, MatchAll, Not, Or, Phrase, PlainText
|
||||
from ....utils import (
|
||||
ADD,
|
||||
MUL,
|
||||
OR,
|
||||
get_content_type_pk,
|
||||
get_descendants_content_types_pks,
|
||||
)
|
||||
from ...base import (
|
||||
BaseSearchBackend,
|
||||
BaseSearchQueryCompiler,
|
||||
BaseSearchResults,
|
||||
FilterFieldError,
|
||||
)
|
||||
from .query import Lexeme
|
||||
from .weights import get_sql_weights, get_weight
|
||||
|
||||
EMPTY_VECTOR = SearchVector(Value("", output_field=TextField()))
|
||||
|
||||
|
||||
class ObjectIndexer:
|
||||
"""
|
||||
Responsible for extracting data from an object to be inserted into the index.
|
||||
"""
|
||||
|
||||
def __init__(self, obj, backend):
|
||||
self.obj = obj
|
||||
self.search_fields = obj.get_search_fields()
|
||||
self.config = backend.config
|
||||
self.autocomplete_config = backend.autocomplete_config
|
||||
|
||||
def prepare_value(self, value):
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
|
||||
elif isinstance(value, list):
|
||||
return ", ".join(self.prepare_value(item) for item in value)
|
||||
|
||||
elif isinstance(value, dict):
|
||||
return ", ".join(self.prepare_value(item) for item in value.values())
|
||||
|
||||
return force_str(value)
|
||||
|
||||
def prepare_field(self, obj, field):
|
||||
if isinstance(field, SearchField):
|
||||
yield (
|
||||
field,
|
||||
get_weight(field.boost),
|
||||
self.prepare_value(field.get_value(obj)),
|
||||
)
|
||||
|
||||
elif isinstance(field, AutocompleteField):
|
||||
# AutocompleteField does not define a boost parameter, so use a base weight of 'D'
|
||||
yield (field, "D", self.prepare_value(field.get_value(obj)))
|
||||
|
||||
elif isinstance(field, RelatedFields):
|
||||
sub_obj = field.get_value(obj)
|
||||
if sub_obj is None:
|
||||
return
|
||||
|
||||
if isinstance(sub_obj, Manager):
|
||||
sub_objs = sub_obj.all()
|
||||
|
||||
else:
|
||||
if callable(sub_obj):
|
||||
sub_obj = sub_obj()
|
||||
|
||||
sub_objs = [sub_obj]
|
||||
|
||||
for sub_obj in sub_objs:
|
||||
for sub_field in field.fields:
|
||||
yield from self.prepare_field(sub_obj, sub_field)
|
||||
|
||||
def as_vector(self, texts, for_autocomplete=False):
|
||||
"""
|
||||
Converts an array of strings into a SearchVector that can be indexed.
|
||||
"""
|
||||
texts = [(text.strip(), weight) for text, weight in texts]
|
||||
texts = [(text, weight) for text, weight in texts if text]
|
||||
|
||||
if not texts:
|
||||
return EMPTY_VECTOR
|
||||
|
||||
search_config = self.autocomplete_config if for_autocomplete else self.config
|
||||
|
||||
return ADD(
|
||||
[
|
||||
SearchVector(
|
||||
Value(text, output_field=TextField()),
|
||||
weight=weight,
|
||||
config=search_config,
|
||||
)
|
||||
for text, weight in texts
|
||||
]
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def id(self):
|
||||
"""
|
||||
Returns the value to use as the ID of the record in the index
|
||||
"""
|
||||
return force_str(self.obj.pk)
|
||||
|
||||
@cached_property
|
||||
def title(self):
|
||||
"""
|
||||
Returns all values to index as "title". This is the value of all SearchFields that have the field_name 'title'
|
||||
"""
|
||||
texts = []
|
||||
for field in self.search_fields:
|
||||
for current_field, boost, value in self.prepare_field(self.obj, field):
|
||||
if (
|
||||
isinstance(current_field, SearchField)
|
||||
and current_field.field_name == "title"
|
||||
):
|
||||
texts.append((value, boost))
|
||||
|
||||
return self.as_vector(texts)
|
||||
|
||||
@cached_property
|
||||
def body(self):
|
||||
"""
|
||||
Returns all values to index as "body". This is the value of all SearchFields excluding the title
|
||||
"""
|
||||
texts = []
|
||||
for field in self.search_fields:
|
||||
for current_field, boost, value in self.prepare_field(self.obj, field):
|
||||
if (
|
||||
isinstance(current_field, SearchField)
|
||||
and not current_field.field_name == "title"
|
||||
):
|
||||
texts.append((value, boost))
|
||||
|
||||
return self.as_vector(texts)
|
||||
|
||||
@cached_property
|
||||
def autocomplete(self):
|
||||
"""
|
||||
Returns all values to index as "autocomplete". This is the value of all AutocompleteFields
|
||||
"""
|
||||
texts = []
|
||||
for field in self.search_fields:
|
||||
for current_field, boost, value in self.prepare_field(self.obj, field):
|
||||
if isinstance(current_field, AutocompleteField):
|
||||
texts.append((value, boost))
|
||||
|
||||
return self.as_vector(texts, for_autocomplete=True)
|
||||
|
||||
|
||||
class Index:
|
||||
def __init__(self, backend, db_alias=None):
|
||||
self.backend = backend
|
||||
self.name = self.backend.index_name
|
||||
self.db_alias = DEFAULT_DB_ALIAS if db_alias is None else db_alias
|
||||
self.connection = connections[self.db_alias]
|
||||
if self.connection.vendor != "postgresql":
|
||||
raise NotSupportedError(
|
||||
"You must select a PostgreSQL database " "to use PostgreSQL search."
|
||||
)
|
||||
|
||||
# Whether to allow adding items via the faster upsert method available in Postgres >=9.5
|
||||
self._enable_upsert = self.connection.pg_version >= 90500
|
||||
|
||||
self.entries = IndexEntry._default_manager.using(self.db_alias)
|
||||
|
||||
def add_model(self, model):
|
||||
pass
|
||||
|
||||
def refresh(self):
|
||||
pass
|
||||
|
||||
def _refresh_title_norms(self, full=False):
|
||||
"""
|
||||
Refreshes the value of the title_norm field.
|
||||
|
||||
This needs to be set to 'lavg/ld' where:
|
||||
- lavg is the average length of titles in all documents (also in terms)
|
||||
- ld is the length of the title field in this document (in terms)
|
||||
"""
|
||||
|
||||
lavg = (
|
||||
self.entries.annotate(title_length=Length("title"))
|
||||
.filter(title_length__gt=0)
|
||||
.aggregate(Avg("title_length"))["title_length__avg"]
|
||||
)
|
||||
|
||||
if full:
|
||||
# Update the whole table
|
||||
# This is the most accurate option but requires a full table rewrite
|
||||
# so we can't do it too often as it could lead to locking issues.
|
||||
entries = self.entries
|
||||
|
||||
else:
|
||||
# Only update entries where title_norm is 1.0
|
||||
# This is the default value set on new entries.
|
||||
# It's possible that other entries could have this exact value but there shouldn't be too many of those
|
||||
entries = self.entries.filter(title_norm=1.0)
|
||||
|
||||
entries.annotate(title_length=Length("title")).filter(
|
||||
title_length__gt=0
|
||||
).update(title_norm=lavg / F("title_length"))
|
||||
|
||||
def delete_stale_model_entries(self, model):
|
||||
existing_pks = (
|
||||
model._default_manager.using(self.db_alias)
|
||||
.annotate(object_id=Cast("pk", TextField()))
|
||||
.values("object_id")
|
||||
)
|
||||
content_types_pks = get_descendants_content_types_pks(model)
|
||||
stale_entries = self.entries.filter(
|
||||
content_type_id__in=content_types_pks
|
||||
).exclude(object_id__in=existing_pks)
|
||||
stale_entries.delete()
|
||||
|
||||
def delete_stale_entries(self):
|
||||
for model in get_indexed_models():
|
||||
# We don’t need to delete stale entries for non-root models,
|
||||
# since we already delete them by deleting roots.
|
||||
if not model._meta.parents:
|
||||
self.delete_stale_model_entries(model)
|
||||
|
||||
def add_item(self, obj):
|
||||
self.add_items(obj._meta.model, [obj])
|
||||
|
||||
def add_items_upsert(self, content_type_pk, indexers):
|
||||
compiler = InsertQuery(IndexEntry).get_compiler(connection=self.connection)
|
||||
title_sql = []
|
||||
autocomplete_sql = []
|
||||
body_sql = []
|
||||
data_params = []
|
||||
|
||||
for indexer in indexers:
|
||||
data_params.extend((content_type_pk, indexer.id))
|
||||
|
||||
# Compile title value
|
||||
value = compiler.prepare_value(
|
||||
IndexEntry._meta.get_field("title"), indexer.title
|
||||
)
|
||||
sql, params = value.as_sql(compiler, self.connection)
|
||||
title_sql.append(sql)
|
||||
data_params.extend(params)
|
||||
|
||||
# Compile autocomplete value
|
||||
value = compiler.prepare_value(
|
||||
IndexEntry._meta.get_field("autocomplete"), indexer.autocomplete
|
||||
)
|
||||
sql, params = value.as_sql(compiler, self.connection)
|
||||
autocomplete_sql.append(sql)
|
||||
data_params.extend(params)
|
||||
|
||||
# Compile body value
|
||||
value = compiler.prepare_value(
|
||||
IndexEntry._meta.get_field("body"), indexer.body
|
||||
)
|
||||
sql, params = value.as_sql(compiler, self.connection)
|
||||
body_sql.append(sql)
|
||||
data_params.extend(params)
|
||||
|
||||
data_sql = ", ".join(
|
||||
[
|
||||
f"(%s, %s, {a}, {b}, {c}, 1.0)"
|
||||
for a, b, c in zip(title_sql, autocomplete_sql, body_sql)
|
||||
]
|
||||
)
|
||||
|
||||
with self.connection.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO %s (content_type_id, object_id, title, autocomplete, body, title_norm)
|
||||
(VALUES %s)
|
||||
ON CONFLICT (content_type_id, object_id)
|
||||
DO UPDATE SET title = EXCLUDED.title,
|
||||
title_norm = 1.0,
|
||||
autocomplete = EXCLUDED.autocomplete,
|
||||
body = EXCLUDED.body
|
||||
"""
|
||||
% (IndexEntry._meta.db_table, data_sql),
|
||||
data_params,
|
||||
)
|
||||
|
||||
self._refresh_title_norms()
|
||||
|
||||
def add_items_update_then_create(self, content_type_pk, indexers):
|
||||
ids_and_data = {}
|
||||
for indexer in indexers:
|
||||
ids_and_data[indexer.id] = (
|
||||
indexer.title,
|
||||
indexer.autocomplete,
|
||||
indexer.body,
|
||||
)
|
||||
|
||||
index_entries_for_ct = self.entries.filter(content_type_id=content_type_pk)
|
||||
indexed_ids = frozenset(
|
||||
index_entries_for_ct.filter(object_id__in=ids_and_data.keys()).values_list(
|
||||
"object_id", flat=True
|
||||
)
|
||||
)
|
||||
for indexed_id in indexed_ids:
|
||||
title, autocomplete, body = ids_and_data[indexed_id]
|
||||
index_entries_for_ct.filter(object_id=indexed_id).update(
|
||||
title=title, autocomplete=autocomplete, body=body
|
||||
)
|
||||
|
||||
to_be_created = []
|
||||
for object_id in ids_and_data.keys():
|
||||
if object_id not in indexed_ids:
|
||||
title, autocomplete, body = ids_and_data[object_id]
|
||||
to_be_created.append(
|
||||
IndexEntry(
|
||||
content_type_id=content_type_pk,
|
||||
object_id=object_id,
|
||||
title=title,
|
||||
autocomplete=autocomplete,
|
||||
body=body,
|
||||
)
|
||||
)
|
||||
|
||||
self.entries.bulk_create(to_be_created)
|
||||
|
||||
self._refresh_title_norms()
|
||||
|
||||
def add_items(self, model, objs):
|
||||
search_fields = model.get_search_fields()
|
||||
if not search_fields:
|
||||
return
|
||||
|
||||
indexers = [ObjectIndexer(obj, self.backend) for obj in objs]
|
||||
|
||||
# TODO: Delete unindexed objects while dealing with proxy models.
|
||||
if indexers:
|
||||
content_type_pk = get_content_type_pk(model)
|
||||
|
||||
update_method = (
|
||||
self.add_items_upsert
|
||||
if self._enable_upsert
|
||||
else self.add_items_update_then_create
|
||||
)
|
||||
update_method(content_type_pk, indexers)
|
||||
|
||||
def delete_item(self, item):
|
||||
item.index_entries.all()._raw_delete(using=self.db_alias)
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
|
||||
class PostgresSearchQueryCompiler(BaseSearchQueryCompiler):
|
||||
DEFAULT_OPERATOR = "and"
|
||||
LAST_TERM_IS_PREFIX = False
|
||||
TARGET_SEARCH_FIELD_TYPE = SearchField
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
local_search_fields = self.get_search_fields_for_model()
|
||||
|
||||
# Due to a Django bug, arrays are not automatically converted
|
||||
# when we use WEIGHTS_VALUES.
|
||||
self.sql_weights = get_sql_weights()
|
||||
|
||||
if self.fields is None:
|
||||
# search over the fields defined on the current model
|
||||
self.search_fields = local_search_fields
|
||||
else:
|
||||
# build a search_fields set from the passed definition,
|
||||
# which may involve traversing relations
|
||||
self.search_fields = {
|
||||
field_lookup: self.get_search_field(
|
||||
field_lookup, fields=local_search_fields
|
||||
)
|
||||
for field_lookup in self.fields
|
||||
}
|
||||
|
||||
def get_config(self, backend):
|
||||
return backend.config
|
||||
|
||||
def get_search_fields_for_model(self):
|
||||
return self.queryset.model.get_searchable_search_fields()
|
||||
|
||||
def get_search_field(self, field_lookup, fields=None):
|
||||
if fields is None:
|
||||
fields = self.search_fields
|
||||
|
||||
if LOOKUP_SEP in field_lookup:
|
||||
field_lookup, sub_field_name = field_lookup.split(LOOKUP_SEP, 1)
|
||||
else:
|
||||
sub_field_name = None
|
||||
|
||||
for field in fields:
|
||||
if (
|
||||
isinstance(field, self.TARGET_SEARCH_FIELD_TYPE)
|
||||
and field.field_name == field_lookup
|
||||
):
|
||||
return field
|
||||
|
||||
# Note: Searching on a specific related field using
|
||||
# `.search(fields=…)` is not yet supported by Wagtail.
|
||||
# This method anticipates by already implementing it.
|
||||
if isinstance(field, RelatedFields) and field.field_name == field_lookup:
|
||||
return self.get_search_field(sub_field_name, field.fields)
|
||||
|
||||
def build_tsquery_content(self, query, config=None, invert=False):
|
||||
if isinstance(query, PlainText):
|
||||
terms = query.query_string.split()
|
||||
if not terms:
|
||||
return None
|
||||
|
||||
last_term = terms.pop()
|
||||
|
||||
lexemes = Lexeme(last_term, invert=invert, prefix=self.LAST_TERM_IS_PREFIX)
|
||||
for term in terms:
|
||||
new_lexeme = Lexeme(term, invert=invert)
|
||||
|
||||
if query.operator == "and":
|
||||
lexemes &= new_lexeme
|
||||
else:
|
||||
lexemes |= new_lexeme
|
||||
|
||||
return SearchQuery(lexemes, search_type="raw", config=config)
|
||||
|
||||
elif isinstance(query, Phrase):
|
||||
return SearchQuery(query.query_string, search_type="phrase", config=config)
|
||||
|
||||
elif isinstance(query, Boost):
|
||||
# Not supported
|
||||
msg = "The Boost query is not supported by the PostgreSQL search backend."
|
||||
warnings.warn(msg, RuntimeWarning)
|
||||
|
||||
return self.build_tsquery_content(
|
||||
query.subquery, config=config, invert=invert
|
||||
)
|
||||
|
||||
elif isinstance(query, Not):
|
||||
return self.build_tsquery_content(
|
||||
query.subquery, config=config, invert=not invert
|
||||
)
|
||||
|
||||
elif isinstance(query, (And, Or)):
|
||||
# If this part of the query is inverted, we swap the operator and
|
||||
# pass down the inversion state to the child queries.
|
||||
# This works thanks to De Morgan's law.
|
||||
#
|
||||
# For example, the following query:
|
||||
#
|
||||
# Not(And(Term("A"), Term("B")))
|
||||
#
|
||||
# Is equivalent to:
|
||||
#
|
||||
# Or(Not(Term("A")), Not(Term("B")))
|
||||
#
|
||||
# It's simpler to code it this way as we only need to store the
|
||||
# invert status of the terms rather than all the operators.
|
||||
|
||||
subquery_lexemes = [
|
||||
self.build_tsquery_content(subquery, config=config, invert=invert)
|
||||
for subquery in query.subqueries
|
||||
]
|
||||
|
||||
is_and = isinstance(query, And)
|
||||
|
||||
if invert:
|
||||
is_and = not is_and
|
||||
|
||||
if is_and:
|
||||
return reduce(lambda a, b: a & b, subquery_lexemes)
|
||||
else:
|
||||
return reduce(lambda a, b: a | b, subquery_lexemes)
|
||||
|
||||
raise NotImplementedError(
|
||||
"`%s` is not supported by the PostgreSQL search backend."
|
||||
% query.__class__.__name__
|
||||
)
|
||||
|
||||
def build_tsquery(self, query, config=None):
|
||||
return self.build_tsquery_content(query, config=config)
|
||||
|
||||
def build_tsrank(self, vector, query, config=None, boost=1.0):
|
||||
if isinstance(query, (Phrase, PlainText, Not)):
|
||||
rank_expression = SearchRank(
|
||||
vector,
|
||||
self.build_tsquery(query, config=config),
|
||||
weights=self.sql_weights,
|
||||
)
|
||||
|
||||
if boost != 1.0:
|
||||
rank_expression *= boost
|
||||
|
||||
return rank_expression
|
||||
|
||||
elif isinstance(query, Boost):
|
||||
boost *= query.boost
|
||||
return self.build_tsrank(vector, query.subquery, config=config, boost=boost)
|
||||
|
||||
elif isinstance(query, And):
|
||||
return (
|
||||
MUL(
|
||||
1 + self.build_tsrank(vector, subquery, config=config, boost=boost)
|
||||
for subquery in query.subqueries
|
||||
)
|
||||
- 1
|
||||
)
|
||||
|
||||
elif isinstance(query, Or):
|
||||
return ADD(
|
||||
self.build_tsrank(vector, subquery, config=config, boost=boost)
|
||||
for subquery in query.subqueries
|
||||
) / (len(query.subqueries) or 1)
|
||||
|
||||
raise NotImplementedError(
|
||||
"`%s` is not supported by the PostgreSQL search backend."
|
||||
% query.__class__.__name__
|
||||
)
|
||||
|
||||
def get_index_vectors(self, search_query):
|
||||
return [
|
||||
(F("index_entries__title"), F("index_entries__title_norm")),
|
||||
(F("index_entries__body"), 1.0),
|
||||
]
|
||||
|
||||
def get_fields_vectors(self, search_query):
|
||||
return [
|
||||
(
|
||||
SearchVector(
|
||||
field_lookup,
|
||||
config=search_query.config,
|
||||
),
|
||||
search_field.boost,
|
||||
)
|
||||
for field_lookup, search_field in self.search_fields.items()
|
||||
]
|
||||
|
||||
def get_search_vectors(self, search_query):
|
||||
if self.fields is None:
|
||||
return self.get_index_vectors(search_query)
|
||||
|
||||
else:
|
||||
return self.get_fields_vectors(search_query)
|
||||
|
||||
def _build_rank_expression(self, vectors, config):
|
||||
rank_expressions = [
|
||||
self.build_tsrank(vector, self.query, config=config) * boost
|
||||
for vector, boost in vectors
|
||||
]
|
||||
|
||||
rank_expression = rank_expressions[0]
|
||||
for other_rank_expression in rank_expressions[1:]:
|
||||
rank_expression += other_rank_expression
|
||||
|
||||
return rank_expression
|
||||
|
||||
def search(self, config, start, stop, score_field=None):
|
||||
# TODO: Handle MatchAll nested inside other search query classes.
|
||||
if isinstance(self.query, MatchAll):
|
||||
return self.queryset[start:stop]
|
||||
|
||||
elif isinstance(self.query, Not) and isinstance(self.query.subquery, MatchAll):
|
||||
return self.queryset.none()
|
||||
|
||||
search_query = self.build_tsquery(self.query, config=config)
|
||||
vectors = self.get_search_vectors(search_query)
|
||||
rank_expression = self._build_rank_expression(vectors, config)
|
||||
|
||||
combined_vector = vectors[0][0]
|
||||
for vector, boost in vectors[1:]:
|
||||
combined_vector = combined_vector._combine(vector, "||", False)
|
||||
|
||||
queryset = self.queryset.annotate(_vector_=combined_vector).filter(
|
||||
_vector_=search_query
|
||||
)
|
||||
|
||||
if self.order_by_relevance:
|
||||
queryset = queryset.order_by(rank_expression.desc(), "-pk")
|
||||
|
||||
elif not queryset.query.order_by:
|
||||
# Adds a default ordering to avoid issue #3729.
|
||||
queryset = queryset.order_by("-pk")
|
||||
rank_expression = F("pk")
|
||||
|
||||
if score_field is not None:
|
||||
queryset = queryset.annotate(**{score_field: rank_expression})
|
||||
|
||||
return queryset[start:stop]
|
||||
|
||||
def _process_lookup(self, field, lookup, value):
|
||||
lhs = field.get_attname(self.queryset.model) + "__" + lookup
|
||||
return Q(**{lhs: value})
|
||||
|
||||
def _connect_filters(self, filters, connector, negated):
|
||||
if connector == "AND":
|
||||
q = Q(*filters)
|
||||
|
||||
elif connector == "OR":
|
||||
q = OR([Q(fil) for fil in filters])
|
||||
|
||||
else:
|
||||
return
|
||||
|
||||
if negated:
|
||||
q = ~q
|
||||
|
||||
return q
|
||||
|
||||
|
||||
class PostgresAutocompleteQueryCompiler(PostgresSearchQueryCompiler):
|
||||
LAST_TERM_IS_PREFIX = True
|
||||
TARGET_SEARCH_FIELD_TYPE = AutocompleteField
|
||||
|
||||
def get_config(self, backend):
|
||||
return backend.autocomplete_config
|
||||
|
||||
def get_search_fields_for_model(self):
|
||||
return self.queryset.model.get_autocomplete_search_fields()
|
||||
|
||||
def get_index_vectors(self, search_query):
|
||||
return [(F("index_entries__autocomplete"), 1.0)]
|
||||
|
||||
def get_fields_vectors(self, search_query):
|
||||
return [
|
||||
(
|
||||
SearchVector(
|
||||
field_lookup,
|
||||
config=search_query.config,
|
||||
weight="D",
|
||||
),
|
||||
1.0,
|
||||
)
|
||||
for field_lookup, search_field in self.search_fields.items()
|
||||
]
|
||||
|
||||
|
||||
class PostgresSearchResults(BaseSearchResults):
|
||||
def get_queryset(self, for_count=False):
|
||||
if for_count:
|
||||
start = None
|
||||
stop = None
|
||||
else:
|
||||
start = self.start
|
||||
stop = self.stop
|
||||
|
||||
return self.query_compiler.search(
|
||||
self.query_compiler.get_config(self.backend),
|
||||
start,
|
||||
stop,
|
||||
score_field=self._score_field,
|
||||
)
|
||||
|
||||
def _do_search(self):
|
||||
return list(self.get_queryset())
|
||||
|
||||
def _do_count(self):
|
||||
return self.get_queryset(for_count=True).count()
|
||||
|
||||
supports_facet = True
|
||||
|
||||
def facet(self, field_name):
|
||||
# Get field
|
||||
field = self.query_compiler._get_filterable_field(field_name)
|
||||
if field is None:
|
||||
raise FilterFieldError(
|
||||
'Cannot facet search results with field "'
|
||||
+ field_name
|
||||
+ "\". Please add index.FilterField('"
|
||||
+ field_name
|
||||
+ "') to "
|
||||
+ self.query_compiler.queryset.model.__name__
|
||||
+ ".search_fields.",
|
||||
field_name=field_name,
|
||||
)
|
||||
|
||||
query = self.query_compiler.search(
|
||||
self.query_compiler.get_config(self.backend), None, None
|
||||
)
|
||||
results = (
|
||||
query.values(field_name).annotate(count=Count("pk")).order_by("-count")
|
||||
)
|
||||
|
||||
return OrderedDict(
|
||||
[(result[field_name], result["count"]) for result in results]
|
||||
)
|
||||
|
||||
|
||||
class PostgresSearchRebuilder:
|
||||
def __init__(self, index):
|
||||
self.index = index
|
||||
|
||||
def start(self):
|
||||
self.index.delete_stale_entries()
|
||||
return self.index
|
||||
|
||||
def finish(self):
|
||||
self.index._refresh_title_norms(full=True)
|
||||
|
||||
|
||||
class PostgresSearchAtomicRebuilder(PostgresSearchRebuilder):
|
||||
def __init__(self, index):
|
||||
super().__init__(index)
|
||||
self.transaction = transaction.atomic(using=index.db_alias)
|
||||
self.transaction_opened = False
|
||||
|
||||
def start(self):
|
||||
self.transaction.__enter__()
|
||||
self.transaction_opened = True
|
||||
return super().start()
|
||||
|
||||
def finish(self):
|
||||
self.index._refresh_title_norms(full=True)
|
||||
|
||||
self.transaction.__exit__(None, None, None)
|
||||
self.transaction_opened = False
|
||||
|
||||
def __del__(self):
|
||||
# TODO: Implement a cleaner way to close the connection on failure.
|
||||
if self.transaction_opened:
|
||||
self.transaction.needs_rollback = True
|
||||
self.finish()
|
||||
|
||||
|
||||
class PostgresSearchBackend(BaseSearchBackend):
|
||||
query_compiler_class = PostgresSearchQueryCompiler
|
||||
autocomplete_query_compiler_class = PostgresAutocompleteQueryCompiler
|
||||
results_class = PostgresSearchResults
|
||||
rebuilder_class = PostgresSearchRebuilder
|
||||
atomic_rebuilder_class = PostgresSearchAtomicRebuilder
|
||||
|
||||
def __init__(self, params):
|
||||
super().__init__(params)
|
||||
self.index_name = params.get("INDEX", "default")
|
||||
self.config = params.get("SEARCH_CONFIG")
|
||||
|
||||
# Use 'simple' config for autocomplete to disable stemming
|
||||
# A good description for why this is important can be found at:
|
||||
# https://www.postgresql.org/docs/9.1/datatype-textsearch.html#DATATYPE-TSQUERY
|
||||
self.autocomplete_config = params.get("AUTOCOMPLETE_SEARCH_CONFIG", "simple")
|
||||
|
||||
if params.get("ATOMIC_REBUILD", False):
|
||||
self.rebuilder_class = self.atomic_rebuilder_class
|
||||
|
||||
def get_index_for_model(self, model, db_alias=None):
|
||||
return Index(self, db_alias)
|
||||
|
||||
def get_index_for_object(self, obj):
|
||||
return self.get_index_for_model(obj._meta.model, obj._state.db)
|
||||
|
||||
def reset_index(self):
|
||||
for connection in [
|
||||
connection
|
||||
for connection in connections.all()
|
||||
if connection.vendor == "postgresql"
|
||||
]:
|
||||
IndexEntry._default_manager.all()._raw_delete(using=connection.alias)
|
||||
|
||||
def add_type(self, model):
|
||||
pass # Not needed.
|
||||
|
||||
def refresh_index(self):
|
||||
pass # Not needed.
|
||||
|
||||
def add(self, obj):
|
||||
self.get_index_for_object(obj).add_item(obj)
|
||||
|
||||
def add_bulk(self, model, obj_list):
|
||||
if obj_list:
|
||||
self.get_index_for_object(obj_list[0]).add_items(model, obj_list)
|
||||
|
||||
def delete(self, obj):
|
||||
self.get_index_for_object(obj).delete_item(obj)
|
||||
|
||||
|
||||
SearchBackend = PostgresSearchBackend
|
||||
83
env/lib/python3.10/site-packages/wagtail/search/backends/database/postgres/query.py
vendored
Normal file
83
env/lib/python3.10/site-packages/wagtail/search/backends/database/postgres/query.py
vendored
Normal file
@@ -0,0 +1,83 @@
|
||||
from django.contrib.postgres.search import SearchQueryField
|
||||
from django.db.models.expressions import Expression, Value
|
||||
|
||||
|
||||
class LexemeCombinable(Expression):
|
||||
BITAND = "&"
|
||||
BITOR = "|"
|
||||
|
||||
def _combine(self, other, connector, reversed, node=None):
|
||||
if not isinstance(other, LexemeCombinable):
|
||||
raise TypeError(
|
||||
f"Lexeme can only be combined with other Lexemes, got {type(other)}."
|
||||
)
|
||||
if reversed:
|
||||
return CombinedLexeme(other, connector, self)
|
||||
return CombinedLexeme(self, connector, other)
|
||||
|
||||
# On Combinable, these are not implemented to reduce confusion with Q. In
|
||||
# this case we are actually (ab)using them to do logical combination so
|
||||
# it's consistent with other usage in Django.
|
||||
def bitand(self, other):
|
||||
return self._combine(other, self.BITAND, False)
|
||||
|
||||
def bitor(self, other):
|
||||
return self._combine(other, self.BITOR, False)
|
||||
|
||||
def __or__(self, other):
|
||||
return self._combine(other, self.BITOR, False)
|
||||
|
||||
def __and__(self, other):
|
||||
return self._combine(other, self.BITAND, False)
|
||||
|
||||
|
||||
class Lexeme(LexemeCombinable, Value):
|
||||
_output_field = SearchQueryField()
|
||||
|
||||
def __init__(
|
||||
self, value, output_field=None, *, invert=False, prefix=False, weight=None
|
||||
):
|
||||
self.prefix = prefix
|
||||
self.invert = invert
|
||||
self.weight = weight
|
||||
super().__init__(value, output_field=output_field)
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
param = "'%s'" % self.value.replace("'", "''").replace("\\", "\\\\")
|
||||
|
||||
template = "%s"
|
||||
|
||||
label = ""
|
||||
if self.prefix:
|
||||
label += "*"
|
||||
if self.weight:
|
||||
label += self.weight
|
||||
|
||||
if label:
|
||||
param = f"{param}:{label}"
|
||||
if self.invert:
|
||||
param = f"!{param}"
|
||||
|
||||
return template, [param]
|
||||
|
||||
|
||||
class CombinedLexeme(LexemeCombinable):
|
||||
_output_field = SearchQueryField()
|
||||
|
||||
def __init__(self, lhs, connector, rhs, output_field=None):
|
||||
super().__init__(output_field=output_field)
|
||||
self.connector = connector
|
||||
self.lhs = lhs
|
||||
self.rhs = rhs
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
value_params = []
|
||||
lsql, params = compiler.compile(self.lhs)
|
||||
value_params.extend(params)
|
||||
|
||||
rsql, params = compiler.compile(self.rhs)
|
||||
value_params.extend(params)
|
||||
|
||||
combined_sql = f"({lsql} {self.connector} {rsql})"
|
||||
combined_value = combined_sql % tuple(value_params)
|
||||
return "%s", [combined_value]
|
||||
63
env/lib/python3.10/site-packages/wagtail/search/backends/database/postgres/weights.py
vendored
Normal file
63
env/lib/python3.10/site-packages/wagtail/search/backends/database/postgres/weights.py
vendored
Normal file
@@ -0,0 +1,63 @@
|
||||
from itertools import zip_longest
|
||||
|
||||
from django.apps import apps
|
||||
|
||||
from wagtail.search.index import Indexed
|
||||
from wagtail.search.utils import get_search_fields
|
||||
|
||||
# This file contains the implementation of weights for PostgreSQL tsvectors. Only PostgreSQL has support for them, so that's why we define them here.
|
||||
|
||||
|
||||
WEIGHTS = "ABCD"
|
||||
WEIGHTS_COUNT = len(WEIGHTS)
|
||||
# These are filled when apps are ready.
|
||||
BOOSTS_WEIGHTS = []
|
||||
WEIGHTS_VALUES = []
|
||||
|
||||
|
||||
def get_boosts():
|
||||
boosts = set()
|
||||
for model in apps.get_models():
|
||||
if issubclass(model, Indexed):
|
||||
for search_field in get_search_fields(model.get_search_fields()):
|
||||
boost = search_field.boost
|
||||
if boost is not None:
|
||||
boosts.add(boost)
|
||||
return boosts
|
||||
|
||||
|
||||
def determine_boosts_weights(boosts=()):
|
||||
if not boosts:
|
||||
boosts = get_boosts()
|
||||
boosts = sorted(boosts, reverse=True)
|
||||
min_boost = boosts[-1]
|
||||
if len(boosts) <= WEIGHTS_COUNT:
|
||||
return list(zip_longest(boosts, WEIGHTS, fillvalue=min(min_boost, 0)))
|
||||
max_boost = boosts[0]
|
||||
boost_step = (max_boost - min_boost) / (WEIGHTS_COUNT - 1)
|
||||
return [(max_boost - (i * boost_step), weight) for i, weight in enumerate(WEIGHTS)]
|
||||
|
||||
|
||||
def set_weights():
|
||||
BOOSTS_WEIGHTS.extend(determine_boosts_weights())
|
||||
weights = [w for w, c in BOOSTS_WEIGHTS]
|
||||
min_weight = min(weights)
|
||||
if min_weight <= 0:
|
||||
if min_weight == 0:
|
||||
min_weight = -0.1
|
||||
weights = [w - min_weight for w in weights]
|
||||
max_weight = max(weights)
|
||||
WEIGHTS_VALUES.extend([w / max_weight for w in reversed(weights)])
|
||||
|
||||
|
||||
def get_weight(boost):
|
||||
if boost is None:
|
||||
return WEIGHTS[-1]
|
||||
for max_boost, weight in BOOSTS_WEIGHTS:
|
||||
if boost >= max_boost:
|
||||
return weight
|
||||
return weight
|
||||
|
||||
|
||||
def get_sql_weights():
|
||||
return "{" + ",".join(map(str, WEIGHTS_VALUES)) + "}"
|
||||
0
env/lib/python3.10/site-packages/wagtail/search/backends/database/sqlite/__init__.py
vendored
Normal file
0
env/lib/python3.10/site-packages/wagtail/search/backends/database/sqlite/__init__.py
vendored
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
294
env/lib/python3.10/site-packages/wagtail/search/backends/database/sqlite/query.py
vendored
Normal file
294
env/lib/python3.10/site-packages/wagtail/search/backends/database/sqlite/query.py
vendored
Normal file
@@ -0,0 +1,294 @@
|
||||
from typing import Any, List, Tuple
|
||||
|
||||
from django.db.backends.base.base import BaseDatabaseWrapper
|
||||
from django.db.models.expressions import CombinedExpression, Expression, Func, Value
|
||||
from django.db.models.fields import BooleanField, Field, FloatField
|
||||
from django.db.models.sql.compiler import SQLCompiler
|
||||
|
||||
from wagtail.search.query import And, MatchAll, Not, Or, Phrase, PlainText, SearchQuery
|
||||
|
||||
|
||||
class BM25(Func):
|
||||
function = "bm25"
|
||||
output_field = FloatField()
|
||||
|
||||
def __init__(self):
|
||||
expressions = ()
|
||||
super().__init__(*expressions)
|
||||
|
||||
def as_sql(
|
||||
self,
|
||||
compiler: SQLCompiler,
|
||||
connection: BaseDatabaseWrapper,
|
||||
function=None,
|
||||
template=None,
|
||||
):
|
||||
sql, params = "bm25(wagtailsearch_indexentry_fts)", []
|
||||
return sql, params
|
||||
|
||||
|
||||
class LexemeCombinable(Expression):
|
||||
BITAND = "AND"
|
||||
BITOR = "OR"
|
||||
|
||||
def _combine(self, other, connector, reversed, node=None):
|
||||
if not isinstance(other, LexemeCombinable):
|
||||
raise TypeError(
|
||||
f"Lexeme can only be combined with other Lexemes, got {type(other)}."
|
||||
)
|
||||
if reversed:
|
||||
return CombinedLexeme(other, connector, self)
|
||||
return CombinedLexeme(self, connector, other)
|
||||
|
||||
# On Combinable, these are not implemented to reduce confusion with Q. In
|
||||
# this case we are actually (ab)using them to do logical combination so
|
||||
# it's consistent with other usage in Django.
|
||||
def bitand(self, other):
|
||||
return self._combine(other, self.BITAND, False)
|
||||
|
||||
def bitor(self, other):
|
||||
return self._combine(other, self.BITOR, False)
|
||||
|
||||
def __or__(self, other):
|
||||
return self._combine(other, self.BITOR, False)
|
||||
|
||||
def __and__(self, other):
|
||||
return self._combine(other, self.BITAND, False)
|
||||
|
||||
|
||||
class SearchQueryField(Field):
|
||||
def db_type(self, connection):
|
||||
return None
|
||||
|
||||
|
||||
class Lexeme(LexemeCombinable, Value):
|
||||
_output_field = SearchQueryField()
|
||||
|
||||
def __init__(self, value, output_field=None, *, prefix=False, weight=None):
|
||||
self.prefix = prefix
|
||||
self.weight = weight
|
||||
super().__init__(value, output_field=output_field)
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
param = self.value.replace("'", "''").replace("\\", "\\\\")
|
||||
|
||||
if self.prefix:
|
||||
template = '"%s"*'
|
||||
else:
|
||||
template = '"%s"'
|
||||
|
||||
return template, [param]
|
||||
|
||||
|
||||
class CombinedLexeme(LexemeCombinable):
|
||||
_output_field = SearchQueryField()
|
||||
|
||||
def __init__(self, lhs, connector, rhs, output_field=None):
|
||||
super().__init__(output_field=output_field)
|
||||
self.connector = connector
|
||||
self.lhs = lhs
|
||||
self.rhs = rhs
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
value_params = []
|
||||
lsql, params = compiler.compile(self.lhs)
|
||||
value_params.extend(params)
|
||||
|
||||
rsql, params = compiler.compile(self.rhs)
|
||||
value_params.extend(params)
|
||||
|
||||
combined_sql = f"{lsql} {self.connector} {rsql}"
|
||||
combined_value = combined_sql % tuple(value_params)
|
||||
return "%s", [combined_value]
|
||||
|
||||
|
||||
class SearchQueryCombinable:
|
||||
BITAND = "AND"
|
||||
BITOR = "OR"
|
||||
|
||||
def _combine(self, other, connector: str, reversed: bool = False):
|
||||
if not isinstance(other, SearchQueryCombinable):
|
||||
raise TypeError(
|
||||
"SearchQuery can only be combined with other SearchQuery "
|
||||
"instances, got %s." % type(other).__name__
|
||||
)
|
||||
if reversed:
|
||||
return CombinedSearchQueryExpression(other, connector, self)
|
||||
return CombinedSearchQueryExpression(self, connector, other)
|
||||
|
||||
# On Combinable, these are not implemented to reduce confusion with Q. In
|
||||
# this case we are actually (ab)using them to do logical combination so
|
||||
# it's consistent with other usage in Django.
|
||||
def __or__(self, other):
|
||||
return self._combine(other, self.BITOR, False)
|
||||
|
||||
def __ror__(self, other):
|
||||
return self._combine(other, self.BITOR, True)
|
||||
|
||||
def __and__(self, other):
|
||||
return self._combine(other, self.BITAND, False)
|
||||
|
||||
def __rand__(self, other):
|
||||
return self._combine(other, self.BITAND, True)
|
||||
|
||||
|
||||
class SearchQueryExpression(SearchQueryCombinable, Expression):
|
||||
def __init__(self, value: LexemeCombinable, using=None, **extra):
|
||||
super().__init__(output_field=SearchQueryField())
|
||||
self.using = using
|
||||
self.extra = extra
|
||||
if isinstance(value, str): # If the value is a string, we assume it's a phrase
|
||||
self.value = Value(
|
||||
'"%s"' % value
|
||||
) # We wrap it in quotes to make sure it's parsed as a phrase
|
||||
else: # Otherwise, we assume it's a lexeme
|
||||
self.value = value
|
||||
|
||||
def as_sql(
|
||||
self,
|
||||
compiler: SQLCompiler,
|
||||
connection: BaseDatabaseWrapper,
|
||||
**extra_context: Any,
|
||||
) -> Tuple[str, List[Any]]:
|
||||
sql, params = compiler.compile(self.value)
|
||||
return (sql, params)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return self.value.__repr__()
|
||||
|
||||
|
||||
class CombinedSearchQueryExpression(SearchQueryCombinable, CombinedExpression):
|
||||
def __init__(self, lhs, connector, rhs, output_field=None):
|
||||
super().__init__(lhs, connector, rhs, output_field)
|
||||
|
||||
def __str__(self):
|
||||
return "(%s)" % super().__str__()
|
||||
|
||||
|
||||
class MatchExpression(Expression):
|
||||
filterable = True
|
||||
template = (
|
||||
"wagtailsearch_indexentry_fts MATCH %s" # TODO: Can the table name be inferred?
|
||||
)
|
||||
output_field = BooleanField()
|
||||
|
||||
def __init__(self, columns: List[str], query: SearchQueryCombinable) -> None:
|
||||
super().__init__(output_field=self.output_field)
|
||||
self.columns = columns
|
||||
self.query = query
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
joined_columns = " ".join(
|
||||
self.columns
|
||||
) # The format of the columns is 'column1 column2'
|
||||
compiled_query = compiler.compile(self.query) # Compile the query to a string
|
||||
formatted_query = compiled_query[0] % tuple(
|
||||
compiled_query[1]
|
||||
) # Substitute the params in the query
|
||||
params = [
|
||||
"{{{column}}} : ({query})".format(
|
||||
column=joined_columns, query=formatted_query
|
||||
)
|
||||
] # Build the full MATCH search query. It will be a parameter to the template, so no SQL injections are possible here.
|
||||
return (self.template, params)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<MatchExpression: {self.columns!r} = {self.query!r}>"
|
||||
|
||||
|
||||
class AndNot(SearchQuery):
|
||||
"""
|
||||
This is a binary search query, where there are two subqueries, and the search is done by searching the first, and excluding the second subquery.
|
||||
For example, AndNot(X, Y) would be equivalent to doing And(X, Not(Y)), where X is the first subquery, and Y is the second subquery (the negated one).
|
||||
This is done because the SQLite FTS5 module doesn't support the unary NOT operator.
|
||||
"""
|
||||
|
||||
def __init__(self, subquery_a: SearchQuery, subquery_b: SearchQuery):
|
||||
self.subquery_a = subquery_a
|
||||
self.subquery_b = subquery_b
|
||||
|
||||
def __repr__(self):
|
||||
return f"<{repr(self.subquery_a)} AndNot {repr(self.subquery_b)}>"
|
||||
|
||||
|
||||
def normalize(search_query: SearchQuery) -> Tuple[SearchQuery]:
|
||||
"""
|
||||
Turns this query into a normalized version.
|
||||
For example, And(Not(PlainText("Arepa")), PlainText("Crepe")) would be turned into AndNot(PlainText("Crepe"), PlainText("Arepa")): "Crepe AND NOT Arepa".
|
||||
This is done because we need to get the NOT operator to the front of the query, so it can be used in the search, because the SQLite FTS5 module doesn't support the unary NOT operator. This means that, in order to support the NOT operator, we need to match against the non-negated version of the query, and then return everything that is not in the results of the non-negated query.
|
||||
"""
|
||||
if isinstance(search_query, Phrase):
|
||||
return search_query # We can't normalize a Phrase.
|
||||
if isinstance(search_query, PlainText):
|
||||
return search_query # We can't normalize a PlainText.
|
||||
if isinstance(search_query, And):
|
||||
normalized_subqueries: List[SearchQuery] = [
|
||||
normalize(subquery) for subquery in search_query.subqueries
|
||||
] # This builds a list of normalized subqueries.
|
||||
|
||||
not_negated_subqueries = [
|
||||
subquery
|
||||
for subquery in normalized_subqueries
|
||||
if not isinstance(subquery, Not)
|
||||
] # All the non-negated subqueries.
|
||||
not_negated_subqueries = [
|
||||
subquery
|
||||
for subquery in not_negated_subqueries
|
||||
if not isinstance(subquery, MatchAll)
|
||||
] # We can ignore all MatchAll SearchQueries here, because they are redundant.
|
||||
negated_subqueries = [
|
||||
subquery.subquery
|
||||
for subquery in normalized_subqueries
|
||||
if isinstance(subquery, Not)
|
||||
]
|
||||
|
||||
if (
|
||||
negated_subqueries == []
|
||||
): # If there are no negated subqueries, return an And(), now without the redundant MatchAll subqueries.
|
||||
return And(not_negated_subqueries)
|
||||
|
||||
for subquery in (
|
||||
negated_subqueries
|
||||
): # If there's a negated MatchAll subquery, then nothing will get matched.
|
||||
if isinstance(subquery, MatchAll):
|
||||
return Not(MatchAll())
|
||||
|
||||
return AndNot(And(not_negated_subqueries), Or(negated_subqueries))
|
||||
if isinstance(search_query, Or):
|
||||
normalized_subqueries: List[SearchQuery] = [
|
||||
normalize(subquery) for subquery in search_query.subqueries
|
||||
] # This builds a list of (subquery, negated) tuples.
|
||||
|
||||
negated_subqueries = [
|
||||
subquery.subquery
|
||||
for subquery in normalized_subqueries
|
||||
if isinstance(subquery, Not)
|
||||
]
|
||||
if (
|
||||
negated_subqueries == []
|
||||
): # If there are no negated subqueries, return an Or().
|
||||
return Or(normalized_subqueries)
|
||||
|
||||
for subquery in (
|
||||
negated_subqueries
|
||||
): # If there's a MatchAll subquery, then anything will get matched.
|
||||
if isinstance(subquery, MatchAll):
|
||||
return MatchAll()
|
||||
|
||||
not_negated_subqueries = [
|
||||
subquery
|
||||
for subquery in normalized_subqueries
|
||||
if not isinstance(subquery, Not)
|
||||
] # All the non-negated subqueries.
|
||||
not_negated_subqueries = [
|
||||
subquery
|
||||
for subquery in not_negated_subqueries
|
||||
if not isinstance(subquery, MatchAll)
|
||||
] # We can ignore all MatchAll SearchQueries here, because they are redundant.
|
||||
|
||||
return AndNot(MatchAll(), And(negated_subqueries))
|
||||
if isinstance(search_query, Not):
|
||||
normalized = normalize(search_query.subquery)
|
||||
return Not(normalized) # Normalize the subquery, then invert it.
|
||||
if isinstance(search_query, MatchAll):
|
||||
return search_query # We can't normalize a MatchAll.
|
||||
707
env/lib/python3.10/site-packages/wagtail/search/backends/database/sqlite/sqlite.py
vendored
Normal file
707
env/lib/python3.10/site-packages/wagtail/search/backends/database/sqlite/sqlite.py
vendored
Normal file
@@ -0,0 +1,707 @@
|
||||
from collections import OrderedDict
|
||||
from functools import reduce
|
||||
|
||||
from django.db import DEFAULT_DB_ALIAS, NotSupportedError, connections, transaction
|
||||
from django.db.models import Avg, Count, F, Manager, Q, TextField
|
||||
from django.db.models.constants import LOOKUP_SEP
|
||||
from django.db.models.functions import Cast, Length
|
||||
from django.db.utils import OperationalError
|
||||
from django.utils.encoding import force_str
|
||||
from django.utils.functional import cached_property
|
||||
|
||||
from ....index import AutocompleteField, RelatedFields, SearchField, get_indexed_models
|
||||
from ....models import IndexEntry, SQLiteFTSIndexEntry
|
||||
from ....query import And, MatchAll, Not, Or, Phrase, PlainText
|
||||
from ....utils import (
|
||||
ADD,
|
||||
MUL,
|
||||
OR,
|
||||
get_content_type_pk,
|
||||
get_descendants_content_types_pks,
|
||||
)
|
||||
from ...base import (
|
||||
BaseSearchBackend,
|
||||
BaseSearchQueryCompiler,
|
||||
BaseSearchResults,
|
||||
FilterFieldError,
|
||||
)
|
||||
from .query import (
|
||||
BM25,
|
||||
AndNot,
|
||||
Lexeme,
|
||||
MatchExpression,
|
||||
SearchQueryExpression,
|
||||
normalize,
|
||||
)
|
||||
|
||||
|
||||
class ObjectIndexer:
|
||||
"""
|
||||
Responsible for extracting data from an object to be inserted into the index.
|
||||
"""
|
||||
|
||||
def __init__(self, obj, backend):
|
||||
self.obj = obj
|
||||
self.search_fields = obj.get_search_fields()
|
||||
self.config = backend.config
|
||||
|
||||
def prepare_value(self, value):
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
|
||||
elif isinstance(value, list):
|
||||
return ", ".join(self.prepare_value(item) for item in value)
|
||||
|
||||
elif isinstance(value, dict):
|
||||
return ", ".join(self.prepare_value(item) for item in value.values())
|
||||
|
||||
return force_str(value)
|
||||
|
||||
def prepare_field(self, obj, field):
|
||||
if isinstance(field, SearchField):
|
||||
yield (field, self.prepare_value(field.get_value(obj)))
|
||||
|
||||
elif isinstance(field, AutocompleteField):
|
||||
yield (field, self.prepare_value(field.get_value(obj)))
|
||||
|
||||
elif isinstance(field, RelatedFields):
|
||||
sub_obj = field.get_value(obj)
|
||||
if sub_obj is None:
|
||||
return
|
||||
|
||||
if isinstance(sub_obj, Manager):
|
||||
sub_objs = sub_obj.all()
|
||||
|
||||
else:
|
||||
if callable(sub_obj):
|
||||
sub_obj = sub_obj()
|
||||
|
||||
sub_objs = [sub_obj]
|
||||
|
||||
for sub_obj in sub_objs:
|
||||
for sub_field in field.fields:
|
||||
yield from self.prepare_field(sub_obj, sub_field)
|
||||
|
||||
@cached_property
|
||||
def id(self):
|
||||
"""
|
||||
Returns the value to use as the ID of the record in the index
|
||||
"""
|
||||
return force_str(self.obj.pk)
|
||||
|
||||
@cached_property
|
||||
def title(self):
|
||||
"""
|
||||
Returns all values to index as "title". This is the value of all SearchFields that have the field_name 'title'
|
||||
"""
|
||||
texts = []
|
||||
for field in self.search_fields:
|
||||
for current_field, value in self.prepare_field(self.obj, field):
|
||||
if (
|
||||
isinstance(current_field, SearchField)
|
||||
and current_field.field_name == "title"
|
||||
):
|
||||
texts.append(value)
|
||||
|
||||
return " ".join(texts)
|
||||
|
||||
@cached_property
|
||||
def body(self):
|
||||
"""
|
||||
Returns all values to index as "body". This is the value of all SearchFields excluding the title
|
||||
"""
|
||||
texts = []
|
||||
for field in self.search_fields:
|
||||
for current_field, value in self.prepare_field(self.obj, field):
|
||||
if (
|
||||
isinstance(current_field, SearchField)
|
||||
and not current_field.field_name == "title"
|
||||
):
|
||||
texts.append(value)
|
||||
|
||||
return " ".join(texts)
|
||||
|
||||
@cached_property
|
||||
def autocomplete(self):
|
||||
"""
|
||||
Returns all values to index as "autocomplete". This is the value of all AutocompleteFields
|
||||
"""
|
||||
texts = []
|
||||
for field in self.search_fields:
|
||||
for current_field, value in self.prepare_field(self.obj, field):
|
||||
if isinstance(current_field, AutocompleteField):
|
||||
texts.append(value)
|
||||
|
||||
return " ".join(texts)
|
||||
|
||||
def as_vector(self, texts, for_autocomplete=False):
|
||||
"""
|
||||
Converts an array of strings into a SearchVector that can be indexed.
|
||||
"""
|
||||
texts = [(text.strip(), weight) for text, weight in texts]
|
||||
texts = [(text, weight) for text, weight in texts if text]
|
||||
|
||||
return " ".join(texts)
|
||||
|
||||
|
||||
class Index:
|
||||
def __init__(self, backend, db_alias=None):
|
||||
self.backend = backend
|
||||
self.name = self.backend.index_name
|
||||
self.db_alias = DEFAULT_DB_ALIAS if db_alias is None else db_alias
|
||||
self.connection = connections[self.db_alias]
|
||||
if self.connection.vendor != "sqlite":
|
||||
raise NotSupportedError(
|
||||
"You must select a SQLite database " "to use the SQLite search backend."
|
||||
)
|
||||
|
||||
self.entries = IndexEntry._default_manager.using(self.db_alias)
|
||||
|
||||
def add_model(self, model):
|
||||
pass
|
||||
|
||||
def refresh(self):
|
||||
pass
|
||||
|
||||
def _refresh_title_norms(self, full=False):
|
||||
"""
|
||||
Refreshes the value of the title_norm field.
|
||||
|
||||
This needs to be set to 'lavg/ld' where:
|
||||
- lavg is the average length of titles in all documents (also in terms)
|
||||
- ld is the length of the title field in this document (in terms)
|
||||
"""
|
||||
|
||||
lavg = (
|
||||
self.entries.annotate(title_length=Length("title"))
|
||||
.filter(title_length__gt=0)
|
||||
.aggregate(Avg("title_length"))["title_length__avg"]
|
||||
)
|
||||
|
||||
if full:
|
||||
# Update the whole table
|
||||
# This is the most accurate option but requires a full table rewrite
|
||||
# so we can't do it too often as it could lead to locking issues.
|
||||
entries = self.entries
|
||||
|
||||
else:
|
||||
# Only update entries where title_norm is 1.0
|
||||
# This is the default value set on new entries.
|
||||
# It's possible that other entries could have this exact value but there shouldn't be too many of those
|
||||
entries = self.entries.filter(title_norm=1.0)
|
||||
|
||||
entries.annotate(title_length=Length("title")).filter(
|
||||
title_length__gt=0
|
||||
).update(title_norm=lavg / F("title_length"))
|
||||
|
||||
def delete_stale_model_entries(self, model):
|
||||
existing_pks = (
|
||||
model._default_manager.using(self.db_alias)
|
||||
.annotate(object_id=Cast("pk", TextField()))
|
||||
.values("object_id")
|
||||
)
|
||||
content_types_pks = get_descendants_content_types_pks(model)
|
||||
stale_entries = self.entries.filter(
|
||||
content_type_id__in=content_types_pks
|
||||
).exclude(object_id__in=existing_pks)
|
||||
stale_entries.delete()
|
||||
|
||||
def delete_stale_entries(self):
|
||||
for model in get_indexed_models():
|
||||
# We don’t need to delete stale entries for non-root models,
|
||||
# since we already delete them by deleting roots.
|
||||
if not model._meta.parents:
|
||||
self.delete_stale_model_entries(model)
|
||||
|
||||
def add_item(self, obj):
|
||||
self.add_items(obj._meta.model, [obj])
|
||||
|
||||
def add_items_update_then_create(self, content_type_pk, indexers):
|
||||
ids_and_data = {}
|
||||
for indexer in indexers:
|
||||
ids_and_data[indexer.id] = (
|
||||
indexer.title,
|
||||
indexer.autocomplete,
|
||||
indexer.body,
|
||||
)
|
||||
|
||||
index_entries_for_ct = self.entries.filter(content_type_id=content_type_pk)
|
||||
indexed_ids = frozenset(
|
||||
index_entries_for_ct.filter(object_id__in=ids_and_data.keys()).values_list(
|
||||
"object_id", flat=True
|
||||
)
|
||||
)
|
||||
for indexed_id in indexed_ids:
|
||||
title, autocomplete, body = ids_and_data[indexed_id]
|
||||
index_entries_for_ct.filter(object_id=indexed_id).update(
|
||||
title=title, autocomplete=autocomplete, body=body
|
||||
)
|
||||
|
||||
to_be_created = []
|
||||
for object_id in ids_and_data.keys():
|
||||
if object_id not in indexed_ids:
|
||||
title, autocomplete, body = ids_and_data[object_id]
|
||||
to_be_created.append(
|
||||
IndexEntry(
|
||||
content_type_id=content_type_pk,
|
||||
object_id=object_id,
|
||||
title=title,
|
||||
autocomplete=autocomplete,
|
||||
body=body,
|
||||
)
|
||||
)
|
||||
|
||||
self.entries.bulk_create(to_be_created)
|
||||
|
||||
self._refresh_title_norms()
|
||||
|
||||
def add_items(self, model, objs):
|
||||
search_fields = model.get_search_fields()
|
||||
if not search_fields:
|
||||
return
|
||||
|
||||
indexers = [ObjectIndexer(obj, self.backend) for obj in objs]
|
||||
|
||||
# TODO: Delete unindexed objects while dealing with proxy models.
|
||||
if indexers:
|
||||
content_type_pk = get_content_type_pk(model)
|
||||
|
||||
update_method = self.add_items_update_then_create
|
||||
update_method(content_type_pk, indexers)
|
||||
|
||||
def delete_item(self, item):
|
||||
item.index_entries.all()._raw_delete(using=self.db_alias)
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
|
||||
class SQLiteSearchRebuilder:
|
||||
def __init__(self, index):
|
||||
self.index = index
|
||||
|
||||
def start(self):
|
||||
self.index.delete_stale_entries()
|
||||
return self.index
|
||||
|
||||
def finish(self):
|
||||
self.index._refresh_title_norms(full=True)
|
||||
|
||||
|
||||
class SQLiteSearchAtomicRebuilder(SQLiteSearchRebuilder):
|
||||
def __init__(self, index):
|
||||
super().__init__(index)
|
||||
self.transaction = transaction.atomic(using=index.db_alias)
|
||||
self.transaction_opened = False
|
||||
|
||||
def start(self):
|
||||
self.transaction.__enter__()
|
||||
self.transaction_opened = True
|
||||
return super().start()
|
||||
|
||||
def finish(self):
|
||||
self.index._refresh_title_norms(full=True)
|
||||
|
||||
self.transaction.__exit__(None, None, None)
|
||||
self.transaction_opened = False
|
||||
|
||||
def __del__(self):
|
||||
# TODO: Implement a cleaner way to close the connection on failure.
|
||||
if self.transaction_opened:
|
||||
self.transaction.needs_rollback = True
|
||||
self.finish()
|
||||
|
||||
|
||||
class SQLiteSearchQueryCompiler(BaseSearchQueryCompiler):
|
||||
DEFAULT_OPERATOR = "AND"
|
||||
LAST_TERM_IS_PREFIX = False
|
||||
TARGET_SEARCH_FIELD_TYPE = SearchField
|
||||
FTS_TABLE_FIELDS = ["title", "body"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
local_search_fields = self.get_search_fields_for_model()
|
||||
|
||||
if self.fields is None:
|
||||
# search over the fields defined on the current model
|
||||
self.search_fields = local_search_fields
|
||||
else:
|
||||
# build a search_fields set from the passed definition,
|
||||
# which may involve traversing relations
|
||||
self.search_fields = {
|
||||
field_lookup: self.get_search_field(
|
||||
field_lookup, fields=local_search_fields
|
||||
)
|
||||
for field_lookup in self.fields
|
||||
}
|
||||
|
||||
def get_config(self, backend):
|
||||
return backend.config
|
||||
|
||||
def get_search_fields_for_model(self):
|
||||
return self.queryset.model.get_searchable_search_fields()
|
||||
|
||||
def get_search_field(self, field_lookup, fields=None):
|
||||
if fields is None:
|
||||
fields = self.search_fields
|
||||
|
||||
if LOOKUP_SEP in field_lookup:
|
||||
field_lookup, sub_field_name = field_lookup.split(LOOKUP_SEP, 1)
|
||||
else:
|
||||
sub_field_name = None
|
||||
|
||||
for field in fields:
|
||||
if (
|
||||
isinstance(field, self.TARGET_SEARCH_FIELD_TYPE)
|
||||
and field.field_name == field_lookup
|
||||
):
|
||||
return field
|
||||
|
||||
# Note: Searching on a specific related field using
|
||||
# `.search(fields=…)` is not yet supported by Wagtail.
|
||||
# This method anticipates by already implementing it.
|
||||
if isinstance(field, RelatedFields) and field.field_name == field_lookup:
|
||||
return self.get_search_field(sub_field_name, field.fields)
|
||||
|
||||
def build_search_query_content(self, query, config=None):
|
||||
"""
|
||||
Takes a SearchQuery and returns another SearchQuery object, which can be used to construct the query in SQL.
|
||||
"""
|
||||
if isinstance(query, PlainText):
|
||||
terms = query.query_string.split()
|
||||
if not terms:
|
||||
return None
|
||||
|
||||
last_term = terms.pop()
|
||||
|
||||
lexemes = Lexeme(
|
||||
last_term, prefix=self.LAST_TERM_IS_PREFIX
|
||||
) # Combine all terms into a single lexeme.
|
||||
for term in terms:
|
||||
new_lexeme = Lexeme(term)
|
||||
|
||||
if query.operator.upper() == "AND":
|
||||
lexemes &= new_lexeme
|
||||
else:
|
||||
lexemes |= new_lexeme
|
||||
|
||||
return SearchQueryExpression(lexemes, config=config)
|
||||
|
||||
elif isinstance(query, Phrase):
|
||||
return SearchQueryExpression(query.query_string)
|
||||
|
||||
elif isinstance(query, AndNot):
|
||||
# Combine the two sub-queries into a query of the form `(first) AND NOT (second)`.
|
||||
subquery_a = self.build_search_query_content(
|
||||
query.subquery_a, config=config
|
||||
)
|
||||
subquery_b = self.build_search_query_content(
|
||||
query.subquery_b, config=config
|
||||
)
|
||||
combined_query = subquery_a._combine(subquery_b, "NOT")
|
||||
return combined_query
|
||||
|
||||
elif isinstance(query, (And, Or)):
|
||||
subquery_lexemes = [
|
||||
self.build_search_query_content(subquery, config=config)
|
||||
for subquery in query.subqueries
|
||||
]
|
||||
|
||||
is_and = isinstance(query, And)
|
||||
|
||||
if is_and:
|
||||
return reduce(lambda a, b: a & b, subquery_lexemes)
|
||||
else:
|
||||
return reduce(lambda a, b: a | b, subquery_lexemes)
|
||||
|
||||
raise NotImplementedError(
|
||||
"`%s` is not supported by the SQLite search backend."
|
||||
% query.__class__.__name__
|
||||
)
|
||||
|
||||
def build_search_query(self, query, config=None):
|
||||
if isinstance(query, MatchAll):
|
||||
return query
|
||||
if isinstance(query, Not):
|
||||
unwrapped_query = query.subquery
|
||||
built_query = Not(
|
||||
self.build_search_query(unwrapped_query, config=config)
|
||||
) # We don't take the Not operator into account.
|
||||
else:
|
||||
built_query = self.build_search_query_content(query, config=config)
|
||||
return built_query
|
||||
|
||||
def build_tsrank(self, vector, query, config=None, boost=1.0):
|
||||
if isinstance(query, (Phrase, PlainText, Not)):
|
||||
rank_expression = BM25()
|
||||
|
||||
if boost != 1.0:
|
||||
rank_expression *= boost
|
||||
|
||||
return rank_expression
|
||||
|
||||
elif isinstance(query, And):
|
||||
return (
|
||||
MUL(
|
||||
1 + self.build_tsrank(vector, subquery, config=config, boost=boost)
|
||||
for subquery in query.subqueries
|
||||
)
|
||||
- 1
|
||||
)
|
||||
|
||||
elif isinstance(query, Or):
|
||||
return ADD(
|
||||
self.build_tsrank(vector, subquery, config=config, boost=boost)
|
||||
for subquery in query.subqueries
|
||||
) / (len(query.subqueries) or 1)
|
||||
|
||||
raise NotImplementedError(
|
||||
"`%s` is not supported by the SQLite search backend."
|
||||
% query.__class__.__name__
|
||||
)
|
||||
|
||||
def get_index_vectors(self):
|
||||
return [
|
||||
(F("index_entries__title"), F("index_entries__title_norm")),
|
||||
(F("index_entries__body"), 1.0),
|
||||
]
|
||||
|
||||
def get_search_vectors(self):
|
||||
return self.get_index_vectors()
|
||||
|
||||
def _build_rank_expression(self, vectors, config):
|
||||
# TODO: Come up with my own expression class that compiles down to bm25
|
||||
|
||||
rank_expressions = [
|
||||
self.build_tsrank(vector, self.query, config=config) * boost
|
||||
for vector, boost in vectors
|
||||
]
|
||||
|
||||
rank_expression = rank_expressions[0]
|
||||
for other_rank_expression in rank_expressions[1:]:
|
||||
rank_expression += other_rank_expression
|
||||
|
||||
return rank_expression
|
||||
|
||||
def search(self, config, start, stop, score_field=None):
|
||||
normalized_query = normalize(self.query)
|
||||
|
||||
if isinstance(normalized_query, MatchAll):
|
||||
return self.queryset[start:stop]
|
||||
|
||||
elif isinstance(normalized_query, Not) and isinstance(
|
||||
normalized_query.subquery, MatchAll
|
||||
):
|
||||
return self.queryset.none()
|
||||
|
||||
if isinstance(normalized_query, Not):
|
||||
normalized_query = normalized_query.subquery
|
||||
negated = True
|
||||
else:
|
||||
negated = False
|
||||
|
||||
search_query = self.build_search_query(
|
||||
normalized_query, config=config
|
||||
) # We build a search query here, for example: "%s MATCH '(hello AND world)'"
|
||||
vectors = self.get_search_vectors()
|
||||
rank_expression = self._build_rank_expression(vectors, config)
|
||||
|
||||
combined_vector = vectors[
|
||||
0
|
||||
][
|
||||
0
|
||||
] # We create a combined vector for the search results queryset. We start with the first vector and build from there.
|
||||
for vector, boost in vectors[1:]:
|
||||
combined_vector = combined_vector._combine(
|
||||
vector, " ", False
|
||||
) # We add the subsequent vectors to the combined vector.
|
||||
|
||||
# Build the FTS match expression.
|
||||
expr = MatchExpression(self.fields or self.FTS_TABLE_FIELDS, search_query)
|
||||
# Perform the FTS search. We'll get entries in the SQLiteFTSIndexEntry model.
|
||||
objs = (
|
||||
SQLiteFTSIndexEntry.objects.filter(expr)
|
||||
.select_related("index_entry")
|
||||
.filter(
|
||||
index_entry__content_type__in=get_descendants_content_types_pks(
|
||||
self.queryset.model
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if self.order_by_relevance:
|
||||
objs = objs.order_by(BM25().desc())
|
||||
elif not objs.query.order_by:
|
||||
# Adds a default ordering to avoid issue #3729.
|
||||
queryset = objs.order_by("-pk")
|
||||
rank_expression = F("pk")
|
||||
|
||||
from django.db import connection
|
||||
from django.db.models.sql.subqueries import InsertQuery
|
||||
|
||||
compiler = InsertQuery(IndexEntry).get_compiler(connection=connection)
|
||||
|
||||
try:
|
||||
obj_ids = [
|
||||
obj.index_entry.object_id for obj in objs
|
||||
] # Get the IDs of the objects that matched. They're stored in the IndexEntry model, so we need to get that first.
|
||||
except OperationalError as e:
|
||||
raise OperationalError(
|
||||
str(e)
|
||||
+ " The original query was: "
|
||||
+ compiler.compile(objs.query)[0]
|
||||
+ str(compiler.compile(objs.query)[1])
|
||||
) from e
|
||||
|
||||
if not negated:
|
||||
queryset = self.queryset.filter(
|
||||
id__in=obj_ids
|
||||
) # We need to filter the source queryset to get the objects that matched the search query.
|
||||
else:
|
||||
queryset = self.queryset.exclude(
|
||||
id__in=obj_ids
|
||||
) # We exclude the objects that matched the search query from the source queryset, if the query is negated.
|
||||
|
||||
if score_field is not None:
|
||||
queryset = queryset.annotate(**{score_field: rank_expression})
|
||||
|
||||
return queryset[start:stop]
|
||||
|
||||
def _process_lookup(self, field, lookup, value):
|
||||
lhs = field.get_attname(self.queryset.model) + "__" + lookup
|
||||
return Q(**{lhs: value})
|
||||
|
||||
def _connect_filters(self, filters, connector, negated):
|
||||
if connector == "AND":
|
||||
q = Q(*filters)
|
||||
|
||||
elif connector == "OR":
|
||||
q = OR([Q(fil) for fil in filters])
|
||||
|
||||
else:
|
||||
return
|
||||
|
||||
if negated:
|
||||
q = ~q
|
||||
|
||||
return q
|
||||
|
||||
|
||||
class SQLiteAutocompleteQueryCompiler(SQLiteSearchQueryCompiler):
|
||||
LAST_TERM_IS_PREFIX = True
|
||||
TARGET_SEARCH_FIELD_TYPE = AutocompleteField
|
||||
FTS_TABLE_FIELDS = ["autocomplete"]
|
||||
|
||||
def get_config(self, backend):
|
||||
return backend.autocomplete_config
|
||||
|
||||
def get_search_fields_for_model(self):
|
||||
return self.queryset.model.get_autocomplete_search_fields()
|
||||
|
||||
def get_index_vectors(self):
|
||||
return [(F("index_entries__autocomplete"), 1.0)]
|
||||
|
||||
|
||||
class SQLiteSearchResults(BaseSearchResults):
|
||||
def get_queryset(self, for_count=False):
|
||||
if for_count:
|
||||
start = None
|
||||
stop = None
|
||||
else:
|
||||
start = self.start
|
||||
stop = self.stop
|
||||
|
||||
return self.query_compiler.search(
|
||||
self.query_compiler.get_config(self.backend),
|
||||
start,
|
||||
stop,
|
||||
score_field=self._score_field,
|
||||
)
|
||||
|
||||
def _do_search(self):
|
||||
return list(self.get_queryset())
|
||||
|
||||
def _do_count(self):
|
||||
return self.get_queryset(for_count=True).count()
|
||||
|
||||
supports_facet = True
|
||||
|
||||
def facet(self, field_name):
|
||||
# Get field
|
||||
field = self.query_compiler._get_filterable_field(field_name)
|
||||
if field is None:
|
||||
raise FilterFieldError(
|
||||
'Cannot facet search results with field "'
|
||||
+ field_name
|
||||
+ "\". Please add index.FilterField('"
|
||||
+ field_name
|
||||
+ "') to "
|
||||
+ self.query_compiler.queryset.model.__name__
|
||||
+ ".search_fields.",
|
||||
field_name=field_name,
|
||||
)
|
||||
|
||||
query = self.query_compiler.search(
|
||||
self.query_compiler.get_config(self.backend), None, None
|
||||
)
|
||||
results = (
|
||||
query.values(field_name).annotate(count=Count("pk")).order_by("-count")
|
||||
)
|
||||
|
||||
return OrderedDict(
|
||||
[(result[field_name], result["count"]) for result in results]
|
||||
)
|
||||
|
||||
|
||||
class SQLiteSearchBackend(BaseSearchBackend):
|
||||
query_compiler_class = SQLiteSearchQueryCompiler
|
||||
autocomplete_query_compiler_class = SQLiteAutocompleteQueryCompiler
|
||||
|
||||
results_class = SQLiteSearchResults
|
||||
rebuilder_class = SQLiteSearchRebuilder
|
||||
atomic_rebuilder_class = SQLiteSearchAtomicRebuilder
|
||||
|
||||
def __init__(self, params):
|
||||
super().__init__(params)
|
||||
self.index_name = params.get("INDEX", "default")
|
||||
|
||||
# SQLite backend currently has no config options
|
||||
self.config = None
|
||||
self.autocomplete_config = None
|
||||
|
||||
if params.get("ATOMIC_REBUILD", False):
|
||||
self.rebuilder_class = self.atomic_rebuilder_class
|
||||
|
||||
def get_index_for_model(self, model, db_alias=None):
|
||||
return Index(self, db_alias)
|
||||
|
||||
def get_index_for_object(self, obj):
|
||||
return self.get_index_for_model(obj._meta.model, obj._state.db)
|
||||
|
||||
def reset_index(self):
|
||||
for connection in [
|
||||
connection
|
||||
for connection in connections.all()
|
||||
if connection.vendor == "sqlite"
|
||||
]:
|
||||
IndexEntry._default_manager.all()._raw_delete(using=connection.alias)
|
||||
|
||||
def add_type(self, model):
|
||||
pass # Not needed.
|
||||
|
||||
def refresh_index(self):
|
||||
pass # Not needed.
|
||||
|
||||
def add(self, obj):
|
||||
self.get_index_for_object(obj).add_item(obj)
|
||||
|
||||
def add_bulk(self, model, obj_list):
|
||||
if obj_list:
|
||||
self.get_index_for_object(obj_list[0]).add_items(model, obj_list)
|
||||
|
||||
def delete(self, obj):
|
||||
self.get_index_for_object(obj).delete_item(obj)
|
||||
|
||||
|
||||
SearchBackend = SQLiteSearchBackend
|
||||
35
env/lib/python3.10/site-packages/wagtail/search/backends/database/sqlite/utils.py
vendored
Normal file
35
env/lib/python3.10/site-packages/wagtail/search/backends/database/sqlite/utils.py
vendored
Normal file
@@ -0,0 +1,35 @@
|
||||
import sqlite3
|
||||
|
||||
from django.db import OperationalError
|
||||
|
||||
|
||||
def fts5_available():
|
||||
# based on https://stackoverflow.com/a/36656216/1853523
|
||||
if sqlite3.sqlite_version_info < (3, 19, 0):
|
||||
# Prior to version 3.19, SQLite doesn't support FTS5 queries with
|
||||
# column filters ('{column_1 column_2} : query'), which the sqlite
|
||||
# fulltext backend needs
|
||||
return False
|
||||
|
||||
tmp_db = sqlite3.connect(":memory:")
|
||||
try:
|
||||
tmp_db.execute("CREATE VIRTUAL TABLE fts5test USING fts5 (data);")
|
||||
except sqlite3.OperationalError:
|
||||
return False
|
||||
finally:
|
||||
tmp_db.close()
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def fts_table_exists():
|
||||
from wagtail.search.models import SQLiteFTSIndexEntry
|
||||
|
||||
try:
|
||||
# ignore result of query; we are only interested in the query failing,
|
||||
# not the presence of index entries
|
||||
SQLiteFTSIndexEntry.objects.exists()
|
||||
except OperationalError:
|
||||
return False
|
||||
|
||||
return True
|
||||
1260
env/lib/python3.10/site-packages/wagtail/search/backends/elasticsearch7.py
vendored
Normal file
1260
env/lib/python3.10/site-packages/wagtail/search/backends/elasticsearch7.py
vendored
Normal file
File diff suppressed because it is too large
Load Diff
110
env/lib/python3.10/site-packages/wagtail/search/backends/elasticsearch8.py
vendored
Normal file
110
env/lib/python3.10/site-packages/wagtail/search/backends/elasticsearch8.py
vendored
Normal file
@@ -0,0 +1,110 @@
|
||||
from django.core.exceptions import ImproperlyConfigured
|
||||
from elasticsearch import NotFoundError
|
||||
|
||||
from wagtail.search.backends.elasticsearch7 import (
|
||||
Elasticsearch7AutocompleteQueryCompiler,
|
||||
Elasticsearch7Index,
|
||||
Elasticsearch7Mapping,
|
||||
Elasticsearch7SearchBackend,
|
||||
Elasticsearch7SearchQueryCompiler,
|
||||
Elasticsearch7SearchResults,
|
||||
)
|
||||
from wagtail.search.index import class_is_indexed
|
||||
|
||||
|
||||
class Elasticsearch8Mapping(Elasticsearch7Mapping):
|
||||
pass
|
||||
|
||||
|
||||
class Elasticsearch8Index(Elasticsearch7Index):
|
||||
def put(self):
|
||||
self.es.indices.create(index=self.name, **self.backend.settings)
|
||||
|
||||
def delete(self):
|
||||
try:
|
||||
self.es.indices.delete(index=self.name)
|
||||
except NotFoundError:
|
||||
pass
|
||||
|
||||
def refresh(self):
|
||||
self.es.indices.refresh(index=self.name)
|
||||
|
||||
def add_model(self, model):
|
||||
# Get mapping
|
||||
mapping = self.mapping_class(model)
|
||||
|
||||
# Put mapping
|
||||
self.es.indices.put_mapping(index=self.name, **mapping.get_mapping())
|
||||
|
||||
def add_item(self, item):
|
||||
# Make sure the object can be indexed
|
||||
if not class_is_indexed(item.__class__):
|
||||
return
|
||||
|
||||
# Get mapping
|
||||
mapping = self.mapping_class(item.__class__)
|
||||
|
||||
# Add document to index
|
||||
self.es.index(
|
||||
index=self.name,
|
||||
document=mapping.get_document(item),
|
||||
id=mapping.get_document_id(item),
|
||||
)
|
||||
|
||||
|
||||
class Elasticsearch8SearchQueryCompiler(Elasticsearch7SearchQueryCompiler):
|
||||
mapping_class = Elasticsearch8Mapping
|
||||
|
||||
|
||||
class Elasticsearch8SearchResults(Elasticsearch7SearchResults):
|
||||
def _backend_do_search(self, body, **kwargs):
|
||||
# As of Elasticsearch 7.15, the 'body' parameter is deprecated; instead, the top-level
|
||||
# keys of the body dict are now kwargs in their own right
|
||||
return self.backend.es.search(**body, **kwargs)
|
||||
|
||||
|
||||
class Elasticsearch8AutocompleteQueryCompiler(Elasticsearch7AutocompleteQueryCompiler):
|
||||
mapping_class = Elasticsearch8Mapping
|
||||
|
||||
|
||||
class Elasticsearch8SearchBackend(Elasticsearch7SearchBackend):
|
||||
mapping_class = Elasticsearch8Mapping
|
||||
index_class = Elasticsearch8Index
|
||||
query_compiler_class = Elasticsearch8SearchQueryCompiler
|
||||
autocomplete_query_compiler_class = Elasticsearch8AutocompleteQueryCompiler
|
||||
results_class = Elasticsearch8SearchResults
|
||||
timeout_kwarg_name = "request_timeout"
|
||||
|
||||
def _get_host_config_from_url(self, url):
|
||||
"""Given a parsed URL, return the host configuration to be added to self.hosts"""
|
||||
use_ssl = url.scheme == "https"
|
||||
port = url.port or (443 if use_ssl else 80)
|
||||
|
||||
# the verify_certs and http_auth options are no longer valid in Elasticsearch 8
|
||||
return {
|
||||
"host": url.hostname,
|
||||
"port": port,
|
||||
"path_prefix": url.path,
|
||||
"scheme": url.scheme,
|
||||
}
|
||||
|
||||
def _get_options_from_host_urls(self, urls):
|
||||
"""Given a list of parsed URLs, return a dict of additional options to be passed into the
|
||||
Elasticsearch constructor; necessary for options that aren't valid as part of the 'hosts' config"""
|
||||
opts = super()._get_options_from_host_urls(urls)
|
||||
|
||||
basic_auth = (urls[0].username, urls[0].password)
|
||||
# Ensure that all urls have the same credentials
|
||||
if any((url.username, url.password) != basic_auth for url in urls):
|
||||
raise ImproperlyConfigured(
|
||||
"Elasticsearch host configuration is invalid. "
|
||||
"Elasticsearch 8 does not support multiple hosts with differing authentication credentials."
|
||||
)
|
||||
|
||||
if basic_auth != (None, None):
|
||||
opts["basic_auth"] = basic_auth
|
||||
|
||||
return opts
|
||||
|
||||
|
||||
SearchBackend = Elasticsearch8SearchBackend
|
||||
Reference in New Issue
Block a user