From 5b4b38ef1a7f8974f7858085d7535c2b9618c7e1 Mon Sep 17 00:00:00 2001 From: Sam Johnston Date: Sat, 20 Jul 2024 12:50:21 -0700 Subject: [PATCH 1/5] migrate to SQLAlchemy --- backend/db.py | 46 ++++++---- backend/env.py | 2 +- backend/managers/AssetsManager.py | 127 +++++++++++++--------------- backend/managers/ChannelsManager.py | 99 +++++++++++----------- backend/managers/ConfigManager.py | 44 ++++++---- backend/managers/UsersManager.py | 98 ++++++++++----------- backend/models.py | 43 +++++----- backend/requirements.txt | 2 +- backend/tests/test_db.py | 76 +++++++---------- common/paths.py | 5 +- 10 files changed, 271 insertions(+), 271 deletions(-) diff --git a/backend/db.py b/backend/db.py index 11260491..32ba0b1c 100644 --- a/backend/db.py +++ b/backend/db.py @@ -1,33 +1,45 @@ # database helper functions import os -import aiosqlite import logging +from sqlalchemy.orm import sessionmaker, declarative_base +from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession from alembic import command from alembic.config import Config as AlembicConfig -from common.paths import base_dir, db_path +from common.paths import base_dir, db_path, db_url +from contextlib import asynccontextmanager logger = logging.getLogger(__name__) +# Define the SQLAlchemy Base +Base = declarative_base() + +# Create async engine +engine = create_async_engine(db_url, echo=True) + +# Create async session factory +AsyncSessionLocal = sessionmaker( + bind=engine, + class_=AsyncSession, + expire_on_commit=False, +) + # use alembic to create the database or migrate to the latest schema def init_db(): logger.info("Initializing database...") alembic_cfg = AlembicConfig() os.makedirs(db_path.parent, exist_ok=True) alembic_cfg.set_main_option("script_location", str(base_dir / "migrations")) - alembic_cfg.set_main_option("sqlalchemy.url", f"sqlite:///{db_path}") + alembic_cfg.set_main_option("sqlalchemy.url", db_url.replace("+aiosqlite", "")) # because Alembic doesn't like async apparently command.upgrade(alembic_cfg, "head") -async def execute_query(query, params=None): - # TODO: logger.adebug from structlog - logger.debug(f"Executing query: {query} with params: {params}") - async with aiosqlite.connect(db_path) as conn: - async with conn.cursor() as cursor: - try: - await cursor.execute(query, params or ()) - result = await cursor.fetchall() - await conn.commit() - except Exception as e: - await conn.rollback() - raise e - - return result +@asynccontextmanager +async def db_session_context(): + session = AsyncSessionLocal() + try: + yield session + await session.commit() + except Exception: + await session.rollback() + raise + finally: + await session.close() \ No newline at end of file diff --git a/backend/env.py b/backend/env.py index 46efa123..d8071ff6 100644 --- a/backend/env.py +++ b/backend/env.py @@ -29,7 +29,7 @@ def check_env(): print(f"\nOnce you have activated the virtual environment, run this again.") sys.exit(1) - required_modules = ['connexion', 'uvicorn'] + required_modules = ['connexion', 'uvicorn', 'sqlalchemy', 'alembic', 'aiosqlite'] for module in required_modules: try: __import__(module) diff --git a/backend/managers/AssetsManager.py b/backend/managers/AssetsManager.py index 4cb1b6d8..61d7a7b8 100644 --- a/backend/managers/AssetsManager.py +++ b/backend/managers/AssetsManager.py @@ -1,7 +1,9 @@ from uuid import uuid4 -import backend.db as db -from backend.utils import remove_null_fields, zip_fields from threading import Lock +from sqlalchemy import select, insert, update, delete, func, or_ +from backend.models import Asset +from backend.db import db_session_context +from backend.utils import remove_null_fields class AssetsManager: _instance = None @@ -18,80 +20,69 @@ def __init__(self): if not hasattr(self, '_initialized'): with self._lock: if not hasattr(self, '_initialized'): - db.init_db() self._initialized = True async def create_asset(self, user_id, title, creator, subject, description): - id = str(uuid4()) - query = 'INSERT INTO asset (id, user_id, title, creator, subject, description) VALUES (?, ?, ?, ?, ?, ?)' - await db.execute_query(query, (id, user_id, title, creator, subject, description)) - return id + async with db_session_context() as session: + new_asset = Asset(id=str(uuid4()), user_id=user_id, title=title, creator=creator, subject=subject, description=description) + session.add(new_asset) + await session.commit() + return new_asset.id async def update_asset(self, id, user_id, title, creator, subject, description): - query = 'INSERT OR REPLACE INTO asset (id, user_id, title, creator, subject, description) VALUES (?, ?, ?, ?, ?, ?)' - return await db.execute_query(query, (id, user_id, title, creator, subject, description)) + async with db_session_context() as session: + stmt = update(Asset).where(Asset.id == id).values(user_id=user_id, title=title, creator=creator, subject=subject, description=description) + result = await session.execute(stmt) + await session.commit() + return result.rowcount > 0 async def delete_asset(self, id): - query = 'DELETE FROM asset WHERE id = ?' - return await db.execute_query(query, (id,)) + async with db_session_context() as session: + stmt = delete(Asset).where(Asset.id == id) + result = await session.execute(stmt) + await session.commit() + return result.rowcount > 0 async def retrieve_asset(self, id): - query = 'SELECT user_id, title, creator, subject, description FROM asset WHERE id = ?' - result = await db.execute_query(query, (id,)) - if result: - fields = ['user_id', 'title', 'creator', 'subject', 'description'] - asset = remove_null_fields(zip_fields(fields, result[0])) - asset['id'] = id - return asset - return None + async with db_session_context() as session: + result = await session.execute(select(Asset).filter(Asset.id == id)) + asset = result.scalar_one_or_none() + return remove_null_fields(asset.to_dict()) if asset else None async def retrieve_assets(self, offset=0, limit=100, sort_by=None, sort_order='asc', filters=None, query=None): - base_query = 'SELECT id, user_id, title, creator, subject, description FROM asset' - query_params = [] - - # Apply filters - filter_clauses = [] - if filters: - for key, value in filters.items(): - if isinstance(value, list): - placeholders = ', '.join(['?'] * len(value)) - filter_clauses.append(f"{key} IN ({placeholders})") - query_params.extend(value) - else: - filter_clauses.append(f"{key} = ?") - query_params.append(value) - - # Apply free text search - if query: - query_clause = "(title LIKE ? OR description LIKE ? OR creator LIKE ? OR subject LIKE ?)" - query_params.extend([f"%{query}%"] * 4) - filter_clauses.append(query_clause) - - if filter_clauses: - base_query += ' WHERE ' + ' AND '.join(filter_clauses) - - # Validate and apply sorting - valid_sort_columns = ['id', 'user_id', 'title', 'creator', 'subject', 'description'] - if sort_by and sort_by in valid_sort_columns: - sort_order = 'DESC' if sort_order.lower() == 'desc' else 'ASC' - base_query += f' ORDER BY {sort_by} {sort_order}' - - # Apply pagination - base_query += ' LIMIT ? OFFSET ?' - query_params.extend([limit, offset]) - - # Execute the main query - results = await db.execute_query(base_query, tuple(query_params)) - - fields = ['id', 'user_id', 'title', 'creator', 'subject', 'description'] - assets = [remove_null_fields(zip_fields(fields, result)) for result in results] - - # Get the total count of assets - total_count_query = 'SELECT COUNT(*) FROM asset' - total_count_params = query_params[:-2] # Exclude limit and offset for the count query - if filter_clauses: - total_count_query += ' WHERE ' + ' AND '.join(filter_clauses) - total_count_result = await db.execute_query(total_count_query, tuple(total_count_params)) - total_count = total_count_result[0][0] if total_count_result else 0 - - return assets, total_count + async with db_session_context() as session: + stmt = select(Asset) + + if filters: + for key, value in filters.items(): + if isinstance(value, list): + stmt = stmt.filter(getattr(Asset, key).in_(value)) + else: + stmt = stmt.filter(getattr(Asset, key) == value) + + if query: + search_condition = or_( + Asset.title.ilike(f"%{query}%"), + Asset.description.ilike(f"%{query}%"), + Asset.creator.ilike(f"%{query}%"), + Asset.subject.ilike(f"%{query}%") + ) + stmt = stmt.filter(search_condition) + + if sort_by and hasattr(Asset, sort_by): + order_column = getattr(Asset, sort_by) + stmt = stmt.order_by(order_column.desc() if sort_order.lower() == 'desc' else order_column) + + stmt = stmt.offset(offset).limit(limit) + + result = await session.execute(stmt) + assets = [remove_null_fields(asset.to_dict()) for asset in result.scalars().all()] + + # Get total count + count_stmt = select(func.count()).select_from(Asset) + if filters or query: + count_stmt = count_stmt.filter(stmt.whereclause) + total_count = await session.execute(count_stmt) + total_count = total_count.scalar() + + return assets, total_count \ No newline at end of file diff --git a/backend/managers/ChannelsManager.py b/backend/managers/ChannelsManager.py index efd66af7..72408a92 100644 --- a/backend/managers/ChannelsManager.py +++ b/backend/managers/ChannelsManager.py @@ -1,6 +1,8 @@ from uuid import uuid4 -import backend.db as db from threading import Lock +from sqlalchemy import select, insert, update, delete, func +from backend.models import Channel +from backend.db import db_session_context class ChannelsManager: _instance = None @@ -17,68 +19,63 @@ def __init__(self): if not hasattr(self, '_initialized'): with self._lock: if not hasattr(self, '_initialized'): - db.init_db() self._initialized = True async def create_channel(self, name, uri): - id = str(uuid4()) - query = 'INSERT INTO channel (id, name, uri) VALUES (?, ?, ?)' - await db.execute_query(query, (id, name, uri)) - return id + async with db_session_context() as session: + new_channel = Channel(id=str(uuid4()), name=name, uri=uri) + session.add(new_channel) + await session.commit() + return new_channel.id async def update_channel(self, id, name, uri): - query = 'INSERT OR REPLACE INTO channel (id, name, uri) VALUES (?, ?, ?)' - await db.execute_query(query, (id, name, uri)) + async with db_session_context() as session: + stmt = update(Channel).where(Channel.id == id).values(name=name, uri=uri) + await session.execute(stmt) + await session.commit() async def delete_channel(self, id): - query = 'DELETE FROM channel WHERE id = ?' - await db.execute_query(query, (id,)) + async with db_session_context() as session: + stmt = delete(Channel).where(Channel.id == id) + await session.execute(stmt) + await session.commit() async def retrieve_channel(self, id): - query = 'SELECT name, uri FROM channel WHERE id = ?' - result = await db.execute_query(query, (id,)) - if result: - return {'id': id, 'name': result[0][0], 'uri': result[0][1]} - return None - + async with db_session_context() as session: + result = await session.execute(select(Channel).filter(Channel.id == id)) + channel = result.scalar_one_or_none() + return channel.to_dict() if channel else None + async def retrieve_channels(self, offset=0, limit=100, sort_by=None, sort_order='asc', filters=None): - base_query = 'SELECT id, name, uri FROM channel' - query_params = [] + async with db_session_context() as session: + query = select(Channel) + + if filters: + for key, value in filters.items(): + if isinstance(value, list): + query = query.filter(getattr(Channel, key).in_(value)) + else: + query = query.filter(getattr(Channel, key) == value) + + if sort_by and sort_by in ['id', 'name', 'uri']: + order_column = getattr(Channel, sort_by) + query = query.order_by(order_column.desc() if sort_order.lower() == 'desc' else order_column) - # Apply filters - if filters: - filter_clauses = [] - for key, value in filters.items(): - if isinstance(value, list): - placeholders = ', '.join(['?'] * len(value)) - filter_clauses.append(f"{key} IN ({placeholders})") - query_params.extend(value) - else: - filter_clauses.append(f"{key} = ?") - query_params.append(value) - base_query += ' WHERE ' + ' AND '.join(filter_clauses) + query = query.offset(offset).limit(limit) - # Validate and apply sorting - valid_sort_columns = ['id', 'name', 'uri'] - if sort_by and sort_by in valid_sort_columns: - sort_order = 'DESC' if sort_order.lower() == 'desc' else 'ASC' - base_query += f' ORDER BY {sort_by} {sort_order}' + result = await session.execute(query) + channels = [channel.to_dict() for channel in result.scalars().all()] - # Apply pagination - base_query += ' LIMIT ? OFFSET ?' - query_params.extend([limit, offset]) + # Get total count + count_query = select(func.count()).select_from(Channel) + if filters: + for key, value in filters.items(): + if isinstance(value, list): + count_query = count_query.filter(getattr(Channel, key).in_(value)) + else: + count_query = count_query.filter(getattr(Channel, key) == value) - results = await db.execute_query(base_query, tuple(query_params)) - - channels = [] - for result in results: - channels.append({'id': result[0], 'name': result[1], 'uri': result[2]}) - - # Assuming you have a way to get the total count of channels - total_count_query = 'SELECT COUNT(*) FROM channel' - if filters: - total_count_query += ' WHERE ' + ' AND '.join(filter_clauses) - total_count_result = await db.execute_query(total_count_query, tuple(query_params[:len(query_params) - 2] if filters else ())) - total_count = total_count_result[0][0] if total_count_result else 0 + total_count = await session.execute(count_query) + total_count = total_count.scalar() - return channels, total_count + return channels, total_count diff --git a/backend/managers/ConfigManager.py b/backend/managers/ConfigManager.py index 0d599577..fa5df1b2 100644 --- a/backend/managers/ConfigManager.py +++ b/backend/managers/ConfigManager.py @@ -1,7 +1,9 @@ from uuid import uuid4 -import backend.db as db -from backend.encryption import Encryption from threading import Lock +from sqlalchemy import select, insert, update, delete +from backend.models import Config +from backend.db import db_session_context, init_db +from backend.encryption import Encryption class ConfigManager: _instance = None @@ -20,32 +22,38 @@ def __init__(self, tenant=None): if not hasattr(self, '_initialized'): self.encryption = Encryption() self.tenant = tenant - db.init_db() + init_db() self._initialized = True - # CRUD operations - # Note: Creating a new config item without specifying a key is unusual; use update_config_item instead. async def create_config_item(self, value): key = str(uuid4()) encrypted_value = self.encryption.encrypt_value(value) - print(f"ConfigManager: create_config_item {encrypted_value}") - query = 'INSERT INTO config (key, value) VALUES (?, ?)' - await db.execute_query(query, (key, encrypted_value)) + async with db_session_context() as session: + new_config = Config(key=key, value=encrypted_value) + session.add(new_config) + await session.commit() return key - + async def retrieve_config_item(self, key): - query = 'SELECT value FROM config WHERE key = ?' - result = await db.execute_query(query, (key,)) - if result: - encrypted_value = result[0][0] - return self.encryption.decrypt_value(encrypted_value) + async with db_session_context() as session: + result = await session.execute(select(Config).filter(Config.key == key)) + config = result.scalar_one_or_none() + if config: + return self.encryption.decrypt_value(config.value) return None async def update_config_item(self, key, value): encrypted_value = self.encryption.encrypt_value(value) - query = 'INSERT OR REPLACE INTO config (key, value) VALUES (?, ?)' - await db.execute_query(query, (key, encrypted_value)) + async with db_session_context() as session: + stmt = update(Config).where(Config.key == key).values(value=encrypted_value) + result = await session.execute(stmt) + if result.rowcount == 0: + new_config = Config(key=key, value=encrypted_value) + session.add(new_config) + await session.commit() async def delete_config_item(self, key): - query = 'DELETE FROM config WHERE key = ?' - await db.execute_query(query, (key,)) + async with db_session_context() as session: + stmt = delete(Config).where(Config.key == key) + await session.execute(stmt) + await session.commit() diff --git a/backend/managers/UsersManager.py b/backend/managers/UsersManager.py index a2006d02..72c4aec5 100644 --- a/backend/managers/UsersManager.py +++ b/backend/managers/UsersManager.py @@ -1,6 +1,8 @@ from uuid import uuid4 -import backend.db as db from threading import Lock +from sqlalchemy import select, insert, update, delete, func +from backend.models import User +from backend.db import db_session_context class UsersManager: _instance = None @@ -17,69 +19,63 @@ def __init__(self): if not hasattr(self, '_initialized'): with self._lock: if not hasattr(self, '_initialized'): - db.init_db() self._initialized = True async def create_user(self, name, email): - id = str(uuid4()) - query = 'INSERT INTO user (id, name, email) VALUES (?, ?, ?)' - await db.execute_query(query, (id, name, email)) - return id + async with db_session_context() as session: + new_user = User(id=str(uuid4()), name=name, email=email) + session.add(new_user) + await session.commit() + return new_user.id async def update_user(self, id, name, email): - query = 'INSERT OR REPLACE INTO user (id, name, email) VALUES (?, ?, ?)' - return await db.execute_query(query, (id, name, email)) + async with db_session_context() as session: + stmt = update(User).where(User.id == id).values(name=name, email=email) + await session.execute(stmt) + await session.commit() async def delete_user(self, id): - query = 'DELETE FROM user WHERE id = ?' - return await db.execute_query(query, (id,)) + async with db_session_context() as session: + stmt = delete(User).where(User.id == id) + await session.execute(stmt) + await session.commit() async def retrieve_user(self, id): - query = 'SELECT name, email FROM user WHERE id = ?' - result = await db.execute_query(query, (id,)) - if result: - return {'id': id, 'name': result[0][0], 'email': result[0][1]} - return None + async with db_session_context() as session: + result = await session.execute(select(User).filter(User.id == id)) + user = result.scalar_one_or_none() + return user.to_dict() if user else None async def retrieve_users(self, offset=0, limit=100, sort_by=None, sort_order='asc', filters=None): - base_query = 'SELECT id, name, email FROM user' - query_params = [] + async with db_session_context() as session: + query = select(User) - # Apply filters - if filters: - filter_clauses = [] - for key, value in filters.items(): - if isinstance(value, list): - placeholders = ', '.join(['?'] * len(value)) - filter_clauses.append(f"{key} IN ({placeholders})") - query_params.extend(value) - else: - filter_clauses.append(f"{key} = ?") - query_params.append(value) - base_query += ' WHERE ' + ' AND '.join(filter_clauses) + if filters: + for key, value in filters.items(): + if isinstance(value, list): + query = query.filter(getattr(User, key).in_(value)) + else: + query = query.filter(getattr(User, key) == value) - # Validate and apply sorting - valid_sort_columns = ['id', 'name', 'email'] - if sort_by and sort_by in valid_sort_columns: - sort_order = 'DESC' if sort_order.lower() == 'desc' else 'ASC' - base_query += f' ORDER BY {sort_by} {sort_order}' + if sort_by and sort_by in ['id', 'name', 'email']: + order_column = getattr(User, sort_by) + query = query.order_by(order_column.desc() if sort_order.lower() == 'desc' else order_column) - # Apply pagination - base_query += ' LIMIT ? OFFSET ?' - query_params.extend([limit, offset]) + query = query.offset(offset).limit(limit) - results = await db.execute_query(base_query, tuple(query_params)) - - users = [] - for result in results: - users.append({'id': result[0], 'name': result[1], 'email': result[2]}) - - # Assuming you have a way to get the total count of users - total_count_query = 'SELECT COUNT(*) FROM user' - if filters: - total_count_query += ' WHERE ' + ' AND '.join(filter_clauses) - total_count_result = await db.execute_query(total_count_query, tuple(query_params[:len(query_params) - 2] if filters else ())) - total_count = total_count_result[0][0] if total_count_result else 0 + result = await session.execute(query) + users = [user.to_dict() for user in result.scalars().all()] - return users, total_count - \ No newline at end of file + # Get total count + count_query = select(func.count()).select_from(User) + if filters: + for key, value in filters.items(): + if isinstance(value, list): + count_query = count_query.filter(getattr(User, key).in_(value)) + else: + count_query = count_query.filter(getattr(User, key) == value) + + total_count = await session.execute(count_query) + total_count = total_count.scalar() + + return users, total_count \ No newline at end of file diff --git a/backend/models.py b/backend/models.py index c7cf029d..2320655b 100644 --- a/backend/models.py +++ b/backend/models.py @@ -1,23 +1,28 @@ -from sqlmodel import Field, SQLModel +from sqlalchemy import Column, String +from backend.db import Base -class Config(SQLModel, table=True): - key: str = Field(default=None, primary_key=True) - value: str = Field(nullable=True) +class Config(Base): + __tablename__ = "config" + key = Column(String, primary_key=True) + value = Column(String, nullable=True) -class Channel(SQLModel, table=True): - id: str = Field(default=None, primary_key=True) - name: str = Field(nullable=False) - uri: str = Field(nullable=False) +class Channel(Base): + __tablename__ = "channel" + id = Column(String, primary_key=True) + name = Column(String, nullable=False) + uri = Column(String, nullable=False) -class User(SQLModel, table=True): - id: str = Field(default=None, primary_key=True) - name: str = Field(nullable=False) - email: str = Field(nullable=False) +class User(Base): + __tablename__ = "user" + id = Column(String, primary_key=True) + name = Column(String, nullable=False) + email = Column(String, nullable=False) -class Asset(SQLModel, table=True): - id: str = Field(default=None, primary_key=True) - user_id: str = Field(nullable=True) - title: str = Field(nullable=False) - creator: str = Field(nullable=True) - subject: str = Field(nullable=True) - description: str = Field(nullable=True) +class Asset(Base): + __tablename__ = "asset" + id = Column(String, primary_key=True) + user_id = Column(String, nullable=True) + title = Column(String, nullable=False) + creator = Column(String, nullable=True) + subject = Column(String, nullable=True) + description = Column(String, nullable=True) diff --git a/backend/requirements.txt b/backend/requirements.txt index 37f8ce02..a9e2b17b 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -7,7 +7,7 @@ cryptography setuptools packaging alembic -sqlmodel +sqlalchemy aiosqlite asyncio aiohttp diff --git a/backend/tests/test_db.py b/backend/tests/test_db.py index 885e3a0b..17e755dc 100644 --- a/backend/tests/test_db.py +++ b/backend/tests/test_db.py @@ -1,60 +1,48 @@ import unittest -from db import create_config_item, read_config_item, update_config_item, delete_config_item, set_config_item +from backend.managers.ConfigManager import ConfigManager +import asyncio -class TestDbFunctions(unittest.TestCase): +class TestConfigManager(unittest.TestCase): def setUp(self): - pass + self.config_manager = ConfigManager() + + def asyncTest(func): + def wrapper(*args, **kwargs): + return asyncio.run(func(*args, **kwargs)) + return wrapper - def test_create_config_item(self): - # Test that create_config_item correctly encrypts the value and inserts it into the database - tenant = 'test' - key = 'test_key' + @asyncTest + async def test_create_config_item(self): value = 'test_value' - create_config_item(key, value, tenant=tenant) - result = read_config_item(key, tenant=tenant) - delete_config_item(key, tenant=tenant) + key = await self.config_manager.create_config_item(value) + result = await self.config_manager.retrieve_config_item(key) + await self.config_manager.delete_config_item(key) self.assertEqual(result, value) - def test_read_config_item(self): - # Test that read_config_item correctly retrieves and decrypts a value from the database - tenant = 'test' - key = 'test_key' + @asyncTest + async def test_read_config_item(self): value = 'test_value' - create_config_item(key, value, tenant=tenant) - result = read_config_item(key, tenant=tenant) - delete_config_item(key, tenant=tenant) + key = await self.config_manager.create_config_item(value) + result = await self.config_manager.retrieve_config_item(key) + await self.config_manager.delete_config_item(key) self.assertEqual(result, value) - def test_update_config_item(self): - # Test that update_config_item correctly updates a value in the database - tenant = 'test' - key = 'test_key' - old_value = 'old_test_value' - new_value = 'new_test_value' - create_config_item(key, old_value, tenant=tenant) - update_config_item(key, new_value, tenant=tenant) - result = read_config_item(key, tenant=tenant) - delete_config_item(key, tenant=tenant) - self.assertEqual(result, new_value) - - def test_set_config_item(self): - # Test that set_config_item correctly updates a value in the database - tenant = 'test' - key = 'test_key' - old_value = 'old_test_value' + @asyncTest + async def test_update_config_item(self): + value = 'test_value' + key = await self.config_manager.create_config_item(value) new_value = 'new_test_value' - set_config_item(key, old_value, tenant=tenant) - set_config_item(key, new_value, tenant=tenant) - result = read_config_item(key, tenant=tenant) - delete_config_item(key, tenant=tenant) + await self.config_manager.update_config_item(key, new_value) + result = await self.config_manager.retrieve_config_item(key) + await self.config_manager.delete_config_item(key) self.assertEqual(result, new_value) - def test_delete_config_item(self): - # Test that delete_config_item correctly removes a value from the database - tenant = 'test' - key = 'test_key' - delete_config_item(key, tenant=tenant) - result = read_config_item(key, tenant=tenant) + @asyncTest + async def test_delete_config_item(self): + value = 'test_value' + key = await self.config_manager.create_config_item(value) + await self.config_manager.delete_config_item(key) + result = await self.config_manager.retrieve_config_item(key) self.assertIsNone(result) def tearDown(self): diff --git a/common/paths.py b/common/paths.py index 5be2e94c..d30711c5 100644 --- a/common/paths.py +++ b/common/paths.py @@ -30,5 +30,8 @@ abilities_data_dir = data_dir / abilities_subdir # paths -db_path = data_dir / 'paios.db' +db_name = 'paios.db' +db_path = data_dir / db_name +db_url = f"sqlite+aiosqlite:///{db_path}" downloads_dir = data_dir / 'downloads' + From 95326a11b3cb582b69ef74459fc94cced589774d Mon Sep 17 00:00:00 2001 From: Sam Johnston Date: Sat, 20 Jul 2024 15:22:55 -0700 Subject: [PATCH 2/5] add pydantic schemas --- backend/api/AbilitiesView.py | 1 - backend/api/AssetsView.py | 40 +++++++++------------- backend/api/ChannelsView.py | 24 +++++++++----- backend/api/ConfigView.py | 31 +++++++++++++----- backend/api/DownloadsView.py | 1 - backend/api/UsersView.py | 12 +++++-- backend/managers/AssetsManager.py | 51 +++++++++++++++++++++-------- backend/managers/ChannelsManager.py | 38 +++++++++++++-------- backend/managers/ConfigManager.py | 17 ++++++++-- backend/managers/UsersManager.py | 9 +++-- backend/schemas.py | 48 +++++++++++++++++++++++++++ 11 files changed, 192 insertions(+), 80 deletions(-) create mode 100644 backend/schemas.py diff --git a/backend/api/AbilitiesView.py b/backend/api/AbilitiesView.py index b493129b..09278e54 100644 --- a/backend/api/AbilitiesView.py +++ b/backend/api/AbilitiesView.py @@ -1,7 +1,6 @@ from starlette.responses import JSONResponse from backend.managers.AbilitiesManager import AbilitiesManager from backend.pagination import parse_pagination_params -from pkg_resources import ContextualVersionConflict import logging logger = logging.getLogger(__name__) diff --git a/backend/api/AssetsView.py b/backend/api/AssetsView.py index 6b4d82fe..6caec5c1 100644 --- a/backend/api/AssetsView.py +++ b/backend/api/AssetsView.py @@ -2,6 +2,8 @@ from backend.managers.AssetsManager import AssetsManager from common.paths import api_base_url from backend.pagination import parse_pagination_params +from backend.schemas import AssetCreateSchema, AssetSchema +from typing import List class AssetsView: def __init__(self): @@ -11,34 +13,22 @@ async def get(self, id: str): asset = await self.am.retrieve_asset(id) if asset is None: return JSONResponse({"error": "Asset not found"}, status_code=404) - return JSONResponse(asset, status_code=200) + return JSONResponse(asset.model_dump(), status_code=200) - async def post(self, body: dict): - asset_data = { - 'user_id': body.get('user_id'), - 'title': body.get('title'), - 'creator': body.get('creator'), - 'subject': body.get('subject'), - 'description': body.get('description') - } - id = await self.am.create_asset(**asset_data) - asset = await self.am.retrieve_asset(id) - return JSONResponse(asset, status_code=201, headers={'Location': f'{api_base_url}/assets/{id}'}) + async def post(self, body: AssetCreateSchema): + new_asset = await self.am.create_asset(body) + return JSONResponse(new_asset.model_dump(), status_code=201, headers={'Location': f'{api_base_url}/assets/{new_asset.id}'}) - async def put(self, id: str, body: dict): - asset_data = { - 'user_id': body.get('user_id'), - 'title': body.get('title'), - 'creator': body.get('creator'), - 'subject': body.get('subject'), - 'description': body.get('description') - } - await self.am.update_asset(id, **asset_data) - asset = await self.am.retrieve_asset(id) - return JSONResponse(asset, status_code=200) + async def put(self, id: str, body: AssetCreateSchema): + updated_asset = await self.am.update_asset(id, body) + if updated_asset is None: + return JSONResponse({"error": "Asset not found"}, status_code=404) + return JSONResponse(updated_asset.model_dump(), status_code=200) async def delete(self, id: str): - await self.am.delete_asset(id) + success = await self.am.delete_asset(id) + if not success: + return JSONResponse({"error": "Asset not found"}, status_code=404) return Response(status_code=204) async def search(self, filter: str = None, range: str = None, sort: str = None): @@ -63,4 +53,4 @@ async def search(self, filter: str = None, range: str = None, sort: str = None): 'X-Total-Count': str(total_count), 'Content-Range': f'assets {offset}-{offset + len(assets) - 1}/{total_count}' } - return JSONResponse(assets, status_code=200, headers=headers) \ No newline at end of file + return JSONResponse([asset.model_dump() for asset in assets], status_code=200, headers=headers) diff --git a/backend/api/ChannelsView.py b/backend/api/ChannelsView.py index 6d47d3f9..9f3ebadb 100644 --- a/backend/api/ChannelsView.py +++ b/backend/api/ChannelsView.py @@ -2,6 +2,8 @@ from common.paths import api_base_url from backend.managers.ChannelsManager import ChannelsManager from backend.pagination import parse_pagination_params +from backend.schemas import ChannelCreateSchema +from typing import List class ChannelsView: def __init__(self): @@ -11,18 +13,22 @@ async def get(self, channel_id: str): channel = await self.cm.retrieve_channel(channel_id) if channel is None: return JSONResponse({"error": "Channel not found"}, status_code=404) - return JSONResponse(channel, status_code=200) + return JSONResponse(channel.model_dump(), status_code=200) - async def post(self, body: dict): - channel_id = await self.cm.create_channel(body['name'], body['uri']) - return JSONResponse({"id": channel_id}, status_code=201, headers={'Location': f'{api_base_url}/channels/{channel_id}'}) + async def post(self, body: ChannelCreateSchema): + new_channel = await self.cm.create_channel(body) + return JSONResponse(new_channel.model_dump(), status_code=201, headers={'Location': f'{api_base_url}/channels/{new_channel.id}'}) - async def put(self, channel_id: str, body: dict): - await self.cm.update_channel(channel_id, body['name'], body['uri']) - return JSONResponse({"message": "Channel updated successfully"}, status_code=200) + async def put(self, channel_id: str, body: ChannelCreateSchema): + updated_channel = await self.cm.update_channel(channel_id, body) + if updated_channel is None: + return JSONResponse({"error": "Channel not found"}, status_code=404) + return JSONResponse(updated_channel.model_dump(), status_code=200) async def delete(self, channel_id: str): - await self.cm.delete_channel(channel_id) + success = await self.cm.delete_channel(channel_id) + if not success: + return JSONResponse({"error": "Channel not found"}, status_code=404) return Response(status_code=204) async def search(self, filter: str = None, range: str = None, sort: str = None): @@ -37,4 +43,4 @@ async def search(self, filter: str = None, range: str = None, sort: str = None): 'X-Total-Count': str(total_count), 'Content-Range': f'channels {offset}-{offset + len(channels) - 1}/{total_count}' } - return JSONResponse(channels, status_code=200, headers=headers) + return JSONResponse([channel.model_dump() for channel in channels], status_code=200, headers=headers) diff --git a/backend/api/ConfigView.py b/backend/api/ConfigView.py index 587fabee..9413d536 100644 --- a/backend/api/ConfigView.py +++ b/backend/api/ConfigView.py @@ -1,21 +1,34 @@ from starlette.responses import JSONResponse, Response from backend.managers.ConfigManager import ConfigManager +from backend.schemas import ConfigSchema class ConfigView: def __init__(self): self.cm = ConfigManager() async def get(self, key: str): - value = await self.cm.retrieve_config_item(key) - if value is None: - return JSONResponse(status_code=404, headers={"error": "Config item not found"}) - return JSONResponse(value, status_code=200) + config_item = await self.cm.retrieve_config_item(key) + if config_item is None: + return JSONResponse(status_code=404, content={"error": "Config item not found"}) + return JSONResponse(config_item.model_dump(), status_code=200) - async def put(self, key: str, body: dict): + async def put(self, key: str, body: ConfigSchema): print(f"ConfigView: PUT {key}->{body}") - await self.cm.update_config_item(key, body) - return JSONResponse({"message": "Config item updated successfully"}, status_code=200) + updated_config = await self.cm.update_config_item(key, body.value) + if updated_config: + return JSONResponse(updated_config.model_dump(), status_code=200) + return JSONResponse({"error": "Failed to update config item"}, status_code=400) async def delete(self, key: str): - await self.cm.delete_config_item(key) - return Response(status_code=204) + success = await self.cm.delete_config_item(key) + if success: + return Response(status_code=204) + return JSONResponse({"error": "Config item not found"}, status_code=404) + + async def list(self): + config_items = await self.cm.retrieve_all_config_items() + return JSONResponse([item.model_dump() for item in config_items], status_code=200) + + async def create(self, body: ConfigSchema): + new_config = await self.cm.create_config_item(body.value) + return JSONResponse(new_config.model_dump(), status_code=201) diff --git a/backend/api/DownloadsView.py b/backend/api/DownloadsView.py index fa172a80..e91a33e9 100644 --- a/backend/api/DownloadsView.py +++ b/backend/api/DownloadsView.py @@ -1,4 +1,3 @@ -from starlette.requests import Request from starlette.responses import Response, JSONResponse from backend.managers.DownloadsManager import DownloadsManager from backend.pagination import parse_pagination_params diff --git a/backend/api/UsersView.py b/backend/api/UsersView.py index c6c44967..f2cbd803 100644 --- a/backend/api/UsersView.py +++ b/backend/api/UsersView.py @@ -3,6 +3,7 @@ from backend.managers.UsersManager import UsersManager from backend.pagination import parse_pagination_params from aiosqlite import IntegrityError +from backend.schemas import UserSchema class UsersView: def __init__(self): @@ -12,7 +13,7 @@ async def get(self, id: str): user = await self.um.retrieve_user(id) if user is None: return JSONResponse(status_code=404, headers={"error": "User not found"}) - return JSONResponse(user, status_code=200) + return JSONResponse(user.model_dump(), status_code=200) async def post(self, body: dict): try: @@ -37,8 +38,13 @@ async def search(self, filter: str = None, range: str = None, sort: str = None): offset, limit, sort_by, sort_order, filters = result users, total_count = await self.um.retrieve_users(limit=limit, offset=offset, sort_by=sort_by, sort_order=sort_order, filters=filters) + + # Convert Pydantic models to dictionaries + users_dict = [user.model_dump() for user in users] + headers = { 'X-Total-Count': str(total_count), - 'Content-Range': f'users {offset}-{offset + len(users) - 1}/{total_count}' + 'Content-Range': f'users {offset}-{offset+len(users)}/{total_count}', + 'Access-Control-Expose-Headers': 'Content-Range' } - return JSONResponse(users, status_code=200, headers=headers) + return JSONResponse(users_dict, status_code=200, headers=headers) diff --git a/backend/managers/AssetsManager.py b/backend/managers/AssetsManager.py index 61d7a7b8..1ba8d02d 100644 --- a/backend/managers/AssetsManager.py +++ b/backend/managers/AssetsManager.py @@ -3,7 +3,8 @@ from sqlalchemy import select, insert, update, delete, func, or_ from backend.models import Asset from backend.db import db_session_context -from backend.utils import remove_null_fields +from backend.schemas import AssetSchema, AssetCreateSchema +from typing import List, Tuple, Optional, Dict, Any class AssetsManager: _instance = None @@ -22,34 +23,49 @@ def __init__(self): if not hasattr(self, '_initialized'): self._initialized = True - async def create_asset(self, user_id, title, creator, subject, description): + async def create_asset(self, asset_data: AssetCreateSchema) -> AssetSchema: async with db_session_context() as session: - new_asset = Asset(id=str(uuid4()), user_id=user_id, title=title, creator=creator, subject=subject, description=description) + new_asset = Asset(id=str(uuid4()), **asset_data.model_dump()) session.add(new_asset) await session.commit() - return new_asset.id + await session.refresh(new_asset) + return AssetSchema(id=new_asset.id, **asset_data.model_dump()) - async def update_asset(self, id, user_id, title, creator, subject, description): + async def update_asset(self, id: str, asset_data: AssetCreateSchema) -> Optional[AssetSchema]: async with db_session_context() as session: - stmt = update(Asset).where(Asset.id == id).values(user_id=user_id, title=title, creator=creator, subject=subject, description=description) + stmt = update(Asset).where(Asset.id == id).values(**asset_data.model_dump(exclude_unset=True)) result = await session.execute(stmt) - await session.commit() - return result.rowcount > 0 + if result.rowcount > 0: + await session.commit() + updated_asset = await session.get(Asset, id) + return AssetSchema(id=updated_asset.id, **asset_data.model_dump()) + return None - async def delete_asset(self, id): + async def delete_asset(self, id: str) -> bool: async with db_session_context() as session: stmt = delete(Asset).where(Asset.id == id) result = await session.execute(stmt) await session.commit() return result.rowcount > 0 - async def retrieve_asset(self, id): + async def retrieve_asset(self, id: str) -> Optional[AssetSchema]: async with db_session_context() as session: result = await session.execute(select(Asset).filter(Asset.id == id)) asset = result.scalar_one_or_none() - return remove_null_fields(asset.to_dict()) if asset else None + if asset: + return AssetSchema( + id=asset.id, + title=asset.title, + user_id=asset.user_id, + creator=asset.creator, + subject=asset.subject, + description=asset.description + ) + return None - async def retrieve_assets(self, offset=0, limit=100, sort_by=None, sort_order='asc', filters=None, query=None): + async def retrieve_assets(self, offset: int = 0, limit: int = 100, sort_by: Optional[str] = None, + sort_order: str = 'asc', filters: Optional[Dict[str, Any]] = None, + query: Optional[str] = None) -> Tuple[List[AssetSchema], int]: async with db_session_context() as session: stmt = select(Asset) @@ -76,7 +92,14 @@ async def retrieve_assets(self, offset=0, limit=100, sort_by=None, sort_order='a stmt = stmt.offset(offset).limit(limit) result = await session.execute(stmt) - assets = [remove_null_fields(asset.to_dict()) for asset in result.scalars().all()] + assets = [AssetSchema( + id=asset.id, + title=asset.title, + user_id=asset.user_id, + creator=asset.creator, + subject=asset.subject, + description=asset.description + ) for asset in result.scalars().all()] # Get total count count_stmt = select(func.count()).select_from(Asset) @@ -85,4 +108,4 @@ async def retrieve_assets(self, offset=0, limit=100, sort_by=None, sort_order='a total_count = await session.execute(count_stmt) total_count = total_count.scalar() - return assets, total_count \ No newline at end of file + return assets, total_count diff --git a/backend/managers/ChannelsManager.py b/backend/managers/ChannelsManager.py index 72408a92..0ef808cd 100644 --- a/backend/managers/ChannelsManager.py +++ b/backend/managers/ChannelsManager.py @@ -3,6 +3,8 @@ from sqlalchemy import select, insert, update, delete, func from backend.models import Channel from backend.db import db_session_context +from backend.schemas import ChannelCreateSchema, ChannelSchema +from typing import List, Tuple, Optional, Dict, Any class ChannelsManager: _instance = None @@ -21,32 +23,41 @@ def __init__(self): if not hasattr(self, '_initialized'): self._initialized = True - async def create_channel(self, name, uri): + async def create_channel(self, channel_data: ChannelCreateSchema) -> ChannelSchema: async with db_session_context() as session: - new_channel = Channel(id=str(uuid4()), name=name, uri=uri) + new_channel = Channel(id=str(uuid4()), **channel_data.model_dump()) session.add(new_channel) await session.commit() - return new_channel.id + await session.refresh(new_channel) + return ChannelSchema(id=new_channel.id, **channel_data.model_dump()) - async def update_channel(self, id, name, uri): + async def update_channel(self, id: str, channel_data: ChannelCreateSchema) -> Optional[ChannelSchema]: async with db_session_context() as session: - stmt = update(Channel).where(Channel.id == id).values(name=name, uri=uri) - await session.execute(stmt) - await session.commit() + stmt = update(Channel).where(Channel.id == id).values(**channel_data.dict()) + result = await session.execute(stmt) + if result.rowcount > 0: + await session.commit() + updated_channel = await session.get(Channel, id) + return ChannelSchema(id=updated_channel.id, **channel_data.model_dump()) + return None - async def delete_channel(self, id): + async def delete_channel(self, id: str) -> bool: async with db_session_context() as session: stmt = delete(Channel).where(Channel.id == id) - await session.execute(stmt) + result = await session.execute(stmt) await session.commit() + return result.rowcount > 0 - async def retrieve_channel(self, id): + async def retrieve_channel(self, id: str) -> Optional[ChannelSchema]: async with db_session_context() as session: result = await session.execute(select(Channel).filter(Channel.id == id)) channel = result.scalar_one_or_none() - return channel.to_dict() if channel else None + if channel: + return ChannelSchema(id=channel.id, name=channel.name, uri=channel.uri) + return None - async def retrieve_channels(self, offset=0, limit=100, sort_by=None, sort_order='asc', filters=None): + async def retrieve_channels(self, offset: int = 0, limit: int = 100, sort_by: Optional[str] = None, + sort_order: str = 'asc', filters: Optional[Dict[str, Any]] = None) -> Tuple[List[ChannelSchema], int]: async with db_session_context() as session: query = select(Channel) @@ -64,7 +75,8 @@ async def retrieve_channels(self, offset=0, limit=100, sort_by=None, sort_order= query = query.offset(offset).limit(limit) result = await session.execute(query) - channels = [channel.to_dict() for channel in result.scalars().all()] + channels = [ChannelSchema(id=channel.id, name=channel.name, uri=channel.uri) + for channel in result.scalars().all()] # Get total count count_query = select(func.count()).select_from(Channel) diff --git a/backend/managers/ConfigManager.py b/backend/managers/ConfigManager.py index fa5df1b2..f147a192 100644 --- a/backend/managers/ConfigManager.py +++ b/backend/managers/ConfigManager.py @@ -4,6 +4,7 @@ from backend.models import Config from backend.db import db_session_context, init_db from backend.encryption import Encryption +from backend.schemas import ConfigSchema class ConfigManager: _instance = None @@ -32,14 +33,15 @@ async def create_config_item(self, value): new_config = Config(key=key, value=encrypted_value) session.add(new_config) await session.commit() - return key + return ConfigSchema(key=key, value=value) async def retrieve_config_item(self, key): async with db_session_context() as session: result = await session.execute(select(Config).filter(Config.key == key)) config = result.scalar_one_or_none() if config: - return self.encryption.decrypt_value(config.value) + decrypted_value = self.encryption.decrypt_value(config.value) + return ConfigSchema(key=config.key, value=decrypted_value) return None async def update_config_item(self, key, value): @@ -51,9 +53,18 @@ async def update_config_item(self, key, value): new_config = Config(key=key, value=encrypted_value) session.add(new_config) await session.commit() + return ConfigSchema(key=key, value=value) async def delete_config_item(self, key): async with db_session_context() as session: stmt = delete(Config).where(Config.key == key) - await session.execute(stmt) + result = await session.execute(stmt) await session.commit() + return result.rowcount > 0 + + async def retrieve_all_config_items(self): + async with db_session_context() as session: + result = await session.execute(select(Config)) + configs = result.scalars().all() + return [ConfigSchema(key=config.key, value=self.encryption.decrypt_value(config.value)) + for config in configs] \ No newline at end of file diff --git a/backend/managers/UsersManager.py b/backend/managers/UsersManager.py index 72c4aec5..2ef7cc93 100644 --- a/backend/managers/UsersManager.py +++ b/backend/managers/UsersManager.py @@ -3,6 +3,7 @@ from sqlalchemy import select, insert, update, delete, func from backend.models import User from backend.db import db_session_context +from backend.schemas import UserSchema class UsersManager: _instance = None @@ -44,7 +45,7 @@ async def retrieve_user(self, id): async with db_session_context() as session: result = await session.execute(select(User).filter(User.id == id)) user = result.scalar_one_or_none() - return user.to_dict() if user else None + return UserSchema(id=user.id, name=user.name, email=user.email) if user else None async def retrieve_users(self, offset=0, limit=100, sort_by=None, sort_order='asc', filters=None): async with db_session_context() as session: @@ -64,7 +65,11 @@ async def retrieve_users(self, offset=0, limit=100, sort_by=None, sort_order='as query = query.offset(offset).limit(limit) result = await session.execute(query) - users = [user.to_dict() for user in result.scalars().all()] + users = [UserSchema( + id=user.id, + name=user.name, + email=user.email + ) for user in result.scalars().all()] # Get total count count_query = select(func.count()).select_from(User) diff --git a/backend/schemas.py b/backend/schemas.py new file mode 100644 index 00000000..74602dfe --- /dev/null +++ b/backend/schemas.py @@ -0,0 +1,48 @@ +from pydantic import BaseModel +from typing import Optional + +# We have *Create schemas because API clients ideally don't set the id field, it's set by the server +# Alternatively we could have made the id optional but then we would have to check if it's set by the client + +# Config schemas +class ConfigBaseSchema(BaseModel): + value: Optional[str] = None + +class ConfigSchema(ConfigBaseSchema): + key: str + +# Channel schemas +class ChannelBaseSchema(BaseModel): + name: str + uri: str + +class ChannelCreateSchema(ChannelBaseSchema): + pass + +class ChannelSchema(ChannelBaseSchema): + id: str + +# User schemas +class UserBaseSchema(BaseModel): + name: str + email: str + +class UserCreateSchema(UserBaseSchema): + pass + +class UserSchema(UserBaseSchema): + id: str + +# Asset schemas +class AssetBaseSchema(BaseModel): + title: str + user_id: Optional[str] = None + creator: Optional[str] = None + subject: Optional[str] = None + description: Optional[str] = None + +class AssetCreateSchema(AssetBaseSchema): + pass + +class AssetSchema(AssetBaseSchema): + id: str From 0a89cac5b227f6a36302acdb0592661126abe6a8 Mon Sep 17 00:00:00 2001 From: Sam Johnston Date: Sat, 20 Jul 2024 12:50:21 -0700 Subject: [PATCH 3/5] migrate to SQLAlchemy --- backend/db.py | 46 ++++++---- backend/env.py | 2 +- backend/managers/AssetsManager.py | 127 +++++++++++++--------------- backend/managers/ChannelsManager.py | 99 +++++++++++----------- backend/managers/ConfigManager.py | 44 ++++++---- backend/managers/UsersManager.py | 98 ++++++++++----------- backend/models.py | 43 +++++----- backend/requirements.txt | 2 +- backend/tests/test_db.py | 76 +++++++---------- common/paths.py | 5 +- 10 files changed, 271 insertions(+), 271 deletions(-) diff --git a/backend/db.py b/backend/db.py index 81e5ea21..2561b159 100644 --- a/backend/db.py +++ b/backend/db.py @@ -1,33 +1,45 @@ # database helper functions import os -import aiosqlite import logging +from sqlalchemy.orm import sessionmaker, declarative_base +from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession from alembic import command from alembic.config import Config as AlembicConfig -from common.paths import base_dir, db_path +from common.paths import base_dir, db_path, db_url +from contextlib import asynccontextmanager logger = logging.getLogger(__name__) +# Define the SQLAlchemy Base +Base = declarative_base() + +# Create async engine +engine = create_async_engine(db_url, echo=True) + +# Create async session factory +AsyncSessionLocal = sessionmaker( + bind=engine, + class_=AsyncSession, + expire_on_commit=False, +) + # use alembic to create the database or migrate to the latest schema def init_db(): logger.info("Initializing database.") alembic_cfg = AlembicConfig() os.makedirs(db_path.parent, exist_ok=True) alembic_cfg.set_main_option("script_location", str(base_dir / "migrations")) - alembic_cfg.set_main_option("sqlalchemy.url", f"sqlite:///{db_path}") + alembic_cfg.set_main_option("sqlalchemy.url", db_url.replace("+aiosqlite", "")) # because Alembic doesn't like async apparently command.upgrade(alembic_cfg, "head") -async def execute_query(query, params=None): - # TODO: logger.adebug from structlog - logger.debug(f"Executing query: {query} with params: {params}") - async with aiosqlite.connect(db_path) as conn: - async with conn.cursor() as cursor: - try: - await cursor.execute(query, params or ()) - result = await cursor.fetchall() - await conn.commit() - except Exception as e: - await conn.rollback() - raise e - - return result +@asynccontextmanager +async def db_session_context(): + session = AsyncSessionLocal() + try: + yield session + await session.commit() + except Exception: + await session.rollback() + raise + finally: + await session.close() \ No newline at end of file diff --git a/backend/env.py b/backend/env.py index 46efa123..d8071ff6 100644 --- a/backend/env.py +++ b/backend/env.py @@ -29,7 +29,7 @@ def check_env(): print(f"\nOnce you have activated the virtual environment, run this again.") sys.exit(1) - required_modules = ['connexion', 'uvicorn'] + required_modules = ['connexion', 'uvicorn', 'sqlalchemy', 'alembic', 'aiosqlite'] for module in required_modules: try: __import__(module) diff --git a/backend/managers/AssetsManager.py b/backend/managers/AssetsManager.py index 4cb1b6d8..61d7a7b8 100644 --- a/backend/managers/AssetsManager.py +++ b/backend/managers/AssetsManager.py @@ -1,7 +1,9 @@ from uuid import uuid4 -import backend.db as db -from backend.utils import remove_null_fields, zip_fields from threading import Lock +from sqlalchemy import select, insert, update, delete, func, or_ +from backend.models import Asset +from backend.db import db_session_context +from backend.utils import remove_null_fields class AssetsManager: _instance = None @@ -18,80 +20,69 @@ def __init__(self): if not hasattr(self, '_initialized'): with self._lock: if not hasattr(self, '_initialized'): - db.init_db() self._initialized = True async def create_asset(self, user_id, title, creator, subject, description): - id = str(uuid4()) - query = 'INSERT INTO asset (id, user_id, title, creator, subject, description) VALUES (?, ?, ?, ?, ?, ?)' - await db.execute_query(query, (id, user_id, title, creator, subject, description)) - return id + async with db_session_context() as session: + new_asset = Asset(id=str(uuid4()), user_id=user_id, title=title, creator=creator, subject=subject, description=description) + session.add(new_asset) + await session.commit() + return new_asset.id async def update_asset(self, id, user_id, title, creator, subject, description): - query = 'INSERT OR REPLACE INTO asset (id, user_id, title, creator, subject, description) VALUES (?, ?, ?, ?, ?, ?)' - return await db.execute_query(query, (id, user_id, title, creator, subject, description)) + async with db_session_context() as session: + stmt = update(Asset).where(Asset.id == id).values(user_id=user_id, title=title, creator=creator, subject=subject, description=description) + result = await session.execute(stmt) + await session.commit() + return result.rowcount > 0 async def delete_asset(self, id): - query = 'DELETE FROM asset WHERE id = ?' - return await db.execute_query(query, (id,)) + async with db_session_context() as session: + stmt = delete(Asset).where(Asset.id == id) + result = await session.execute(stmt) + await session.commit() + return result.rowcount > 0 async def retrieve_asset(self, id): - query = 'SELECT user_id, title, creator, subject, description FROM asset WHERE id = ?' - result = await db.execute_query(query, (id,)) - if result: - fields = ['user_id', 'title', 'creator', 'subject', 'description'] - asset = remove_null_fields(zip_fields(fields, result[0])) - asset['id'] = id - return asset - return None + async with db_session_context() as session: + result = await session.execute(select(Asset).filter(Asset.id == id)) + asset = result.scalar_one_or_none() + return remove_null_fields(asset.to_dict()) if asset else None async def retrieve_assets(self, offset=0, limit=100, sort_by=None, sort_order='asc', filters=None, query=None): - base_query = 'SELECT id, user_id, title, creator, subject, description FROM asset' - query_params = [] - - # Apply filters - filter_clauses = [] - if filters: - for key, value in filters.items(): - if isinstance(value, list): - placeholders = ', '.join(['?'] * len(value)) - filter_clauses.append(f"{key} IN ({placeholders})") - query_params.extend(value) - else: - filter_clauses.append(f"{key} = ?") - query_params.append(value) - - # Apply free text search - if query: - query_clause = "(title LIKE ? OR description LIKE ? OR creator LIKE ? OR subject LIKE ?)" - query_params.extend([f"%{query}%"] * 4) - filter_clauses.append(query_clause) - - if filter_clauses: - base_query += ' WHERE ' + ' AND '.join(filter_clauses) - - # Validate and apply sorting - valid_sort_columns = ['id', 'user_id', 'title', 'creator', 'subject', 'description'] - if sort_by and sort_by in valid_sort_columns: - sort_order = 'DESC' if sort_order.lower() == 'desc' else 'ASC' - base_query += f' ORDER BY {sort_by} {sort_order}' - - # Apply pagination - base_query += ' LIMIT ? OFFSET ?' - query_params.extend([limit, offset]) - - # Execute the main query - results = await db.execute_query(base_query, tuple(query_params)) - - fields = ['id', 'user_id', 'title', 'creator', 'subject', 'description'] - assets = [remove_null_fields(zip_fields(fields, result)) for result in results] - - # Get the total count of assets - total_count_query = 'SELECT COUNT(*) FROM asset' - total_count_params = query_params[:-2] # Exclude limit and offset for the count query - if filter_clauses: - total_count_query += ' WHERE ' + ' AND '.join(filter_clauses) - total_count_result = await db.execute_query(total_count_query, tuple(total_count_params)) - total_count = total_count_result[0][0] if total_count_result else 0 - - return assets, total_count + async with db_session_context() as session: + stmt = select(Asset) + + if filters: + for key, value in filters.items(): + if isinstance(value, list): + stmt = stmt.filter(getattr(Asset, key).in_(value)) + else: + stmt = stmt.filter(getattr(Asset, key) == value) + + if query: + search_condition = or_( + Asset.title.ilike(f"%{query}%"), + Asset.description.ilike(f"%{query}%"), + Asset.creator.ilike(f"%{query}%"), + Asset.subject.ilike(f"%{query}%") + ) + stmt = stmt.filter(search_condition) + + if sort_by and hasattr(Asset, sort_by): + order_column = getattr(Asset, sort_by) + stmt = stmt.order_by(order_column.desc() if sort_order.lower() == 'desc' else order_column) + + stmt = stmt.offset(offset).limit(limit) + + result = await session.execute(stmt) + assets = [remove_null_fields(asset.to_dict()) for asset in result.scalars().all()] + + # Get total count + count_stmt = select(func.count()).select_from(Asset) + if filters or query: + count_stmt = count_stmt.filter(stmt.whereclause) + total_count = await session.execute(count_stmt) + total_count = total_count.scalar() + + return assets, total_count \ No newline at end of file diff --git a/backend/managers/ChannelsManager.py b/backend/managers/ChannelsManager.py index efd66af7..72408a92 100644 --- a/backend/managers/ChannelsManager.py +++ b/backend/managers/ChannelsManager.py @@ -1,6 +1,8 @@ from uuid import uuid4 -import backend.db as db from threading import Lock +from sqlalchemy import select, insert, update, delete, func +from backend.models import Channel +from backend.db import db_session_context class ChannelsManager: _instance = None @@ -17,68 +19,63 @@ def __init__(self): if not hasattr(self, '_initialized'): with self._lock: if not hasattr(self, '_initialized'): - db.init_db() self._initialized = True async def create_channel(self, name, uri): - id = str(uuid4()) - query = 'INSERT INTO channel (id, name, uri) VALUES (?, ?, ?)' - await db.execute_query(query, (id, name, uri)) - return id + async with db_session_context() as session: + new_channel = Channel(id=str(uuid4()), name=name, uri=uri) + session.add(new_channel) + await session.commit() + return new_channel.id async def update_channel(self, id, name, uri): - query = 'INSERT OR REPLACE INTO channel (id, name, uri) VALUES (?, ?, ?)' - await db.execute_query(query, (id, name, uri)) + async with db_session_context() as session: + stmt = update(Channel).where(Channel.id == id).values(name=name, uri=uri) + await session.execute(stmt) + await session.commit() async def delete_channel(self, id): - query = 'DELETE FROM channel WHERE id = ?' - await db.execute_query(query, (id,)) + async with db_session_context() as session: + stmt = delete(Channel).where(Channel.id == id) + await session.execute(stmt) + await session.commit() async def retrieve_channel(self, id): - query = 'SELECT name, uri FROM channel WHERE id = ?' - result = await db.execute_query(query, (id,)) - if result: - return {'id': id, 'name': result[0][0], 'uri': result[0][1]} - return None - + async with db_session_context() as session: + result = await session.execute(select(Channel).filter(Channel.id == id)) + channel = result.scalar_one_or_none() + return channel.to_dict() if channel else None + async def retrieve_channels(self, offset=0, limit=100, sort_by=None, sort_order='asc', filters=None): - base_query = 'SELECT id, name, uri FROM channel' - query_params = [] + async with db_session_context() as session: + query = select(Channel) + + if filters: + for key, value in filters.items(): + if isinstance(value, list): + query = query.filter(getattr(Channel, key).in_(value)) + else: + query = query.filter(getattr(Channel, key) == value) + + if sort_by and sort_by in ['id', 'name', 'uri']: + order_column = getattr(Channel, sort_by) + query = query.order_by(order_column.desc() if sort_order.lower() == 'desc' else order_column) - # Apply filters - if filters: - filter_clauses = [] - for key, value in filters.items(): - if isinstance(value, list): - placeholders = ', '.join(['?'] * len(value)) - filter_clauses.append(f"{key} IN ({placeholders})") - query_params.extend(value) - else: - filter_clauses.append(f"{key} = ?") - query_params.append(value) - base_query += ' WHERE ' + ' AND '.join(filter_clauses) + query = query.offset(offset).limit(limit) - # Validate and apply sorting - valid_sort_columns = ['id', 'name', 'uri'] - if sort_by and sort_by in valid_sort_columns: - sort_order = 'DESC' if sort_order.lower() == 'desc' else 'ASC' - base_query += f' ORDER BY {sort_by} {sort_order}' + result = await session.execute(query) + channels = [channel.to_dict() for channel in result.scalars().all()] - # Apply pagination - base_query += ' LIMIT ? OFFSET ?' - query_params.extend([limit, offset]) + # Get total count + count_query = select(func.count()).select_from(Channel) + if filters: + for key, value in filters.items(): + if isinstance(value, list): + count_query = count_query.filter(getattr(Channel, key).in_(value)) + else: + count_query = count_query.filter(getattr(Channel, key) == value) - results = await db.execute_query(base_query, tuple(query_params)) - - channels = [] - for result in results: - channels.append({'id': result[0], 'name': result[1], 'uri': result[2]}) - - # Assuming you have a way to get the total count of channels - total_count_query = 'SELECT COUNT(*) FROM channel' - if filters: - total_count_query += ' WHERE ' + ' AND '.join(filter_clauses) - total_count_result = await db.execute_query(total_count_query, tuple(query_params[:len(query_params) - 2] if filters else ())) - total_count = total_count_result[0][0] if total_count_result else 0 + total_count = await session.execute(count_query) + total_count = total_count.scalar() - return channels, total_count + return channels, total_count diff --git a/backend/managers/ConfigManager.py b/backend/managers/ConfigManager.py index 0d599577..fa5df1b2 100644 --- a/backend/managers/ConfigManager.py +++ b/backend/managers/ConfigManager.py @@ -1,7 +1,9 @@ from uuid import uuid4 -import backend.db as db -from backend.encryption import Encryption from threading import Lock +from sqlalchemy import select, insert, update, delete +from backend.models import Config +from backend.db import db_session_context, init_db +from backend.encryption import Encryption class ConfigManager: _instance = None @@ -20,32 +22,38 @@ def __init__(self, tenant=None): if not hasattr(self, '_initialized'): self.encryption = Encryption() self.tenant = tenant - db.init_db() + init_db() self._initialized = True - # CRUD operations - # Note: Creating a new config item without specifying a key is unusual; use update_config_item instead. async def create_config_item(self, value): key = str(uuid4()) encrypted_value = self.encryption.encrypt_value(value) - print(f"ConfigManager: create_config_item {encrypted_value}") - query = 'INSERT INTO config (key, value) VALUES (?, ?)' - await db.execute_query(query, (key, encrypted_value)) + async with db_session_context() as session: + new_config = Config(key=key, value=encrypted_value) + session.add(new_config) + await session.commit() return key - + async def retrieve_config_item(self, key): - query = 'SELECT value FROM config WHERE key = ?' - result = await db.execute_query(query, (key,)) - if result: - encrypted_value = result[0][0] - return self.encryption.decrypt_value(encrypted_value) + async with db_session_context() as session: + result = await session.execute(select(Config).filter(Config.key == key)) + config = result.scalar_one_or_none() + if config: + return self.encryption.decrypt_value(config.value) return None async def update_config_item(self, key, value): encrypted_value = self.encryption.encrypt_value(value) - query = 'INSERT OR REPLACE INTO config (key, value) VALUES (?, ?)' - await db.execute_query(query, (key, encrypted_value)) + async with db_session_context() as session: + stmt = update(Config).where(Config.key == key).values(value=encrypted_value) + result = await session.execute(stmt) + if result.rowcount == 0: + new_config = Config(key=key, value=encrypted_value) + session.add(new_config) + await session.commit() async def delete_config_item(self, key): - query = 'DELETE FROM config WHERE key = ?' - await db.execute_query(query, (key,)) + async with db_session_context() as session: + stmt = delete(Config).where(Config.key == key) + await session.execute(stmt) + await session.commit() diff --git a/backend/managers/UsersManager.py b/backend/managers/UsersManager.py index a2006d02..72c4aec5 100644 --- a/backend/managers/UsersManager.py +++ b/backend/managers/UsersManager.py @@ -1,6 +1,8 @@ from uuid import uuid4 -import backend.db as db from threading import Lock +from sqlalchemy import select, insert, update, delete, func +from backend.models import User +from backend.db import db_session_context class UsersManager: _instance = None @@ -17,69 +19,63 @@ def __init__(self): if not hasattr(self, '_initialized'): with self._lock: if not hasattr(self, '_initialized'): - db.init_db() self._initialized = True async def create_user(self, name, email): - id = str(uuid4()) - query = 'INSERT INTO user (id, name, email) VALUES (?, ?, ?)' - await db.execute_query(query, (id, name, email)) - return id + async with db_session_context() as session: + new_user = User(id=str(uuid4()), name=name, email=email) + session.add(new_user) + await session.commit() + return new_user.id async def update_user(self, id, name, email): - query = 'INSERT OR REPLACE INTO user (id, name, email) VALUES (?, ?, ?)' - return await db.execute_query(query, (id, name, email)) + async with db_session_context() as session: + stmt = update(User).where(User.id == id).values(name=name, email=email) + await session.execute(stmt) + await session.commit() async def delete_user(self, id): - query = 'DELETE FROM user WHERE id = ?' - return await db.execute_query(query, (id,)) + async with db_session_context() as session: + stmt = delete(User).where(User.id == id) + await session.execute(stmt) + await session.commit() async def retrieve_user(self, id): - query = 'SELECT name, email FROM user WHERE id = ?' - result = await db.execute_query(query, (id,)) - if result: - return {'id': id, 'name': result[0][0], 'email': result[0][1]} - return None + async with db_session_context() as session: + result = await session.execute(select(User).filter(User.id == id)) + user = result.scalar_one_or_none() + return user.to_dict() if user else None async def retrieve_users(self, offset=0, limit=100, sort_by=None, sort_order='asc', filters=None): - base_query = 'SELECT id, name, email FROM user' - query_params = [] + async with db_session_context() as session: + query = select(User) - # Apply filters - if filters: - filter_clauses = [] - for key, value in filters.items(): - if isinstance(value, list): - placeholders = ', '.join(['?'] * len(value)) - filter_clauses.append(f"{key} IN ({placeholders})") - query_params.extend(value) - else: - filter_clauses.append(f"{key} = ?") - query_params.append(value) - base_query += ' WHERE ' + ' AND '.join(filter_clauses) + if filters: + for key, value in filters.items(): + if isinstance(value, list): + query = query.filter(getattr(User, key).in_(value)) + else: + query = query.filter(getattr(User, key) == value) - # Validate and apply sorting - valid_sort_columns = ['id', 'name', 'email'] - if sort_by and sort_by in valid_sort_columns: - sort_order = 'DESC' if sort_order.lower() == 'desc' else 'ASC' - base_query += f' ORDER BY {sort_by} {sort_order}' + if sort_by and sort_by in ['id', 'name', 'email']: + order_column = getattr(User, sort_by) + query = query.order_by(order_column.desc() if sort_order.lower() == 'desc' else order_column) - # Apply pagination - base_query += ' LIMIT ? OFFSET ?' - query_params.extend([limit, offset]) + query = query.offset(offset).limit(limit) - results = await db.execute_query(base_query, tuple(query_params)) - - users = [] - for result in results: - users.append({'id': result[0], 'name': result[1], 'email': result[2]}) - - # Assuming you have a way to get the total count of users - total_count_query = 'SELECT COUNT(*) FROM user' - if filters: - total_count_query += ' WHERE ' + ' AND '.join(filter_clauses) - total_count_result = await db.execute_query(total_count_query, tuple(query_params[:len(query_params) - 2] if filters else ())) - total_count = total_count_result[0][0] if total_count_result else 0 + result = await session.execute(query) + users = [user.to_dict() for user in result.scalars().all()] - return users, total_count - \ No newline at end of file + # Get total count + count_query = select(func.count()).select_from(User) + if filters: + for key, value in filters.items(): + if isinstance(value, list): + count_query = count_query.filter(getattr(User, key).in_(value)) + else: + count_query = count_query.filter(getattr(User, key) == value) + + total_count = await session.execute(count_query) + total_count = total_count.scalar() + + return users, total_count \ No newline at end of file diff --git a/backend/models.py b/backend/models.py index c7cf029d..2320655b 100644 --- a/backend/models.py +++ b/backend/models.py @@ -1,23 +1,28 @@ -from sqlmodel import Field, SQLModel +from sqlalchemy import Column, String +from backend.db import Base -class Config(SQLModel, table=True): - key: str = Field(default=None, primary_key=True) - value: str = Field(nullable=True) +class Config(Base): + __tablename__ = "config" + key = Column(String, primary_key=True) + value = Column(String, nullable=True) -class Channel(SQLModel, table=True): - id: str = Field(default=None, primary_key=True) - name: str = Field(nullable=False) - uri: str = Field(nullable=False) +class Channel(Base): + __tablename__ = "channel" + id = Column(String, primary_key=True) + name = Column(String, nullable=False) + uri = Column(String, nullable=False) -class User(SQLModel, table=True): - id: str = Field(default=None, primary_key=True) - name: str = Field(nullable=False) - email: str = Field(nullable=False) +class User(Base): + __tablename__ = "user" + id = Column(String, primary_key=True) + name = Column(String, nullable=False) + email = Column(String, nullable=False) -class Asset(SQLModel, table=True): - id: str = Field(default=None, primary_key=True) - user_id: str = Field(nullable=True) - title: str = Field(nullable=False) - creator: str = Field(nullable=True) - subject: str = Field(nullable=True) - description: str = Field(nullable=True) +class Asset(Base): + __tablename__ = "asset" + id = Column(String, primary_key=True) + user_id = Column(String, nullable=True) + title = Column(String, nullable=False) + creator = Column(String, nullable=True) + subject = Column(String, nullable=True) + description = Column(String, nullable=True) diff --git a/backend/requirements.txt b/backend/requirements.txt index 37f8ce02..a9e2b17b 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -7,7 +7,7 @@ cryptography setuptools packaging alembic -sqlmodel +sqlalchemy aiosqlite asyncio aiohttp diff --git a/backend/tests/test_db.py b/backend/tests/test_db.py index 885e3a0b..17e755dc 100644 --- a/backend/tests/test_db.py +++ b/backend/tests/test_db.py @@ -1,60 +1,48 @@ import unittest -from db import create_config_item, read_config_item, update_config_item, delete_config_item, set_config_item +from backend.managers.ConfigManager import ConfigManager +import asyncio -class TestDbFunctions(unittest.TestCase): +class TestConfigManager(unittest.TestCase): def setUp(self): - pass + self.config_manager = ConfigManager() + + def asyncTest(func): + def wrapper(*args, **kwargs): + return asyncio.run(func(*args, **kwargs)) + return wrapper - def test_create_config_item(self): - # Test that create_config_item correctly encrypts the value and inserts it into the database - tenant = 'test' - key = 'test_key' + @asyncTest + async def test_create_config_item(self): value = 'test_value' - create_config_item(key, value, tenant=tenant) - result = read_config_item(key, tenant=tenant) - delete_config_item(key, tenant=tenant) + key = await self.config_manager.create_config_item(value) + result = await self.config_manager.retrieve_config_item(key) + await self.config_manager.delete_config_item(key) self.assertEqual(result, value) - def test_read_config_item(self): - # Test that read_config_item correctly retrieves and decrypts a value from the database - tenant = 'test' - key = 'test_key' + @asyncTest + async def test_read_config_item(self): value = 'test_value' - create_config_item(key, value, tenant=tenant) - result = read_config_item(key, tenant=tenant) - delete_config_item(key, tenant=tenant) + key = await self.config_manager.create_config_item(value) + result = await self.config_manager.retrieve_config_item(key) + await self.config_manager.delete_config_item(key) self.assertEqual(result, value) - def test_update_config_item(self): - # Test that update_config_item correctly updates a value in the database - tenant = 'test' - key = 'test_key' - old_value = 'old_test_value' - new_value = 'new_test_value' - create_config_item(key, old_value, tenant=tenant) - update_config_item(key, new_value, tenant=tenant) - result = read_config_item(key, tenant=tenant) - delete_config_item(key, tenant=tenant) - self.assertEqual(result, new_value) - - def test_set_config_item(self): - # Test that set_config_item correctly updates a value in the database - tenant = 'test' - key = 'test_key' - old_value = 'old_test_value' + @asyncTest + async def test_update_config_item(self): + value = 'test_value' + key = await self.config_manager.create_config_item(value) new_value = 'new_test_value' - set_config_item(key, old_value, tenant=tenant) - set_config_item(key, new_value, tenant=tenant) - result = read_config_item(key, tenant=tenant) - delete_config_item(key, tenant=tenant) + await self.config_manager.update_config_item(key, new_value) + result = await self.config_manager.retrieve_config_item(key) + await self.config_manager.delete_config_item(key) self.assertEqual(result, new_value) - def test_delete_config_item(self): - # Test that delete_config_item correctly removes a value from the database - tenant = 'test' - key = 'test_key' - delete_config_item(key, tenant=tenant) - result = read_config_item(key, tenant=tenant) + @asyncTest + async def test_delete_config_item(self): + value = 'test_value' + key = await self.config_manager.create_config_item(value) + await self.config_manager.delete_config_item(key) + result = await self.config_manager.retrieve_config_item(key) self.assertIsNone(result) def tearDown(self): diff --git a/common/paths.py b/common/paths.py index 5be2e94c..d30711c5 100644 --- a/common/paths.py +++ b/common/paths.py @@ -30,5 +30,8 @@ abilities_data_dir = data_dir / abilities_subdir # paths -db_path = data_dir / 'paios.db' +db_name = 'paios.db' +db_path = data_dir / db_name +db_url = f"sqlite+aiosqlite:///{db_path}" downloads_dir = data_dir / 'downloads' + From 28fc4090187e7a5f8fa24bc0af5488568af463a5 Mon Sep 17 00:00:00 2001 From: Sam Johnston Date: Sat, 20 Jul 2024 15:22:55 -0700 Subject: [PATCH 4/5] add pydantic schemas --- backend/api/AbilitiesView.py | 1 - backend/api/AssetsView.py | 40 +++++++++------------- backend/api/ChannelsView.py | 24 +++++++++----- backend/api/ConfigView.py | 31 +++++++++++++----- backend/api/DownloadsView.py | 1 - backend/api/UsersView.py | 12 +++++-- backend/managers/AssetsManager.py | 51 +++++++++++++++++++++-------- backend/managers/ChannelsManager.py | 38 +++++++++++++-------- backend/managers/ConfigManager.py | 17 ++++++++-- backend/managers/UsersManager.py | 9 +++-- backend/schemas.py | 48 +++++++++++++++++++++++++++ 11 files changed, 192 insertions(+), 80 deletions(-) create mode 100644 backend/schemas.py diff --git a/backend/api/AbilitiesView.py b/backend/api/AbilitiesView.py index b493129b..09278e54 100644 --- a/backend/api/AbilitiesView.py +++ b/backend/api/AbilitiesView.py @@ -1,7 +1,6 @@ from starlette.responses import JSONResponse from backend.managers.AbilitiesManager import AbilitiesManager from backend.pagination import parse_pagination_params -from pkg_resources import ContextualVersionConflict import logging logger = logging.getLogger(__name__) diff --git a/backend/api/AssetsView.py b/backend/api/AssetsView.py index 6b4d82fe..6caec5c1 100644 --- a/backend/api/AssetsView.py +++ b/backend/api/AssetsView.py @@ -2,6 +2,8 @@ from backend.managers.AssetsManager import AssetsManager from common.paths import api_base_url from backend.pagination import parse_pagination_params +from backend.schemas import AssetCreateSchema, AssetSchema +from typing import List class AssetsView: def __init__(self): @@ -11,34 +13,22 @@ async def get(self, id: str): asset = await self.am.retrieve_asset(id) if asset is None: return JSONResponse({"error": "Asset not found"}, status_code=404) - return JSONResponse(asset, status_code=200) + return JSONResponse(asset.model_dump(), status_code=200) - async def post(self, body: dict): - asset_data = { - 'user_id': body.get('user_id'), - 'title': body.get('title'), - 'creator': body.get('creator'), - 'subject': body.get('subject'), - 'description': body.get('description') - } - id = await self.am.create_asset(**asset_data) - asset = await self.am.retrieve_asset(id) - return JSONResponse(asset, status_code=201, headers={'Location': f'{api_base_url}/assets/{id}'}) + async def post(self, body: AssetCreateSchema): + new_asset = await self.am.create_asset(body) + return JSONResponse(new_asset.model_dump(), status_code=201, headers={'Location': f'{api_base_url}/assets/{new_asset.id}'}) - async def put(self, id: str, body: dict): - asset_data = { - 'user_id': body.get('user_id'), - 'title': body.get('title'), - 'creator': body.get('creator'), - 'subject': body.get('subject'), - 'description': body.get('description') - } - await self.am.update_asset(id, **asset_data) - asset = await self.am.retrieve_asset(id) - return JSONResponse(asset, status_code=200) + async def put(self, id: str, body: AssetCreateSchema): + updated_asset = await self.am.update_asset(id, body) + if updated_asset is None: + return JSONResponse({"error": "Asset not found"}, status_code=404) + return JSONResponse(updated_asset.model_dump(), status_code=200) async def delete(self, id: str): - await self.am.delete_asset(id) + success = await self.am.delete_asset(id) + if not success: + return JSONResponse({"error": "Asset not found"}, status_code=404) return Response(status_code=204) async def search(self, filter: str = None, range: str = None, sort: str = None): @@ -63,4 +53,4 @@ async def search(self, filter: str = None, range: str = None, sort: str = None): 'X-Total-Count': str(total_count), 'Content-Range': f'assets {offset}-{offset + len(assets) - 1}/{total_count}' } - return JSONResponse(assets, status_code=200, headers=headers) \ No newline at end of file + return JSONResponse([asset.model_dump() for asset in assets], status_code=200, headers=headers) diff --git a/backend/api/ChannelsView.py b/backend/api/ChannelsView.py index 6d47d3f9..9f3ebadb 100644 --- a/backend/api/ChannelsView.py +++ b/backend/api/ChannelsView.py @@ -2,6 +2,8 @@ from common.paths import api_base_url from backend.managers.ChannelsManager import ChannelsManager from backend.pagination import parse_pagination_params +from backend.schemas import ChannelCreateSchema +from typing import List class ChannelsView: def __init__(self): @@ -11,18 +13,22 @@ async def get(self, channel_id: str): channel = await self.cm.retrieve_channel(channel_id) if channel is None: return JSONResponse({"error": "Channel not found"}, status_code=404) - return JSONResponse(channel, status_code=200) + return JSONResponse(channel.model_dump(), status_code=200) - async def post(self, body: dict): - channel_id = await self.cm.create_channel(body['name'], body['uri']) - return JSONResponse({"id": channel_id}, status_code=201, headers={'Location': f'{api_base_url}/channels/{channel_id}'}) + async def post(self, body: ChannelCreateSchema): + new_channel = await self.cm.create_channel(body) + return JSONResponse(new_channel.model_dump(), status_code=201, headers={'Location': f'{api_base_url}/channels/{new_channel.id}'}) - async def put(self, channel_id: str, body: dict): - await self.cm.update_channel(channel_id, body['name'], body['uri']) - return JSONResponse({"message": "Channel updated successfully"}, status_code=200) + async def put(self, channel_id: str, body: ChannelCreateSchema): + updated_channel = await self.cm.update_channel(channel_id, body) + if updated_channel is None: + return JSONResponse({"error": "Channel not found"}, status_code=404) + return JSONResponse(updated_channel.model_dump(), status_code=200) async def delete(self, channel_id: str): - await self.cm.delete_channel(channel_id) + success = await self.cm.delete_channel(channel_id) + if not success: + return JSONResponse({"error": "Channel not found"}, status_code=404) return Response(status_code=204) async def search(self, filter: str = None, range: str = None, sort: str = None): @@ -37,4 +43,4 @@ async def search(self, filter: str = None, range: str = None, sort: str = None): 'X-Total-Count': str(total_count), 'Content-Range': f'channels {offset}-{offset + len(channels) - 1}/{total_count}' } - return JSONResponse(channels, status_code=200, headers=headers) + return JSONResponse([channel.model_dump() for channel in channels], status_code=200, headers=headers) diff --git a/backend/api/ConfigView.py b/backend/api/ConfigView.py index 587fabee..9413d536 100644 --- a/backend/api/ConfigView.py +++ b/backend/api/ConfigView.py @@ -1,21 +1,34 @@ from starlette.responses import JSONResponse, Response from backend.managers.ConfigManager import ConfigManager +from backend.schemas import ConfigSchema class ConfigView: def __init__(self): self.cm = ConfigManager() async def get(self, key: str): - value = await self.cm.retrieve_config_item(key) - if value is None: - return JSONResponse(status_code=404, headers={"error": "Config item not found"}) - return JSONResponse(value, status_code=200) + config_item = await self.cm.retrieve_config_item(key) + if config_item is None: + return JSONResponse(status_code=404, content={"error": "Config item not found"}) + return JSONResponse(config_item.model_dump(), status_code=200) - async def put(self, key: str, body: dict): + async def put(self, key: str, body: ConfigSchema): print(f"ConfigView: PUT {key}->{body}") - await self.cm.update_config_item(key, body) - return JSONResponse({"message": "Config item updated successfully"}, status_code=200) + updated_config = await self.cm.update_config_item(key, body.value) + if updated_config: + return JSONResponse(updated_config.model_dump(), status_code=200) + return JSONResponse({"error": "Failed to update config item"}, status_code=400) async def delete(self, key: str): - await self.cm.delete_config_item(key) - return Response(status_code=204) + success = await self.cm.delete_config_item(key) + if success: + return Response(status_code=204) + return JSONResponse({"error": "Config item not found"}, status_code=404) + + async def list(self): + config_items = await self.cm.retrieve_all_config_items() + return JSONResponse([item.model_dump() for item in config_items], status_code=200) + + async def create(self, body: ConfigSchema): + new_config = await self.cm.create_config_item(body.value) + return JSONResponse(new_config.model_dump(), status_code=201) diff --git a/backend/api/DownloadsView.py b/backend/api/DownloadsView.py index fa172a80..e91a33e9 100644 --- a/backend/api/DownloadsView.py +++ b/backend/api/DownloadsView.py @@ -1,4 +1,3 @@ -from starlette.requests import Request from starlette.responses import Response, JSONResponse from backend.managers.DownloadsManager import DownloadsManager from backend.pagination import parse_pagination_params diff --git a/backend/api/UsersView.py b/backend/api/UsersView.py index c6c44967..f2cbd803 100644 --- a/backend/api/UsersView.py +++ b/backend/api/UsersView.py @@ -3,6 +3,7 @@ from backend.managers.UsersManager import UsersManager from backend.pagination import parse_pagination_params from aiosqlite import IntegrityError +from backend.schemas import UserSchema class UsersView: def __init__(self): @@ -12,7 +13,7 @@ async def get(self, id: str): user = await self.um.retrieve_user(id) if user is None: return JSONResponse(status_code=404, headers={"error": "User not found"}) - return JSONResponse(user, status_code=200) + return JSONResponse(user.model_dump(), status_code=200) async def post(self, body: dict): try: @@ -37,8 +38,13 @@ async def search(self, filter: str = None, range: str = None, sort: str = None): offset, limit, sort_by, sort_order, filters = result users, total_count = await self.um.retrieve_users(limit=limit, offset=offset, sort_by=sort_by, sort_order=sort_order, filters=filters) + + # Convert Pydantic models to dictionaries + users_dict = [user.model_dump() for user in users] + headers = { 'X-Total-Count': str(total_count), - 'Content-Range': f'users {offset}-{offset + len(users) - 1}/{total_count}' + 'Content-Range': f'users {offset}-{offset+len(users)}/{total_count}', + 'Access-Control-Expose-Headers': 'Content-Range' } - return JSONResponse(users, status_code=200, headers=headers) + return JSONResponse(users_dict, status_code=200, headers=headers) diff --git a/backend/managers/AssetsManager.py b/backend/managers/AssetsManager.py index 61d7a7b8..1ba8d02d 100644 --- a/backend/managers/AssetsManager.py +++ b/backend/managers/AssetsManager.py @@ -3,7 +3,8 @@ from sqlalchemy import select, insert, update, delete, func, or_ from backend.models import Asset from backend.db import db_session_context -from backend.utils import remove_null_fields +from backend.schemas import AssetSchema, AssetCreateSchema +from typing import List, Tuple, Optional, Dict, Any class AssetsManager: _instance = None @@ -22,34 +23,49 @@ def __init__(self): if not hasattr(self, '_initialized'): self._initialized = True - async def create_asset(self, user_id, title, creator, subject, description): + async def create_asset(self, asset_data: AssetCreateSchema) -> AssetSchema: async with db_session_context() as session: - new_asset = Asset(id=str(uuid4()), user_id=user_id, title=title, creator=creator, subject=subject, description=description) + new_asset = Asset(id=str(uuid4()), **asset_data.model_dump()) session.add(new_asset) await session.commit() - return new_asset.id + await session.refresh(new_asset) + return AssetSchema(id=new_asset.id, **asset_data.model_dump()) - async def update_asset(self, id, user_id, title, creator, subject, description): + async def update_asset(self, id: str, asset_data: AssetCreateSchema) -> Optional[AssetSchema]: async with db_session_context() as session: - stmt = update(Asset).where(Asset.id == id).values(user_id=user_id, title=title, creator=creator, subject=subject, description=description) + stmt = update(Asset).where(Asset.id == id).values(**asset_data.model_dump(exclude_unset=True)) result = await session.execute(stmt) - await session.commit() - return result.rowcount > 0 + if result.rowcount > 0: + await session.commit() + updated_asset = await session.get(Asset, id) + return AssetSchema(id=updated_asset.id, **asset_data.model_dump()) + return None - async def delete_asset(self, id): + async def delete_asset(self, id: str) -> bool: async with db_session_context() as session: stmt = delete(Asset).where(Asset.id == id) result = await session.execute(stmt) await session.commit() return result.rowcount > 0 - async def retrieve_asset(self, id): + async def retrieve_asset(self, id: str) -> Optional[AssetSchema]: async with db_session_context() as session: result = await session.execute(select(Asset).filter(Asset.id == id)) asset = result.scalar_one_or_none() - return remove_null_fields(asset.to_dict()) if asset else None + if asset: + return AssetSchema( + id=asset.id, + title=asset.title, + user_id=asset.user_id, + creator=asset.creator, + subject=asset.subject, + description=asset.description + ) + return None - async def retrieve_assets(self, offset=0, limit=100, sort_by=None, sort_order='asc', filters=None, query=None): + async def retrieve_assets(self, offset: int = 0, limit: int = 100, sort_by: Optional[str] = None, + sort_order: str = 'asc', filters: Optional[Dict[str, Any]] = None, + query: Optional[str] = None) -> Tuple[List[AssetSchema], int]: async with db_session_context() as session: stmt = select(Asset) @@ -76,7 +92,14 @@ async def retrieve_assets(self, offset=0, limit=100, sort_by=None, sort_order='a stmt = stmt.offset(offset).limit(limit) result = await session.execute(stmt) - assets = [remove_null_fields(asset.to_dict()) for asset in result.scalars().all()] + assets = [AssetSchema( + id=asset.id, + title=asset.title, + user_id=asset.user_id, + creator=asset.creator, + subject=asset.subject, + description=asset.description + ) for asset in result.scalars().all()] # Get total count count_stmt = select(func.count()).select_from(Asset) @@ -85,4 +108,4 @@ async def retrieve_assets(self, offset=0, limit=100, sort_by=None, sort_order='a total_count = await session.execute(count_stmt) total_count = total_count.scalar() - return assets, total_count \ No newline at end of file + return assets, total_count diff --git a/backend/managers/ChannelsManager.py b/backend/managers/ChannelsManager.py index 72408a92..0ef808cd 100644 --- a/backend/managers/ChannelsManager.py +++ b/backend/managers/ChannelsManager.py @@ -3,6 +3,8 @@ from sqlalchemy import select, insert, update, delete, func from backend.models import Channel from backend.db import db_session_context +from backend.schemas import ChannelCreateSchema, ChannelSchema +from typing import List, Tuple, Optional, Dict, Any class ChannelsManager: _instance = None @@ -21,32 +23,41 @@ def __init__(self): if not hasattr(self, '_initialized'): self._initialized = True - async def create_channel(self, name, uri): + async def create_channel(self, channel_data: ChannelCreateSchema) -> ChannelSchema: async with db_session_context() as session: - new_channel = Channel(id=str(uuid4()), name=name, uri=uri) + new_channel = Channel(id=str(uuid4()), **channel_data.model_dump()) session.add(new_channel) await session.commit() - return new_channel.id + await session.refresh(new_channel) + return ChannelSchema(id=new_channel.id, **channel_data.model_dump()) - async def update_channel(self, id, name, uri): + async def update_channel(self, id: str, channel_data: ChannelCreateSchema) -> Optional[ChannelSchema]: async with db_session_context() as session: - stmt = update(Channel).where(Channel.id == id).values(name=name, uri=uri) - await session.execute(stmt) - await session.commit() + stmt = update(Channel).where(Channel.id == id).values(**channel_data.dict()) + result = await session.execute(stmt) + if result.rowcount > 0: + await session.commit() + updated_channel = await session.get(Channel, id) + return ChannelSchema(id=updated_channel.id, **channel_data.model_dump()) + return None - async def delete_channel(self, id): + async def delete_channel(self, id: str) -> bool: async with db_session_context() as session: stmt = delete(Channel).where(Channel.id == id) - await session.execute(stmt) + result = await session.execute(stmt) await session.commit() + return result.rowcount > 0 - async def retrieve_channel(self, id): + async def retrieve_channel(self, id: str) -> Optional[ChannelSchema]: async with db_session_context() as session: result = await session.execute(select(Channel).filter(Channel.id == id)) channel = result.scalar_one_or_none() - return channel.to_dict() if channel else None + if channel: + return ChannelSchema(id=channel.id, name=channel.name, uri=channel.uri) + return None - async def retrieve_channels(self, offset=0, limit=100, sort_by=None, sort_order='asc', filters=None): + async def retrieve_channels(self, offset: int = 0, limit: int = 100, sort_by: Optional[str] = None, + sort_order: str = 'asc', filters: Optional[Dict[str, Any]] = None) -> Tuple[List[ChannelSchema], int]: async with db_session_context() as session: query = select(Channel) @@ -64,7 +75,8 @@ async def retrieve_channels(self, offset=0, limit=100, sort_by=None, sort_order= query = query.offset(offset).limit(limit) result = await session.execute(query) - channels = [channel.to_dict() for channel in result.scalars().all()] + channels = [ChannelSchema(id=channel.id, name=channel.name, uri=channel.uri) + for channel in result.scalars().all()] # Get total count count_query = select(func.count()).select_from(Channel) diff --git a/backend/managers/ConfigManager.py b/backend/managers/ConfigManager.py index fa5df1b2..f147a192 100644 --- a/backend/managers/ConfigManager.py +++ b/backend/managers/ConfigManager.py @@ -4,6 +4,7 @@ from backend.models import Config from backend.db import db_session_context, init_db from backend.encryption import Encryption +from backend.schemas import ConfigSchema class ConfigManager: _instance = None @@ -32,14 +33,15 @@ async def create_config_item(self, value): new_config = Config(key=key, value=encrypted_value) session.add(new_config) await session.commit() - return key + return ConfigSchema(key=key, value=value) async def retrieve_config_item(self, key): async with db_session_context() as session: result = await session.execute(select(Config).filter(Config.key == key)) config = result.scalar_one_or_none() if config: - return self.encryption.decrypt_value(config.value) + decrypted_value = self.encryption.decrypt_value(config.value) + return ConfigSchema(key=config.key, value=decrypted_value) return None async def update_config_item(self, key, value): @@ -51,9 +53,18 @@ async def update_config_item(self, key, value): new_config = Config(key=key, value=encrypted_value) session.add(new_config) await session.commit() + return ConfigSchema(key=key, value=value) async def delete_config_item(self, key): async with db_session_context() as session: stmt = delete(Config).where(Config.key == key) - await session.execute(stmt) + result = await session.execute(stmt) await session.commit() + return result.rowcount > 0 + + async def retrieve_all_config_items(self): + async with db_session_context() as session: + result = await session.execute(select(Config)) + configs = result.scalars().all() + return [ConfigSchema(key=config.key, value=self.encryption.decrypt_value(config.value)) + for config in configs] \ No newline at end of file diff --git a/backend/managers/UsersManager.py b/backend/managers/UsersManager.py index 72c4aec5..2ef7cc93 100644 --- a/backend/managers/UsersManager.py +++ b/backend/managers/UsersManager.py @@ -3,6 +3,7 @@ from sqlalchemy import select, insert, update, delete, func from backend.models import User from backend.db import db_session_context +from backend.schemas import UserSchema class UsersManager: _instance = None @@ -44,7 +45,7 @@ async def retrieve_user(self, id): async with db_session_context() as session: result = await session.execute(select(User).filter(User.id == id)) user = result.scalar_one_or_none() - return user.to_dict() if user else None + return UserSchema(id=user.id, name=user.name, email=user.email) if user else None async def retrieve_users(self, offset=0, limit=100, sort_by=None, sort_order='asc', filters=None): async with db_session_context() as session: @@ -64,7 +65,11 @@ async def retrieve_users(self, offset=0, limit=100, sort_by=None, sort_order='as query = query.offset(offset).limit(limit) result = await session.execute(query) - users = [user.to_dict() for user in result.scalars().all()] + users = [UserSchema( + id=user.id, + name=user.name, + email=user.email + ) for user in result.scalars().all()] # Get total count count_query = select(func.count()).select_from(User) diff --git a/backend/schemas.py b/backend/schemas.py new file mode 100644 index 00000000..74602dfe --- /dev/null +++ b/backend/schemas.py @@ -0,0 +1,48 @@ +from pydantic import BaseModel +from typing import Optional + +# We have *Create schemas because API clients ideally don't set the id field, it's set by the server +# Alternatively we could have made the id optional but then we would have to check if it's set by the client + +# Config schemas +class ConfigBaseSchema(BaseModel): + value: Optional[str] = None + +class ConfigSchema(ConfigBaseSchema): + key: str + +# Channel schemas +class ChannelBaseSchema(BaseModel): + name: str + uri: str + +class ChannelCreateSchema(ChannelBaseSchema): + pass + +class ChannelSchema(ChannelBaseSchema): + id: str + +# User schemas +class UserBaseSchema(BaseModel): + name: str + email: str + +class UserCreateSchema(UserBaseSchema): + pass + +class UserSchema(UserBaseSchema): + id: str + +# Asset schemas +class AssetBaseSchema(BaseModel): + title: str + user_id: Optional[str] = None + creator: Optional[str] = None + subject: Optional[str] = None + description: Optional[str] = None + +class AssetCreateSchema(AssetBaseSchema): + pass + +class AssetSchema(AssetBaseSchema): + id: str From 532b2ddc849f8031452086d76d9f5a3bff9529f4 Mon Sep 17 00:00:00 2001 From: Sam Johnston Date: Sat, 20 Jul 2024 16:05:23 -0700 Subject: [PATCH 5/5] turn off sql echo --- backend/db.py | 2 +- common/config.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/backend/db.py b/backend/db.py index 2561b159..fbf1b163 100644 --- a/backend/db.py +++ b/backend/db.py @@ -14,7 +14,7 @@ Base = declarative_base() # Create async engine -engine = create_async_engine(db_url, echo=True) +engine = create_async_engine(db_url, echo=False) # Create async session factory AsyncSessionLocal = sessionmaker( diff --git a/common/config.py b/common/config.py index 2b27c70c..df504cba 100644 --- a/common/config.py +++ b/common/config.py @@ -65,5 +65,6 @@ "uvicorn.error": {"level": "INFO"}, "uvicorn.access": {"handlers": ["access"], "level": "INFO", "propagate": False}, "watchfiles.main": {"level": "ERROR"}, # filter watchfiles noise + "sqlalchemy.engine": {"level": "WARNING", "propagate": False}, # filter sqlalchemy noise }, }