Skip to content

Commit

Permalink
Make private attributes explicit by prefixing their names with an underscore
Browse files Browse the repository at this point in the history
  • Loading branch information
Guest400123064 committed May 2, 2024
1 parent 3b41ef7 commit 2564442
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 33 deletions.
12 changes: 3 additions & 9 deletions src/bbm25_haystack/bbm25_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,15 @@ def _validate_search_params(filters: Optional[dict[str, Any]], top_k: int) -> No
:raises TypeError: If filters is not a dictionary.
"""
if not isinstance(top_k, int):
msg = f"top_k must be an integer; got {type(top_k)} instead."
msg = f"'top_k' must be an integer; got '{type(top_k)}' instead."
raise TypeError(msg)

if top_k <= 0:
msg = f"top_k must be > 0; got {top_k} instead."
msg = f"'top_k' must be > 0; got '{top_k}' instead."
raise ValueError(msg)

if filters is not None and (not isinstance(filters, dict)):
msg = f"filters must be a dictionary; got {type(filters)} instead."
msg = f"'filters' must be a dictionary; got '{type(filters)}' instead."
raise TypeError(msg)


Expand Down Expand Up @@ -80,20 +80,14 @@ def __init__(
_validate_search_params(filters, top_k)

self.filters = filters
"""@private"""

self.top_k = top_k
"""@private"""

self.set_score = set_score
"""@private"""

if not isinstance(document_store, BetterBM25DocumentStore):
msg = "'document_store' must be of type 'BetterBM25DocumentStore'"
raise TypeError(msg)

self.document_store = document_store
"""@private"""

@component.output_types(documents=list[Document])
def run(
Expand Down
41 changes: 17 additions & 24 deletions src/bbm25_haystack/bbm25_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,9 @@ class BetterBM25DocumentStore:
``InMemoryDocumentStore`` shipped with Haystack.
"""

default_sp_file: Final = os.path.join(
_default_sp_file: Final = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "default.model"
)
"""@private"""

def __init__(
self,
Expand Down Expand Up @@ -89,19 +88,13 @@ def __init__(
logic or the one implemented in this store.
:type haystack_filter_logic: ``Optional[bool]``
"""
self.k = k
"""@private"""
self._k = k
self._b = b

self.b = b
"""@private"""

self.delta = delta / (self.k + 1.0)
"""@private
Adjust the delta value so that we can bring the ``(k1 + 1)``
term out of the 'term frequency' term in BM25+ formula and
delete it; this will not affect the ranking.
"""
# Adjust the delta value so that we can bring the ``(k1 + 1)``
# term out of the 'term frequency' term in BM25+ formula and
# delete it; this will not affect the ranking.
self._delta = delta / (self._k + 1.0)

self._parse_sp_file(sp_file=sp_file)
self._parse_n_grams(n_grams=n_grams)
Expand All @@ -121,27 +114,27 @@ def _parse_sp_file(self, sp_file: Optional[str]) -> None:
self._sp_file = sp_file

if sp_file is None:
self._sp_inst = SentencePieceProcessor(model_file=self.default_sp_file)
self._sp_inst = SentencePieceProcessor(model_file=self._default_sp_file)
return

if not os.path.exists(sp_file) or not os.path.isfile(sp_file):
msg = (
f"Tokenizer model file '{sp_file}' not accessible; "
f"fallback to default {self.default_sp_file}."
f"fallback to default {self._default_sp_file}."
)
logger.warn(msg)
self._sp_inst = SentencePieceProcessor(model_file=self.default_sp_file)
self._sp_inst = SentencePieceProcessor(model_file=self._default_sp_file)
return

try:
self._sp_inst = SentencePieceProcessor(model_file=sp_file)
except Exception as exc:
msg = (
f"Failed to load tokenizer model file '{sp_file}': {exc}; "
f"fallback to default {self.default_sp_file}."
f"fallback to default {self._default_sp_file}."
)
logger.error(msg)
self._sp_inst = SentencePieceProcessor(model_file=self.default_sp_file)
self._sp_inst = SentencePieceProcessor(model_file=self._default_sp_file)

def _parse_n_grams(self, n_grams: Optional[Union[int, tuple[int, int]]]) -> None:
self._n_grams = n_grams
Expand Down Expand Up @@ -222,9 +215,9 @@ def _compute_bm25plus(
scr = 0.0
for token, idf_val in idf.items():
freq_term = freq.get(token, 0.0)
freq_damp = self.k * (1 + self.b * (doc_len_scaled - 1))
freq_damp = self._k * (1 + self._b * (doc_len_scaled - 1))

tf_val = freq_term / (freq_term + freq_damp) + self.delta
tf_val = freq_term / (freq_term + freq_damp) + self._delta
scr += idf_val * tf_val

sim.append((doc, scr))
Expand Down Expand Up @@ -392,9 +385,9 @@ def to_dict(self) -> dict[str, Any]:
"""Serializes this store to a dictionary."""
return default_to_dict(
self,
k=self.k,
b=self.b,
delta=self.delta * (self.k + 1.0), # Because we scaled it on init
k=self._k,
b=self._b,
delta=self._delta * (self._k + 1.0), # Because we scaled it on init
sp_file=self._sp_file,
n_grams=self._n_grams,
haystack_filter_logic=self._haystack_filter_logic,
Expand Down

0 comments on commit 2564442

Please sign in to comment.