import asyncio
import collections
import collections.abc  # needed for the AsyncIterable annotation; `import collections` alone does not guarantee it
import logging
from typing import Union, Optional

from src.database import db
import asyncpg

logger = logging.getLogger("rcgcdb.queue_handler")


class UpdateDB:
    """Queue of SQL statements that are written to the database in batches.

    Statements are queued with :meth:`add` and executed by the long-running
    :meth:`update_db` coroutine, which flushes the queue every 10 seconds
    inside a single transaction per flush.
    """

    def __init__(self):
        # Each entry is (SQL statement, positional query arguments); arguments
        # may include bytes (e.g. pickled Discord messages, see json()).
        self.updated: list[tuple[str, tuple[Union[str, int, bytes], ...]]] = []

    def json(self):
        """Return a JSON-serializable copy of the pending queue.

        Pickled Discord messages are ``bytes`` objects, which are not
        serializable, so every ``bytes`` argument is stripped from the copy.
        The queue itself is not modified.
        """
        return [(item[0], [arg for arg in item[1] if not isinstance(arg, bytes)]) for item in self.updated]

    def add(self, sql_expression: tuple[str, tuple[Union[str, int, bytes], ...]]):
        """Queue ``(statement, args)`` for execution on the next flush."""
        self.updated.append(sql_expression)

    def clear_list(self):
        """Drop every pending statement without executing it."""
        self.updated.clear()

    async def fetch_rows(self, SQLstatement: str, *args: Union[str, int]) -> collections.abc.AsyncIterable:
        """Yield the rows produced by ``SQLstatement`` one at a time.

        Rows are read through ``connection.cursor`` inside a transaction; the
        pooled connection and the transaction stay open until the generator is
        exhausted or closed, so consumers should not abandon it half-read.
        """
        async with db.pool().acquire() as connection:
            async with connection.transaction():
                async for row in connection.cursor(SQLstatement, *args):
                    yield row

    async def _flush(self) -> int:
        """Execute every queued statement in a single transaction.

        Entries are removed from the queue only after the transaction has
        committed. If a statement fails or the task is cancelled mid-flush, the
        transaction rolls back and all entries stay queued instead of being lost.

        Returns:
            The number of statements committed (0 if the queue was empty).
        """
        if not self.updated:
            return 0
        async with db.pool().acquire() as connection:
            async with connection.transaction():
                done = 0
                # Re-read the length each iteration so statements queued while we
                # await the database are included in this same transaction.
                while done < len(self.updated):
                    entry = self.updated[done]
                    logger.debug("Executing: %s %s", entry[0], entry[1])
                    await connection.execute(entry[0], *entry[1])
                    done += 1
        # Reached only after a successful commit; one O(n) slice delete instead of
        # repeated O(n) list.pop(0) calls.
        del self.updated[:done]
        return done

    async def update_db(self):
        """Flush the queue every 10 seconds until the task is cancelled.

        On cancellation, whatever is still queued is flushed one last time and
        the database connection is shut down, even if that final flush fails.
        """
        try:
            while True:
                logger.debug("Running DB check")
                await self._flush()
                await asyncio.sleep(10.0)
        except asyncio.CancelledError:
            # NOTE(review): CancelledError is not re-raised, so the task finishes
            # normally after cancel(); confirm callers rely on this before changing.
            logger.info("Shutting down after updating DB with %s more entries...", len(self.updated))
            try:
                await self._flush()
            finally:
                await db.shutdown_connection()


dbmanager = UpdateDB()