Skip to content

Commit

Permalink
Support supplied index dir and xref index choice
Browse files Browse the repository at this point in the history
  • Loading branch information
jbothma committed Jun 12, 2024
1 parent 1197dfc commit 25e5ba2
Show file tree
Hide file tree
Showing 9 changed files with 71 additions and 32 deletions.
5 changes: 4 additions & 1 deletion nomenklatura/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from nomenklatura.xref import xref as run_xref
from nomenklatura.tui import dedupe_ui

INDEX_SEGMENT = "index-data"

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -75,6 +76,7 @@ def xref_file(
run_xref(
resolver_,
store,
path.parent / INDEX_SEGMENT,
auto_threshold=auto_threshold,
algorithm=algorithm_type,
scored=scored,
Expand Down Expand Up @@ -139,7 +141,8 @@ def dedupe(path: Path, xref: bool = False, resolver: Optional[Path] = None) -> N
resolver_ = _get_resolver(path, resolver)
store = load_entity_file_store(path, resolver=resolver_)
if xref:
run_xref(resolver_, store)
index_dir = path.parent / INDEX_SEGMENT
run_xref(resolver_, store, index_dir)

dedupe_ui(resolver_, store)
resolver_.save()
Expand Down
3 changes: 2 additions & 1 deletion nomenklatura/index/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from nomenklatura.index.index import Index
from nomenklatura.index.tantivy_index import TantivyIndex
from nomenklatura.index.common import BaseIndex

__all__ = ["Index", "TantivyIndex"]
__all__ = ["BaseIndex", "Index", "TantivyIndex"]
22 changes: 22 additions & 0 deletions nomenklatura/index/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from pathlib import Path
from typing import Generator, Generic, List, Tuple
from nomenklatura.resolver import Pair, Identifier
from nomenklatura.dataset import DS
from nomenklatura.entity import CE
from nomenklatura.store import View


class BaseIndex(Generic[DS, CE]):
    """Common interface for entity index implementations.

    The class is generic over the dataset type ``DS`` and the entity type
    ``CE``. Every method defined here raises ``NotImplementedError``, so a
    concrete index class has to override each one it wants to support.
    """

    # Identifier of the index implementation. Only declared here (no value
    # is assigned), so each subclass is expected to set it.
    name: str

    def __init__(self, view: View[DS, CE], data_dir: Path) -> None:
        """Set up the index for the given view.

        Args:
            view: The store view the index is constructed over.
            data_dir: A directory path handed to the index.
                NOTE(review): presumably the location where an
                implementation may keep on-disk state -- confirm against
                the concrete subclasses.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError

    def build(self) -> None:
        """Build the index.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError

    def pairs(self) -> List[Tuple[Tuple[Identifier, Identifier], float]]:
        """Return a list of ``((identifier, identifier), float)`` tuples.

        NOTE(review): the float is presumably a score for the identifier
        pair; its meaning and ordering are defined by the subclass --
        confirm against the implementations.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError

    def match(self, entity: CE) -> List[Tuple[Identifier, float]]:
        """Return a list of ``(identifier, float)`` tuples for ``entity``.

        NOTE(review): the float is presumably a match score for the
        returned identifier -- confirm against the implementations.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError
11 changes: 7 additions & 4 deletions nomenklatura/index/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,16 @@
from nomenklatura.store import View
from nomenklatura.index.entry import Field
from nomenklatura.index.tokenizer import NAME_PART_FIELD, WORD_FIELD, Tokenizer
from nomenklatura.index.common import BaseIndex

log = logging.getLogger(__name__)


class Index(Generic[DS, CE]):
class Index(BaseIndex[DS, CE]):
"""An in-memory search index to match entities against a given dataset."""

name = "memory"

BOOSTS = {
NAME_PART_FIELD: 2.0,
WORD_FIELD: 0.5,
Expand All @@ -37,7 +40,7 @@ class Index(Generic[DS, CE]):

__slots__ = "view", "fields", "tokenizer", "entities"

def __init__(self, view: View[DS, CE]):
def __init__(self, view: View[DS, CE], data_dir: Path):
self.view = view
self.tokenizer = Tokenizer[DS, CE]()
self.fields: Dict[str, Field] = {}
Expand Down Expand Up @@ -119,8 +122,8 @@ def save(self, path: PathLike) -> None:
pickle.dump(self.to_dict(), fh)

@classmethod
def load(cls, view: View[DS, CE], path: Path) -> "Index[DS, CE]":
index = Index(view)
def load(cls, view: View[DS, CE], path: Path, data_dir: Path) -> "Index[DS, CE]":
index = Index(view, data_dir)
if not path.exists():
log.debug("Cannot load: %r", index)
index.build()
Expand Down
7 changes: 5 additions & 2 deletions nomenklatura/index/tantivy_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from nomenklatura.resolver import Identifier, Pair
from nomenklatura.store import View
from nomenklatura.util import clean_text_basic, fingerprint_name
from nomenklatura.index.common import BaseIndex

log = logging.getLogger(__name__)

Expand All @@ -27,7 +28,9 @@
}


class TantivyIndex:
class TantivyIndex(BaseIndex[DS, CE]):
name = "tantivy"

def __init__(
self, view: View[DS, CE], data_dir: Path, options: Dict[str, Any] = {}
):
Expand All @@ -48,7 +51,7 @@ def __init__(
schema_builder.add_text_field(registry.date.name, tokenizer_name="raw")
self.schema = schema_builder.build()

self.index_dir = data_dir / "tantivy-index"
self.index_dir = data_dir
if self.index_dir.exists():
self.exists = True
self.index = tantivy.Index.open(self.index_dir.as_posix())
Expand Down
9 changes: 4 additions & 5 deletions nomenklatura/xref.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from nomenklatura.store import Store
from nomenklatura.judgement import Judgement
from nomenklatura.resolver import Resolver
from nomenklatura.index import TantivyIndex
from nomenklatura.index import TantivyIndex, Index, BaseIndex
from nomenklatura.matching import DefaultAlgorithm, ScoringAlgorithm

log = logging.getLogger(__name__)
Expand All @@ -31,6 +31,7 @@ def _print_stats(pairs: int, suggested: int, scores: List[float]) -> None:
def xref(
resolver: Resolver[CE],
store: Store[DS, CE],
index_dir: Path,
limit: int = 5000,
scored: bool = True,
external: bool = True,
Expand All @@ -39,11 +40,11 @@ def xref(
focus_dataset: Optional[str] = None,
algorithm: Type[ScoringAlgorithm] = DefaultAlgorithm,
user: Optional[str] = None,
index_class: Type[BaseIndex[DS, CE]] = TantivyIndex,
) -> None:
log.info("Begin xref: %r, resolver: %s", store, resolver)
view = store.default_view(external=external)
working_dir = Path(mkdtemp())
index = TantivyIndex(view, working_dir)
index = index_class(view, index_dir)
index.build()
try:
scores: List[float] = []
Expand Down Expand Up @@ -98,5 +99,3 @@ def xref(

except KeyboardInterrupt:
log.info("User cancelled, xref will end gracefully.")
finally:
shutil.rmtree(working_dir, ignore_errors=True)
21 changes: 13 additions & 8 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,17 +67,22 @@ def test_dataset() -> Dataset:
return Dataset.make({"name": "test_dataset", "title": "Test Dataset"})


@pytest.fixture(scope="module")
def dindex(dstore: SimpleMemoryStore):
index = Index(dstore.default_view())
@pytest.fixture(scope="function")
def dindex(index_path: Path, dstore: SimpleMemoryStore):
index = Index(dstore.default_view(), index_path)
index.build()
return index


@pytest.fixture(scope="module")
def tantivy_index(dstore: SimpleMemoryStore):
state_path = Path(mkdtemp())
index = TantivyIndex(dstore.default_view(), state_path)
@pytest.fixture(scope="function")
def tantivy_index(index_path: Path, dstore: SimpleMemoryStore):
index = TantivyIndex(dstore.default_view(), index_path)
index.build()
yield index
shutil.rmtree(state_path, ignore_errors=True)


@pytest.fixture(scope="function")
def index_path():
index_path = Path(mkdtemp()) / "index-dir"
yield index_path
shutil.rmtree(index_path, ignore_errors=True)
20 changes: 11 additions & 9 deletions tests/index/test_index.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path
from tempfile import NamedTemporaryFile
from tempfile import NamedTemporaryFile, TemporaryDirectory

from nomenklatura.dataset import Dataset
from nomenklatura.entity import CompositeEntity
Expand All @@ -19,8 +19,8 @@
}


def test_index_build(dstore: SimpleMemoryStore):
index = Index(dstore.default_view())
def test_index_build(index_path: Path, dstore: SimpleMemoryStore):
index = Index(dstore.default_view(), index_path)
assert len(index) == 0, index.fields
assert len(index.fields) == 0, index.fields
index.build()
Expand All @@ -29,16 +29,18 @@ def test_index_build(dstore: SimpleMemoryStore):

def test_index_persist(dstore: SimpleMemoryStore, dindex):
view = dstore.default_view()
with NamedTemporaryFile("w") as fh:
path = Path(fh.name)
dindex.save(path)
loaded = Index.load(dstore.default_view(), path)
with TemporaryDirectory() as tmpdir:
with NamedTemporaryFile("w") as fh:
path = Path(fh.name)
dindex.save(path)
loaded = Index.load(dstore.default_view(), path, tmpdir)
assert len(dindex.entities) == len(loaded.entities), (dindex, loaded)
assert len(dindex) == len(loaded), (dindex, loaded)

path.unlink(missing_ok=True)
empty = Index.load(view, path)
assert len(empty) == len(loaded), (empty, loaded)
with TemporaryDirectory() as tmpdir:
empty = Index.load(view, path, tmpdir)
assert len(empty) == len(loaded), (empty, loaded)


def test_index_pairs(dstore: SimpleMemoryStore, dindex: Index):
Expand Down
5 changes: 3 additions & 2 deletions tests/test_xref.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from pathlib import Path
from nomenklatura.xref import xref
from nomenklatura.store import SimpleMemoryStore
from nomenklatura.resolver import Resolver
from nomenklatura.entity import CompositeEntity


def test_xref_candidates(
dresolver: Resolver[CompositeEntity], dstore: SimpleMemoryStore
index_path: Path, dresolver: Resolver[CompositeEntity], dstore: SimpleMemoryStore
):
xref(dresolver, dstore)
xref(dresolver, dstore, index_path)
view = dstore.default_view(external=True)
candidates = list(dresolver.get_candidates(limit=20))
assert len(candidates) == 20
Expand Down

0 comments on commit 25e5ba2

Please sign in to comment.