From a2f1d54f39e4f8b194192e1214f256063767b3a6 Mon Sep 17 00:00:00 2001 From: Frisk Date: Sat, 20 Mar 2021 13:42:54 +0100 Subject: [PATCH] Added PostgreSQL compatibility, dropped SQLite compatibility --- settings.json.example | 4 + src/bot.py | 319 +++++++++++++++++++++--------------------- src/database.py | 43 ++++-- src/discord.py | 30 ++-- src/queue_handler.py | 19 +-- src/wiki.py | 5 +- 6 files changed, 227 insertions(+), 193 deletions(-) diff --git a/settings.json.example b/settings.json.example index 67dd2f2..33f9848 100644 --- a/settings.json.example +++ b/settings.json.example @@ -8,6 +8,10 @@ "monitoring_webhook": "111111111111111111/AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA", "support": "https://discord.gg/v77RTk5", "irc_overtime": 3600, + "pg_user": "postgres", + "pg_host": "localhost", + "pg_db": "rcgcdb", + "pg_pass": "secret_password", "irc_servers": { "your custom name for the farm": { "domains": ["wikipedia.org", "otherwikipedia.org"], diff --git a/src/bot.py b/src/bot.py index ccc69dc..2b26a43 100644 --- a/src/bot.py +++ b/src/bot.py @@ -11,7 +11,7 @@ from typing import Generator from contextlib import asynccontextmanager from src.argparser import command_line_args from src.config import settings -from src.database import connection, setup_connection, shutdown_connection +from src.database import db from src.exceptions import * from src.misc import get_paths, get_domain from src.msgqueue import messagequeue, send_to_discord @@ -39,11 +39,13 @@ main_tasks: dict = {} # First populate the all_wikis list with every wiki # Reasons for this: 1. we require amount of wikis to calculate the cooldown between requests # 2. Easier to code + async def populate_allwikis(): - async with connection.transaction(): - async for db_wiki in connection.cursor('SELECT DISTINCT wiki, rcid FROM rcgcdw'): - all_wikis[db_wiki["wiki"]] = Wiki() # populate all_wikis - all_wikis[db_wiki["wiki"]].rc_active = db_wiki["rcid"] + async with db.pool().acquire() as connection: + async with connection.transaction(): + async for db_wiki in connection.cursor('SELECT DISTINCT wiki, rcid FROM rcgcdw'): + all_wikis[db_wiki["wiki"]] = Wiki() # populate all_wikis + all_wikis[db_wiki["wiki"]].rc_active = db_wiki["rcid"] queue_limit = settings.get("queue_limit", 30) QueuedWiki = namedtuple("QueuedWiki", ['url', 'amount']) @@ -104,10 +106,11 @@ class RcQueue: del self.domain_list[group] async def check_if_domain_in_db(self, domain): - async with connection.transaction(): - async for wiki in connection.cursor('SELECT DISTINCT wiki FROM rcgcdw WHERE rcid != -1;'): - if get_domain(wiki["wiki"]) == domain: - return True + async with db.pool().acquire() as connection: + async with connection.transaction(): + async for wiki in connection.cursor('SELECT DISTINCT wiki FROM rcgcdw WHERE rcid != -1;'): + if get_domain(wiki["wiki"]) == domain: + return True return False @asynccontextmanager @@ -146,47 +149,48 @@ class RcQueue: try: self.to_remove = [x[0] for x in filter(self.filter_rc_active, all_wikis.items())] # first populate this list and remove wikis that are still in the db, clean up the rest full = set() - async with connection.transaction(): - async for db_wiki in connection.cursor('SELECT DISTINCT wiki, row_number() over (ORDER BY webhook) AS ROWID, webhook, lang, display, rcid FROM rcgcdw WHERE rcid != -1 OR rcid IS NULL order by webhook'): - domain = get_domain(db_wiki["wiki"]) - try: - if db_wiki["wiki"] not in all_wikis: - raise AssertionError - self.to_remove.remove(db_wiki["wiki"]) - except AssertionError: - all_wikis[db_wiki["wiki"]] = Wiki() - all_wikis[db_wiki["wiki"]].rc_active = db_wiki["rcid"] - except ValueError: - pass - if domain in full: - continue - try: - current_domain: dict = self[domain] - if current_domain["irc"]: - logger.debug("DOMAIN LIST FOR IRC: {}".format(current_domain["irc"].updated)) - logger.debug("CURRENT DOMAIN INFO: {}".format(domain)) - logger.debug("IS WIKI IN A LIST?: {}".format(db_wiki["wiki"] in current_domain["irc"].updated)) - logger.debug("LAST CHECK FOR THE WIKI {} IS {}".format(db_wiki["wiki"], all_wikis[db_wiki["wiki"]].last_check)) - if db_wiki["wiki"] in current_domain["irc"].updated: # Priority wikis are the ones with IRC, if they get updated forcefully add them to queue - current_domain["irc"].updated.remove(db_wiki["wiki"]) - current_domain["query"].append(QueuedWiki(db_wiki["wiki"], 20), forced=True) - logger.debug("Updated in IRC so adding to queue.") - continue - elif all_wikis[db_wiki["wiki"]].last_check+settings["irc_overtime"] < time.time(): # if time went by and wiki should be updated now use default mechanics - logger.debug("Overtime so adding to queue.") - pass - else: # Continue without adding - logger.debug("No condition fulfilled so skipping.") - continue - if not db_wiki["rowid"] < current_domain["last_rowid"]: - current_domain["query"].append(QueuedWiki(db_wiki["wiki"], 20)) - except KeyError: - await self.start_group(domain, [QueuedWiki(db_wiki["wiki"], 20)]) - logger.info("A new domain group ({}) has been added since last time, adding it to the domain_list and starting a task...".format(domain)) - except ListFull: - full.add(domain) - current_domain["last_rowid"] = db_wiki["rowid"] - continue + async with db.pool().acquire() as connection: + async with connection.transaction(): + async for db_wiki in connection.cursor('SELECT DISTINCT wiki, row_number() over (ORDER BY webhook) AS ROWID, webhook, lang, display, rcid FROM rcgcdw WHERE rcid != -1 OR rcid IS NULL order by webhook'): + domain = get_domain(db_wiki["wiki"]) + try: + if db_wiki["wiki"] not in all_wikis: + raise AssertionError + self.to_remove.remove(db_wiki["wiki"]) + except AssertionError: + all_wikis[db_wiki["wiki"]] = Wiki() + all_wikis[db_wiki["wiki"]].rc_active = db_wiki["rcid"] + except ValueError: + pass + if domain in full: + continue + try: + current_domain: dict = self[domain] + if current_domain["irc"]: + logger.debug("DOMAIN LIST FOR IRC: {}".format(current_domain["irc"].updated)) + logger.debug("CURRENT DOMAIN INFO: {}".format(domain)) + logger.debug("IS WIKI IN A LIST?: {}".format(db_wiki["wiki"] in current_domain["irc"].updated)) + logger.debug("LAST CHECK FOR THE WIKI {} IS {}".format(db_wiki["wiki"], all_wikis[db_wiki["wiki"]].last_check)) + if db_wiki["wiki"] in current_domain["irc"].updated: # Priority wikis are the ones with IRC, if they get updated forcefully add them to queue + current_domain["irc"].updated.remove(db_wiki["wiki"]) + current_domain["query"].append(QueuedWiki(db_wiki["wiki"], 20), forced=True) + logger.debug("Updated in IRC so adding to queue.") + continue + elif all_wikis[db_wiki["wiki"]].last_check+settings["irc_overtime"] < time.time(): # if time went by and wiki should be updated now use default mechanics + logger.debug("Overtime so adding to queue.") + pass + else: # Continue without adding + logger.debug("No condition fulfilled so skipping.") + continue + if not db_wiki["rowid"] < current_domain["last_rowid"]: + current_domain["query"].append(QueuedWiki(db_wiki["wiki"], 20)) + except KeyError: + await self.start_group(domain, [QueuedWiki(db_wiki["wiki"], 20)]) + logger.info("A new domain group ({}) has been added since last time, adding it to the domain_list and starting a task...".format(domain)) + except ListFull: + full.add(domain) + current_domain["last_rowid"] = db_wiki["rowid"] + continue for wiki in self.to_remove: await self.remove_wiki_from_group(wiki) for group, data in self.domain_list.items(): @@ -232,10 +236,11 @@ async def generate_targets(wiki_url: str, additional_requirements: str) -> defau request to the wiki just to duplicate the message. """ combinations = defaultdict(list) - async with connection.transaction(): - async for webhook in connection.cursor('SELECT webhook, lang, display FROM rcgcdw WHERE wiki = $1 {}'.format(additional_requirements), wiki_url): - combination = (webhook["lang"], webhook["display"]) - combinations[combination].append(webhook["webhook"]) + async with db.pool().acquire() as connection: + async with connection.transaction(): + async for webhook in connection.cursor('SELECT webhook, lang, display FROM rcgcdw WHERE wiki = $1 {}'.format(additional_requirements), wiki_url): + combination = (webhook["lang"], webhook["display"]) + combinations[combination].append(webhook["webhook"]) return combinations @@ -244,9 +249,10 @@ async def generate_domain_groups(): :returns tuple[str, list]""" domain_wikis = defaultdict(list) - async with connection.transaction(): - async for db_wiki in connection.cursor('SELECT DISTINCT wiki, webhook, lang, display, rcid FROM rcgcdw WHERE rcid != -1 OR rcid IS NULL'): - domain_wikis[get_domain(db_wiki["wiki"])].append(QueuedWiki(db_wiki["wiki"], 20)) + async with db.pool().acquire() as connection: + async with connection.transaction(): + async for db_wiki in connection.cursor('SELECT DISTINCT wiki, webhook, lang, display, rcid FROM rcgcdw WHERE rcid != -1 OR rcid IS NULL'): + domain_wikis[get_domain(db_wiki["wiki"])].append(QueuedWiki(db_wiki["wiki"], 20)) for group, db_wikis in domain_wikis.items(): yield group, db_wikis @@ -396,107 +402,108 @@ async def message_sender(): async def discussion_handler(): try: while True: - async with connection.transaction(): - async for db_wiki in connection.cursor("SELECT DISTINCT wiki, rcid, postid FROM rcgcdw WHERE postid != '-1' OR postid IS NULL"): - try: - local_wiki = all_wikis[db_wiki["wiki"]] # set a reference to a wiki object from memory - except KeyError: - local_wiki = all_wikis[db_wiki["wiki"]] = Wiki() - local_wiki.rc_active = db_wiki["rcid"] - if db_wiki["wiki"] not in rcqueue.irc_mapping["fandom.com"].updated_discussions and local_wiki.last_discussion_check+settings["irc_overtime"] > time.time(): # I swear if another wiki farm ever starts using Fandom discussions I'm gonna use explosion magic - continue - else: + async with db.pool().acquire() as connection: + async with connection.transaction(): + async for db_wiki in connection.cursor("SELECT DISTINCT wiki, rcid, postid FROM rcgcdw WHERE postid != '-1' OR postid IS NULL"): try: - rcqueue.irc_mapping["fandom.com"].updated_discussions.remove(db_wiki["wiki"]) + local_wiki = all_wikis[db_wiki["wiki"]] # set a reference to a wiki object from memory except KeyError: - pass # to be expected - header = settings["header"] - header["Accept"] = "application/hal+json" - async with aiohttp.ClientSession(headers=header, - timeout=aiohttp.ClientTimeout(6.0)) as session: - try: - feeds_response = await local_wiki.fetch_feeds(db_wiki["wiki"], session) - except (WikiServerError, WikiError): - continue # ignore this wiki if it throws errors - try: - discussion_feed_resp = await feeds_response.json(encoding="UTF-8") - if "error" in discussion_feed_resp: - error = discussion_feed_resp["error"] - if error == "NotFoundException": # Discussions disabled - if db_wiki["rcid"] != -1: # RC feed is disabled - await connection.execute("UPDATE rcgcdw SET postid = ? WHERE wiki = ?", - ("-1", db_wiki["wiki"],)) - else: - await local_wiki.remove(db_wiki["wiki"], 1000) - await DBHandler.update_db() - continue - raise WikiError - discussion_feed = discussion_feed_resp["_embedded"]["doc:posts"] - discussion_feed.reverse() - except aiohttp.ContentTypeError: - logger.exception("Wiki seems to be resulting in non-json content.") + local_wiki = all_wikis[db_wiki["wiki"]] = Wiki() + local_wiki.rc_active = db_wiki["rcid"] + if db_wiki["wiki"] not in rcqueue.irc_mapping["fandom.com"].updated_discussions and local_wiki.last_discussion_check+settings["irc_overtime"] > time.time(): # I swear if another wiki farm ever starts using Fandom discussions I'm gonna use explosion magic continue - except asyncio.TimeoutError: - logger.debug("Timeout on reading JSON of discussion post feeed.") - continue - except: - logger.exception("On loading json of response.") - continue - if db_wiki["postid"] is None: # new wiki, just get the last post to not spam the channel - if len(discussion_feed) > 0: - DBHandler.add(db_wiki["wiki"], discussion_feed[-1]["id"], True) else: - DBHandler.add(db_wiki["wiki"], "0", True) - await DBHandler.update_db() - continue - comment_events = [] - targets = await generate_targets(db_wiki["wiki"], "AND NOT postid = '-1'") - for post in discussion_feed: - if post["_embedded"]["thread"][0]["containerType"] == "ARTICLE_COMMENT" and post["id"] > db_wiki["postid"]: - comment_events.append(post["forumId"]) - comment_pages: dict = {} - if comment_events: - try: - comment_pages = await local_wiki.safe_request( - "{wiki}wikia.php?controller=FeedsAndPosts&method=getArticleNamesAndUsernames&stablePageIds={pages}&format=json".format( - wiki=db_wiki["wiki"], pages=",".join(comment_events) - ), RateLimiter(), "articleNames") - except aiohttp.ClientResponseError: # Fandom can be funny sometimes... See #30 - comment_pages = None - except: - if command_line_args.debug: - logger.exception("Exception on Feeds article comment request") - shutdown(loop=asyncio.get_event_loop()) + try: + rcqueue.irc_mapping["fandom.com"].updated_discussions.remove(db_wiki["wiki"]) + except KeyError: + pass # to be expected + header = settings["header"] + header["Accept"] = "application/hal+json" + async with aiohttp.ClientSession(headers=header, + timeout=aiohttp.ClientTimeout(6.0)) as session: + try: + feeds_response = await local_wiki.fetch_feeds(db_wiki["wiki"], session) + except (WikiServerError, WikiError): + continue # ignore this wiki if it throws errors + try: + discussion_feed_resp = await feeds_response.json(encoding="UTF-8") + if "error" in discussion_feed_resp: + error = discussion_feed_resp["error"] + if error == "NotFoundException": # Discussions disabled + if db_wiki["rcid"] != -1: # RC feed is disabled + await connection.execute("UPDATE rcgcdw SET postid = $1 WHERE wiki = $2", + ("-1", db_wiki["wiki"],)) + else: + await local_wiki.remove(db_wiki["wiki"], 1000) + await DBHandler.update_db() + continue + raise WikiError + discussion_feed = discussion_feed_resp["_embedded"]["doc:posts"] + discussion_feed.reverse() + except aiohttp.ContentTypeError: + logger.exception("Wiki seems to be resulting in non-json content.") + continue + except asyncio.TimeoutError: + logger.debug("Timeout on reading JSON of discussion post feeed.") + continue + except: + logger.exception("On loading json of response.") + continue + if db_wiki["postid"] is None: # new wiki, just get the last post to not spam the channel + if len(discussion_feed) > 0: + DBHandler.add(db_wiki["wiki"], discussion_feed[-1]["id"], True) else: - logger.exception("Exception on Feeds article comment request") - await generic_msg_sender_exception_logger(traceback.format_exc(), - "Exception on Feeds article comment request", - Post=str(post)[0:1000], Wiki=db_wiki["wiki"]) - message_list = defaultdict(list) - for post in discussion_feed: # Yeah, second loop since the comments require an extra request - if post["id"] > db_wiki["postid"]: - for target in targets.items(): - try: - message = await essential_feeds(post, comment_pages, db_wiki, target) - if message is not None: - message_list[target[0]].append(message) - except asyncio.CancelledError: - raise - except: - if command_line_args.debug: - logger.exception("Exception on Feeds formatter") - shutdown(loop=asyncio.get_event_loop()) - else: - logger.exception("Exception on Feeds formatter") - await generic_msg_sender_exception_logger(traceback.format_exc(), "Exception in feed formatter", Post=str(post)[0:1000], Wiki=db_wiki["wiki"]) - # Lets stack the messages - for messages in message_list.values(): - messages = stack_message_list(messages) - for message in messages: - await send_to_discord(message) - if discussion_feed: - DBHandler.add(db_wiki["wiki"], post["id"], True) - await asyncio.sleep(delay=2.0) # hardcoded really doesn't need much more + DBHandler.add(db_wiki["wiki"], "0", True) + await DBHandler.update_db() + continue + comment_events = [] + targets = await generate_targets(db_wiki["wiki"], "AND NOT postid = '-1'") + for post in discussion_feed: + if post["_embedded"]["thread"][0]["containerType"] == "ARTICLE_COMMENT" and post["id"] > db_wiki["postid"]: + comment_events.append(post["forumId"]) + comment_pages: dict = {} + if comment_events: + try: + comment_pages = await local_wiki.safe_request( + "{wiki}wikia.php?controller=FeedsAndPosts&method=getArticleNamesAndUsernames&stablePageIds={pages}&format=json".format( + wiki=db_wiki["wiki"], pages=",".join(comment_events) + ), RateLimiter(), "articleNames") + except aiohttp.ClientResponseError: # Fandom can be funny sometimes... See #30 + comment_pages = None + except: + if command_line_args.debug: + logger.exception("Exception on Feeds article comment request") + shutdown(loop=asyncio.get_event_loop()) + else: + logger.exception("Exception on Feeds article comment request") + await generic_msg_sender_exception_logger(traceback.format_exc(), + "Exception on Feeds article comment request", + Post=str(post)[0:1000], Wiki=db_wiki["wiki"]) + message_list = defaultdict(list) + for post in discussion_feed: # Yeah, second loop since the comments require an extra request + if post["id"] > db_wiki["postid"]: + for target in targets.items(): + try: + message = await essential_feeds(post, comment_pages, db_wiki, target) + if message is not None: + message_list[target[0]].append(message) + except asyncio.CancelledError: + raise + except: + if command_line_args.debug: + logger.exception("Exception on Feeds formatter") + shutdown(loop=asyncio.get_event_loop()) + else: + logger.exception("Exception on Feeds formatter") + await generic_msg_sender_exception_logger(traceback.format_exc(), "Exception in feed formatter", Post=str(post)[0:1000], Wiki=db_wiki["wiki"]) + # Lets stack the messages + for messages in message_list.values(): + messages = stack_message_list(messages) + for message in messages: + await send_to_discord(message) + if discussion_feed: + DBHandler.add(db_wiki["wiki"], post["id"], True) + await asyncio.sleep(delay=2.0) # hardcoded really doesn't need much more await asyncio.sleep(delay=1.0) # Avoid lock on no wikis await DBHandler.update_db() except asyncio.CancelledError: @@ -543,8 +550,8 @@ async def main_loop(): global main_tasks loop = asyncio.get_event_loop() nest_asyncio.apply(loop) - await setup_connection() - logger.debug("Connection type: {}".format(connection)) + await db.setup_connection() + logger.debug("Connection type: {}".format(db.connection)) await populate_allwikis() try: signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT) @@ -562,7 +569,7 @@ async def main_loop(): await asyncio.gather(main_tasks["wiki_scanner"], main_tasks["discussion_handler"], main_tasks["message_sender"]) except KeyboardInterrupt: await DBHandler.update_db() - await shutdown_connection() + await db.shutdown_connection() shutdown(loop) except asyncio.CancelledError: return diff --git a/src/database.py b/src/database.py index 4c50562..86cb293 100644 --- a/src/database.py +++ b/src/database.py @@ -4,19 +4,38 @@ from typing import Optional from src.config import settings logger = logging.getLogger("rcgcdb.database") -connection: Optional[asyncpg.Connection] = None +# connection: Optional[asyncpg.Connection] = None -async def setup_connection(): - global connection - # Establish a connection to an existing database named "test" - # as a "postgres" user. - logger.debug("Setting up the Database connection...") - connection = await asyncpg.connect(user=settings["pg_user"], host=settings.get("pg_host", "localhost"), - database=settings.get("pg_db", "rcgcdb"), password=settings.get("pg_pass")) - logger.debug("Database connection established! Connection: {}".format(connection)) +class db_connection: + connection: Optional[asyncpg.Pool] = None + + async def setup_connection(self): + # Establish a connection to an existing database named "test" + # as a "postgres" user. + logger.debug("Setting up the Database connection...") + self.connection = await asyncpg.create_pool(user=settings["pg_user"], host=settings.get("pg_host", "localhost"), + database=settings.get("pg_db", "rcgcdb"), password=settings.get("pg_pass")) + logger.debug("Database connection established! Connection: {}".format(self.connection)) + + async def shutdown_connection(self): + await self.connection.close() + + def pool(self) -> asyncpg.Pool: + return self.connection + + # Tried to make it a decorator but tbh won't probably work + # async def in_transaction(self, func): + # async def single_transaction(): + # async with self.connection.acquire() as connection: + # async with connection.transaction(): + # await func() + # return single_transaction + + # async def query(self, string, *arg): + # async with self.connection.acquire() as connection: + # async with connection.transaction(): + # return connection.cursor(string, *arg) -async def shutdown_connection(): - global connection - await connection.close() +db = db_connection() diff --git a/src/discord.py b/src/discord.py index 89ad9bf..fedfd5a 100644 --- a/src/discord.py +++ b/src/discord.py @@ -3,7 +3,7 @@ from collections import defaultdict from src.misc import logger from src.config import settings -from src.database import connection +from src.database import db from src.i18n import langs from src.exceptions import EmbedListFull from asyncio import TimeoutError @@ -22,18 +22,19 @@ default_header["X-RateLimit-Precision"] = "millisecond" # User facing webhook functions async def wiki_removal(wiki_url, status): - async with connection.transaction(): - async for observer in connection.cursor('SELECT webhook, lang FROM rcgcdw WHERE wiki = ?', wiki_url): - _ = langs[observer["lang"]]["discord"].gettext - reasons = {410: _("wiki deleted"), 404: _("wiki deleted"), 401: _("wiki inaccessible"), - 402: _("wiki inaccessible"), 403: _("wiki inaccessible"), 1000: _("discussions disabled")} - reason = reasons.get(status, _("unknown error")) - await send_to_discord_webhook(DiscordMessage("compact", "webhook/remove", webhook_url=[], content=_("This recent changes webhook has been removed for `{reason}`!").format(reason=reason), wiki=None), webhook_url=observer["webhook"]) - header = settings["header"] - header['Content-Type'] = 'application/json' - header['X-Audit-Log-Reason'] = "Wiki becoming unavailable" - async with aiohttp.ClientSession(headers=header, timeout=aiohttp.ClientTimeout(5.0)) as session: - await session.delete("https://discord.com/api/webhooks/"+observer["webhook"]) + async with db.pool().acquire() as connection: + async with connection.transaction(): + async for observer in connection.cursor('SELECT webhook, lang FROM rcgcdw WHERE wiki = $1', wiki_url): + _ = langs[observer["lang"]]["discord"].gettext + reasons = {410: _("wiki deleted"), 404: _("wiki deleted"), 401: _("wiki inaccessible"), + 402: _("wiki inaccessible"), 403: _("wiki inaccessible"), 1000: _("discussions disabled")} + reason = reasons.get(status, _("unknown error")) + await send_to_discord_webhook(DiscordMessage("compact", "webhook/remove", webhook_url=[], content=_("This recent changes webhook has been removed for `{reason}`!").format(reason=reason), wiki=None), webhook_url=observer["webhook"]) + header = settings["header"] + header['Content-Type'] = 'application/json' + header['X-Audit-Log-Reason'] = "Wiki becoming unavailable" + async with aiohttp.ClientSession(headers=header, timeout=aiohttp.ClientTimeout(5.0)) as session: + await session.delete("https://discord.com/api/webhooks/"+observer["webhook"]) async def webhook_removal_monitor(webhook_url: str, reason: int): @@ -239,7 +240,8 @@ async def handle_discord_http(code: int, formatted_embed: str, result: aiohttp.C return 1 elif code == 401 or code == 404: # HTTP UNAUTHORIZED AND NOT FOUND logger.error("Webhook URL is invalid or no longer in use, please replace it with proper one.") - await connection.execute("DELETE FROM rcgcdw WHERE webhook = ?", (webhook_url,)) + async with db.pool().acquire() as connection: + await connection.execute("DELETE FROM rcgcdw WHERE webhook = $1", (webhook_url,)) await webhook_removal_monitor(webhook_url, code) return 1 elif code == 429: diff --git a/src/queue_handler.py b/src/queue_handler.py index 49952c5..9e4f81e 100644 --- a/src/queue_handler.py +++ b/src/queue_handler.py @@ -1,5 +1,5 @@ import logging -from src.database import connection +from src.database import db logger = logging.getLogger("rcgcdb.queue_handler") @@ -15,14 +15,15 @@ class UpdateDB: self.updated.clear() async def update_db(self): - async with connection.transaction(): - for update in self.updated: - if update[2] is None: - sql = "UPDATE rcgcdw SET rcid = ? WHERE wiki = ? AND ( rcid != -1 OR rcid IS NULL )" - else: - sql = "UPDATE rcgcdw SET postid = ? WHERE wiki = ? AND ( postid != '-1' OR postid IS NULL )" - await connection.execute(sql) - self.clear_list() + async with db.pool().acquire() as connection: + async with connection.transaction(): + for update in self.updated: + if update[2] is None: + sql = "UPDATE rcgcdw SET rcid = $2 WHERE wiki = $1 AND ( rcid != -1 OR rcid IS NULL )" + else: + sql = "UPDATE rcgcdw SET postid = $2 WHERE wiki = $1 AND ( postid != '-1' OR postid IS NULL )" + await connection.execute(sql, update[0], update[1]) + self.clear_list() DBHandler = UpdateDB() diff --git a/src/wiki.py b/src/wiki.py index 1c0fde8..02c9d19 100644 --- a/src/wiki.py +++ b/src/wiki.py @@ -2,7 +2,7 @@ from dataclasses import dataclass import re import logging, aiohttp from src.exceptions import * -from src.database import connection +from src.database import db from src.formatters.rc import embed_formatter, compact_formatter from src.formatters.discussions import feeds_embed_formatter, feeds_compact_formatter from src.misc import parse_link @@ -109,7 +109,8 @@ class Wiki: logger.info("Removing a wiki {}".format(wiki_url)) await src.discord.wiki_removal(wiki_url, reason) await src.discord.wiki_removal_monitor(wiki_url, reason) - result = await connection.execute('DELETE FROM rcgcdw WHERE wiki = ?', wiki_url) + async with db.pool().acquire() as connection: + result = await connection.execute('DELETE FROM rcgcdw WHERE wiki = $1', wiki_url) logger.warning('{} rows affected by DELETE FROM rcgcdw WHERE wiki = "{}"'.format(result, wiki_url)) async def pull_comment(self, comment_id, WIKI_API_PATH, rate_limiter):