Source code for drf_haystack.serializers

# -*- coding: utf-8 -*-

from __future__ import absolute_import, unicode_literals

import copy
import warnings
from itertools import chain
from datetime import datetime

try:
    from collections import OrderedDict
except ImportError:
    from django.utils.datastructures import SortedDict as OrderedDict

from django.core.exceptions import ImproperlyConfigured, FieldDoesNotExist
from django.utils import six

from haystack import fields as haystack_fields
from haystack.query import EmptySearchQuerySet
from haystack.utils import Highlighter

from rest_framework import serializers
from rest_framework.fields import empty
from rest_framework.utils.field_mapping import ClassLookupDict, get_field_kwargs

from drf_haystack.fields import (
    HaystackBooleanField, HaystackCharField, HaystackDateField, HaystackDateTimeField,
    HaystackDecimalField, HaystackFloatField, HaystackIntegerField, HaystackMultiValueField,
    FacetDictField, FacetListField
)


[docs]class Meta(type): """ Template for the HaystackSerializerMeta.Meta class. """ fields = tuple() exclude = tuple() search_fields = tuple() index_classes = tuple() serializers = tuple() ignore_fields = tuple() field_aliases = {} field_options = {} index_aliases = {} def __new__(mcs, name, bases, attrs): cls = super(Meta, mcs).__new__(mcs, str(name), bases, attrs) if cls.fields and cls.exclude: raise ImproperlyConfigured("%s cannot define both 'fields' and 'exclude'." % name) return cls def __setattr__(cls, key, value): raise AttributeError("Meta class is immutable.") def __delattr__(cls, key, value): raise AttributeError("Meta class is immutable.")
[docs]class HaystackSerializerMeta(serializers.SerializerMetaclass): """ Metaclass for the HaystackSerializer that ensures that all declared subclasses implemented a Meta. """ def __new__(mcs, name, bases, attrs): attrs.setdefault("_abstract", False) cls = super(HaystackSerializerMeta, mcs).__new__(mcs, str(name), bases, attrs) if getattr(cls, "Meta", None): cls.Meta = Meta("Meta", (Meta,), dict(cls.Meta.__dict__)) elif not cls._abstract: raise ImproperlyConfigured("%s must implement a Meta class or have the property _abstract" % name) return cls
[docs]class HaystackSerializer(six.with_metaclass(HaystackSerializerMeta, serializers.Serializer)): """ A `HaystackSerializer` which populates fields based on which models that are available in the SearchQueryset. """ _abstract = True _field_mapping = ClassLookupDict({ haystack_fields.BooleanField: HaystackBooleanField, haystack_fields.CharField: HaystackCharField, haystack_fields.DateField: HaystackDateField, haystack_fields.DateTimeField: HaystackDateTimeField, haystack_fields.DecimalField: HaystackDecimalField, haystack_fields.EdgeNgramField: HaystackCharField, haystack_fields.FacetBooleanField: HaystackBooleanField, haystack_fields.FacetCharField: HaystackCharField, haystack_fields.FacetDateField: HaystackDateField, haystack_fields.FacetDateTimeField: HaystackDateTimeField, haystack_fields.FacetDecimalField: HaystackDecimalField, haystack_fields.FacetFloatField: HaystackFloatField, haystack_fields.FacetIntegerField: HaystackIntegerField, haystack_fields.FacetMultiValueField: HaystackMultiValueField, haystack_fields.FloatField: HaystackFloatField, haystack_fields.IntegerField: HaystackIntegerField, haystack_fields.LocationField: HaystackCharField, haystack_fields.MultiValueField: HaystackMultiValueField, haystack_fields.NgramField: HaystackCharField, }) def __init__(self, instance=None, data=empty, **kwargs): super(HaystackSerializer, self).__init__(instance, data, **kwargs) if not self.Meta.index_classes and not self.Meta.serializers: raise ImproperlyConfigured("You must set either the 'index_classes' or 'serializers' " "attribute on the serializer Meta class.") if not self.instance: self.instance = EmptySearchQuerySet() @staticmethod def _get_default_field_kwargs(model, field): """ Get the required attributes from the model field in order to instantiate a REST Framework serializer field. """ kwargs = {} try: field_name = field.model_attr or field.index_fieldname model_field = model._meta.get_field(field_name) kwargs.update(get_field_kwargs(field_name, model_field)) # Remove stuff we don't care about! delete_attrs = [ "allow_blank", "choices", "model_field", "allow_unicode", ] for attr in delete_attrs: if attr in kwargs: del kwargs[attr] except FieldDoesNotExist: pass return kwargs def _get_index_field(self, field_name): """ Returns the correct index field. """ return field_name def _get_index_class_name(self, index_cls): """ Converts in index model class to a name suitable for use as a field name prefix. A user may optionally specify custom aliases via an 'index_aliases' attribute on the Meta class """ cls_name = index_cls.__name__ aliases = self.Meta.index_aliases return aliases.get(cls_name, cls_name.split('.')[-1])
[docs] def get_fields(self): """ Get the required fields for serializing the result. """ fields = self.Meta.fields exclude = self.Meta.exclude ignore_fields = self.Meta.ignore_fields indices = self.Meta.index_classes declared_fields = copy.deepcopy(self._declared_fields) prefix_field_names = len(indices) > 1 field_mapping = OrderedDict() # overlapping fields on multiple indices is supported by internally prefixing the field # names with the index class to which they belong or, optionally, a user-provided alias # for the index. for index_cls in self.Meta.index_classes: prefix = "" if prefix_field_names: prefix = "_%s__" % self._get_index_class_name(index_cls) for field_name, field_type in six.iteritems(index_cls.fields): orig_name = field_name field_name = "%s%s" % (prefix, field_name) # Don't use this field if it is in `ignore_fields` if orig_name in ignore_fields or field_name in ignore_fields: continue # When fields to include are decided by `exclude` if exclude: if orig_name in exclude or field_name in exclude: continue # When fields to include are decided by `fields` if fields: if orig_name not in fields and field_name not in fields: continue # Look up the field attributes on the current index model, # in order to correctly instantiate the serializer field. model = index_cls().get_model() kwargs = self._get_default_field_kwargs(model, field_type) kwargs['prefix_field_names'] = prefix_field_names field_mapping[field_name] = self._field_mapping[field_type](**kwargs) # Add any explicitly declared fields. They *will* override any index fields # in case of naming collision!. if declared_fields: for field_name in declared_fields: field_mapping[field_name] = declared_fields[field_name] return field_mapping
[docs] def to_representation(self, instance): """ If we have a serializer mapping, use that. Otherwise, use standard serializer behavior Since we might be dealing with multiple indexes, some fields might not be valid for all results. Do not render the fields which don't belong to the search result. """ if self.Meta.serializers: ret = self.multi_serializer_representation(instance) else: ret = super(HaystackSerializer, self).to_representation(instance) prefix_field_names = len(getattr(self.Meta, "index_classes")) > 1 current_index = self._get_index_class_name(type(instance.searchindex)) for field in self.fields.keys(): orig_field = field if prefix_field_names: parts = field.split("__") if len(parts) > 1: index = parts[0][1:] # trim the preceding '_' field = parts[1] if index == current_index: ret[field] = ret[orig_field] del ret[orig_field] elif field not in chain(instance.searchindex.fields.keys(), self._declared_fields.keys()): del ret[orig_field] # include the highlighted field in either case if getattr(instance, "highlighted", None): ret["highlighted"] = instance.highlighted[0] return ret
[docs] def multi_serializer_representation(self, instance): serializers = self.Meta.serializers index = instance.searchindex serializer_class = serializers.get(type(index), None) if not serializer_class: raise ImproperlyConfigured("Could not find serializer for %s in mapping" % index) return serializer_class(context=self._context).to_representation(instance)
[docs]class FacetFieldSerializer(serializers.Serializer): """ Responsible for serializing a faceted result. """ text = serializers.SerializerMethodField() count = serializers.SerializerMethodField() narrow_url = serializers.SerializerMethodField() def __init__(self, *args, **kwargs): self._parent_field = None super(FacetFieldSerializer, self).__init__(*args, **kwargs) @property def parent_field(self): return self._parent_field @parent_field.setter def parent_field(self, value): self._parent_field = value
[docs] def get_paginate_by_param(self): """ Returns the ``paginate_by_param`` for the (root) view paginator class. This is needed in order to remove the query parameter from faceted narrow urls. If using a custom pagination class, this class attribute needs to be set manually. """ if hasattr(self.root, "paginate_by_param") and self.root.paginate_by_param: return self.root.paginate_by_param pagination_class = self.context["view"].pagination_class if not pagination_class: return None # PageNumberPagination if hasattr(pagination_class, "page_query_param"): return pagination_class.page_query_param # LimitOffsetPagination elif hasattr(pagination_class, "offset_query_param"): return pagination_class.offset_query_param # CursorPagination elif hasattr(pagination_class, "cursor_query_param"): return pagination_class.cursor_query_param else: raise AttributeError( "%(root_cls)s is missing a `paginate_by_param` attribute. " "Define a %(root_cls)s.paginate_by_param or override " "%(cls)s.get_paginate_by_param()." % { "root_cls": self.root.__class__.__name__, "cls": self.__class__.__name__ })
[docs] def get_text(self, instance): """ Haystack facets are returned as a two-tuple (value, count). The text field should contain the faceted value. """ instance = instance[0] if isinstance(instance, (six.text_type, six.string_types)): return serializers.CharField(read_only=True).to_representation(instance) elif isinstance(instance, datetime): return serializers.DateTimeField(read_only=True).to_representation(instance) return instance
[docs] def get_count(self, instance): """ Haystack facets are returned as a two-tuple (value, count). The count field should contain the faceted count. """ instance = instance[1] return serializers.IntegerField(read_only=True).to_representation(instance)
[docs] def get_narrow_url(self, instance): """ Return a link suitable for narrowing on the current item. """ text = instance[0] request = self.context["request"] query_params = request.GET.copy() # Never keep the page query parameter in narrowing urls. # It will raise a NotFound exception when trying to paginate a narrowed queryset. page_query_param = self.get_paginate_by_param() if page_query_param and page_query_param in query_params: del query_params[page_query_param] selected_facets = set(query_params.pop(self.root.facet_query_params_text, [])) selected_facets.add("%(field)s_exact:%(text)s" % {"field": self.parent_field, "text": text}) query_params.setlist(self.root.facet_query_params_text, sorted(selected_facets)) path = "%(path)s?%(query)s" % {"path": request.path_info, "query": query_params.urlencode()} url = request.build_absolute_uri(path) return serializers.Hyperlink(url, "narrow-url")
[docs] def to_representation(self, field, instance): """ Set the ``parent_field`` property equal to the current field on the serializer class, so that each field can query it to see what kind of attribute they are processing. """ self.parent_field = field return super(FacetFieldSerializer, self).to_representation(instance)
[docs]class HaystackFacetSerializer(six.with_metaclass(HaystackSerializerMeta, serializers.Serializer)): """ The ``HaystackFacetSerializer`` is used to serialize the ``facet_counts()`` dictionary results on a ``SearchQuerySet`` instance. """ _abstract = True serialize_objects = False paginate_by_param = None facet_dict_field_class = FacetDictField facet_list_field_class = FacetListField facet_field_serializer_class = FacetFieldSerializer
[docs] def get_fields(self): """ This returns a dictionary containing the top most fields, ``dates``, ``fields`` and ``queries``. """ field_mapping = OrderedDict() for field, data in self.instance.items(): field_mapping.update( {field: self.facet_dict_field_class( child=self.facet_list_field_class(child=self.facet_field_serializer_class(data)), required=False)} ) if self.serialize_objects is True: field_mapping["objects"] = serializers.SerializerMethodField() return field_mapping
[docs] def get_objects(self, instance): """ Return a list of objects matching the faceted result. """ view = self.context["view"] queryset = self.context["objects"] page = view.paginate_queryset(queryset) if page is not None: serializer = view.get_facet_objects_serializer(page, many=True) return OrderedDict([ ("count", self.get_count(queryset)), ("next", view.paginator.get_next_link()), ("previous", view.paginator.get_previous_link()), ("results", serializer.data) ]) serializer = view.get_serializer(queryset, many=True) return serializer.data
[docs] def get_count(self, queryset): """ Determine an object count, supporting either querysets or regular lists. """ try: return queryset.count() except (AttributeError, TypeError): return len(queryset)
@property def facet_query_params_text(self): return self.context["facet_query_params_text"]
[docs]class HaystackSerializerMixin(object): """ This mixin can be added to a serializer to use the actual object as the data source for serialization rather than the data stored in the search index fields. This makes it easy to return data from search results in the same format as elsewhere in your API and reuse your existing serializers """
[docs] def to_representation(self, instance): obj = instance.object return super(HaystackSerializerMixin, self).to_representation(obj)
[docs]class HighlighterMixin(object): """ This mixin adds support for ``highlighting`` (the pure python, portable version, not SearchQuerySet().highlight()). See Haystack docs for more info). """ highlighter_class = Highlighter highlighter_css_class = "highlighted" highlighter_html_tag = "span" highlighter_max_length = 200 highlighter_field = None
[docs] def get_highlighter(self): if not self.highlighter_class: raise ImproperlyConfigured( "%(cls)s is missing a highlighter_class. Define %(cls)s.highlighter_class, " "or override %(cls)s.get_highlighter()." % {"cls": self.__class__.__name__} ) return self.highlighter_class
[docs] @staticmethod def get_document_field(instance): """ Returns which field the search index has marked as it's `document=True` field. """ for name, field in instance.searchindex.fields.items(): if field.document is True: return name
[docs] def to_representation(self, instance): ret = super(HighlighterMixin, self).to_representation(instance) terms = " ".join(six.itervalues(self.context["request"].GET)) if terms: highlighter = self.get_highlighter()(terms, **{ "html_tag": self.highlighter_html_tag, "css_class": self.highlighter_css_class, "max_length": self.highlighter_max_length }) document_field = self.get_document_field(instance) if highlighter and document_field: ret["highlighted"] = highlighter.highlight(getattr(instance, self.highlighter_field or document_field)) return ret