diff --git a/requirements.txt b/requirements.txt index c752fe3..2533b8e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,4 +2,6 @@ beautifulsoup4 >= 4.6.0; python_version >= '3.6' aiohttp >= 3.6.2 lxml >= 4.2.1 nest-asyncio >= 1.4.0 -irc >= 19.0.1 \ No newline at end of file +irc >= 19.0.1 +beautifulsoup4>=4.9.3 +asyncpg>=0.22.0 \ No newline at end of file diff --git a/src/bot.py b/src/bot.py index ac6808e..8500343 100644 --- a/src/bot.py +++ b/src/bot.py @@ -11,7 +11,7 @@ from typing import Generator from contextlib import asynccontextmanager from src.argparser import command_line_args from src.config import settings -from src.database import db_cursor, db_connection +from src.database import connection, setup_connection, shutdown_connection from src.exceptions import * from src.misc import get_paths, get_domain from src.msgqueue import messagequeue, send_to_discord @@ -301,7 +301,7 @@ async def scan_group(group: str): else: local_wiki.rc_active = 0 DBHandler.add(queued_wiki.url, 0) - DBHandler.update_db() + await DBHandler.update_db() continue categorize_events = {} targets = generate_targets(queued_wiki.url, "AND (rcid != -1 OR rcid IS NULL)") @@ -347,7 +347,7 @@ async def scan_group(group: str): if recent_changes: # we don't have to test for highest_rc being null, because if there are no RC entries recent_changes will be an empty list which will result in false in here and DO NOT save the value local_wiki.rc_active = highest_rc DBHandler.add(queued_wiki.url, highest_rc) - DBHandler.update_db() + await DBHandler.update_db() except asyncio.CancelledError: return except QueueEmpty: @@ -426,7 +426,7 @@ async def discussion_handler(): ("-1", db_wiki["wiki"],)) else: await local_wiki.remove(db_wiki["wiki"], 1000) - DBHandler.update_db() + await DBHandler.update_db() continue raise WikiError discussion_feed = discussion_feed_resp["_embedded"]["doc:posts"] @@ -445,7 +445,7 @@ async def discussion_handler(): DBHandler.add(db_wiki["wiki"], discussion_feed[-1]["id"], True) else: DBHandler.add(db_wiki["wiki"], "0", True) - DBHandler.update_db() + await DBHandler.update_db() continue comment_events = [] targets = generate_targets(db_wiki["wiki"], "AND NOT postid = '-1'") @@ -496,7 +496,7 @@ async def discussion_handler(): DBHandler.add(db_wiki["wiki"], post["id"], True) await asyncio.sleep(delay=2.0) # hardcoded really doesn't need much more await asyncio.sleep(delay=1.0) # Avoid lock on no wikis - DBHandler.update_db() + await DBHandler.update_db() except asyncio.CancelledError: pass except: @@ -510,8 +510,6 @@ async def discussion_handler(): def shutdown(loop, signal=None): global main_tasks - DBHandler.update_db() - db_connection.close() loop.remove_signal_handler(signal) if len(messagequeue) > 0: logger.warning("Some messages are still queued!") @@ -558,6 +556,8 @@ async def main_loop(): main_tasks["msg_queue_shield"] = asyncio.shield(main_tasks["message_sender"]) await asyncio.gather(main_tasks["wiki_scanner"], main_tasks["discussion_handler"], main_tasks["message_sender"]) except KeyboardInterrupt: + await DBHandler.update_db() + await shutdown_connection() shutdown(loop) except asyncio.CancelledError: return diff --git a/src/database.py b/src/database.py index 4366668..3fb39cd 100644 --- a/src/database.py +++ b/src/database.py @@ -1,6 +1,18 @@ -import sqlite3 +import asyncpg +from typing import Any, Union, Optional from src.config import settings -db_connection = sqlite3.connect(settings.get("database_path", 'rcgcdb.db')) -db_connection.row_factory = sqlite3.Row -db_cursor = db_connection.cursor() +connection: Optional[asyncpg.Connection] = None + + +async def setup_connection(): + global connection + # Establish a connection to an existing database named "test" + # as a "postgres" user. + connection: asyncpg.connection = await asyncpg.connect(user=settings["pg_user"], host=settings.get("pg_host", "localhost"), + database=settings.get("pg_db", "RcGcDb"), password=settings.get("pg_pass")) + + +async def shutdown_connection(): + global connection + await connection.close() diff --git a/src/queue_handler.py b/src/queue_handler.py index a37a10d..49952c5 100644 --- a/src/queue_handler.py +++ b/src/queue_handler.py @@ -1,5 +1,5 @@ import logging -from src.database import db_cursor, db_connection +from src.database import connection logger = logging.getLogger("rcgcdb.queue_handler") @@ -14,15 +14,15 @@ class UpdateDB: def clear_list(self): self.updated.clear() - def update_db(self): - for update in self.updated: - if update[2] is None: - sql = "UPDATE rcgcdw SET rcid = ? WHERE wiki = ? AND ( rcid != -1 OR rcid IS NULL )" - else: - sql = "UPDATE rcgcdw SET postid = ? WHERE wiki = ? AND ( postid != '-1' OR postid IS NULL )" - db_cursor.execute(sql, (update[1], update[0],)) - db_connection.commit() - self.clear_list() + async def update_db(self): + async with connection.transaction(): + for update in self.updated: + if update[2] is None: + sql = "UPDATE rcgcdw SET rcid = ? WHERE wiki = ? AND ( rcid != -1 OR rcid IS NULL )" + else: + sql = "UPDATE rcgcdw SET postid = ? WHERE wiki = ? AND ( postid != '-1' OR postid IS NULL )" + await connection.execute(sql) + self.clear_list() DBHandler = UpdateDB()