-
-
Notifications
You must be signed in to change notification settings - Fork 38
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Naive first-pass implementation that only passes the simplest test
- Loading branch information
Showing
3 changed files
with
132 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
from pathlib import Path | ||
from typing import Dict, Generator, List, Tuple | ||
from nomenklatura.index.index import Index | ||
import duckdb | ||
import logging | ||
|
||
from nomenklatura.util import PathLike | ||
from nomenklatura.resolver import Pair, Identifier | ||
from nomenklatura.dataset import DS | ||
from nomenklatura.entity import CE | ||
from nomenklatura.store import View | ||
from nomenklatura.index.entry import Field | ||
from nomenklatura.index.tokenizer import NAME_PART_FIELD, WORD_FIELD, Tokenizer | ||
|
||
log = logging.getLogger(__name__) | ||
|
||
|
||
class DuckDBIndex(Index):
    """An index backed by a DuckDB table of (entity id, field, token) rows.

    Scoring is TF-style: for a given (field, token), each entity's mention
    count is divided by the entity's total token count in that field, then
    multiplied by a per-field boost (``self.BOOSTS``, inherited from Index).
    """

    def __init__(self, view: View[DS, CE], path: Path):
        self.view = view
        self.tokenizer = Tokenizer[DS, CE]()
        self.con = duckdb.connect(path.as_posix())
        self.con.execute("CREATE TABLE entries (id TEXT, field TEXT, token TEXT)")

    def index(self, entity: CE) -> None:
        """Index one entity. This is not idempotent, you need to remove the
        entity before re-indexing it."""
        if not entity.schema.matchable or entity.id is None:
            return
        rows = [
            [entity.id, field, token]
            for field, token in self.tokenizer.entity(entity)
        ]
        # An entity may tokenize to nothing; skip the executemany call
        # rather than hand DuckDB an empty parameter list.
        if rows:
            self.con.executemany("INSERT INTO entries VALUES (?, ?, ?)", rows)

    def build(self) -> None:
        """Index all entities in the dataset."""
        log.info("Building index from: %r...", self.view)
        # One transaction around the whole build keeps the inserts cheap.
        self.con.execute("BEGIN TRANSACTION")
        for idx, entity in enumerate(self.view.entities()):
            if idx % 10000 == 0:
                log.info("Indexing entity %s", idx)
            self.index(entity)
        self.con.execute("COMMIT")

    def match(self, entity: CE) -> List[Tuple[Identifier, float]]:
        """Score indexed entities against the query entity, best first.

        Each matching token contributes its frequency weight multiplied by
        the field's boost factor. Returns a list of (Identifier, score)
        pairs sorted by descending score.
        """
        scores: Dict[str, float] = {}
        for field_name, token in self.tokenizer.entity(entity):
            boost = self.BOOSTS.get(field_name, 1.0)
            for entity_id, weight in self.frequencies(field_name, token):
                scores[entity_id] = scores.get(entity_id, 0.0) + weight * boost
        # Fix: the previous version rebound `scores` to a sorted *list* and
        # then called .items() on it (AttributeError), and yielded results
        # despite the List[...] annotation. Sort once, return a real list.
        ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)
        return [(Identifier.get(entity_id), score) for entity_id, score in ranked]

    def frequencies(
        self, field: str, token: str
    ) -> Generator[Tuple[str, float], None, None]:
        """Yield (entity id, weight) for one (field, token), where weight is
        the entity's mention count divided by the entity's total token count
        in that field (floored at 1, mirroring the old ``max(1, ...)``)."""
        # Single relational query instead of one field-length lookup per id.
        query = """
            SELECT m.id, m.mentions / GREATEST(1, f.field_len)
            FROM (
                SELECT id, COUNT(*) AS mentions
                FROM entries
                WHERE field = ? AND token = ?
                GROUP BY id
            ) m
            JOIN (
                SELECT id, COUNT(*) AS field_len
                FROM entries
                WHERE field = ?
                GROUP BY id
            ) f ON m.id = f.id
        """
        result = self.con.execute(query, [field, token, field])
        for entity_id, weight in result.fetchall():
            yield entity_id, weight

    def __repr__(self) -> str:
        return "<DuckDBIndex(%r, %r)>" % (
            self.view.scope.name,
            self.con,
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
from pathlib import Path | ||
from tempfile import NamedTemporaryFile | ||
|
||
from nomenklatura.dataset import Dataset | ||
from nomenklatura.entity import CompositeEntity | ||
from nomenklatura.index import Index | ||
from nomenklatura.resolver.identifier import Identifier | ||
from nomenklatura.store import SimpleMemoryStore | ||
|
||
# Fixture entity ids from the test dataset — presumably content hashes;
# TODO confirm how these ids are derived.
DAIMLER = "66ce9f62af8c7d329506da41cb7c36ba058b3d28"
VERBAND_ID = "62ad0fe6f56dbbf6fee57ce3da76e88c437024d5"
VERBAND_BADEN_ID = "69401823a9f0a97cfdc37afa7c3158374e007669"
# Raw payload for a query entity that is itself NOT present in the index;
# its name tokens overlap with several indexed entities.
VERBAND_BADEN_DATA = {
    "id": "bla",
    "schema": "Company",
    "properties": {
        "name": ["VERBAND DER METALL UND ELEKTROINDUSTRIE BADEN WURTTEMBERG"]
    },
}
|
||
|
||
def test_match_score(dstore: SimpleMemoryStore, duckdb_index: Index):
    """Match an entity that isn't itself in the index"""
    dataset = Dataset.make({"name": "test", "title": "Test"})
    query = CompositeEntity.from_data(dataset, VERBAND_BADEN_DATA)
    matches = duckdb_index.match(query)
    # 9 entities in the index where some token in the query entity matches some
    # token in the index.
    assert len(matches) == 9, matches

    best_id, best_score = matches[0]
    assert best_id == Identifier(VERBAND_BADEN_ID), matches[0]
    assert 1.99 < best_score < 2, matches[0]

    runner_up_id, runner_up_score = matches[1]
    assert runner_up_id == Identifier(VERBAND_ID), matches[1]
    assert 1.66 < runner_up_score < 1.67, matches[1]

    matched_ids = {str(pair[0]) for pair in matches}
    assert VERBAND_ID in matched_ids  # validity
    assert DAIMLER not in matched_ids