Skip to content

Commit

Permalink
Absolutely stupid implementation, just passing the simplest test
Browse files Browse the repository at this point in the history
  • Loading branch information
jbothma committed May 23, 2024
1 parent e6a93e5 commit b348ef8
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 0 deletions.
82 changes: 82 additions & 0 deletions nomenklatura/index/duckdb_index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from pathlib import Path
from typing import Dict, Generator, List, Tuple
from nomenklatura.index.index import Index
import duckdb
import logging

from nomenklatura.util import PathLike
from nomenklatura.resolver import Pair, Identifier
from nomenklatura.dataset import DS
from nomenklatura.entity import CE
from nomenklatura.store import View
from nomenklatura.index.entry import Field
from nomenklatura.index.tokenizer import NAME_PART_FIELD, WORD_FIELD, Tokenizer

log = logging.getLogger(__name__)


class DuckDBIndex(Index):
    """A matching index backed by a DuckDB database file.

    Each (entity id, field, token) occurrence produced by the tokenizer is
    stored as one row in a single ``entries`` table. Match scores are the
    sum over shared tokens of the token's frequency within the candidate
    entity's field, weighted by the per-field ``BOOSTS`` factor.
    """

    def __init__(self, view: View[DS, CE], path: Path):
        """Create (or open) the index database at *path* for *view*.

        Note: unconditionally creates the ``entries`` table, so *path*
        must not already contain an initialised index.
        """
        self.view = view
        self.tokenizer = Tokenizer[DS, CE]()
        self.con = duckdb.connect(path.as_posix())
        self.con.execute("CREATE TABLE entries (id TEXT, field TEXT, token TEXT)")

    def index(self, entity: CE) -> None:
        """Index one entity. This is not idempotent, you need to remove the
        entity before re-indexing it."""
        if not entity.schema.matchable or entity.id is None:
            return
        rows = [
            [entity.id, field, token]
            for field, token in self.tokenizer.entity(entity)
        ]
        # An entity can tokenize to nothing; don't hand duckdb an empty
        # parameter list for executemany.
        if rows:
            self.con.executemany("INSERT INTO entries VALUES (?, ?, ?)", rows)

    def build(self) -> None:
        """Index all entities in the dataset."""
        log.info("Building index from: %r...", self.view)
        self.con.execute("BEGIN TRANSACTION")
        for idx, entity in enumerate(self.view.entities()):
            if idx % 10000 == 0:
                log.info("Indexing entity %s", idx)
            self.index(entity)
        self.con.execute("COMMIT")

    def match(self, entity: CE) -> List[Tuple[Identifier, float]]:
        """Score indexed entities against *entity*.

        Returns (identifier, score) pairs sorted best-first. The query
        entity itself does not need to be in the index.
        """
        # Bug fix: this used to be a generator that also called .items() on a
        # sorted *list* of pairs — callers doing len()/indexing got a generator
        # and iteration raised AttributeError. Build and return the list.
        scores: Dict[str, float] = {}
        for field_name, token in self.tokenizer.entity(entity):
            boost = self.BOOSTS.get(field_name, 1.0)
            for entity_id, weight in self.frequencies(field_name, token):
                scores[entity_id] = scores.get(entity_id, 0.0) + weight * boost
        ranked = sorted(scores.items(), key=lambda pair: pair[1], reverse=True)
        return [(Identifier.get(entity_id), score) for entity_id, score in ranked]

    def frequencies(
        self, field: str, token: str
    ) -> Generator[Tuple[str, float], None, None]:
        """Yield (entity id, relative frequency) for entities mentioning
        *token* in *field*, where the frequency is the mention count divided
        by the total number of tokens that entity has in that field."""
        # TODO: this could probably be done with a relational query in DuckDB
        # instead of one field-length SELECT per matching id.
        mentions_query = """
            SELECT id, count(*) as mentions
            FROM entries
            WHERE field = ? AND token = ?
            GROUP BY id
        """
        field_len_query = """
            SELECT count(*) from entries WHERE field = ? and id = ?
        """
        mentions_result = self.con.execute(mentions_query, [field, token])
        for entity_id, mentions in mentions_result.fetchall():
            (field_len,) = self.con.execute(
                field_len_query, [field, entity_id]
            ).fetchone()
            # Guard against division by zero for a degenerate field length.
            field_len = max(1, field_len)
            yield entity_id, mentions / field_len

    def __repr__(self) -> str:
        return "<DuckDBIndex(%r, %r)>" % (
            self.view.scope.name,
            self.con,
        )
9 changes: 9 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from tempfile import mkdtemp

from nomenklatura import settings
from nomenklatura.index.duckdb_index import DuckDBIndex
from nomenklatura.store import load_entity_file_store, SimpleMemoryStore
from nomenklatura.dataset import Dataset
from nomenklatura.entity import CompositeEntity
Expand Down Expand Up @@ -62,3 +63,11 @@ def dindex(dstore: SimpleMemoryStore):
index = Index(dstore.default_view())
index.build()
return index


@pytest.fixture(scope="module")
def duckdb_index(dstore: SimpleMemoryStore):
    """Module-scoped DuckDB-backed index built over the fixture store."""
    db_path = Path(WORK_PATH, "duckdb_index.db")
    built = DuckDBIndex(dstore.default_view(), db_path)
    built.build()
    return built
41 changes: 41 additions & 0 deletions tests/test_duckdb_index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from pathlib import Path
from tempfile import NamedTemporaryFile

from nomenklatura.dataset import Dataset
from nomenklatura.entity import CompositeEntity
from nomenklatura.index import Index
from nomenklatura.resolver.identifier import Identifier
from nomenklatura.store import SimpleMemoryStore

# Entity IDs present in the fixture dataset (presumably content-hash
# identifiers generated by the test corpus — TODO confirm against fixtures).
DAIMLER = "66ce9f62af8c7d329506da41cb7c36ba058b3d28"
VERBAND_ID = "62ad0fe6f56dbbf6fee57ce3da76e88c437024d5"
VERBAND_BADEN_ID = "69401823a9f0a97cfdc37afa7c3158374e007669"
# Raw data for a query entity that is NOT itself in the index; its name
# tokens are expected to match VERBAND_BADEN_ID most strongly.
VERBAND_BADEN_DATA = {
    "id": "bla",
    "schema": "Company",
    "properties": {
        "name": ["VERBAND DER METALL UND ELEKTROINDUSTRIE BADEN WURTTEMBERG"]
    },
}


def test_match_score(dstore: SimpleMemoryStore, duckdb_index: Index):
    """Match an entity that isn't itself in the index"""
    dataset = Dataset.make({"name": "test", "title": "Test"})
    query = CompositeEntity.from_data(dataset, VERBAND_BADEN_DATA)
    matches = duckdb_index.match(query)
    # 9 entities in the index where some token in the query entity matches some
    # token in the index.
    assert len(matches) == 9, matches

    best_id, best_score = matches[0]
    assert best_id == Identifier(VERBAND_BADEN_ID), matches[0]
    assert 1.99 < best_score < 2, matches[0]

    runner_up_id, runner_up_score = matches[1]
    assert runner_up_id == Identifier(VERBAND_ID), matches[1]
    assert 1.66 < runner_up_score < 1.67, matches[1]

    matched_ids = {str(ident) for ident, _ in matches}
    assert VERBAND_ID in matched_ids  # validity
    assert DAIMLER not in matched_ids

0 comments on commit b348ef8

Please sign in to comment.