Skip to content

Commit

Permalink
Merge branch 'feature/sqlalchemy'
Browse files Browse the repository at this point in the history
  • Loading branch information
samj committed Jul 21, 2024
2 parents 8a5f930 + f02a3bb commit c53ebe6
Show file tree
Hide file tree
Showing 18 changed files with 446 additions and 333 deletions.
1 change: 0 additions & 1 deletion backend/api/AbilitiesView.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from starlette.responses import JSONResponse
from backend.managers.AbilitiesManager import AbilitiesManager
from backend.pagination import parse_pagination_params
from pkg_resources import ContextualVersionConflict
import logging
logger = logging.getLogger(__name__)

Expand Down
40 changes: 15 additions & 25 deletions backend/api/AssetsView.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from backend.managers.AssetsManager import AssetsManager
from common.paths import api_base_url
from backend.pagination import parse_pagination_params
from backend.schemas import AssetCreateSchema, AssetSchema
from typing import List

class AssetsView:
def __init__(self):
Expand All @@ -11,34 +13,22 @@ async def get(self, id: str):
asset = await self.am.retrieve_asset(id)
if asset is None:
return JSONResponse({"error": "Asset not found"}, status_code=404)
return JSONResponse(asset, status_code=200)
return JSONResponse(asset.model_dump(), status_code=200)

async def post(self, body: dict):
asset_data = {
'user_id': body.get('user_id'),
'title': body.get('title'),
'creator': body.get('creator'),
'subject': body.get('subject'),
'description': body.get('description')
}
id = await self.am.create_asset(**asset_data)
asset = await self.am.retrieve_asset(id)
return JSONResponse(asset, status_code=201, headers={'Location': f'{api_base_url}/assets/{id}'})
async def post(self, body: AssetCreateSchema):
new_asset = await self.am.create_asset(body)
return JSONResponse(new_asset.model_dump(), status_code=201, headers={'Location': f'{api_base_url}/assets/{new_asset.id}'})

async def put(self, id: str, body: dict):
asset_data = {
'user_id': body.get('user_id'),
'title': body.get('title'),
'creator': body.get('creator'),
'subject': body.get('subject'),
'description': body.get('description')
}
await self.am.update_asset(id, **asset_data)
asset = await self.am.retrieve_asset(id)
return JSONResponse(asset, status_code=200)
async def put(self, id: str, body: AssetCreateSchema):
updated_asset = await self.am.update_asset(id, body)
if updated_asset is None:
return JSONResponse({"error": "Asset not found"}, status_code=404)
return JSONResponse(updated_asset.model_dump(), status_code=200)

async def delete(self, id: str):
await self.am.delete_asset(id)
success = await self.am.delete_asset(id)
if not success:
return JSONResponse({"error": "Asset not found"}, status_code=404)
return Response(status_code=204)

async def search(self, filter: str = None, range: str = None, sort: str = None):
Expand All @@ -63,4 +53,4 @@ async def search(self, filter: str = None, range: str = None, sort: str = None):
'X-Total-Count': str(total_count),
'Content-Range': f'assets {offset}-{offset + len(assets) - 1}/{total_count}'
}
return JSONResponse(assets, status_code=200, headers=headers)
return JSONResponse([asset.model_dump() for asset in assets], status_code=200, headers=headers)
24 changes: 15 additions & 9 deletions backend/api/ChannelsView.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from common.paths import api_base_url
from backend.managers.ChannelsManager import ChannelsManager
from backend.pagination import parse_pagination_params
from backend.schemas import ChannelCreateSchema
from typing import List

class ChannelsView:
def __init__(self):
Expand All @@ -11,18 +13,22 @@ async def get(self, channel_id: str):
channel = await self.cm.retrieve_channel(channel_id)
if channel is None:
return JSONResponse({"error": "Channel not found"}, status_code=404)
return JSONResponse(channel, status_code=200)
return JSONResponse(channel.model_dump(), status_code=200)

async def post(self, body: dict):
channel_id = await self.cm.create_channel(body['name'], body['uri'])
return JSONResponse({"id": channel_id}, status_code=201, headers={'Location': f'{api_base_url}/channels/{channel_id}'})
async def post(self, body: ChannelCreateSchema):
new_channel = await self.cm.create_channel(body)
return JSONResponse(new_channel.model_dump(), status_code=201, headers={'Location': f'{api_base_url}/channels/{new_channel.id}'})

async def put(self, channel_id: str, body: dict):
await self.cm.update_channel(channel_id, body['name'], body['uri'])
return JSONResponse({"message": "Channel updated successfully"}, status_code=200)
async def put(self, channel_id: str, body: ChannelCreateSchema):
updated_channel = await self.cm.update_channel(channel_id, body)
if updated_channel is None:
return JSONResponse({"error": "Channel not found"}, status_code=404)
return JSONResponse(updated_channel.model_dump(), status_code=200)

async def delete(self, channel_id: str):
await self.cm.delete_channel(channel_id)
success = await self.cm.delete_channel(channel_id)
if not success:
return JSONResponse({"error": "Channel not found"}, status_code=404)
return Response(status_code=204)

async def search(self, filter: str = None, range: str = None, sort: str = None):
Expand All @@ -37,4 +43,4 @@ async def search(self, filter: str = None, range: str = None, sort: str = None):
'X-Total-Count': str(total_count),
'Content-Range': f'channels {offset}-{offset + len(channels) - 1}/{total_count}'
}
return JSONResponse(channels, status_code=200, headers=headers)
return JSONResponse([channel.model_dump() for channel in channels], status_code=200, headers=headers)
31 changes: 22 additions & 9 deletions backend/api/ConfigView.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,34 @@
from starlette.responses import JSONResponse, Response
from backend.managers.ConfigManager import ConfigManager
from backend.schemas import ConfigSchema

