Source code for drf_haystack.query

# -*- coding: utf-8 -*-

from __future__ import absolute_import, unicode_literals

import operator
import warnings
from itertools import chain

from django.utils import six
from django.utils.six.moves import zip
from dateutil import parser

from drf_haystack import constants
from drf_haystack.utils import merge_dict


class BaseQueryBuilder(object):
    """
    Query builder base class.

    Stores the backend and the view it is constructed with; subclasses
    implement :meth:`build_query` and may use :meth:`tokenize` to split
    raw query parameter values.
    """

    def __init__(self, backend, view):
        """
        :param backend: Object stored as ``self.backend``. Subclasses in this
            module read configuration attributes from it.
        :param view: Object stored as ``self.view``. Subclasses in this module
            read attributes such as ``lookup_sep`` from it.
        """
        self.backend = backend
        self.view = view

    def build_query(self, **filters):
        """
        Build a query from the given filters. Must be overridden.

        :param dict[str, list[str]] filters: is an expanded QueryDict or
            a mapping of keys to a list of parameters.
        :raises NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("You should override this method in subclasses.")

    @staticmethod
    def tokenize(stream, separator):
        """
        Tokenize and yield query parameter values.

        Every value in ``stream`` is split on ``separator``; each piece is
        stripped of surrounding whitespace and yielded unless it is empty.

        :param stream: Iterable of string values.
        :param separator: Character to use to separate the tokens.
        :return: Generator yielding the non-empty, stripped tokens.
        """
        for value in stream:
            for token in value.split(separator):
                # Strip *before* the emptiness check so that whitespace-only
                # pieces (e.g. the middle of "a, ,b") are not yielded as "".
                token = token.strip()
                if token:
                    yield token
class BoostQueryBuilder(BaseQueryBuilder):
    """
    Query builder class for adding boost to queries.
    """

    def build_query(self, **filters):
        """
        Turn the backend's boost query parameter into a boost filter.

        The parameter name is read from ``self.backend.query_param``. Its
        value is tokenized with ``self.view.lookup_sep`` and must produce
        exactly two tokens: the term and a numerical boost.

        :param dict[str, list[str]] filters: is an expanded QueryDict or
            a mapping of keys to a list of parameters.
        :return: ``{"term": <term>, "boost": <float>}`` or ``None`` when the
            parameter is absent or empty.
        :raises ValueError: If the value does not tokenize into exactly two
            tokens, or the boost token is not a valid float.
        """
        query_param = getattr(self.backend, "query_param", None)
        raw_value = filters.pop(query_param, None)
        if not raw_value:
            return None

        try:
            # Unpacking the generator enforces "exactly two tokens".
            term, raw_boost = self.tokenize(raw_value, self.view.lookup_sep)
        except ValueError:
            raise ValueError("Cannot convert the '%s' query parameter to a valid boost filter." % query_param)

        try:
            boost = float(raw_boost)
        except ValueError:
            raise ValueError("Cannot convert boost to float value. Make sure to provide a "
                             "numerical boost value.")

        return {"term": term, "boost": boost}
class FilterQueryBuilder(BaseQueryBuilder):
    """
    Query builder class suitable for doing basic filtering.
    """

    def __init__(self, backend, view):
        """
        :param backend: Backend whose ``default_operator`` must be either
            ``operator.and_`` or ``operator.or_``.
        :param view: View the query is built for.
        :raises AssertionError: If ``backend.default_operator`` is missing or
            is not one of the two accepted operators.
        """
        super(FilterQueryBuilder, self).__init__(backend, view)

        assert getattr(self.backend, "default_operator", None) in (operator.and_, operator.or_), (
            "%(cls)s.default_operator must be either 'operator.and_' or 'operator.or_'." % {
                "cls": self.backend.__class__.__name__
            })
        self.default_operator = self.backend.default_operator

    def build_query(self, **filters):
        """
        Creates a single filter and a single exclusion from querystring
        parameters, using ``view.query_object`` to build each term.

        When the view has a serializer class, parameters are checked against
        its ``Meta.fields`` / ``Meta.search_fields`` / ``Meta.exclude`` and
        parameters that are not allowed (or have no value) are ignored;
        ``Meta.field_aliases``, when present, is used to translate the field
        name. Terms for the same parameter are combined with ``OR``; the
        resulting per-parameter terms are combined with
        ``self.default_operator``.

        :param dict[str, list[str]] filters: is an expanded QueryDict or
            a mapping of keys to a list of parameters.
        :return: A ``(applicable_filters, applicable_exclusions)`` tuple. Each
            element is the combined query object, or an empty list when no
            term applied.
        """
        applicable_filters = []
        applicable_exclusions = []
        negation_keyword = constants.DRF_HAYSTACK_NEGATION_KEYWORD

        for param, value in filters.items():
            excluding_term = False
            param_parts = param.split("__")
            base_param = param_parts[0]  # only test against field without lookup
            if len(param_parts) > 1 and param_parts[1] == negation_keyword:
                excluding_term = True
                # haystack wouldn't understand our negation. Only the first
                # occurrence is removed: it is the keyword directly following
                # the field name, and later text must be left untouched.
                param = param.replace("__%s" % negation_keyword, "", 1)

            if self.view.serializer_class:
                meta = self.view.serializer_class.Meta
                if hasattr(meta, 'field_aliases'):
                    old_base = base_param
                    base_param = meta.field_aliases.get(base_param, base_param)
                    # Need to replace the alias. `param` starts with
                    # `old_base`, so limiting the replacement to one
                    # occurrence swaps the field name only and cannot
                    # corrupt the lookup suffix (e.g. alias "a" in
                    # "a__contains").
                    param = param.replace(old_base, base_param, 1)

                fields = getattr(meta, 'fields', [])
                exclude = getattr(meta, 'exclude', [])
                search_fields = getattr(meta, 'search_fields', [])

                # Skip if the parameter is not listed in the serializer's `fields`
                # or if it's in the `exclude` list.
                if ((fields or search_fields) and base_param not in
                        chain(fields, search_fields)) or base_param in exclude or not value:
                    continue

            field_queries = []
            if len(param_parts) > 1 and param_parts[-1] in ('in', 'range'):
                # `in` and `range` filters expects a list of values
                field_queries.append(self.view.query_object((param, list(self.tokenize(value, self.view.lookup_sep)))))
            else:
                for token in self.tokenize(value, self.view.lookup_sep):
                    field_queries.append(self.view.query_object((param, token)))

            # Drop falsy query objects before combining.
            field_queries = [fq for fq in field_queries if fq]
            if len(field_queries) > 0:
                term = six.moves.reduce(operator.or_, field_queries)
                if excluding_term:
                    applicable_exclusions.append(term)
                else:
                    applicable_filters.append(term)

        applicable_filters = six.moves.reduce(
            self.default_operator, filter(lambda x: x, applicable_filters)) if applicable_filters else []

        applicable_exclusions = six.moves.reduce(
            self.default_operator, filter(lambda x: x, applicable_exclusions)) if applicable_exclusions else []

        return applicable_filters, applicable_exclusions
class FacetQueryBuilder(BaseQueryBuilder):
    """
    Query builder class suitable for constructing faceted queries.
    """

    def build_query(self, **filters):
        """
        Creates a dict of dictionaries suitable for passing to the
        SearchQuerySet `facet`, `date_facet` or `query_facet` method.
        All key word arguments should be wrapped in a list.

        Field options start from the facet serializer's
        ``Meta.field_options`` and are merged (via ``merge_dict``) with the
        options parsed from the query string for every field that is listed
        in ``Meta.fields`` and not in ``Meta.exclude``.

        :param dict[str, list[str]] filters: is an expanded QueryDict or
            a mapping of keys to a list of parameters.
        :return: A dict with the keys ``date_facets``, ``field_facets`` and
            ``query_facets`` (the latter is always empty here).
        :raises AttributeError: If ``view.lookup_sep`` is ``":"``, which
            conflicts with the ``token:value`` option syntax.
        :raises ValueError: If a date facet lacks one of ``start_date``,
            ``end_date`` or ``gap_by``, or has an invalid ``gap_by``.
        """
        field_facets = {}
        date_facets = {}
        query_facets = {}
        facet_serializer_cls = self.view.get_facet_serializer_class()

        if self.view.lookup_sep == ":":
            raise AttributeError("The %(cls)s.lookup_sep attribute conflicts with the HaystackFacetFilter "
                                 "query parameter parser. Please choose another `lookup_sep` attribute "
                                 "for %(cls)s." % {"cls": self.view.__class__.__name__})

        fields = facet_serializer_cls.Meta.fields
        exclude = facet_serializer_cls.Meta.exclude
        field_options = facet_serializer_cls.Meta.field_options

        for field, options in filters.items():
            if field not in fields or field in exclude:
                continue

            # Only the option strings are passed on: `parse_field_options`
            # reads the separator from the view itself, and treating the
            # separator as an option would only trigger bogus warnings.
            field_options = merge_dict(field_options, {field: self.parse_field_options(*options)})

        valid_gap = ("year", "month", "day", "hour", "minute", "second")
        date_option_keys = ("start_date", "end_date", "gap_by", "gap_amount")
        required_date_keys = ("start_date", "end_date", "gap_by")

        for field, options in field_options.items():
            if any(key in options for key in date_option_keys):
                # Every required key must be present, not just `gap_by`.
                if not all(key in options for key in required_date_keys):
                    raise ValueError("Date faceting requires at least 'start_date', 'end_date' "
                                     "and 'gap_by' to be set.")

                if options["gap_by"] not in valid_gap:
                    raise ValueError("The 'gap_by' parameter must be one of %s." % ", ".join(valid_gap))

                options.setdefault("gap_amount", 1)
                date_facets[field] = field_options[field]
            else:
                field_facets[field] = field_options[field]

        return {
            "date_facets": date_facets,
            "field_facets": field_facets,
            "query_facets": query_facets
        }

    def parse_field_options(self, *options):
        """
        Parse the field options query string and return it as a dictionary.

        Each text option is split on ``view.lookup_sep`` into tokens of the
        form ``param:value``. ``start_date`` and ``end_date`` values are
        parsed with ``dateutil.parser.parse``; ``gap_amount`` is converted to
        ``int``. Tokens that do not split into exactly two parts on ``":"``
        are skipped with a warning.

        :param options: Option strings; non-text values are ignored.
        :return: Dict mapping option names to their (converted) values.
        """
        defaults = {}
        for option in options:
            if not isinstance(option, six.text_type):
                continue

            tokens = [token.strip() for token in option.split(self.view.lookup_sep)]
            for token in tokens:
                # NOTE(review): values that themselves contain ":" (e.g. a
                # time component) are rejected by this check -- confirm
                # whether that is intended.
                if not len(token.split(":")) == 2:
                    warnings.warn("The %s token is not properly formatted. Tokens need to be "
                                  "formatted as 'token:value' pairs." % token)
                    continue

                param, value = token.split(":", 1)

                if param in ("start_date", "end_date"):
                    value = parser.parse(value)
                elif param == "gap_amount":
                    value = int(value)

                defaults[param] = value

        return defaults
class SpatialQueryBuilder(BaseQueryBuilder):
    """
    Query builder class suitable for construction spatial queries.
    """

    def __init__(self, backend, view):
        """
        :param backend: Backend whose ``point_field`` names the field to
            filter on; it must not be ``None``.
        :param view: View the query is built for.
        :raises AssertionError: If ``backend.point_field`` is missing or
            ``None``.
        :raises ImportError: Re-raised (after a warning) when
            ``haystack.utils.geo`` cannot be imported.
        """
        super(SpatialQueryBuilder, self).__init__(backend, view)

        point_field = getattr(self.backend, "point_field", None)
        assert point_field is not None, (
            "%(cls)s.point_field cannot be None. Set the %(cls)s.point_field "
            "to the name of the `LocationField` you want to filter on your index class." % {
                "cls": self.backend.__class__.__name__
            })

        try:
            from haystack.utils.geo import D, Point
        except ImportError:
            warnings.warn("Make sure you've installed the `libgeos` library. "
                          "Run `apt-get install libgeos` on debian based linux systems, "
                          "or `brew install geos` on OS X.")
            raise
        else:
            # Keep the geo helpers on the instance for use in build_query().
            self.D = D
            self.Point = Point

    def build_query(self, **filters):
        """
        Build queries for geo spatial filtering.

        Expected query parameters are:
         - a `unit=value` parameter where the unit is a valid UNIT in the
           `django.contrib.gis.measure.Distance` class.
         - `from` which must be a comma separated latitude and longitude.

        Example query:
            /api/v1/search/?km=10&from=59.744076,10.152045

            Will perform a `dwithin` query within 10 km from the point
            with latitude 59.744076 and longitude 10.152045.

        :param dict[str, list[str]] filters: is an expanded QueryDict or
            a mapping of keys to a list of parameters.
        :return: A dict with ``dwithin`` and ``distance`` entries, or ``None``
            when the `from` parameter or a distance is missing.
        :raises ValueError: If the coordinates cannot be converted to two
            floats, or a unit does not have exactly one value.
        """
        origin_param = constants.DRF_HAYSTACK_SPATIAL_QUERY_PARAM

        # Collect the distance parameters, i.e. those keyed by a known unit.
        distance = {}
        for unit_name in list(self.D.UNITS.keys()):
            if unit_name in filters:
                distance[unit_name] = filters[unit_name]

        try:
            latitude, longitude = map(float, self.tokenize(filters[origin_param], self.view.lookup_sep))
            point = self.Point(longitude, latitude, srid=constants.GEO_SRID)
        except ValueError:
            raise ValueError("Cannot convert `from=latitude,longitude` query parameter to "
                             "float values. Make sure to provide numerical values only!")
        except KeyError:
            # If the user has not provided any `from` query string parameter,
            # just return.
            return None

        for unit_name in distance.keys():
            raw_values = distance[unit_name]
            if not len(raw_values) == 1:
                raise ValueError("Each unit must have exactly one value.")
            distance[unit_name] = float(raw_values[0])

        if not (point and distance):
            return None

        point_field = self.backend.point_field
        return {
            "dwithin": {
                "field": point_field,
                "point": point,
                "distance": self.D(**distance)
            },
            "distance": {
                "field": point_field,
                "point": point
            }
        }