Commit 4c8e4d9
fix crud rag index
AnniePacheco committed Sep 18, 2024
1 parent 3785c8b
Showing 4 changed files with 86 additions and 59 deletions.
137 changes: 81 additions & 56 deletions backend/managers/RagManager.py
@@ -42,63 +42,70 @@ def __init__(self):
self._initialized = True

async def create_index(self, resource_id: str, path_files: List[str]) -> List[dict]:
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=int(os.environ.get('CHUNK_SIZE')),
chunk_overlap=int(os.environ.get('CHUNK_OVERLAP')),
add_start_index=bool(strtobool(os.environ.get('ADD_START_INDEX')))
)

all_docs = []
all_ids = []
file_ids = []
file_names = []
page_ids = []
file_info_list = []
all_chunks = []

for path in path_files:
loader = PyPDFLoader(path)
docs = loader.load()
all_docs.append(docs)
file_id = str(uuid4())
all_ids.append(file_id)
file_name = Path(path).name
file_names.append(file_name)

# Collect file_id and file_name into a dictionary
file_info_list.append({"file_id": file_id, "file_name": file_name})
for doc in docs:
all_docs.append([doc])
page_id = str(uuid4())
file_info_list.append({"file_id": file_id, "file_name": file_name, "page_id": page_id})

file_ids = [item['file_id'] for item in file_info_list]
file_names = [item['file_name'] for item in file_info_list]
page_ids = [item['page_id'] for item in file_info_list]

print("\n\nFILE IDS: ", file_ids)
print("\n\nFILE NAMES: ", file_names)

text_splitter = RecursiveCharacterTextSplitter(
chunk_size=int(os.environ.get('CHUNK_SIZE')),
chunk_overlap=int(os.environ.get('CHUNK_OVERLAP')),
add_start_index=bool(strtobool(os.environ.get('ADD_START_INDEX')))
)

# Split documents while retaining metadata
split_documents = []
split_ids = []

for doc, doc_id in zip(all_docs, all_ids):
all_chunks = []
for doc, page_id in zip(all_docs, page_ids):
#split the document into smaller chunks
splits = text_splitter.split_documents(doc)
all_chunks.append(len(splits))
# Append each chunk to the split_documents list
for i, split in enumerate(splits):
split.metadata["original_id"] = doc_id
split.metadata["original_id"] = page_id
split_documents.append(split)
# Create unique IDs for each split based on the original ID and chunk index
split_ids.append(f"{doc_id}-{i}")

await self.create_files_for_resource(resource_id, file_names, all_ids, all_chunks)

split_ids.append(f"{page_id}-{i}")
await self.create_files_for_resource(resource_id, file_ids, file_names, page_ids, all_chunks)
# add the split documents to the vectorstore
vectorstore = await self.initialize_chroma(resource_id)
vectorstore.add_documents(documents=split_documents, ids=split_ids)
return file_info_list
print("AFTER ADDING DOCUMENTS")
return file_info_list


async def create_files_for_resource(self, resource_id: str, files: List[str], ids: List[str], num_chunks:List[int]):
for file_name, file_id, chunk in zip(files, ids, num_chunks):
await self.create_file(file_name, resource_id, file_id, chunk)
async def create_files_for_resource(self, resource_id: str, file_ids:List[str], file_names: List[str], page_ids: List[str], num_chunks:List[int]):
print(f'file_ids: {len(file_ids)}, file_names: {len(file_names)}, page_ids: {len(page_ids)}, num_chunks: {len(num_chunks)}')
for file_id, file_name, page_id, chunk in zip(file_ids, file_names, page_ids, num_chunks):
await self.create_file(resource_id, file_id, file_name, page_id, chunk)

async def create_file(self, file_name: str, assistant_id: str, file_id: str, num_chunks: int = 0):
async def create_file(self, assistant_id: str, file_id: str, file_name: str, page_id: str, num_chunks: int = 0):
async with db_session_context() as session:
new_file = File(id=file_id, name=file_name, assistant_id=assistant_id, num_chunks=str(num_chunks))
session.add(new_file)
await session.commit()
await session.refresh(new_file)
try:
new_file = File(id=page_id, name=file_name, assistant_id=assistant_id, file_id=file_id, num_chunks=str(num_chunks))
session.add(new_file)
await session.commit()
await session.refresh(new_file)
except Exception as e:
print(f"An error occurred: {e}")


async def initialize_chroma(self, collection_name: str):
embed = OllamaEmbeddings(model=os.environ.get('EMBEDDER_MODEL'))
@@ -176,44 +183,56 @@ async def delete_tmp_files(self, assistant_id: str):
except Exception as e:
logger.error(f"An error occurred while deleting folder for assistant {assistant_id}: {e}", exc_info=True)

async def retrieve_file(self, id:str) -> Optional[FileSchema]:
async def retrieve_file(self, file_id:str) -> Optional[List[FileSchema]]:
async with db_session_context() as session:
result = await session.execute(select(File).filter(File.id == id))
file = result.scalar_one_or_none()
if file:
return FileSchema(
id=file.id,
name=file.name,
num_chunks = file.num_chunks
)
result = await session.execute(select(File).filter(File.file_id == file_id))
#file = result.scalar_one_or_none()
files = [FileSchema.from_orm(file) for file in result.scalars().all()]
if files:
return files
# return FileSchema(
# id=file.id,
# name=file.name,
# num_chunks = file.num_chunks
# )
return None

async def delete_documents_from_chroma(self, resource_id: str, file_ids: List[str]):
for file_id in file_ids:
file = await self.retrieve_file(file_id)
if file:
num_chunks = int(file.num_chunks)
vectorstore = await self.initialize_chroma(resource_id)

#iterate over num_chunks to create list of embeddings {file.id}-{chunk}
vectorstore = await self.initialize_chroma(resource_id)
for file_id in file_ids:
files = await self.retrieve_file(file_id)
if files:
page_ids = []
for file in files:
num_chunks = int(file.num_chunks)
page_id = file.id
page_ids.append(page_id)

list_chunks_id = []
for n in range(0, num_chunks):
chunk_id = f"{file_id}-{n}"
chunk_id = f"{page_id}-{n}"
list_chunks_id.append(chunk_id)
vectorstore.delete(ids=list_chunks_id)
else:
return None
return "Documents deleted"
else:
return None
return "Documents deleted"

async def delete_file_from_db(self, file_ids: List[str]):
page_ids = []
for file_id in file_ids:
files = await self.retrieve_file(file_id)
if files:
for file in files:
page_id = file.id
page_ids.append(page_id)
async with db_session_context() as session:
try:
stmt = delete(File).where(File.id.in_(file_ids))
stmt = delete(File).where(File.id.in_(page_ids))
result = await session.execute(stmt)
await session.commit()
return result.rowcount > 0
except Exception as e:
print(e)
print("error in delete from db",e)
return None

async def retrieve_files(self, resource_id: str, offset: int = 0, limit: int = 100, sort_by: Optional[str] = None,
@@ -225,10 +244,16 @@ async def retrieve_files(self, resource_id: str, offset: int = 0, limit: int = 1

result = await session.execute(query)
files = [FileSchema.from_orm(file) for file in result.scalars().all()]
seen_file_ids = set()
unique_files = []
for file in files:
if file.file_id not in seen_file_ids:
unique_files.append(file)
seen_file_ids.add(file.file_id)

total_count = await self._get_total_count(filters)

return files, total_count
return unique_files, total_count

def _apply_filters(self, query, filters: Optional[Dict[str, Any]]):
if filters:
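Taken together, the RagManager.py changes switch the unit of indexing from whole files to individual PDF pages: create_index now returns one entry per page, every Chroma chunk ID is derived from a page_id, and deletion fans out from a file_id to all of its page rows. A minimal usage sketch, assuming only the method names and the "{page_id}-{chunk_index}" ID scheme visible in the diff — the resource ID and file path below are hypothetical placeholders:

import asyncio

async def demo():
    rag = RagManager()

    # One dict per PDF page: {"file_id": ..., "file_name": ..., "page_id": ...}
    file_info_list = await rag.create_index("resource-123", ["/tmp/report.pdf"])
    first = file_info_list[0]

    # Chroma chunk IDs follow the "<page_id>-<chunk_index>" scheme, e.g. "<page_id>-0".
    print(first["file_id"], first["file_name"], first["page_id"])

    # Deleting by the original file_id removes the chunks of every page of that
    # file from the vectorstore, then the page rows from the database.
    await rag.delete_documents_from_chroma("resource-123", file_ids=[first["file_id"]])
    await rag.delete_file_from_db([first["file_id"]])

asyncio.run(demo())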
1 change: 1 addition & 0 deletions backend/models.py
@@ -61,6 +61,7 @@ class File(Base):
name = Column(String, nullable=False)
assistant_id = Column(String, nullable=False)
num_chunks = Column(String, nullable=False)
file_id = Column(String, nullable=False)

class Message(Base):
__tablename__ = "message"
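With the new column, the File model after this commit presumably reads as follows. This is a reconstructed sketch, not verbatim source: the id column and its primary-key status are assumptions based on file.id usage in RagManager.py and the PrimaryKeyConstraint in the migration below.

# Sketch of the File model in backend/models.py after this commit.
class File(Base):
    __tablename__ = "file"
    id = Column(String, primary_key=True)        # page-level ID: one row per PDF page
    name = Column(String, nullable=False)
    assistant_id = Column(String, nullable=False)
    num_chunks = Column(String, nullable=False)  # stored as a string, per the schema
    file_id = Column(String, nullable=False)     # groups all page rows of one uploaded file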
2 changes: 1 addition & 1 deletion backend/schemas.py
@@ -151,7 +151,7 @@ class ConversationSchema(ConversationBaseSchema):
class FileBaseSchema(BaseModel):
name: str
num_chunks: str

file_id: str
class Config:
orm_mode = True
from_attributes = True
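After the schema change, FileBaseSchema carries the grouping file_id alongside the page-level fields. A reconstructed sketch follows; FileSchema itself is not shown in the diff, so its id field is an assumption inferred from file.id usage in RagManager.py:

from pydantic import BaseModel

class FileBaseSchema(BaseModel):
    name: str
    num_chunks: str
    file_id: str

    class Config:
        orm_mode = True
        from_attributes = True

class FileSchema(FileBaseSchema):
    id: str  # page-level primary key, mirrors File.id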
5 changes: 3 additions & 2 deletions migrations/versions/29df33c77244_added_file_table.py
@@ -21,8 +21,9 @@ def upgrade() -> None:
op.create_table('file',
sa.Column('id', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('name', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('assistant_id', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('num_chunks', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('assistant_id', sqlmodel.sql.sqltypes.AutoString(), nullable=True),
sa.Column('file_id', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column('num_chunks', sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.PrimaryKeyConstraint('id')
)

