refactoring; support of haystack filtering logic; use llama2 tokenizer by default
Guest400123064 committed Apr 14, 2024
1 parent 7b0c11f commit 896dd55
Showing 8 changed files with 109 additions and 76 deletions.
8 changes: 5 additions & 3 deletions README.md
Expand Up @@ -16,7 +16,7 @@ pip install bbm25-haystack

## Usage

The initializer takes [three BM25+ hyperparameters](https://en.wikipedia.org/wiki/Okapi_BM25), namely `k1`, `b`, and `delta`, and one path to a trained SentencePiece tokenizer `.model` file. All parameters are optional. The default tokenizer is directly copied from [this SentencePiece test tokenizer](https://github.com/google/sentencepiece/blob/master/python/test/test_model.model) with a vocab size of 1000.
The initializer takes [three BM25+ hyperparameters](https://en.wikipedia.org/wiki/Okapi_BM25), namely `k1`, `b`, and `delta`, one path to a trained SentencePiece tokenizer `.model` file, and a filtering logic flag ([see below](#filtering-logic)). All parameters are optional. The default tokenizer is copied directly from the [LLaMA-2-7B-32K tokenizer](https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/blob/main/tokenizer.model), which has a vocab size of 32,000.
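
As a rough illustration of how the three hyperparameters enter the BM25+ score, the sketch below scores one tokenized document against a query. The function name `bm25_plus_score`, the `k1=1.5` default, and the `log((N + 1) / df)` IDF variant are assumptions made for illustration and are not taken from this package; only `b=0.75` and `delta=1.0` match the defaults visible in this commit.

```python
import math
from collections import Counter


def bm25_plus_score(query_tokens, doc_tokens, doc_freq, n_docs, avg_doc_len,
                    k1=1.5, b=0.75, delta=1.0):
    """Score one document against a query with the BM25+ formula.

    Hypothetical helper, not the package's API. `doc_freq` maps a token
    to the number of documents that contain it.
    """
    term_freq = Counter(doc_tokens)
    doc_len = len(doc_tokens)
    score = 0.0
    # Simplification: repeated query tokens are counted once.
    for term in set(query_tokens):
        tf = term_freq.get(term, 0)
        df = doc_freq.get(term, 0)
        if tf == 0 or df == 0:
            continue  # only terms present in the document contribute
        idf = math.log((n_docs + 1) / df)  # assumed IDF variant
        # Length normalization controlled by b.
        norm = 1 - b + b * doc_len / avg_doc_len
        # delta lower-bounds the contribution of every matching term.
        score += idf * (tf * (k1 + 1) / (tf + k1 * norm) + delta)
    return score


# Tiny usage example with whitespace tokens instead of SentencePiece.
docs = [["php", "is", "popular"], ["python", "is", "popular", "too"]]
doc_freq = Counter(token for doc in docs for token in set(doc))
avg_doc_len = sum(len(doc) for doc in docs) / len(docs)
scores = [
    bm25_plus_score(["python"], doc, doc_freq, len(docs), avg_doc_len)
    for doc in docs
]
```

Setting `delta=0.0` in this sketch drops the lower bound, so each matching term contributes only the saturating term-frequency part of the formula.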

```python
from haystack import Document
Expand All @@ -34,9 +34,11 @@ retriever = BetterBM25Retriever(document_store)
retriever.run(query="How many languages are spoken around the world today?")
```

## Filtering Logic and Caveats
## Filtering Logic

The filtering logic is slightly different from the default implementation shipped with Haystack, but this logic may be subject to changes, and I am open to different suggestions. Please find comments and implementation details in [`filters.py`](./src/bbm25_haystack/filters.py). TL;DR:
The current document store uses `document_matches_filter` shipped with Haystack to perform filtering by default, which behaves the same as `InMemoryDocumentStore` except that it DOES NOT support legacy operator names.

However, there is also an alternative filtering logic shipped with this implementation that is more conservative (and unstable at this point). To use this alternative logic, initialize the document store with `haystack_filter_logic=False`. Please find comments and implementation details in [`filters.py`](./src/bbm25_haystack/filters.py). TL;DR:

- Any comparison involving `None`, i.e., a missing value, always returns `False`, regardless of the document attribute value or filter value.
- Comparison with `DataFrame` is always prohibited to reduce surprises.
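
The `None` rule above can be sketched in plain Python. This is a minimal illustration of the described semantics only; the helper names `conservative_equal` and `conservative_not_equal` are hypothetical and are not the functions defined in `filters.py`, and the `DataFrame` rule is not covered here.

```python
from typing import Any


def conservative_equal(document_value: Any, filter_value: Any) -> bool:
    """Hypothetical `==` comparison: any None operand yields False."""
    if document_value is None or filter_value is None:
        return False
    return document_value == filter_value


def conservative_not_equal(document_value: Any, filter_value: Any) -> bool:
    """Hypothetical `!=` comparison: any None operand yields False."""
    if document_value is None or filter_value is None:
        return False
    return document_value != filter_value


# Stand-ins for `meta.number`; None marks a document without the field.
numbers = [100, 2, None]
matches = [n for n in numbers if conservative_not_equal(n, 100)]
# matches == [2]: the document with a missing value is not returned.
```

Under this reading, a `!=` filter returns only documents whose attribute is present and different, which is the behavior the overridden tests in this commit expect from the alternative logic.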
4 changes: 2 additions & 2 deletions pyproject.toml
Expand Up @@ -85,12 +85,12 @@ allow-direct-references = true

[tool.black]
target-version = ["py39"]
line-length = 90
line-length = 85
skip-string-normalization = true

[tool.ruff]
target-version = "py39"
line-length = 90
line-length = 85
select = [
"A",
"ARG",
21 changes: 12 additions & 9 deletions src/bbm25_haystack/bbm25_retriever.py
Expand Up @@ -56,12 +56,14 @@ def __init__(
"""
Create a BetterBM25Retriever component.
:param document_store: A Document Store object used to retrieve documents
:param document_store: A Document Store object used to
retrieve documents
:type document_store: BetterBM25DocumentStore
:param filters: A dictionary with filters to narrow down the search space
(default is None).
:param filters: A dictionary with filters to narrow down the
search space (default is None).
:type filters: Optional[dict[str, Any]]
:param top_k: The maximum number of documents to retrieve (default is 10).
:param top_k: The maximum number of documents to retrieve
(default is 10).
:type top_k: int
:raises ValueError: If the specified top_k is not > 0.
Expand All @@ -88,10 +90,11 @@ def run(
:param query: The query to run the Retriever on.
:type query: str
:param filters: A dictionary with filters to narrow down the search space
(default is None).
:param filters: A dictionary with filters to narrow
down the search space (default is None).
:type filters: Optional[dict[str, Any]]
:param top_k: The maximum number of documents to retrieve (default is None).
:param top_k: The maximum number of documents to
retrieve (default is None).
:return: The retrieved documents.
"""
Expand Down Expand Up @@ -133,7 +136,7 @@ def from_dict(cls, data: dict[str, Any]) -> "BetterBM25Retriever":
msg = "Missing 'type' in document store's serialization data"
raise DeserializationError(msg)

data["init_parameters"]["document_store"] = BetterBM25DocumentStore.from_dict(
doc_store_params
data["init_parameters"]["document_store"] = (
BetterBM25DocumentStore.from_dict(doc_store_params)
)
return default_from_dict(cls, data)
58 changes: 33 additions & 25 deletions src/bbm25_haystack/bbm25_store.py
Expand Up @@ -38,30 +38,34 @@ def __init__(
b: float = 0.75,
delta: float = 1.0,
sp_file: Optional[str] = None,
haystack_filter_logic: bool = False,
haystack_filter_logic: bool = True,
) -> None:
"""
Creates a new BetterBM25DocumentStore instance.
An in-memory document store intended to improve the default BM25 document
store shipped with Haystack. The default store recompute the index for the
entire document store for every in-coming query, which is significantly
inefficient. This store aims to improve the efficiency by pre-computing
the index for all documents in the store and only do incremental updates
when new documents are added or removed. Further, it leverages a
SentencePiece model to tokenize the input text to allow more flexible
and dynamic tokenization adapted to domain-specific text.
An in-memory document store intended to improve the default
BM25 document store shipped with Haystack. The default store
recompute the index for the entire document store for every
in-coming query, which is significantly inefficient. This
store aims to improve the efficiency by pre-computing the
index for all documents in the store and only do incremental
updates when new documents are added or removed. Further, it
leverages a SentencePiece model to tokenize the input text
to allow more flexible and dynamic tokenization adapted to
domain-specific text.
:param k: the k1 parameter in BM25+ formula.
:type k: float, optional
:param b: the b parameter in BM25+ formula.
:type b: float, optional
:param delta: the delta parameter in BM25+ formula.
:type delta: float, optional
:param sp_file: the SentencePiece model file to use for tokenization.
:param sp_file: the SentencePiece model file to use for
tokenization.
:type sp_file: Optional[str], optional
:param haystack_filter_logic: Whether to use the Haystack filter logic
or the one implemented in this store, which is more conservative.
:param haystack_filter_logic: Whether to use the Haystack
filter logic or the one implemented in this store,
which is more conservative.
:type haystack_filter_logic: bool, optional
"""
self.k = k
Expand Down Expand Up @@ -161,7 +165,7 @@ def _retrieval(
*,
filters: Optional[dict[str, Any]] = None,
top_k: Optional[int] = None,
) -> list[Document]:
) -> list[tuple[Document, float]]:
"""
Retrieve documents from the store using the given query.
Expand Down Expand Up @@ -207,8 +211,12 @@ def filter_documents(
:return: the list of documents that match the given filters.
:rtype: list[Document]
"""
if filters is None or not filters:
return [doc for doc, _, _ in self._index.values()]
return [
doc for doc, _, _ in self._index.values() if self._filter_func(filters, doc)
doc
for doc, _, _ in self._index.values()
if self._filter_func(filters, doc)
]

def write_documents(
Expand All @@ -221,15 +229,15 @@ def write_documents(
:param documents: a list of documents.
:type documents: list[Document]
:param policy: documents with the same ID count as duplicates. When
duplicates are met, the store can:
:param policy: documents with the same ID count as duplicates.
When duplicates are met, the store can:
- skip: keep the existing document and ignore the new one.
- overwrite: remove the old document and write the new one.
- fail: an error is raised
:type policy: DuplicatePolicy, optional
:raises DuplicateDocumentError: Exception trigger on duplicate document if
`policy=DuplicatePolicy.FAIL`
:raises DuplicateDocumentError: Exception trigger on duplicate
document if `policy=DuplicatePolicy.FAIL`
:return: Number of documents written.
:rtype: int
Expand Down Expand Up @@ -257,9 +265,9 @@ def write_documents(

self._index[doc.id] = (doc, Counter(tokens), len(tokens))
self._freq_doc.update(set(tokens))
self._avg_doc_len = (len(tokens) + self._avg_doc_len * len(self._index)) / (
len(self._index) + 1
)
self._avg_doc_len = (
len(tokens) + self._avg_doc_len * len(self._index)
) / (len(self._index) + 1)

logger.debug(f"Document '{doc.id}' written to store.")
n_written += 1
Expand All @@ -268,15 +276,15 @@ def write_documents(

def delete_documents(self, document_ids: list[str]) -> int:
"""
Deletes all documents with a matching document_ids from the document store.
Deletes all documents with a matching document_ids.
Fails with `MissingDocumentError` if no document with this id is present in
the store.
Fails with `MissingDocumentError` if no document with
this id is present in the store.
:param object_ids: the object_ids to delete
:type object_ids: list[str]
:raises MissingDocumentError: Exception trigger on missing document.
:raises MissingDocumentError: trigger on missing document.
:return: Number of documents deleted.
:rtype: int
Binary file modified src/bbm25_haystack/default.model
Binary file not shown.
4 changes: 3 additions & 1 deletion src/bbm25_haystack/filters.py
Expand Up @@ -101,7 +101,9 @@ def _and(document: Document, conditions: list[dict[str, Any]]) -> bool:
:return: True if all conditions are met.
:rtype: bool
"""
return all(_run_comparison_condition(condition, document) for condition in conditions)
return all(
_run_comparison_condition(condition, document) for condition in conditions
)


def _or(document: Document, conditions: list[dict[str, Any]]) -> bool:
68 changes: 37 additions & 31 deletions tests/test_document_store.py
Expand Up @@ -28,6 +28,10 @@ class TestDocumentStore(DocumentStoreBaseTests):
def document_store(self) -> BetterBM25DocumentStore:
return BetterBM25DocumentStore()

@pytest.fixture
def document_store_bbm25_filter(self) -> BetterBM25DocumentStore:
return BetterBM25DocumentStore(haystack_filter_logic=False)

def test_write_documents(self, document_store: DocumentStore):
docs = [Document(id="1")]
assert document_store.write_documents(docs) == 1
Expand Down Expand Up @@ -77,25 +81,32 @@ def test_bm25_retrieval(self, document_store):
# Override a few filter test cases to account for new comparison logic
# Specifically, we alter the expected behavior when comparison involves
# None, DataFrame, and Iterables.
def test_comparison_equal_with_none(self, document_store, filterable_docs):
document_store.write_documents(filterable_docs)
result = document_store.filter_documents(
def test_comparison_equal_with_none_bbm25_filter(
self, document_store_bbm25_filter, filterable_docs
):
document_store_bbm25_filter.write_documents(filterable_docs)
result = document_store_bbm25_filter.filter_documents(
filters={"field": "meta.number", "operator": "==", "value": None}
)
self.assert_documents_are_equal(result, [])

def test_comparison_not_equal_with_none(self, document_store, filterable_docs):
document_store.write_documents(filterable_docs)
result = document_store.filter_documents(
def test_comparison_not_equal_with_none_bbm25_filter(
self, document_store_bbm25_filter, filterable_docs
):
document_store_bbm25_filter.write_documents(filterable_docs)
result = document_store_bbm25_filter.filter_documents(
filters={"field": "meta.number", "operator": "!=", "value": None}
)
self.assert_documents_are_equal(result, [])

def test_comparison_not_equal(self, document_store, filterable_docs):
"""Comparison with missing values will always return False. So the ground
truth is that we should only return documents with a non-missing value."""
document_store.write_documents(filterable_docs)
result = document_store.filter_documents(
def test_comparison_not_equal_bbm25_filter(
self, document_store_bbm25_filter, filterable_docs
):
"""Comparison with missing values will always return False.
So the ground truth is that we should only return documents
with a non-missing value."""
document_store_bbm25_filter.write_documents(filterable_docs)
result = document_store_bbm25_filter.filter_documents(
{"field": "meta.number", "operator": "!=", "value": 100}
)
self.assert_documents_are_equal(
Expand All @@ -107,10 +118,12 @@ def test_comparison_not_equal(self, document_store, filterable_docs):
],
)

def test_comparison_not_in(self, document_store, filterable_docs):
def test_comparison_not_in_bbm25_filter(
self, document_store_bbm25_filter, filterable_docs
):
"""Similar to the test above."""
document_store.write_documents(filterable_docs)
result = document_store.filter_documents(
document_store_bbm25_filter.write_documents(filterable_docs)
result = document_store_bbm25_filter.filter_documents(
{"field": "meta.number", "operator": "not in", "value": [9, 10]}
)
self.assert_documents_are_equal(
Expand All @@ -122,35 +135,28 @@ def test_comparison_not_in(self, document_store, filterable_docs):
],
)

def test_comparison_equal_with_dataframe(self, document_store, filterable_docs):
document_store.write_documents(filterable_docs)
def test_comparison_equal_with_dataframe_bbm25_filter(
self, document_store_bbm25_filter, filterable_docs
):
document_store_bbm25_filter.write_documents(filterable_docs)
with pytest.raises(FilterError):
_ = document_store.filter_documents(
_ = document_store_bbm25_filter.filter_documents(
filters={
"field": "dataframe",
"operator": "==",
"value": pd.DataFrame([1]),
}
)

def test_comparison_not_equal_with_dataframe(self, document_store, filterable_docs):
document_store.write_documents(filterable_docs)
def test_comparison_not_equal_with_dataframe_bbm25_filter(
self, document_store_bbm25_filter, filterable_docs
):
document_store_bbm25_filter.write_documents(filterable_docs)
with pytest.raises(FilterError):
_ = document_store.filter_documents(
_ = document_store_bbm25_filter.filter_documents(
filters={
"field": "dataframe",
"operator": "==",
"value": pd.DataFrame([1]),
}
)

# Pass these two tests as we now support iterables other than lists
def test_comparison_in_with_with_non_list_iterable(
self, document_store, filterable_docs
):
pass

def test_comparison_not_in_with_with_non_list_iterable(
self, document_store, filterable_docs
):
pass
22 changes: 17 additions & 5 deletions tests/test_retriever.py
Expand Up @@ -52,14 +52,20 @@ def test_to_dict(self):
"MyFakeStore", bases=(BetterBM25DocumentStore,)
)
document_store = store_class()
document_store.to_dict = lambda: {"type": "MyFakeStore", "init_parameters": {}}
document_store.to_dict = lambda: {
"type": "MyFakeStore",
"init_parameters": {},
}
component = BetterBM25Retriever(document_store=document_store)

data = component.to_dict()
assert data == {
"type": "bbm25_haystack.bbm25_retriever.BetterBM25Retriever",
"init_parameters": {
"document_store": {"type": "MyFakeStore", "init_parameters": {}},
"document_store": {
"type": "MyFakeStore",
"init_parameters": {},
},
"filters": None,
"top_k": 10,
},
Expand Down Expand Up @@ -104,7 +110,8 @@ def test_from_dict(self):
def test_from_dict_without_docstore(self):
data = {"type": "BetterBM25Retriever", "init_parameters": {}}
with pytest.raises(
DeserializationError, match="Missing 'document_store' in serialization data"
DeserializationError,
match="Missing 'document_store' in serialization data",
):
BetterBM25Retriever.from_dict(data)

Expand All @@ -123,7 +130,10 @@ def test_from_dict_nonexisting_docstore(self):
data = {
"type": "bbm25_haystack.BetterBM25Retriever",
"init_parameters": {
"document_store": {"type": "Nonexisting.Docstore", "init_parameters": {}}
"document_store": {
"type": "Nonexisting.Docstore",
"init_parameters": {},
}
},
}
with pytest.raises(DeserializationError):
Expand All @@ -138,7 +148,9 @@ def test_retriever_valid_run(self, mock_docs):

assert "documents" in result
assert len(result["documents"]) == 5
assert result["documents"][0].content == "PHP is a popular programming language"
assert (
result["documents"][0].content == "PHP is a popular programming language"
)

def test_invalid_run_wrong_store_type(self):
store_class = document_store_class("SomeOtherDocumentStore")