class ConfigView:
def __init__(self):
self.cm = ConfigManager()

async def get(self, key: str):
value = await self.cm.retrieve_config_item(key)
if value is None:
return JSONResponse(status_code=404, headers={"error": "Config item not found"})
return JSONResponse(value, status_code=200)
config_item = await self.cm.retrieve_config_item(key)
if config_item is None:
return JSONResponse(status_code=404, content={"error": "Config item not found"})
return JSONResponse(config_item.model_dump(), status_code=200)

async def put(self, key: str, body: dict):
async def put(self, key: str, body: ConfigSchema):
print(f"ConfigView: PUT {key}->{body}")
await self.cm.update_config_item(key, body)
return JSONResponse({"message": "Config item updated successfully"}, status_code=200)
updated_config = await self.cm.update_config_item(key, body.value)
if updated_config:
return JSONResponse(updated_config.model_dump(), status_code=200)
return JSONResponse({"error": "Failed to update config item"}, status_code=400)

async def delete(self, key: str):
await self.cm.delete_config_item(key)
return Response(status_code=204)
success = await self.cm.delete_config_item(key)
if success:
return Response(status_code=204)
return JSONResponse({"error": "Config item not found"}, status_code=404)

async def list(self):
config_items = await self.cm.retrieve_all_config_items()
return JSONResponse([item.model_dump() for item in config_items], status_code=200)

async def create(self, body: ConfigSchema):
new_config = await self.cm.create_config_item(body.value)
return JSONResponse(new_config.model_dump(), status_code=201)
1 change: 0 additions & 1 deletion backend/api/DownloadsView.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from starlette.requests import Request
from starlette.responses import Response, JSONResponse
from backend.managers.DownloadsManager import DownloadsManager
from backend.pagination import parse_pagination_params
Expand Down
12 changes: 9 additions & 3 deletions backend/api/UsersView.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from backend.managers.UsersManager import UsersManager
from backend.pagination import parse_pagination_params
from aiosqlite import IntegrityError
from backend.schemas import UserSchema

class UsersView:
def __init__(self):
Expand All @@ -12,7 +13,7 @@ async def get(self, id: str):
user = await self.um.retrieve_user(id)
if user is None:
return JSONResponse(status_code=404, headers={"error": "User not found"})
return JSONResponse(user, status_code=200)
return JSONResponse(user.model_dump(), status_code=200)

async def post(self, body: dict):
try:
Expand All @@ -37,8 +38,13 @@ async def search(self, filter: str = None, range: str = None, sort: str = None):
offset, limit, sort_by, sort_order, filters = result

users, total_count = await self.um.retrieve_users(limit=limit, offset=offset, sort_by=sort_by, sort_order=sort_order, filters=filters)

# Convert Pydantic models to dictionaries
users_dict = [user.model_dump() for user in users]

headers = {
'X-Total-Count': str(total_count),
'Content-Range': f'users {offset}-{offset + len(users) - 1}/{total_count}'
'Content-Range': f'users {offset}-{offset+len(users)}/{total_count}',
'Access-Control-Expose-Headers': 'Content-Range'
}
return JSONResponse(users, status_code=200, headers=headers)
return JSONResponse(users_dict, status_code=200, headers=headers)
49 changes: 29 additions & 20 deletions backend/db.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,45 @@
# database helper functions
import os
import aiosqlite
import logging
from sqlalchemy.orm import sessionmaker, declarative_base
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from alembic import command
from alembic.config import Config as AlembicConfig
from common.paths import base_dir, db_path
from common.paths import base_dir, db_path, db_url
from contextlib import asynccontextmanager

logger = logging.getLogger(__name__)

# Define the SQLAlchemy Base
Base = declarative_base()

# Create async engine
engine = create_async_engine(db_url, echo=False)

# Create async session factory
AsyncSessionLocal = sessionmaker(
bind=engine,
class_=AsyncSession,
expire_on_commit=False,
)

# use alembic to create the database or migrate to the latest schema
def init_db():
logger.info("Initializing database.")
alembic_cfg = AlembicConfig()
os.makedirs(db_path.parent, exist_ok=True)
alembic_cfg.set_main_option("script_location", str(base_dir / "migrations"))
alembic_cfg.set_main_option("sqlalchemy.url", f"sqlite:///{db_path}")
alembic_cfg.set_main_option("sqlalchemy.url", db_url.replace("+aiosqlite", "")) # because Alembic doesn't like async apparently
command.upgrade(alembic_cfg, "head")

# Call init_db() when the module is first imported
init_db()

async def execute_query(query, params=None):
# TODO: logger.adebug from structlog
logger.debug(f"Executing query: {query} with params: {params}")
async with aiosqlite.connect(db_path) as conn:
async with conn.cursor() as cursor:
try:
await cursor.execute(query, params or ())
result = await cursor.fetchall()
await conn.commit()
except Exception as e:
await conn.rollback()
raise e

return result
@asynccontextmanager
async def db_session_context():
session = AsyncSessionLocal()
try:
yield session
await session.commit()
except Exception:
await session.rollback()
raise
finally:
await session.close()
2 changes: 1 addition & 1 deletion backend/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def check_env():
print(f"\nOnce you have activated the virtual environment, run this again.")
sys.exit(1)

required_modules = ['connexion', 'uvicorn']
required_modules = ['connexion', 'uvicorn', 'sqlalchemy', 'alembic', 'aiosqlite']
for module in required_modules:
try:
__import__(module)
Expand Down
Loading

0 comments on commit c53ebe6

Please sign in to comment.