Added PostgreSQL compatibility, dropped SQLite compatibility

This commit is contained in:
Frisk 2021-03-20 13:42:54 +01:00
parent 2c8574445c
commit a2f1d54f39
No known key found for this signature in database
GPG key ID: 213F7C15068AF8AC
6 changed files with 227 additions and 193 deletions

View file

@ -8,6 +8,10 @@
"monitoring_webhook": "111111111111111111/AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA", "monitoring_webhook": "111111111111111111/AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA",
"support": "https://discord.gg/v77RTk5", "support": "https://discord.gg/v77RTk5",
"irc_overtime": 3600, "irc_overtime": 3600,
"pg_user": "postgres",
"pg_host": "localhost",
"pg_db": "rcgcdb",
"pg_pass": "secret_password",
"irc_servers": { "irc_servers": {
"your custom name for the farm": { "your custom name for the farm": {
"domains": ["wikipedia.org", "otherwikipedia.org"], "domains": ["wikipedia.org", "otherwikipedia.org"],

View file

@ -11,7 +11,7 @@ from typing import Generator
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from src.argparser import command_line_args from src.argparser import command_line_args
from src.config import settings from src.config import settings
from src.database import connection, setup_connection, shutdown_connection from src.database import db
from src.exceptions import * from src.exceptions import *
from src.misc import get_paths, get_domain from src.misc import get_paths, get_domain
from src.msgqueue import messagequeue, send_to_discord from src.msgqueue import messagequeue, send_to_discord
@ -39,7 +39,9 @@ main_tasks: dict = {}
# First populate the all_wikis list with every wiki # First populate the all_wikis list with every wiki
# Reasons for this: 1. we require amount of wikis to calculate the cooldown between requests # Reasons for this: 1. we require amount of wikis to calculate the cooldown between requests
# 2. Easier to code # 2. Easier to code
async def populate_allwikis(): async def populate_allwikis():
async with db.pool().acquire() as connection:
async with connection.transaction(): async with connection.transaction():
async for db_wiki in connection.cursor('SELECT DISTINCT wiki, rcid FROM rcgcdw'): async for db_wiki in connection.cursor('SELECT DISTINCT wiki, rcid FROM rcgcdw'):
all_wikis[db_wiki["wiki"]] = Wiki() # populate all_wikis all_wikis[db_wiki["wiki"]] = Wiki() # populate all_wikis
@ -104,6 +106,7 @@ class RcQueue:
del self.domain_list[group] del self.domain_list[group]
async def check_if_domain_in_db(self, domain): async def check_if_domain_in_db(self, domain):
async with db.pool().acquire() as connection:
async with connection.transaction(): async with connection.transaction():
async for wiki in connection.cursor('SELECT DISTINCT wiki FROM rcgcdw WHERE rcid != -1;'): async for wiki in connection.cursor('SELECT DISTINCT wiki FROM rcgcdw WHERE rcid != -1;'):
if get_domain(wiki["wiki"]) == domain: if get_domain(wiki["wiki"]) == domain:
@ -146,6 +149,7 @@ class RcQueue:
try: try:
self.to_remove = [x[0] for x in filter(self.filter_rc_active, all_wikis.items())] # first populate this list and remove wikis that are still in the db, clean up the rest self.to_remove = [x[0] for x in filter(self.filter_rc_active, all_wikis.items())] # first populate this list and remove wikis that are still in the db, clean up the rest
full = set() full = set()
async with db.pool().acquire() as connection:
async with connection.transaction(): async with connection.transaction():
async for db_wiki in connection.cursor('SELECT DISTINCT wiki, row_number() over (ORDER BY webhook) AS ROWID, webhook, lang, display, rcid FROM rcgcdw WHERE rcid != -1 OR rcid IS NULL order by webhook'): async for db_wiki in connection.cursor('SELECT DISTINCT wiki, row_number() over (ORDER BY webhook) AS ROWID, webhook, lang, display, rcid FROM rcgcdw WHERE rcid != -1 OR rcid IS NULL order by webhook'):
domain = get_domain(db_wiki["wiki"]) domain = get_domain(db_wiki["wiki"])
@ -232,6 +236,7 @@ async def generate_targets(wiki_url: str, additional_requirements: str) -> defau
request to the wiki just to duplicate the message. request to the wiki just to duplicate the message.
""" """
combinations = defaultdict(list) combinations = defaultdict(list)
async with db.pool().acquire() as connection:
async with connection.transaction(): async with connection.transaction():
async for webhook in connection.cursor('SELECT webhook, lang, display FROM rcgcdw WHERE wiki = $1 {}'.format(additional_requirements), wiki_url): async for webhook in connection.cursor('SELECT webhook, lang, display FROM rcgcdw WHERE wiki = $1 {}'.format(additional_requirements), wiki_url):
combination = (webhook["lang"], webhook["display"]) combination = (webhook["lang"], webhook["display"])
@ -244,6 +249,7 @@ async def generate_domain_groups():
:returns tuple[str, list]""" :returns tuple[str, list]"""
domain_wikis = defaultdict(list) domain_wikis = defaultdict(list)
async with db.pool().acquire() as connection:
async with connection.transaction(): async with connection.transaction():
async for db_wiki in connection.cursor('SELECT DISTINCT wiki, webhook, lang, display, rcid FROM rcgcdw WHERE rcid != -1 OR rcid IS NULL'): async for db_wiki in connection.cursor('SELECT DISTINCT wiki, webhook, lang, display, rcid FROM rcgcdw WHERE rcid != -1 OR rcid IS NULL'):
domain_wikis[get_domain(db_wiki["wiki"])].append(QueuedWiki(db_wiki["wiki"], 20)) domain_wikis[get_domain(db_wiki["wiki"])].append(QueuedWiki(db_wiki["wiki"], 20))
@ -396,6 +402,7 @@ async def message_sender():
async def discussion_handler(): async def discussion_handler():
try: try:
while True: while True:
async with db.pool().acquire() as connection:
async with connection.transaction(): async with connection.transaction():
async for db_wiki in connection.cursor("SELECT DISTINCT wiki, rcid, postid FROM rcgcdw WHERE postid != '-1' OR postid IS NULL"): async for db_wiki in connection.cursor("SELECT DISTINCT wiki, rcid, postid FROM rcgcdw WHERE postid != '-1' OR postid IS NULL"):
try: try:
@ -424,7 +431,7 @@ async def discussion_handler():
error = discussion_feed_resp["error"] error = discussion_feed_resp["error"]
if error == "NotFoundException": # Discussions disabled if error == "NotFoundException": # Discussions disabled
if db_wiki["rcid"] != -1: # RC feed is disabled if db_wiki["rcid"] != -1: # RC feed is disabled
await connection.execute("UPDATE rcgcdw SET postid = ? WHERE wiki = ?", await connection.execute("UPDATE rcgcdw SET postid = $1 WHERE wiki = $2",
("-1", db_wiki["wiki"],)) ("-1", db_wiki["wiki"],))
else: else:
await local_wiki.remove(db_wiki["wiki"], 1000) await local_wiki.remove(db_wiki["wiki"], 1000)
@ -543,8 +550,8 @@ async def main_loop():
global main_tasks global main_tasks
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
nest_asyncio.apply(loop) nest_asyncio.apply(loop)
await setup_connection() await db.setup_connection()
logger.debug("Connection type: {}".format(connection)) logger.debug("Connection type: {}".format(db.connection))
await populate_allwikis() await populate_allwikis()
try: try:
signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT) signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT)
@ -562,7 +569,7 @@ async def main_loop():
await asyncio.gather(main_tasks["wiki_scanner"], main_tasks["discussion_handler"], main_tasks["message_sender"]) await asyncio.gather(main_tasks["wiki_scanner"], main_tasks["discussion_handler"], main_tasks["message_sender"])
except KeyboardInterrupt: except KeyboardInterrupt:
await DBHandler.update_db() await DBHandler.update_db()
await shutdown_connection() await db.shutdown_connection()
shutdown(loop) shutdown(loop)
except asyncio.CancelledError: except asyncio.CancelledError:
return return

View file

@ -4,19 +4,38 @@ from typing import Optional
from src.config import settings from src.config import settings
logger = logging.getLogger("rcgcdb.database") logger = logging.getLogger("rcgcdb.database")
connection: Optional[asyncpg.Connection] = None # connection: Optional[asyncpg.Connection] = None
async def setup_connection(): class db_connection:
global connection connection: Optional[asyncpg.Pool] = None
async def setup_connection(self):
# Connect to the PostgreSQL database named by settings["pg_db"] (default "rcgcdb") # Connect to the PostgreSQL database named by settings["pg_db"] (default "rcgcdb")
# as the user given in settings["pg_user"]. # as the user given in settings["pg_user"].
logger.debug("Setting up the Database connection...") logger.debug("Setting up the Database connection...")
connection = await asyncpg.connect(user=settings["pg_user"], host=settings.get("pg_host", "localhost"), self.connection = await asyncpg.create_pool(user=settings["pg_user"], host=settings.get("pg_host", "localhost"),
database=settings.get("pg_db", "rcgcdb"), password=settings.get("pg_pass")) database=settings.get("pg_db", "rcgcdb"), password=settings.get("pg_pass"))
logger.debug("Database connection established! Connection: {}".format(connection)) logger.debug("Database connection established! Connection: {}".format(self.connection))
async def shutdown_connection(self):
await self.connection.close()
def pool(self) -> asyncpg.Pool:
return self.connection
# Tried to make this a decorator, but it probably won't work
# async def in_transaction(self, func):
# async def single_transaction():
# async with self.connection.acquire() as connection:
# async with connection.transaction():
# await func()
# return single_transaction
# async def query(self, string, *arg):
# async with self.connection.acquire() as connection:
# async with connection.transaction():
# return connection.cursor(string, *arg)
async def shutdown_connection(): db = db_connection()
global connection
await connection.close()

View file

@ -3,7 +3,7 @@ from collections import defaultdict
from src.misc import logger from src.misc import logger
from src.config import settings from src.config import settings
from src.database import connection from src.database import db
from src.i18n import langs from src.i18n import langs
from src.exceptions import EmbedListFull from src.exceptions import EmbedListFull
from asyncio import TimeoutError from asyncio import TimeoutError
@ -22,8 +22,9 @@ default_header["X-RateLimit-Precision"] = "millisecond"
# User facing webhook functions # User facing webhook functions
async def wiki_removal(wiki_url, status): async def wiki_removal(wiki_url, status):
async with db.pool().acquire() as connection:
async with connection.transaction(): async with connection.transaction():
async for observer in connection.cursor('SELECT webhook, lang FROM rcgcdw WHERE wiki = ?', wiki_url): async for observer in connection.cursor('SELECT webhook, lang FROM rcgcdw WHERE wiki = $1', wiki_url):
_ = langs[observer["lang"]]["discord"].gettext _ = langs[observer["lang"]]["discord"].gettext
reasons = {410: _("wiki deleted"), 404: _("wiki deleted"), 401: _("wiki inaccessible"), reasons = {410: _("wiki deleted"), 404: _("wiki deleted"), 401: _("wiki inaccessible"),
402: _("wiki inaccessible"), 403: _("wiki inaccessible"), 1000: _("discussions disabled")} 402: _("wiki inaccessible"), 403: _("wiki inaccessible"), 1000: _("discussions disabled")}
@ -239,7 +240,8 @@ async def handle_discord_http(code: int, formatted_embed: str, result: aiohttp.C
return 1 return 1
elif code == 401 or code == 404: # HTTP UNAUTHORIZED AND NOT FOUND elif code == 401 or code == 404: # HTTP UNAUTHORIZED AND NOT FOUND
logger.error("Webhook URL is invalid or no longer in use, please replace it with proper one.") logger.error("Webhook URL is invalid or no longer in use, please replace it with proper one.")
await connection.execute("DELETE FROM rcgcdw WHERE webhook = ?", (webhook_url,)) async with db.pool().acquire() as connection:
await connection.execute("DELETE FROM rcgcdw WHERE webhook = $1", (webhook_url,))
await webhook_removal_monitor(webhook_url, code) await webhook_removal_monitor(webhook_url, code)
return 1 return 1
elif code == 429: elif code == 429:

View file

@ -1,5 +1,5 @@
import logging import logging
from src.database import connection from src.database import db
logger = logging.getLogger("rcgcdb.queue_handler") logger = logging.getLogger("rcgcdb.queue_handler")
@ -15,13 +15,14 @@ class UpdateDB:
self.updated.clear() self.updated.clear()
async def update_db(self): async def update_db(self):
async with db.pool().acquire() as connection:
async with connection.transaction(): async with connection.transaction():
for update in self.updated: for update in self.updated:
if update[2] is None: if update[2] is None:
sql = "UPDATE rcgcdw SET rcid = ? WHERE wiki = ? AND ( rcid != -1 OR rcid IS NULL )" sql = "UPDATE rcgcdw SET rcid = $2 WHERE wiki = $1 AND ( rcid != -1 OR rcid IS NULL )"
else: else:
sql = "UPDATE rcgcdw SET postid = ? WHERE wiki = ? AND ( postid != '-1' OR postid IS NULL )" sql = "UPDATE rcgcdw SET postid = $2 WHERE wiki = $1 AND ( postid != '-1' OR postid IS NULL )"
await connection.execute(sql) await connection.execute(sql, update[0], update[1])
self.clear_list() self.clear_list()

View file

@ -2,7 +2,7 @@ from dataclasses import dataclass
import re import re
import logging, aiohttp import logging, aiohttp
from src.exceptions import * from src.exceptions import *
from src.database import connection from src.database import db
from src.formatters.rc import embed_formatter, compact_formatter from src.formatters.rc import embed_formatter, compact_formatter
from src.formatters.discussions import feeds_embed_formatter, feeds_compact_formatter from src.formatters.discussions import feeds_embed_formatter, feeds_compact_formatter
from src.misc import parse_link from src.misc import parse_link
@ -109,7 +109,8 @@ class Wiki:
logger.info("Removing a wiki {}".format(wiki_url)) logger.info("Removing a wiki {}".format(wiki_url))
await src.discord.wiki_removal(wiki_url, reason) await src.discord.wiki_removal(wiki_url, reason)
await src.discord.wiki_removal_monitor(wiki_url, reason) await src.discord.wiki_removal_monitor(wiki_url, reason)
result = await connection.execute('DELETE FROM rcgcdw WHERE wiki = ?', wiki_url) async with db.pool().acquire() as connection:
result = await connection.execute('DELETE FROM rcgcdw WHERE wiki = $1', wiki_url)
logger.warning('{} rows affected by DELETE FROM rcgcdw WHERE wiki = "{}"'.format(result, wiki_url)) logger.warning('{} rows affected by DELETE FROM rcgcdw WHERE wiki = "{}"'.format(result, wiki_url))
async def pull_comment(self, comment_id, WIKI_API_PATH, rate_limiter): async def pull_comment(self, comment_id, WIKI_API_PATH, rate_limiter):