-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
18 changed files
with
446 additions
and
333 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,21 +1,34 @@ | ||
from starlette.responses import JSONResponse, Response | ||
from backend.managers.ConfigManager import ConfigManager | ||
from backend.schemas import ConfigSchema | ||
|
||
class ConfigView: | ||
def __init__(self): | ||
self.cm = ConfigManager() | ||
|
||
async def get(self, key: str): | ||
value = await self.cm.retrieve_config_item(key) | ||
if value is None: | ||
return JSONResponse(status_code=404, headers={"error": "Config item not found"}) | ||
return JSONResponse(value, status_code=200) | ||
config_item = await self.cm.retrieve_config_item(key) | ||
if config_item is None: | ||
return JSONResponse(status_code=404, content={"error": "Config item not found"}) | ||
return JSONResponse(config_item.model_dump(), status_code=200) | ||
|
||
async def put(self, key: str, body: dict): | ||
async def put(self, key: str, body: ConfigSchema): | ||
print(f"ConfigView: PUT {key}->{body}") | ||
await self.cm.update_config_item(key, body) | ||
return JSONResponse({"message": "Config item updated successfully"}, status_code=200) | ||
updated_config = await self.cm.update_config_item(key, body.value) | ||
if updated_config: | ||
return JSONResponse(updated_config.model_dump(), status_code=200) | ||
return JSONResponse({"error": "Failed to update config item"}, status_code=400) | ||
|
||
async def delete(self, key: str): | ||
await self.cm.delete_config_item(key) | ||
return Response(status_code=204) | ||
success = await self.cm.delete_config_item(key) | ||
if success: | ||
return Response(status_code=204) | ||
return JSONResponse({"error": "Config item not found"}, status_code=404) | ||
|
||
async def list(self): | ||
config_items = await self.cm.retrieve_all_config_items() | ||
return JSONResponse([item.model_dump() for item in config_items], status_code=200) | ||
|
||
async def create(self, body: ConfigSchema): | ||
new_config = await self.cm.create_config_item(body.value) | ||
return JSONResponse(new_config.model_dump(), status_code=201) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,36 +1,45 @@ | ||
# database helper functions | ||
import os | ||
import aiosqlite | ||
import logging | ||
from sqlalchemy.orm import sessionmaker, declarative_base | ||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession | ||
from alembic import command | ||
from alembic.config import Config as AlembicConfig | ||
from common.paths import base_dir, db_path | ||
from common.paths import base_dir, db_path, db_url | ||
from contextlib import asynccontextmanager | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
# Define the SQLAlchemy Base | ||
Base = declarative_base() | ||
|
||
# Create async engine | ||
engine = create_async_engine(db_url, echo=False) | ||
|
||
# Create async session factory | ||
AsyncSessionLocal = sessionmaker( | ||
bind=engine, | ||
class_=AsyncSession, | ||
expire_on_commit=False, | ||
) | ||
|
||
# use alembic to create the database or migrate to the latest schema | ||
def init_db(): | ||
logger.info("Initializing database.") | ||
alembic_cfg = AlembicConfig() | ||
os.makedirs(db_path.parent, exist_ok=True) | ||
alembic_cfg.set_main_option("script_location", str(base_dir / "migrations")) | ||
alembic_cfg.set_main_option("sqlalchemy.url", f"sqlite:///{db_path}") | ||
alembic_cfg.set_main_option("sqlalchemy.url", db_url.replace("+aiosqlite", "")) # because Alembic doesn't like async apparently | ||
command.upgrade(alembic_cfg, "head") | ||
|
||
# Call init_db() when the module is first imported | ||
init_db() | ||
|
||
async def execute_query(query, params=None): | ||
# TODO: logger.adebug from structlog | ||
logger.debug(f"Executing query: {query} with params: {params}") | ||
async with aiosqlite.connect(db_path) as conn: | ||
async with conn.cursor() as cursor: | ||
try: | ||
await cursor.execute(query, params or ()) | ||
result = await cursor.fetchall() | ||
await conn.commit() | ||
except Exception as e: | ||
await conn.rollback() | ||
raise e | ||
|
||
return result | ||
@asynccontextmanager | ||
async def db_session_context(): | ||
session = AsyncSessionLocal() | ||
try: | ||
yield session | ||
await session.commit() | ||
except Exception: | ||
await session.rollback() | ||
raise | ||
finally: | ||
await session.close() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.