diff --git a/docs/persistence.rst b/docs/persistence.rst
index cb1dc226..ab3ffbe4 100644
--- a/docs/persistence.rst
+++ b/docs/persistence.rst
@@ -9,7 +9,7 @@ layer for your application.
 For more comprehensive examples have a look at the examples_ directory in the
 repository.
 
-.. _examples: https://github.com/elastic/elasticsearch-dsl-py/tree/master/examples
+.. _examples: https://github.com/elastic/elasticsearch-dsl-py/tree/main/examples
 
 .. _doc_type:
 
@@ -66,14 +66,14 @@ settings in elasticsearch (see :ref:`life-cycle` for details).
 Data types
 ~~~~~~~~~~
 
-The ``Document`` instances should be using native python types like
+The ``Document`` instances use native Python types like
 ``str`` and ``datetime``. In case of ``Object`` or ``Nested`` fields an instance of the
-``InnerDoc`` subclass should be used just like in the ``add_comment`` method in
-the above example where we are creating an instance of the ``Comment`` class.
+``InnerDoc`` subclass is used, as in the ``add_comment`` method in the above
+example where we are creating an instance of the ``Comment`` class.
 
 There are some specific types that were created as part of this library to make
-working with specific field types easier, for example the ``Range`` object used
-in any of the `range fields
+working with some field types easier, for example the ``Range`` object used in
+any of the `range fields
 `_:
 
 .. code:: python
 
@@ -103,6 +103,174 @@ in any of the `range fields
     # empty range is unbounded
     Range().lower  # None, False
 
+Python Type Hints
+~~~~~~~~~~~~~~~~~
+
+Document fields can be defined using standard Python type hints if desired.
+Here are some simple examples:
+
+.. code:: python
+
+    from typing import Optional
+
+    class Post(Document):
+        title: str                      # same as title = Text(required=True)
+        created_at: Optional[datetime]  # same as created_at = Date(required=False)
+        published: bool                 # same as published = Boolean(required=True)
+
+Note that when using ``Field`` subclasses such as ``Text``, ``Date`` and
+``Boolean``, they must be given on the right-hand side of an assignment, as
+shown in the examples above. Using these classes as type hints will result in
+errors.
+
+Python types are mapped to their corresponding field type according to the
+following table:
+
+.. list-table:: Python type to DSL field mappings
+   :header-rows: 1
+
+   * - Python type
+     - DSL field
+   * - ``str``
+     - ``Text(required=True)``
+   * - ``bool``
+     - ``Boolean(required=True)``
+   * - ``int``
+     - ``Integer(required=True)``
+   * - ``float``
+     - ``Float(required=True)``
+   * - ``bytes``
+     - ``Binary(required=True)``
+   * - ``datetime``
+     - ``Date(required=True)``
+   * - ``date``
+     - ``Date(format="yyyy-MM-dd", required=True)``
+
+To type a field as optional, the standard ``Optional`` modifier from the Python
+``typing`` package can be used. The ``List`` modifier can be added to a field
+to convert it to an array, similar to using the ``multi=True`` argument on the
+field object.
+
+.. code:: python
+
+    from typing import Optional, List
+
+    class MyDoc(Document):
+        pub_date: Optional[datetime]   # same as pub_date = Date()
+        authors: List[str]             # same as authors = Text(multi=True, required=True)
+        comments: Optional[List[str]]  # same as comments = Text(multi=True)
+
+A field can also be given a type hint of an ``InnerDoc`` subclass, in which
+case it becomes an ``Object`` field of that class. When the ``InnerDoc``
+subclass is wrapped with ``List``, a ``Nested`` field is created instead.
+
+.. code:: python
+
+    from typing import List
+
+    class Address(InnerDoc):
+        ...
+
+    class Comment(InnerDoc):
+        ...
+
+    class Post(Document):
+        address: Address         # same as address = Object(Address, required=True)
+        comments: List[Comment]  # same as comments = Nested(Comment, required=True)
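+
+These modifiers can be combined. For example, a nested field that may be
+missing from a document can be typed as an optional list. The following is a
+brief sketch that reuses the ``Comment`` class from the example above:
+
+.. code:: python
+
+    from typing import List, Optional
+
+    class Post(Document):
+        # same as comments = Nested(Comment, required=False)
+        comments: Optional[List[Comment]]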
+
+Python type hints alone cannot uniquely identify every possible Elasticsearch
+field type. To choose a field type that is different from the ones in the
+table above, the field instance can be given explicitly on the right-hand side
+of the field declaration. The next example creates a field that is typed as
+``Optional[str]``, but is mapped to ``Keyword`` instead of ``Text``:
+
+.. code:: python
+
+    class MyDocument(Document):
+        category: Optional[str] = Keyword()
+
+This form can also be used when additional options need to be given to
+initialize the field, such as when using custom analyzer settings or changing
+the ``required`` default:
+
+.. code:: python
+
+    class Comment(InnerDoc):
+        content: str = Text(analyzer='snowball', required=True)
+
+When using type hints as above, subclasses of ``Document`` and ``InnerDoc``
+inherit some of the behaviors associated with Python dataclasses, as defined by
+`PEP 681 `_ and the
+`dataclass_transform decorator `_.
+To add per-field dataclass options such as ``default`` or ``default_factory``,
+the ``mapped_field()`` wrapper can be used on the right-hand side of a typed
+field declaration:
+
+.. code:: python
+
+    class MyDocument(Document):
+        title: str = mapped_field(default="no title")
+        created_at: datetime = mapped_field(default_factory=datetime.now)
+        published: bool = mapped_field(default=False)
+        category: str = mapped_field(Keyword(required=True), default="general")
+
+When using the ``mapped_field()`` wrapper function, an explicit field type
+instance can be passed as the first positional argument, as the ``category``
+field does in the example above.
+
+Static type checkers such as `mypy `_ and
+`pyright `_ can use the type hints and
+the dataclass-specific options given to the ``mapped_field()`` function to
+improve type inference and provide better real-time suggestions in IDEs.
+
+One situation in which type checkers can't infer the correct type is when
+fields are used as class attributes. Consider the following example:
+
+.. code:: python
+
+    class MyDocument(Document):
+        title: str
+
+    doc = MyDocument()
+    # doc.title is typed as "str" (correct)
+    # MyDocument.title is also typed as "str" (incorrect)
+
+To help type checkers correctly identify class attributes as such, the ``M``
+generic can be used as a wrapper around the type hint, as shown in the next
+examples:
+
+.. code:: python
+
+    from elasticsearch_dsl import M
+
+    class MyDocument(Document):
+        title: M[str]
+        created_at: M[datetime] = mapped_field(default_factory=datetime.now)
+
+    doc = MyDocument()
+    # doc.title is typed as "str"
+    # doc.created_at is typed as "datetime"
+    # MyDocument.title is typed as "InstrumentedField"
+    # MyDocument.created_at is typed as "InstrumentedField"
+
+Note that the ``M`` type hint does not provide any runtime behavior and its use
+is not required, but it can be useful to eliminate spurious type errors in IDEs
+or type checking builds.
+
+The ``InstrumentedField`` objects returned when fields are accessed as class
+attributes are proxies for the field instances that can be used anywhere a
+field needs to be referenced, such as when specifying sort options in a
+``Search`` object:
+
+.. code:: python
+
+    # sort by creation date descending, and title ascending
+    s = MyDocument.search().sort(-MyDocument.created_at, MyDocument.title)
+
+When specifying sorting order, the ``+`` and ``-`` unary operators can be used
+on the class field attributes to indicate ascending and descending order.
+
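+Sorting is not the only place where these proxies are useful. As a brief
+sketch modeled after the ``vectors.py`` example in this change, the ``field``
+argument of a kNN search can also be given as a class field attribute; here
+``query_embedding`` is assumed to be a precomputed query vector:
+
+.. code:: python
+
+    class Passage(InnerDoc):
+        content: M[str]
+        embedding: M[List[float]] = mapped_field(DenseVector())
+
+    class WorkplaceDoc(Document):
+        passages: M[List[Passage]]
+
+    s = WorkplaceDoc.search().knn(
+        field=WorkplaceDoc.passages.embedding,  # same as "passages.embedding"
+        k=5,
+        num_candidates=50,
+        query_vector=query_embedding,
+        inner_hits={"size": 2},
+    )
+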
 Note on dates
 ~~~~~~~~~~~~~
 
diff --git a/elasticsearch_dsl/__init__.py b/elasticsearch_dsl/__init__.py
index e7de5319..fd4433c2 100644
--- a/elasticsearch_dsl/__init__.py
+++ b/elasticsearch_dsl/__init__.py
@@ -19,7 +19,7 @@
 from .aggs import A
 from .analysis import analyzer, char_filter, normalizer, token_filter, tokenizer
 from .document import AsyncDocument, Document
-from .document_base import InnerDoc, MetaField
+from .document_base import InnerDoc, M, MetaField, mapped_field
 from .exceptions import (
     ElasticsearchDslException,
     IllegalOperation,
@@ -148,6 +148,7 @@
     "Keyword",
     "Long",
     "LongRange",
+    "M",
     "Mapping",
     "MetaField",
     "MultiSearch",
@@ -178,6 +179,7 @@
     "char_filter",
     "connections",
     "construct_field",
+    "mapped_field",
     "normalizer",
     "token_filter",
     "tokenizer",
diff --git a/elasticsearch_dsl/_async/document.py b/elasticsearch_dsl/_async/document.py
index 89ed06f4..1dfb5b9d 100644
--- a/elasticsearch_dsl/_async/document.py
+++ b/elasticsearch_dsl/_async/document.py
@@ -18,10 +18,11 @@
 import collections.abc
 
 from elasticsearch.exceptions import NotFoundError, RequestError
+from typing_extensions import dataclass_transform
 
 from .._async.index import AsyncIndex
 from ..async_connections import get_connection
-from ..document_base import DocumentBase, DocumentMeta
+from ..document_base import DocumentBase, DocumentMeta, mapped_field
 from ..exceptions import IllegalOperation
 from ..utils import DOC_META_FIELDS, META_FIELDS, merge
 from .search import AsyncSearch
@@ -62,6 +63,7 @@ def construct_index(cls, opts, bases):
         return i
 
 
+@dataclass_transform(field_specifiers=(mapped_field,))
 class AsyncDocument(DocumentBase, metaclass=AsyncIndexMeta):
     """
     Model-like class for persisting documents in elasticsearch.
diff --git a/elasticsearch_dsl/_sync/document.py b/elasticsearch_dsl/_sync/document.py
index c851c8e8..7e7acd51 100644
--- a/elasticsearch_dsl/_sync/document.py
+++ b/elasticsearch_dsl/_sync/document.py
@@ -18,10 +18,11 @@
 import collections.abc
 
 from elasticsearch.exceptions import NotFoundError, RequestError
+from typing_extensions import dataclass_transform
 
 from .._sync.index import Index
 from ..connections import get_connection
-from ..document_base import DocumentBase, DocumentMeta
+from ..document_base import DocumentBase, DocumentMeta, mapped_field
 from ..exceptions import IllegalOperation
 from ..utils import DOC_META_FIELDS, META_FIELDS, merge
 from .search import Search
@@ -60,6 +61,7 @@ def construct_index(cls, opts, bases):
         return i
 
 
+@dataclass_transform(field_specifiers=(mapped_field,))
 class Document(DocumentBase, metaclass=IndexMeta):
     """
     Model-like class for persisting documents in elasticsearch.
diff --git a/elasticsearch_dsl/document_base.py b/elasticsearch_dsl/document_base.py
index 8abbc796..46694157 100644
--- a/elasticsearch_dsl/document_base.py
+++ b/elasticsearch_dsl/document_base.py
@@ -15,10 +15,24 @@
 # specific language governing permissions and limitations
 # under the License.
+from datetime import date, datetime from fnmatch import fnmatch +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Generic, + List, + Optional, + TypeVar, + Union, + overload, +) + +from typing_extensions import dataclass_transform from .exceptions import ValidationException -from .field import Field +from .field import Binary, Boolean, Date, Field, Float, Integer, Nested, Object, Text from .mapping import Mapping from .utils import DOC_META_FIELDS, ObjectBase @@ -28,26 +42,175 @@ def __init__(self, *args, **kwargs): self.args, self.kwargs = args, kwargs +class InstrumentedField: + """Proxy object for a mapped document field. + + An object of this instance is returned when a field is accessed as a class + attribute of a ``Document`` or ``InnerDoc`` subclass. These objects can + be used in any situation in which a reference to a field is required, such + as when specifying sort options in a search:: + + class MyDocument(Document): + name: str + + s = MyDocument.search() + s = s.sort(-MyDocument.name) # sort by name in descending order + """ + + def __init__(self, name, field): + self._name = name + self._field = field + + def __getattr__(self, attr): + try: + # first let's see if this is an attribute of this object + return super().__getattribute__(attr) + except AttributeError: + try: + # next we see if we have a sub-field with this name + return InstrumentedField(f"{self._name}.{attr}", self._field[attr]) + except KeyError: + # lastly we let the wrapped field resolve this attribute + return getattr(self._field, attr) + + def __pos__(self): + """Return the field name representation for ascending sort order""" + return f"{self._name}" + + def __neg__(self): + """Return the field name representation for descending sort order""" + return f"-{self._name}" + + def __str__(self): + return self._name + + def __repr__(self): + return f"InstrumentedField[{self._name}]" + + class DocumentMeta(type): def __new__(cls, name, bases, attrs): # DocumentMeta filters attrs in place attrs["_doc_type"] = DocumentOptions(name, bases, attrs) return super().__new__(cls, name, bases, attrs) + def __getattr__(cls, attr): + if attr in cls._doc_type.mapping: + return InstrumentedField(attr, cls._doc_type.mapping[attr]) + return super().__getattribute__(attr) + class DocumentOptions: + type_annotation_map = { + int: (Integer, {}), + float: (Float, {}), + bool: (Boolean, {}), + str: (Text, {}), + bytes: (Binary, {}), + datetime: (Date, {}), + date: (Date, {"format": "yyyy-MM-dd"}), + } + def __init__(self, name, bases, attrs): meta = attrs.pop("Meta", None) # create the mapping instance self.mapping = getattr(meta, "mapping", Mapping()) - # register all declared fields into the mapping - for name, value in list(attrs.items()): - if isinstance(value, Field): - self.mapping.field(name, value) + # register the document's fields, which can be given in a few formats: + # + # class MyDocument(Document): + # # required field using native typing + # # (str, int, float, bool, datetime, date) + # field1: str + # + # # optional field using native typing + # field2: Optional[datetime] + # + # # array field using native typing + # field3: list[int] + # + # # sub-object, same as Object(MyInnerDoc) + # field4: MyInnerDoc + # + # # nested sub-objects, same as Nested(MyInnerDoc) + # field5: list[MyInnerDoc] + # + # # use typing, but override with any stock or custom field + # field6: bool = MyCustomField() + # + # # best mypy and pyright support and dataclass-like behavior + # field7: M[date] + # field8: M[str] = 
mapped_field(MyCustomText(), default="foo")
+        #
+        #   # legacy format without Python typing
+        #   field9 = Text()
+        annotations = attrs.get("__annotations__", {})
+        fields = set([n for n in attrs if isinstance(attrs[n], Field)])
+        fields.update(annotations.keys())
+        field_defaults = {}
+        for name in fields:
+            value = None
+            if name in attrs:
+                # this field has a right-side value, which can be a field
+                # instance on its own or wrapped with mapped_field()
+                value = attrs[name]
+                if isinstance(value, dict):
+                    # the mapped_field() wrapper function was used so we need
+                    # to look for the field instance and also record any
+                    # dataclass-style defaults
+                    value = attrs[name].get("_field")
+                    default_value = attrs[name].get("default")
+                    if default_value is None:
+                        default_value = attrs[name].get("default_factory")
+                    # compare against None (instead of a truth check) so that
+                    # falsy defaults such as False, 0 or [] are preserved
+                    if default_value is not None:
+                        field_defaults[name] = default_value
+            if value is None:
+                # the field does not have an explicit field instance given in
+                # a right-side assignment, so we need to figure out what field
+                # type to use from the annotation
+                type_ = annotations[name]
+                required = True
+                multi = False
+                while hasattr(type_, "__origin__"):
+                    if type_.__origin__ == Mapped:
+                        # M[type] -> extract the wrapped type
+                        type_ = type_.__args__[0]
+                    elif type_.__origin__ == Union:
+                        if len(type_.__args__) == 2 and type_.__args__[1] is type(None):
+                            # Optional[type] -> mark instance as optional
+                            required = False
+                            type_ = type_.__args__[0]
+                        else:
+                            raise TypeError("Unsupported union")
+                    elif type_.__origin__ in [list, List]:
+                        # List[type] -> mark instance as multi
+                        multi = True
+                        type_ = type_.__args__[0]
+                    else:
+                        break
+                field_args = []
+                field_kwargs = {}
+                if not isinstance(type_, type):
+                    raise TypeError(f"Cannot map type {type_}")
+                elif issubclass(type_, InnerDoc):
+                    # object or nested field
+                    field = Nested if multi else Object
+                    field_args = [type_]
+                elif type_ in self.type_annotation_map:
+                    # use best field type for the type hint provided
+                    field, field_kwargs = self.type_annotation_map[type_]
+                else:
+                    raise TypeError(f"Cannot map type {type_}")
+                field_kwargs = {"multi": multi, "required": required, **field_kwargs}
+                value = field(*field_args, **field_kwargs)
+            self.mapping.field(name, value)
+            if name in attrs:
+                del attrs[name]
+
+        # store dataclass-style defaults for ObjectBase.__init__ to assign
+        attrs["_defaults"] = field_defaults
+
         # add all the mappings for meta fields
         for name in dir(meta):
             if isinstance(getattr(meta, name, None), MetaField):
@@ -64,6 +227,86 @@ def name(self):
         return self.mapping.properties.name
 
 
+_FieldType = TypeVar("_FieldType")
+
+
+class Mapped(Generic[_FieldType]):
+    """Class that represents the type of a mapped field.
+
+    This class can be used as an optional wrapper on a field type to help type
+    checkers assign the correct type when the field is used as a class
+    attribute.
+
+    Consider the following definitions::
+
+        class MyDocument(Document):
+            first: str
+            second: M[str]
+
+        mydoc = MyDocument(first="1", second="2")
+
+    Type checkers have no trouble inferring the type of both ``mydoc.first``
+    and ``mydoc.second`` as ``str``, but ``MyDocument.first`` will be
+    incorrectly typed as ``str``, while ``MyDocument.second`` will be
+    correctly typed as ``InstrumentedField``.
+    """
+
+    __slots__ = {}
+
+    if TYPE_CHECKING:
+
+        @overload
+        def __get__(self, instance: None, owner: Any) -> InstrumentedField: ...
+
+        @overload
+        def __get__(self, instance: object, owner: Any) -> _FieldType: ...
+
+        def __get__(
+            self, instance: Optional[object], owner: Any
+        ) -> Union[InstrumentedField, _FieldType]: ...
+
+        def __set__(self, instance: Optional[object], value: _FieldType) -> None: ...
+
+        def __delete__(self, instance: Any) -> None: ...
+
+
+M = Mapped
+
+
+def mapped_field(
+    field: Optional[Field] = None,
+    *,
+    init: bool = True,
+    default: Any = None,
+    default_factory: Optional[Callable] = None,
+    **kwargs,
+) -> Any:
+    """Construct a field using dataclass behaviors.
+
+    This function can be used on the right-hand side of a document field
+    definition as a wrapper for the field instance or as a way to provide
+    dataclass-compatible options.
+
+    :param field: The instance of ``Field`` to use for this field. If not provided,
+        an instance that is appropriate for the type given to the field is used.
+    :param init: a value of ``True`` adds this field to the constructor, and a
+        value of ``False`` omits it. The default is ``True``.
+    :param default: a default value to use for this field when one is not provided
+        explicitly.
+    :param default_factory: a callable that returns a default value for the field,
+        when one isn't provided explicitly. Only one of ``default`` and
+        ``default_factory`` can be used.
+    """
+    return {
+        "_field": field,
+        "init": init,
+        "default": default,
+        "default_factory": default_factory,
+        **kwargs,
+    }
+
+
+@dataclass_transform(field_specifiers=(mapped_field,))
 class InnerDoc(ObjectBase, metaclass=DocumentMeta):
     """
     Common class for inner documents like Object or Nested
diff --git a/elasticsearch_dsl/search_base.py b/elasticsearch_dsl/search_base.py
index 7a940d4a..893464c6 100644
--- a/elasticsearch_dsl/search_base.py
+++ b/elasticsearch_dsl/search_base.py
@@ -523,7 +523,7 @@ def knn(
         """
         Add a k-nearest neighbor (kNN) search.
 
-        :arg field: the name of the vector field to search against
+        :arg field: the vector field to search against as a string or document class attribute
         :arg k: number of nearest neighbors to return as top hits
         :arg num_candidates: number of nearest neighbor candidates to consider per shard
         :arg query_vector: the vector to search for
@@ -542,7 +542,7 @@ def knn(
         s = self._clone()
         s._knn.append(
             {
-                "field": field,
+                "field": str(field),  # str() is for InstrumentedField instances
                 "k": k,
                 "num_candidates": num_candidates,
             }
         )
@@ -596,11 +596,15 @@ def source(self, fields=None, **kwargs):
         """
         Selectively control how the _source field is returned.
 
-        :arg fields: wildcard string, array of wildcards, or dictionary of includes and excludes
+        :arg fields: field name, wildcard string, list of field names or wildcards,
+            or dictionary of includes and excludes
+        :arg kwargs: ``includes`` or ``excludes`` arguments, when ``fields`` is ``None``.
 
-        If ``fields`` is None, the entire document will be returned for
-        each hit. If fields is a dictionary with keys of 'includes' and/or
-        'excludes' the fields will be either included or excluded appropriately.
+        When no arguments are given, the entire document will be returned for
+        each hit. If ``fields`` is a string or list of strings, the field names or field
+        wildcards given will be included. If ``fields`` is a dictionary with keys of
+        'includes' and/or 'excludes', the fields will be either included or excluded
+        appropriately.
 
         Calling this multiple times with the same named parameter will override the
         previous values with the new ones.
@@ -619,8 +623,16 @@ def source(self, fields=None, **kwargs):
         if fields and kwargs:
             raise ValueError("You cannot specify fields and kwargs at the same time.")
 
+        def ensure_strings(fields):
+            # str() converts InstrumentedField instances to plain field names
+            if isinstance(fields, list):
+                return [str(f) for f in fields]
+            elif isinstance(fields, dict):
+                return {k: ensure_strings(v) for k, v in fields.items()}
+            else:
+                return str(fields)
+
         if fields is not None:
-            s._source = fields
+            s._source = fields if isinstance(fields, bool) else ensure_strings(fields)
             return s
 
         if kwargs and not isinstance(s._source, dict):
@@ -633,7 +645,7 @@
             except KeyError:
                 pass
             else:
-                s._source[key] = value
+                s._source[key] = ensure_strings(value)
 
         return s
 
@@ -663,11 +675,12 @@ def sort(self, *keys):
         s = self._clone()
         s._sort = []
         for k in keys:
-            if isinstance(k, str) and k.startswith("-"):
-                if k[1:] == "_score":
-                    raise IllegalOperation("Sorting by `-_score` is not allowed.")
-                k = {k[1:]: {"order": "desc"}}
-            s._sort.append(k)
+            # str() converts InstrumentedField instances to field names, while
+            # dicts with explicit sort options are passed through unchanged
+            sort_field = k if isinstance(k, dict) else str(k)
+            if isinstance(sort_field, str) and sort_field.startswith("-"):
+                if sort_field[1:] == "_score":
+                    raise IllegalOperation("Sorting by `-_score` is not allowed.")
+                sort_field = {sort_field[1:]: {"order": "desc"}}
+            s._sort.append(sort_field)
         return s
 
     def collapse(self, field=None, inner_hits=None, max_concurrent_group_searches=None):
@@ -684,7 +697,7 @@
         if field is None:
             return s
 
-        s._collapse["field"] = field
+        s._collapse["field"] = str(field)
         if inner_hits:
             s._collapse["inner_hits"] = inner_hits
         if max_concurrent_group_searches:
@@ -740,7 +753,7 @@ def highlight(self, *fields, **kwargs):
         """
         s = self._clone()
         for f in fields:
-            s._highlight[f] = kwargs
+            s._highlight[str(f)] = kwargs
         return s
 
     def suggest(self, name, text=None, regex=None, **kwargs):
diff --git a/elasticsearch_dsl/utils.py b/elasticsearch_dsl/utils.py
index 7fd0b08c..a58f0d3f 100644
--- a/elasticsearch_dsl/utils.py
+++ b/elasticsearch_dsl/utils.py
@@ -441,6 +441,15 @@ def __init__(self, meta=None, **kwargs):
 
         super(AttrDict, self).__setattr__("meta", HitMeta(meta))
 
+        # process field defaults
+        if hasattr(self, "_defaults"):
+            for name in self._defaults:
+                if name not in kwargs:
+                    value = self._defaults[name]
+                    if callable(value):
+                        # a default_factory callable produces a fresh value
+                        value = value()
+                    kwargs[name] = value
+
         super().__init__(kwargs)
 
     @classmethod
@@ -513,6 +522,12 @@ def __getattr__(self, name):
                 return value
             raise
 
+    def __setattr__(self, name, value):
+        if name in self.__class__._doc_type.mapping:
+            self._d_[name] = value
+        else:
+            super().__setattr__(name, value)
+
     def to_dict(self, skip_empty=True):
         out = {}
         for k, v in self._d_.items():
diff --git a/examples/async/vectors.py b/examples/async/vectors.py
index 620ea45f..84bc001e 100644
--- a/examples/async/vectors.py
+++ b/examples/async/vectors.py
@@ -47,6 +47,8 @@
 import asyncio
 import json
 import os
+from datetime import datetime
+from typing import List, Optional, cast
 from urllib.request import urlopen
 
 import nltk
@@ -55,13 +57,12 @@
 
 from elasticsearch_dsl import (
     AsyncDocument,
-    Date,
     DenseVector,
     InnerDoc,
     Keyword,
-    Nested,
-    Text,
+    M,
     async_connections,
+    mapped_field,
 )
 
 DATASET_URL = "https://raw.githubusercontent.com/elastic/elasticsearch-labs/main/datasets/workplace-documents.json"
@@ -72,45 +73,43 @@
 class Passage(InnerDoc):
-    content = Text()
-    embedding = DenseVector()
+    content: M[str]
+    embedding: M[List[float]] = mapped_field(DenseVector())
 
 
 class WorkplaceDoc(AsyncDocument):
     class Index:
         name = "workplace_documents"
 
-    name = Text()
-    summary = Text()
-    content = Text()
-    created = Date()
-    updated = Date()
-    url = Keyword()
-    category = Keyword()
-    passages = Nested(Passage)
+    name: M[str]
+    summary: M[str]
+    content: M[str]
+    created: M[datetime]
+    updated: M[Optional[datetime]]
+    url: M[str] = mapped_field(Keyword(required=True))
+    category: M[str] = mapped_field(Keyword(required=True))
+    passages: M[Optional[List[Passage]]] = mapped_field(default_factory=list)
 
     _model = None
 
     @classmethod
-    def get_embedding_model(cls):
+    def get_embedding(cls, input: str) -> List[float]:
         if cls._model is None:
             cls._model = SentenceTransformer(MODEL_NAME)
-        return cls._model
+        return cast(List[float], list(cls._model.encode(input)))
 
     def clean(self):
         # split the content into sentences
         passages = nltk.sent_tokenize(self.content)
 
         # generate an embedding for each passage and save it as a nested document
-        model = self.get_embedding_model()
         for passage in passages:
             self.passages.append(
-                Passage(content=passage, embedding=list(model.encode(passage)))
+                Passage(content=passage, embedding=self.get_embedding(passage))
             )
 
 
 async def create():
-    # create the index
     await WorkplaceDoc._index.delete(ignore_unavailable=True)
     await WorkplaceDoc.init()
@@ -133,12 +132,11 @@ async def create():
 
 
 async def search(query):
-    model = WorkplaceDoc.get_embedding_model()
     return WorkplaceDoc.search().knn(
-        field="passages.embedding",
+        field=WorkplaceDoc.passages.embedding,
         k=5,
         num_candidates=50,
-        query_vector=list(model.encode(query)),
+        query_vector=list(WorkplaceDoc.get_embedding(query)),
         inner_hits={"size": 2},
     )
diff --git a/examples/vectors.py b/examples/vectors.py
index c204cb61..ae34eaae 100644
--- a/examples/vectors.py
+++ b/examples/vectors.py
@@ -46,6 +46,8 @@
 import argparse
 import json
 import os
+from datetime import datetime
+from typing import List, Optional, cast
 from urllib.request import urlopen
 
 import nltk
@@ -53,14 +55,13 @@
 from tqdm import tqdm
 
 from elasticsearch_dsl import (
-    Date,
     DenseVector,
     Document,
     InnerDoc,
     Keyword,
-    Nested,
-    Text,
+    M,
     connections,
+    mapped_field,
 )
 
 DATASET_URL = "https://raw.githubusercontent.com/elastic/elasticsearch-labs/main/datasets/workplace-documents.json"
@@ -71,45 +72,43 @@
 class Passage(InnerDoc):
-    content = Text()
-    embedding = DenseVector()
+    content: M[str]
+    embedding: M[List[float]] = mapped_field(DenseVector())
 
 
 class WorkplaceDoc(Document):
     class Index:
         name = "workplace_documents"
 
-    name = Text()
-    summary = Text()
-    content = Text()
-    created = Date()
-    updated = Date()
-    url = Keyword()
-    category = Keyword()
-    passages = Nested(Passage)
+    name: M[str]
+    summary: M[str]
+    content: M[str]
+    created: M[datetime]
+    updated: M[Optional[datetime]]
+    url: M[str] = mapped_field(Keyword(required=True))
+    category: M[str] = mapped_field(Keyword(required=True))
+    passages: M[Optional[List[Passage]]] = mapped_field(default_factory=list)
 
     _model = None
 
     @classmethod
-    def get_embedding_model(cls):
+    def get_embedding(cls, input: str) -> List[float]:
         if cls._model is None:
             cls._model = SentenceTransformer(MODEL_NAME)
-        return cls._model
+        return cast(List[float], list(cls._model.encode(input)))
 
     def clean(self):
         # split the content into sentences
         passages = nltk.sent_tokenize(self.content)
 
         # generate an embedding for each passage and save it as a nested document
-        model = self.get_embedding_model()
         for passage in passages:
             self.passages.append(
-                Passage(content=passage, embedding=list(model.encode(passage)))
+                Passage(content=passage, embedding=self.get_embedding(passage))
             )
 
 
 def create():
-    # create the index
     WorkplaceDoc._index.delete(ignore_unavailable=True)
     WorkplaceDoc.init()
@@ 
-132,12 +131,11 @@ def create(): def search(query): - model = WorkplaceDoc.get_embedding_model() return WorkplaceDoc.search().knn( - field="passages.embedding", + field=WorkplaceDoc.passages.embedding, k=5, num_candidates=50, - query_vector=list(model.encode(query)), + query_vector=list(WorkplaceDoc.get_embedding(query)), inner_hits={"size": 2}, ) diff --git a/mypy.ini b/mypy.ini index 0c795321..e71761ce 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,3 +1,6 @@ +[mypy] +explicit_package_bases = True + [mypy-elasticsearch_dsl.query] # Allow reexport of SF for tests -implicit_reexport = True \ No newline at end of file +implicit_reexport = True diff --git a/noxfile.py b/noxfile.py index a731f99b..612a0881 100644 --- a/noxfile.py +++ b/noxfile.py @@ -91,7 +91,7 @@ def type_check(session): session.install("mypy", ".[develop]") errors = [] popen = subprocess.Popen( - "mypy --strict elasticsearch_dsl tests", + "mypy --strict elasticsearch_dsl tests examples", env=session.env, shell=True, stdout=subprocess.PIPE, diff --git a/tests/_async/test_document.py b/tests/_async/test_document.py index 238e80fb..3bfb80b2 100644 --- a/tests/_async/test_document.py +++ b/tests/_async/test_document.py @@ -20,6 +20,7 @@ import pickle from datetime import datetime from hashlib import md5 +from typing import List, Optional import pytest from pytest import raises @@ -28,13 +29,16 @@ AsyncDocument, Index, InnerDoc, + M, Mapping, MetaField, Range, analyzer, field, + mapped_field, utils, ) +from elasticsearch_dsl.document_base import InstrumentedField from elasticsearch_dsl.exceptions import IllegalOperation, ValidationException @@ -640,3 +644,176 @@ class MySubDocWithNested(MyDoc): }, "title": {"type": "keyword"}, } + + +def test_doc_with_type_hints(): + class TypedInnerDoc(InnerDoc): + st: M[str] + dt: M[Optional[datetime]] + li: M[List[int]] + + class TypedDoc(AsyncDocument): + st: str + dt: Optional[datetime] + li: List[int] + ob: Optional[TypedInnerDoc] + ns: Optional[List[TypedInnerDoc]] + ip: Optional[str] = field.Ip() + k1: str = field.Keyword(required=True) + k2: M[str] = field.Keyword() + k3: str = mapped_field(field.Keyword(), default="foo") + k4: M[Optional[str]] = mapped_field(field.Keyword()) + s1: Secret = SecretField() + s2: M[Secret] = SecretField() + s3: Secret = mapped_field(SecretField()) + s4: M[Optional[Secret]] = mapped_field( + SecretField(), default_factory=lambda: "foo" + ) + + props = TypedDoc._doc_type.mapping.to_dict()["properties"] + assert props == { + "st": {"type": "text"}, + "dt": {"type": "date"}, + "li": {"type": "integer"}, + "ob": { + "type": "object", + "properties": { + "st": {"type": "text"}, + "dt": {"type": "date"}, + "li": {"type": "integer"}, + }, + }, + "ns": { + "type": "nested", + "properties": { + "st": {"type": "text"}, + "dt": {"type": "date"}, + "li": {"type": "integer"}, + }, + }, + "ip": {"type": "ip"}, + "k1": {"type": "keyword"}, + "k2": {"type": "keyword"}, + "k3": {"type": "keyword"}, + "k4": {"type": "keyword"}, + "s1": {"type": "text"}, + "s2": {"type": "text"}, + "s3": {"type": "text"}, + "s4": {"type": "text"}, + } + + doc = TypedDoc() + assert doc.k3 == "foo" + assert doc.s4 == "foo" + with raises(ValidationException) as exc_info: + doc.full_clean() + assert set(exc_info.value.args[0].keys()) == {"st", "li", "k1"} + + doc.st = "s" + doc.li = [1, 2, 3] + doc.k1 = "k" + doc.full_clean() + + doc.ob = TypedInnerDoc() + with raises(ValidationException) as exc_info: + doc.full_clean() + assert set(exc_info.value.args[0].keys()) == {"ob"} + assert 
set(exc_info.value.args[0]["ob"][0].args[0].keys()) == {"st", "li"} + + doc.ob.st = "s" + doc.ob.li = [1] + doc.full_clean() + + doc.ns.append(TypedInnerDoc(st="s")) + with raises(ValidationException) as exc_info: + doc.full_clean() + + doc.ns[0].li = [1, 2] + doc.full_clean() + + doc.ip = "1.2.3.4" + n = datetime.now() + doc.dt = n + assert doc.to_dict() == { + "st": "s", + "li": [1, 2, 3], + "dt": n, + "ob": { + "st": "s", + "li": [1], + }, + "ns": [ + { + "st": "s", + "li": [1, 2], + } + ], + "ip": "1.2.3.4", + "k1": "k", + "k3": "foo", + "s4": "foo", + } + + s = TypedDoc.search().sort(TypedDoc.st, -TypedDoc.dt, +TypedDoc.ob.st) + assert s.to_dict() == {"sort": ["st", {"dt": {"order": "desc"}}, "ob.st"]} + + +def test_instrumented_field(): + class Child(InnerDoc): + st: M[str] + + class Doc(AsyncDocument): + st: str + ob: Child + ns: List[Child] + + doc = Doc( + st="foo", + ob=Child(st="bar"), + ns=[ + Child(st="baz"), + Child(st="qux"), + ], + ) + + assert type(doc.st) is str + assert doc.st == "foo" + + assert type(doc.ob) is Child + assert doc.ob.st == "bar" + + assert type(doc.ns) is utils.AttrList + assert doc.ns[0].st == "baz" + assert doc.ns[1].st == "qux" + assert type(doc.ns[0]) is Child + assert type(doc.ns[1]) is Child + + assert type(Doc.st) is InstrumentedField + assert str(Doc.st) == "st" + assert +Doc.st == "st" + assert -Doc.st == "-st" + assert Doc.st.to_dict() == {"type": "text"} + with raises(AttributeError): + Doc.st.something + + assert type(Doc.ob) is InstrumentedField + assert str(Doc.ob) == "ob" + assert str(Doc.ob.st) == "ob.st" + assert +Doc.ob.st == "ob.st" + assert -Doc.ob.st == "-ob.st" + assert Doc.ob.st.to_dict() == {"type": "text"} + with raises(AttributeError): + Doc.ob.something + with raises(AttributeError): + Doc.ob.st.something + + assert type(Doc.ns) is InstrumentedField + assert str(Doc.ns) == "ns" + assert str(Doc.ns.st) == "ns.st" + assert +Doc.ns.st == "ns.st" + assert -Doc.ns.st == "-ns.st" + assert Doc.ns.st.to_dict() == {"type": "text"} + with raises(AttributeError): + Doc.ns.something + with raises(AttributeError): + Doc.ns.st.something diff --git a/tests/_sync/test_document.py b/tests/_sync/test_document.py index 5cfa183c..27567ac9 100644 --- a/tests/_sync/test_document.py +++ b/tests/_sync/test_document.py @@ -20,6 +20,7 @@ import pickle from datetime import datetime from hashlib import md5 +from typing import List, Optional import pytest from pytest import raises @@ -28,13 +29,16 @@ Document, Index, InnerDoc, + M, Mapping, MetaField, Range, analyzer, field, + mapped_field, utils, ) +from elasticsearch_dsl.document_base import InstrumentedField from elasticsearch_dsl.exceptions import IllegalOperation, ValidationException @@ -640,3 +644,176 @@ class MySubDocWithNested(MyDoc): }, "title": {"type": "keyword"}, } + + +def test_doc_with_type_hints(): + class TypedInnerDoc(InnerDoc): + st: M[str] + dt: M[Optional[datetime]] + li: M[List[int]] + + class TypedDoc(Document): + st: str + dt: Optional[datetime] + li: List[int] + ob: Optional[TypedInnerDoc] + ns: Optional[List[TypedInnerDoc]] + ip: Optional[str] = field.Ip() + k1: str = field.Keyword(required=True) + k2: M[str] = field.Keyword() + k3: str = mapped_field(field.Keyword(), default="foo") + k4: M[Optional[str]] = mapped_field(field.Keyword()) + s1: Secret = SecretField() + s2: M[Secret] = SecretField() + s3: Secret = mapped_field(SecretField()) + s4: M[Optional[Secret]] = mapped_field( + SecretField(), default_factory=lambda: "foo" + ) + + props = 
TypedDoc._doc_type.mapping.to_dict()["properties"] + assert props == { + "st": {"type": "text"}, + "dt": {"type": "date"}, + "li": {"type": "integer"}, + "ob": { + "type": "object", + "properties": { + "st": {"type": "text"}, + "dt": {"type": "date"}, + "li": {"type": "integer"}, + }, + }, + "ns": { + "type": "nested", + "properties": { + "st": {"type": "text"}, + "dt": {"type": "date"}, + "li": {"type": "integer"}, + }, + }, + "ip": {"type": "ip"}, + "k1": {"type": "keyword"}, + "k2": {"type": "keyword"}, + "k3": {"type": "keyword"}, + "k4": {"type": "keyword"}, + "s1": {"type": "text"}, + "s2": {"type": "text"}, + "s3": {"type": "text"}, + "s4": {"type": "text"}, + } + + doc = TypedDoc() + assert doc.k3 == "foo" + assert doc.s4 == "foo" + with raises(ValidationException) as exc_info: + doc.full_clean() + assert set(exc_info.value.args[0].keys()) == {"st", "li", "k1"} + + doc.st = "s" + doc.li = [1, 2, 3] + doc.k1 = "k" + doc.full_clean() + + doc.ob = TypedInnerDoc() + with raises(ValidationException) as exc_info: + doc.full_clean() + assert set(exc_info.value.args[0].keys()) == {"ob"} + assert set(exc_info.value.args[0]["ob"][0].args[0].keys()) == {"st", "li"} + + doc.ob.st = "s" + doc.ob.li = [1] + doc.full_clean() + + doc.ns.append(TypedInnerDoc(st="s")) + with raises(ValidationException) as exc_info: + doc.full_clean() + + doc.ns[0].li = [1, 2] + doc.full_clean() + + doc.ip = "1.2.3.4" + n = datetime.now() + doc.dt = n + assert doc.to_dict() == { + "st": "s", + "li": [1, 2, 3], + "dt": n, + "ob": { + "st": "s", + "li": [1], + }, + "ns": [ + { + "st": "s", + "li": [1, 2], + } + ], + "ip": "1.2.3.4", + "k1": "k", + "k3": "foo", + "s4": "foo", + } + + s = TypedDoc.search().sort(TypedDoc.st, -TypedDoc.dt, +TypedDoc.ob.st) + assert s.to_dict() == {"sort": ["st", {"dt": {"order": "desc"}}, "ob.st"]} + + +def test_instrumented_field(): + class Child(InnerDoc): + st: M[str] + + class Doc(Document): + st: str + ob: Child + ns: List[Child] + + doc = Doc( + st="foo", + ob=Child(st="bar"), + ns=[ + Child(st="baz"), + Child(st="qux"), + ], + ) + + assert type(doc.st) is str + assert doc.st == "foo" + + assert type(doc.ob) is Child + assert doc.ob.st == "bar" + + assert type(doc.ns) is utils.AttrList + assert doc.ns[0].st == "baz" + assert doc.ns[1].st == "qux" + assert type(doc.ns[0]) is Child + assert type(doc.ns[1]) is Child + + assert type(Doc.st) is InstrumentedField + assert str(Doc.st) == "st" + assert +Doc.st == "st" + assert -Doc.st == "-st" + assert Doc.st.to_dict() == {"type": "text"} + with raises(AttributeError): + Doc.st.something + + assert type(Doc.ob) is InstrumentedField + assert str(Doc.ob) == "ob" + assert str(Doc.ob.st) == "ob.st" + assert +Doc.ob.st == "ob.st" + assert -Doc.ob.st == "-ob.st" + assert Doc.ob.st.to_dict() == {"type": "text"} + with raises(AttributeError): + Doc.ob.something + with raises(AttributeError): + Doc.ob.st.something + + assert type(Doc.ns) is InstrumentedField + assert str(Doc.ns) == "ns" + assert str(Doc.ns.st) == "ns.st" + assert +Doc.ns.st == "ns.st" + assert -Doc.ns.st == "-ns.st" + assert Doc.ns.st.to_dict() == {"type": "text"} + with raises(AttributeError): + Doc.ns.something + with raises(AttributeError): + Doc.ns.st.something