Implement embeddings for use with LLM agents #680

Open · wants to merge 3 commits into base: main
20 changes: 10 additions & 10 deletions lumen/ai/agents.py
@@ -115,10 +115,16 @@ async def _system_prompt_with_context(
         self, messages: list | str, context: str = ""
     ) -> str:
         system_prompt = self.system_prompt
+        table_name = memory.get("current_table")
         if self.embeddings:
-            context = self.embeddings.query(messages)
+            # TODO: refactor this so it joins messages in a more robust way
+            text = "\n".join([message["content"] for message in messages])
+            # TODO: refactor this so it's not subsetting by index
+            # [(0, 'The creator of this dataset is named Andrew HH', 0.7491879463195801, 'windturbines.parquet')]
+            result = self.embeddings.query(text, table_name=table_name)[0][1]
[Review comment by Contributor Author] Another TODO: handle ephemeral tables

+            context += "\n" + result
         if context:
-            system_prompt += f"\n### CONTEXT: {context}"
+            system_prompt = f"{system_prompt}\n### CONTEXT: {context}".strip()
         return system_prompt

     async def _get_closest_tables(self, messages: list | str, tables: list[str], n: int = 3) -> list[str]:
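For context on the `[0][1]` indexing flagged in the TODO above: `Embeddings.query` (added in `lumen/ai/embeddings.py` below) returns a list of `(id, text, similarity, table_name)` tuples ordered by descending cosine similarity, so `[0][1]` picks the text of the best match. A minimal sketch of the shape being unpacked:

```python
# Shape returned by Embeddings.query, per the example in the TODO comment above:
# a list of (id, text, similarity, table_name) tuples, best match first.
hits = [
    (0, "The creator of this dataset is named Andrew HH", 0.7491879463195801, "windturbines.parquet"),
]
top_text = hits[0][1]  # text of the most similar document
```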
@@ -283,10 +289,7 @@ async def _system_prompt_with_context(
             f"\nHere's a summary of the dataset the user just asked about:\n```\n{memory['current_data']}\n```"
         )

-        system_prompt = self.system_prompt
-        if context:
-            system_prompt += f"\n### CONTEXT: {context}"
-        return system_prompt
+        return await super()._system_prompt_with_context(messages, context=context)


 class ChatDetailsAgent(ChatAgent):
@@ -313,7 +316,6 @@ class ChatDetailsAgent(ChatAgent):
     async def _system_prompt_with_context(
         self, messages: list | str, context: str = ""
     ) -> str:
-        system_prompt = self.system_prompt
         topic = (await self.llm.invoke(
             messages,
             system="What is the topic of the table?",
@@ -329,9 +331,7 @@ async def _system_prompt_with_context(
         columns = list(current_data.columns)
         context += f"\nHere are the columns of the table: {columns}"

-        if context:
-            system_prompt += f"\n### CONTEXT: {context}"
-        return system_prompt
+        return await super()._system_prompt_with_context(messages, context=context)


 class LumenBaseAgent(Agent):
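The hunks above all follow the same pattern: each agent subclass now only contributes its own `context` and defers final prompt assembly (the embeddings lookup and the `### CONTEXT` section) to the base implementation via `super()`. An illustrative sketch, where `MyAgent` is a hypothetical subclass, not part of this PR:

```python
class MyAgent(ChatAgent):
    async def _system_prompt_with_context(
        self, messages: list | str, context: str = ""
    ) -> str:
        # Contribute agent-specific context only; the base class appends the
        # embeddings result and renders the "### CONTEXT" section.
        context += "\nDetails only this agent knows about."
        return await super()._system_prompt_with_context(messages, context=context)
```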
8 changes: 8 additions & 0 deletions lumen/ai/assistant.py
@@ -22,6 +22,7 @@
     Agent, AnalysisAgent, ChatAgent, SQLAgent,
 )
 from .config import DEMO_MESSAGES, GETTING_STARTED_SUGGESTIONS
+from .embeddings import Embeddings
 from .export import export_notebook
 from .llm import Llama, Llm
 from .logs import ChatLogs
@@ -37,6 +38,8 @@ class Assistant(Viewer):

     agents = param.List(default=[ChatAgent])

+    embeddings = param.ClassSelector(class_=Embeddings)
+
     llm = param.ClassSelector(class_=Llm, default=Llama())

     interface = param.ClassSelector(class_=ChatInterface)
@@ -54,6 +57,7 @@
     def __init__(
         self,
         llm: Llm | None = None,
+        embeddings: Embeddings | None = None,
         interface: ChatInterface | None = None,
         agents: list[Agent | type[Agent]] | None = None,
         logs_filename: str = "",
@@ -111,11 +115,15 @@ def download_notebook():
         interface.post_hook = on_message

         llm = llm or self.llm
+        embeddings = embeddings or self.embeddings
         instantiated = []
         self._analyses = []
         for agent in agents or self.agents:
             if not isinstance(agent, Agent):
                 kwargs = {"llm": llm} if agent.llm is None else {}
+                if embeddings:
+                    print(f"embeddings for {agent}")
+                    kwargs["embeddings"] = embeddings
                 agent = agent(interface=interface, **kwargs)
             if agent.llm is None:
                 agent.llm = llm
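To illustrate the new wiring, a minimal usage sketch (hypothetical, assuming `OPENAI_API_KEY` is set so `OpenAIEmbeddings` can compute vectors):

```python
from lumen.ai.assistant import Assistant
from lumen.ai.embeddings import OpenAIEmbeddings

# Any Embeddings instance passed here is forwarded to each agent that is
# instantiated from a class, via kwargs["embeddings"].
assistant = Assistant(embeddings=OpenAIEmbeddings())
```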
143 changes: 123 additions & 20 deletions lumen/ai/embeddings.py
@@ -1,33 +1,136 @@
+import os
+
 from pathlib import Path

-from .config import DEFAULT_EMBEDDINGS_PATH
+import duckdb
+
+DEFAULT_EMBEDDINGS_PATH = Path("embeddings")


 class Embeddings:
+    def __init__(self, database_path: str = ":memory:"):
+        self.database_path = database_path
+        self.connection = duckdb.connect(database_path)
+        self.setup_database()

-    def add_directory(self, data_dir: Path):
-        raise NotImplementedError
+    def setup_database(self):
+        self.connection.execute(
+            """
+            INSTALL vss;
+            LOAD vss;
+            CREATE TABLE document_data (
+                id INTEGER,
+                text VARCHAR,
+                embedding FLOAT[1536],
+                table_name VARCHAR
+            );
+            CREATE INDEX embedding_index ON document_data USING HNSW (embedding) WITH (metric = 'cosine');
+            """
+        )
+
+    @classmethod
+    def from_directory(
+        cls,
+        data_dir: Path,
+        file_type: str = "json",
+        database_path: str = ":memory:",
+        table_name: str = "default",
+    ):
+        embeddings = cls(database_path)
+        for i, path in enumerate(data_dir.glob(f"**/*.{file_type}")):
+            text = path.read_text()
+            embedding = embeddings.get_embedding(text)
+            embeddings.connection.execute(
+                """
+                INSERT INTO document_data (id, text, embedding, table_name)
+                VALUES (?, ?, ?, ?);
+                """,
+                [i, text, embedding, table_name],
+            )
+        return embeddings
+
+    @classmethod
+    def from_dict(cls, data: dict, database_path: str = ":memory:"):
+        embeddings = cls(database_path)
+        global_id = 0
+        for table_name, texts in data.items():
+            for text in texts:
+                embedding = embeddings.get_embedding(text)
+                embeddings.connection.execute(
+                    """
+                    INSERT INTO document_data (id, text, embedding, table_name)
+                    VALUES (?, ?, ?, ?);
+                    """,
+                    [global_id, text, embedding, table_name],
+                )
+                global_id += 1
+        return embeddings

-    def query(self, query_texts: str) -> list:
+    def get_embedding(self, text: str) -> list:
         raise NotImplementedError

+    def get_text_chunks(
+        self, text: str, chunk_size: int = 512, overlap: int = 50
+    ) -> list:
+        words = text.split()
+        chunks = []
+        for i in range(0, len(words), chunk_size - overlap):
+            chunk = " ".join(words[i : i + chunk_size])
+            chunks.append(chunk)
+        return chunks

-class ChromaDb(Embeddings):
+    def get_combined_embedding(self, text: str) -> list:
+        chunks = self.get_text_chunks(text)
+        embeddings = [self.get_embedding(chunk) for chunk in chunks]
+        combined_embedding = [sum(x) / len(x) for x in zip(*embeddings)]
+        return combined_embedding

-    def __init__(self, collection: str, persist_dir: str = DEFAULT_EMBEDDINGS_PATH):
-        import chromadb
-        self.client = chromadb.PersistentClient(path=str(persist_dir / collection))
-        self.collection = self.client.get_or_create_collection(collection)
+    def query(self, query_text: str, top_k: int = 1, table_name: str | None = None) -> list:
+        print(query_text, "QUERY")
+        query_embedding = self.get_combined_embedding(query_text)

-    def add_directory(self, data_dir: Path, file_type='json'):
-        add_kwargs = {
-            "ids": [],
-            "documents": [],
-        }
-        for i, path in enumerate(data_dir.glob(f"**/*.{file_type}")):
-            add_kwargs["ids"].append(f"{i}")
-            add_kwargs["documents"].append(path.read_text())
-        self.collection.add(**add_kwargs)
+        if table_name:
+            result = self.connection.execute(
+                """
+                SELECT id, text, array_cosine_similarity(embedding, ?::FLOAT[1536]) AS similarity, table_name
+                FROM document_data
+                WHERE table_name = ?
+                ORDER BY similarity DESC
+                LIMIT ?;
+                """,
+                [query_embedding, table_name, top_k],
+            ).fetchall()
+        else:
+            result = self.connection.execute(
+                """
+                SELECT id, text, array_cosine_similarity(embedding, ?::FLOAT[1536]) AS similarity, table_name
+                FROM document_data
+                ORDER BY similarity DESC
+                LIMIT ?;
+                """,
+                [query_embedding, top_k],
+            ).fetchall()
+
+        return result
+
+    def close(self):
+        self.connection.close()


+class OpenAIEmbeddings(Embeddings):
+    def __init__(
+        self, database_path: str = ":memory:", model: str = "text-embedding-3-small"
+    ):
+        super().__init__(database_path)
+        self.model = model
+
+    def get_embedding(self, text: str) -> list:
+        from openai import OpenAI

-    def query(self, query_texts: str) -> list:
-        return self.collection.query(query_texts=query_texts)["documents"]
+        text = text.replace("\n", " ")
+        return (
+            OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
+            .embeddings.create(input=[text], model=self.model)
+            .data[0]
+            .embedding
+        )
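Putting the new API together, a usage sketch with hypothetical data (note the asymmetry: inserts embed each full text via `get_embedding`, while `query` chunks and mean-pools the query text via `get_combined_embedding`):

```python
from lumen.ai.embeddings import OpenAIEmbeddings

# Build an in-memory DuckDB store keyed by table_name (hypothetical data).
emb = OpenAIEmbeddings.from_dict(
    {"windturbines.parquet": ["The creator of this dataset is named Andrew HH"]}
)
# Returns (id, text, similarity, table_name) tuples, best match first.
hits = emb.query("Who created this dataset?", top_k=1, table_name="windturbines.parquet")
emb.close()
```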
Binary file added lumen/ai/interceptor.db
Binary file not shown.