From 3a87485b9a392998890aa2a905a8326948e5e813 Mon Sep 17 00:00:00 2001 From: Michiel Scholten Date: Fri, 12 Sep 2025 12:03:12 +0200 Subject: [PATCH] Application now uses async DB session everywhere --- src/digimarks/main.py | 86 +++++++++++++++++++++++++------------------ 1 file changed, 50 insertions(+), 36 deletions(-) diff --git a/src/digimarks/main.py b/src/digimarks/main.py index 414156d..7c1cee3 100644 --- a/src/digimarks/main.py +++ b/src/digimarks/main.py @@ -20,7 +20,10 @@ from fastapi.staticfiles import StaticFiles from fastapi.templating import Jinja2Templates from pydantic import AnyUrl, DirectoryPath, FilePath, computed_field from pydantic_settings import BaseSettings -from sqlmodel import AutoString, Field, Session, SQLModel, create_engine, desc, select +from sqlalchemy.ext.asyncio import create_async_engine +from sqlalchemy.orm import sessionmaker +from sqlmodel import AutoString, Field, SQLModel, desc, select +from sqlmodel.ext.asyncio.session import AsyncSession DIGIMARKS_USER_AGENT = 'digimarks/2.0.0-dev' DIGIMARKS_VERSION = '2.0.0a1' @@ -49,16 +52,17 @@ class Settings(BaseSettings): settings = Settings() print(settings.model_dump()) -engine = create_engine(f'sqlite:///{settings.database_file}', connect_args={'check_same_thread': False}) +engine = create_async_engine(f'sqlite+aiosqlite:///{settings.database_file}', connect_args={'check_same_thread': False}) -def get_session(): +async def get_session() -> AsyncSession: """SQLAlchemy session factory.""" - with Session(engine) as session: + async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) + async with async_session() as session: yield session -SessionDep = Annotated[Session, Depends(get_session)] +SessionDep = Annotated[AsyncSession, Depends(get_session)] @asynccontextmanager @@ -358,7 +362,7 @@ def index(request: Request): @app.get('/api/v1/admin/{system_key}/users/{user_id}', response_model=User) -def get_user(session: SessionDep, system_key: str, user_id: int) -> Type[User]: +async def get_user(session: SessionDep, system_key: str, user_id: int) -> Type[User]: """Show user information.""" if system_key != settings.system_key: raise HTTPException(status_code=404) @@ -371,7 +375,7 @@ def get_user(session: SessionDep, system_key: str, user_id: int) -> Type[User]: # @app.get('/admin/{system_key}/users/', response_model=list[User]) @app.get('/api/v1/admin/{system_key}/users/') -def list_users( +async def list_users( session: SessionDep, system_key: str, offset: int = 0, @@ -394,40 +398,42 @@ def list_users( @app.get('/api/v1/{user_key}/bookmarks/') -def list_bookmarks( +async def list_bookmarks( session: SessionDep, user_key: str, offset: int = 0, limit: Annotated[int, Query(le=10000)] = 100, ) -> list[Bookmark]: """List all bookmarks in the database. By default 100 items are returned.""" - bookmarks = session.exec( + result = await session.exec( select(Bookmark) .where(Bookmark.userkey == user_key, Bookmark.status != Visibility.DELETED) .offset(offset) .limit(limit) - ).all() + ) + bookmarks = result.all() return bookmarks @app.get('/api/v1/{user_key}/bookmarks/{url_hash}') -def get_bookmark( +async def get_bookmark( session: SessionDep, user_key: str, url_hash: str, ) -> Bookmark: """Show bookmark details.""" - bookmark = session.exec( + result = await session.exec( select(Bookmark).where( Bookmark.userkey == user_key, Bookmark.url_hash == url_hash, Bookmark.status != Visibility.DELETED ) - ).first() + ) + bookmark = result.first() # bookmark = session.get(Bookmark, {'url_hash': url_hash, 'userkey': user_key}) return bookmark @app.post('/api/v1/{user_key}/autocomplete_bookmark/', response_model=Bookmark) -def autocomplete_bookmark( +async def autocomplete_bookmark( session: SessionDep, request: Request, user_key: str, @@ -441,11 +447,12 @@ def autocomplete_bookmark( update_bookmark_with_info(bookmark, request, strip_params) url_hash = generate_hash(str(bookmark.url)) - bookmark_db = session.exec( + result = await session.exec( select(Bookmark).where( Bookmark.userkey == user_key, Bookmark.url_hash == url_hash, Bookmark.status != Visibility.DELETED ) - ).first() + ) + bookmark_db = result.first() if bookmark_db: # Bookmark with this URL already exists, provide the hash so the frontend can look it up and the user can # merge them if so wanted @@ -455,7 +462,7 @@ def autocomplete_bookmark( @app.post('/api/v1/{user_key}/bookmarks/', response_model=Bookmark) -def add_bookmark( +async def add_bookmark( session: SessionDep, request: Request, user_key: str, @@ -470,13 +477,13 @@ def add_bookmark( bookmark.url_hash = generate_hash(str(bookmark.url)) session.add(bookmark) - session.commit() - session.refresh(bookmark) + await session.commit() + await session.refresh(bookmark) return bookmark @app.patch('/api/v1/{user_key}/bookmarks/{url_hash}', response_model=Bookmark) -def update_bookmark( +async def update_bookmark( session: SessionDep, request: Request, user_key: str, @@ -485,11 +492,12 @@ def update_bookmark( strip_params: bool = False, ): """Update existing bookmark `bookmark_key` for user `user_key`.""" - bookmark_db = session.exec( + result = await session.exec( select(Bookmark).where( Bookmark.userkey == user_key, Bookmark.url_hash == url_hash, Bookmark.status != Visibility.DELETED ) - ).first() + ) + bookmark_db = result.first() if not bookmark_db: raise HTTPException(status_code=404, detail='Bookmark not found') @@ -510,13 +518,14 @@ def update_bookmark( @app.delete('/api/v1/{user_key}/bookmarks/{url_hash}', response_model=Bookmark) -def delete_bookmark( +async def delete_bookmark( session: SessionDep, user_key: str, url_hash: str, ): """(Soft)Delete bookmark `bookmark_key` for user `user_key`.""" - bookmark = session.get(Bookmark, {'url_hash': url_hash, 'userkey': user_key}) + result = await session.get(Bookmark, {'url_hash': url_hash, 'userkey': user_key}) + bookmark = result if not bookmark: raise HTTPException(status_code=404, detail='Bookmark not found') bookmark.deleted_date = datetime.now(UTC) @@ -527,21 +536,23 @@ def delete_bookmark( @app.get('/api/v1/{user_key}/latest_changes/') -def bookmarks_changed_since( +async def bookmarks_changed_since( session: SessionDep, user_key: str, ): """Last update on server, so the (browser) client knows whether to fetch an update.""" - latest_modified_bookmark = session.exec( + result = await session.exec( select(Bookmark) .where(Bookmark.userkey == user_key, Bookmark.status != Visibility.DELETED) .order_by(desc(Bookmark.modified_date)) - ).first() - latest_created_bookmark = session.exec( + ) + latest_modified_bookmark = result.first() + result = await session.exec( select(Bookmark) .where(Bookmark.userkey == user_key, Bookmark.status != Visibility.DELETED) .order_by(desc(Bookmark.created_date)) - ).first() + ) + latest_created_bookmark = result.first() latest_modification = max(latest_modified_bookmark.modified_date, latest_created_bookmark.created_date) @@ -554,14 +565,15 @@ def bookmarks_changed_since( @app.get('/api/v1/{user_key}/tags/') -def list_tags_for_user( +async def list_tags_for_user( session: SessionDep, user_key: str, ) -> list[str]: """List all tags in use by the user.""" - bookmarks = session.exec( + result = await session.exec( select(Bookmark).where(Bookmark.userkey == user_key, Bookmark.status != Visibility.DELETED) - ).all() + ) + bookmarks = result.all() tags = [] for bookmark in bookmarks: tags += bookmark.tag_list @@ -569,23 +581,25 @@ def list_tags_for_user( @app.get('/api/v1/{user_key}/tags/{tag_key}') -def list_tags_for_user( +async def list_tags_for_user( session: SessionDep, user_key: str, ) -> list[str]: """List all tags in use by the user.""" - bookmarks = session.exec(select(Bookmark).where(Bookmark.userkey == user_key)).all() + result = await session.exec(select(Bookmark).where(Bookmark.userkey == user_key)) + bookmarks = result.all() return list_tags_for_bookmarks(bookmarks) @app.get('/{user_key}', response_class=HTMLResponse) -def page_user_landing( +async def page_user_landing( session: SessionDep, request: Request, user_key: str, ): """HTML page with the main view for the user.""" - user = session.exec(select(User).where(User.key == user_key)).first() + result = await session.exec(select(User).where(User.key == user_key)) + user = result.first() if not user: raise HTTPException(status_code=404, detail='User not found') language = 'en'