From 53d18c4504665c93e982af319eb74a3b18f9fd93 Mon Sep 17 00:00:00 2001
From: Miguel Grinberg
Date: Mon, 19 Aug 2024 11:44:48 +0100
Subject: [PATCH] Added support for the `semantic_text` field and `semantic` query type (#1881)

* Added support for the `semantic_text` field and `semantic` query type

* Fix nltk code... again

* feedback

(cherry picked from commit 7fa4f8c8ae85c14f8ed6ae580277ed9085d3f4cb)
---
 elasticsearch_dsl/field.py       |   4 +
 elasticsearch_dsl/query.py       |   4 +
 examples/async/semantic_text.py  | 148 +++++++++++++++++++++++++++++++
 examples/async/sparse_vectors.py |   2 +-
 examples/async/vectors.py        |   2 +-
 examples/semantic_text.py        | 147 ++++++++++++++++++++++++++++++
 examples/sparse_vectors.py       |   2 +-
 examples/vectors.py              |   2 +-
 8 files changed, 307 insertions(+), 4 deletions(-)
 create mode 100644 examples/async/semantic_text.py
 create mode 100644 examples/semantic_text.py

diff --git a/elasticsearch_dsl/field.py b/elasticsearch_dsl/field.py
index 7896fe5f..26f2336b 100644
--- a/elasticsearch_dsl/field.py
+++ b/elasticsearch_dsl/field.py
@@ -560,3 +560,7 @@ class TokenCount(Field):
 
 class Murmur3(Field):
     name = "murmur3"
+
+
+class SemanticText(Field):
+    name = "semantic_text"
diff --git a/elasticsearch_dsl/query.py b/elasticsearch_dsl/query.py
index ce445216..993213c6 100644
--- a/elasticsearch_dsl/query.py
+++ b/elasticsearch_dsl/query.py
@@ -527,6 +527,10 @@ class Shape(Query):
     name = "shape"
 
 
+class Semantic(Query):
+    name = "semantic"
+
+
 class SimpleQueryString(Query):
     name = "simple_query_string"
 
diff --git a/examples/async/semantic_text.py b/examples/async/semantic_text.py
new file mode 100644
index 00000000..cf63b003
--- /dev/null
+++ b/examples/async/semantic_text.py
@@ -0,0 +1,148 @@
+# Licensed to Elasticsearch B.V. under one or more contributor
+# license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright
+# ownership. Elasticsearch B.V. licenses this file to you under
+# the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+"""
+# Semantic Text example
+
+Requirements:
+
+$ pip install "elasticsearch-dsl[async]" tqdm
+
+Before running this example, an ELSER inference endpoint must be created in the
+Elasticsearch cluster. This can be done manually from Kibana, or with the
+following curl command from a terminal:
+
+curl -X PUT \
+  "$ELASTICSEARCH_URL/_inference/sparse_embedding/my-elser-endpoint" \
+  -H "Content-Type: application/json" \
+  -d '{"service":"elser","service_settings":{"num_allocations":1,"num_threads":1}}'
+
+To run the example:
+
+$ python semantic_text.py "text to search"
+
+The index will be created automatically if it does not exist. Add
+`--recreate-index` to the command to regenerate it.
+
+The example dataset includes a selection of workplace documents. The
+following are good example queries to try out with this dataset:
+
+$ python semantic_text.py "work from home"
+$ python semantic_text.py "vacation time"
+$ python semantic_text.py "can I bring a bird to work?"
+
+When the index is created, the inference service will split the documents into
+short passages, and for each passage a sparse embedding will be generated using
+Elastic's ELSER v2 model.
+"""
+
+import argparse
+import asyncio
+import json
+import os
+from datetime import datetime
+from typing import Any, Optional
+from urllib.request import urlopen
+
+from tqdm import tqdm
+
+import elasticsearch_dsl as dsl
+
+DATASET_URL = "https://raw.githubusercontent.com/elastic/elasticsearch-labs/main/datasets/workplace-documents.json"
+
+
+class WorkplaceDoc(dsl.AsyncDocument):
+    class Index:
+        name = "workplace_documents_semantic"
+
+    name: str
+    summary: str
+    content: Any = dsl.mapped_field(
+        dsl.field.SemanticText(inference_id="my-elser-endpoint")
+    )
+    created: datetime
+    updated: Optional[datetime]
+    url: str = dsl.mapped_field(dsl.Keyword())
+    category: str = dsl.mapped_field(dsl.Keyword())
+
+
+async def create() -> None:
+
+    # create the index
+    await WorkplaceDoc._index.delete(ignore_unavailable=True)
+    await WorkplaceDoc.init()
+
+    # download the data
+    dataset = json.loads(urlopen(DATASET_URL).read())
+
+    # import the dataset
+    for data in tqdm(dataset, desc="Indexing documents..."):
+        doc = WorkplaceDoc(
+            name=data["name"],
+            summary=data["summary"],
+            content=data["content"],
+            created=data.get("created_on"),
+            updated=data.get("updated_at"),
+            url=data["url"],
+            category=data["category"],
+        )
+        await doc.save()
+
+    # refresh the index
+    await WorkplaceDoc._index.refresh()
+
+
+async def search(query: str) -> dsl.AsyncSearch[WorkplaceDoc]:
+    search = WorkplaceDoc.search()
+    search = search[:5]
+    return search.query(dsl.query.Semantic(field=WorkplaceDoc.content, query=query))
+
+
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(description="Vector database with Elasticsearch")
+    parser.add_argument(
+        "--recreate-index", action="store_true", help="Recreate and populate the index"
+    )
+    parser.add_argument("query", action="store", help="The search query")
+    return parser.parse_args()
+
+
+async def main() -> None:
+    args = parse_args()
+
+    # initiate the default connection to elasticsearch
+    dsl.async_connections.create_connection(hosts=[os.environ["ELASTICSEARCH_URL"]])
+
+    if args.recreate_index or not await WorkplaceDoc._index.exists():
+        await create()
+
+    results = await search(args.query)
+
+    async for hit in results:
+        print(
+            f"Document: {hit.name} [Category: {hit.category}] [Score: {hit.meta.score}]"
+        )
+        print(f"Content: {hit.content.text}")
+        print("--------------------\n")
+
+    # close the connection
+    await dsl.async_connections.get_connection().close()
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/examples/async/sparse_vectors.py b/examples/async/sparse_vectors.py
index d50b4080..458798a3 100644
--- a/examples/async/sparse_vectors.py
+++ b/examples/async/sparse_vectors.py
@@ -84,7 +84,7 @@
 DATASET_URL = "https://raw.githubusercontent.com/elastic/elasticsearch-labs/main/datasets/workplace-documents.json"
 
 # initialize sentence tokenizer
-nltk.download("punkt", quiet=True)
+nltk.download("punkt_tab", quiet=True)
 
 
 class Passage(InnerDoc):
diff --git a/examples/async/vectors.py b/examples/async/vectors.py
index 5221929d..b58c184e 100644
--- a/examples/async/vectors.py
+++ b/examples/async/vectors.py
@@ -70,7 +70,7 @@
 MODEL_NAME = "all-MiniLM-L6-v2"
 
 # initialize sentence tokenizer
-nltk.download("punkt", quiet=True)
+nltk.download("punkt_tab", quiet=True)
 
 # this will be the embedding model
 embedding_model: Any = None
diff --git a/examples/semantic_text.py b/examples/semantic_text.py
new file mode 100644
index 00000000..7461e18d
--- /dev/null
+++ b/examples/semantic_text.py
@@ -0,0 +1,147 @@
+# Licensed to Elasticsearch B.V. under one or more contributor
+# license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright
+# ownership. Elasticsearch B.V. licenses this file to you under
+# the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+"""
+# Semantic Text example
+
+Requirements:
+
+$ pip install "elasticsearch-dsl" tqdm
+
+Before running this example, an ELSER inference endpoint must be created in the
+Elasticsearch cluster. This can be done manually from Kibana, or with the
+following curl command from a terminal:
+
+curl -X PUT \
+  "$ELASTICSEARCH_URL/_inference/sparse_embedding/my-elser-endpoint" \
+  -H "Content-Type: application/json" \
+  -d '{"service":"elser","service_settings":{"num_allocations":1,"num_threads":1}}'
+
+To run the example:
+
+$ python semantic_text.py "text to search"
+
+The index will be created automatically if it does not exist. Add
+`--recreate-index` to the command to regenerate it.
+
+The example dataset includes a selection of workplace documents. The
+following are good example queries to try out with this dataset:
+
+$ python semantic_text.py "work from home"
+$ python semantic_text.py "vacation time"
+$ python semantic_text.py "can I bring a bird to work?"
+
+When the index is created, the inference service will split the documents into
+short passages, and for each passage a sparse embedding will be generated using
+Elastic's ELSER v2 model.
+"""
+
+import argparse
+import json
+import os
+from datetime import datetime
+from typing import Any, Optional
+from urllib.request import urlopen
+
+from tqdm import tqdm
+
+import elasticsearch_dsl as dsl
+
+DATASET_URL = "https://raw.githubusercontent.com/elastic/elasticsearch-labs/main/datasets/workplace-documents.json"
+
+
+class WorkplaceDoc(dsl.Document):
+    class Index:
+        name = "workplace_documents_semantic"
+
+    name: str
+    summary: str
+    content: Any = dsl.mapped_field(
+        dsl.field.SemanticText(inference_id="my-elser-endpoint")
+    )
+    created: datetime
+    updated: Optional[datetime]
+    url: str = dsl.mapped_field(dsl.Keyword())
+    category: str = dsl.mapped_field(dsl.Keyword())
+
+
+def create() -> None:
+
+    # create the index
+    WorkplaceDoc._index.delete(ignore_unavailable=True)
+    WorkplaceDoc.init()
+
+    # download the data
+    dataset = json.loads(urlopen(DATASET_URL).read())
+
+    # import the dataset
+    for data in tqdm(dataset, desc="Indexing documents..."):
+        doc = WorkplaceDoc(
+            name=data["name"],
+            summary=data["summary"],
+            content=data["content"],
+            created=data.get("created_on"),
+            updated=data.get("updated_at"),
+            url=data["url"],
+            category=data["category"],
+        )
+        doc.save()
+
+    # refresh the index
+    WorkplaceDoc._index.refresh()
+
+
+def search(query: str) -> dsl.Search[WorkplaceDoc]:
+    search = WorkplaceDoc.search()
+    search = search[:5]
+    return search.query(dsl.query.Semantic(field=WorkplaceDoc.content, query=query))
+
+
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(description="Vector database with Elasticsearch")
+    parser.add_argument(
+        "--recreate-index", action="store_true", help="Recreate and populate the index"
+    )
+    parser.add_argument("query", action="store", help="The search query")
+    return parser.parse_args()
+
+
+def main() -> None:
+    args = parse_args()
+
+    # initiate the default connection to elasticsearch
+    dsl.connections.create_connection(hosts=[os.environ["ELASTICSEARCH_URL"]])
+
+    if args.recreate_index or not WorkplaceDoc._index.exists():
+        create()
+
+    results = search(args.query)
+
+    for hit in results:
+        print(
+            f"Document: {hit.name} [Category: {hit.category}] [Score: {hit.meta.score}]"
+        )
+        print(f"Content: {hit.content.text}")
+        print("--------------------\n")
+
+    # close the connection
+    dsl.connections.get_connection().close()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/sparse_vectors.py b/examples/sparse_vectors.py
index ae156fe7..ba853e8c 100644
--- a/examples/sparse_vectors.py
+++ b/examples/sparse_vectors.py
@@ -83,7 +83,7 @@
 DATASET_URL = "https://raw.githubusercontent.com/elastic/elasticsearch-labs/main/datasets/workplace-documents.json"
 
 # initialize sentence tokenizer
-nltk.download("punkt", quiet=True)
+nltk.download("punkt_tab", quiet=True)
 
 
 class Passage(InnerDoc):
diff --git a/examples/vectors.py b/examples/vectors.py
index c983514d..ef1342ee 100644
--- a/examples/vectors.py
+++ b/examples/vectors.py
@@ -69,7 +69,7 @@
 MODEL_NAME = "all-MiniLM-L6-v2"
 
 # initialize sentence tokenizer
-nltk.download("punkt", quiet=True)
+nltk.download("punkt_tab", quiet=True)
 
 # this will be the embedding model
 embedding_model: Any = None