Added PostgreSQL compatibility, dropped SQLite compatibility

This commit is contained in:
Frisk 2021-03-20 13:42:54 +01:00
parent 2c8574445c
commit a2f1d54f39
No known key found for this signature in database
GPG key ID: 213F7C15068AF8AC
6 changed files with 227 additions and 193 deletions

View file

@ -8,6 +8,10 @@
"monitoring_webhook": "111111111111111111/AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA",
"support": "https://discord.gg/v77RTk5",
"irc_overtime": 3600,
"pg_user": "postgres",
"pg_host": "localhost",
"pg_db": "rcgcdb",
"pg_pass": "secret_password",
"irc_servers": {
"your custom name for the farm": {
"domains": ["wikipedia.org", "otherwikipedia.org"],

View file

@ -11,7 +11,7 @@ from typing import Generator
from contextlib import asynccontextmanager
from src.argparser import command_line_args
from src.config import settings
from src.database import connection, setup_connection, shutdown_connection
from src.database import db
from src.exceptions import *
from src.misc import get_paths, get_domain
from src.msgqueue import messagequeue, send_to_discord
@ -39,7 +39,9 @@ main_tasks: dict = {}
# First populate the all_wikis list with every wiki
# Reasons for this: 1. we require amount of wikis to calculate the cooldown between requests
# 2. Easier to code
async def populate_allwikis():
async with db.pool().acquire() as connection:
async with connection.transaction():
async for db_wiki in connection.cursor('SELECT DISTINCT wiki, rcid FROM rcgcdw'):
all_wikis[db_wiki["wiki"]] = Wiki() # populate all_wikis
@ -104,6 +106,7 @@ class RcQueue:
del self.domain_list[group]
async def check_if_domain_in_db(self, domain):
async with db.pool().acquire() as connection:
async with connection.transaction():
async for wiki in connection.cursor('SELECT DISTINCT wiki FROM rcgcdw WHERE rcid != -1;'):
if get_domain(wiki["wiki"]) == domain:
@ -146,6 +149,7 @@ class RcQueue:
try:
self.to_remove = [x[0] for x in filter(self.filter_rc_active, all_wikis.items())] # first populate this list and remove wikis that are still in the db, clean up the rest
full = set()
async with db.pool().acquire() as connection:
async with connection.transaction():
async for db_wiki in connection.cursor('SELECT DISTINCT wiki, row_number() over (ORDER BY webhook) AS ROWID, webhook, lang, display, rcid FROM rcgcdw WHERE rcid != -1 OR rcid IS NULL order by webhook'):
domain = get_domain(db_wiki["wiki"])
@ -232,6 +236,7 @@ async def generate_targets(wiki_url: str, additional_requirements: str) -> defau
request to the wiki just to duplicate the message.
"""
combinations = defaultdict(list)
async with db.pool().acquire() as connection:
async with connection.transaction():
async for webhook in connection.cursor('SELECT webhook, lang, display FROM rcgcdw WHERE wiki = $1 {}'.format(additional_requirements), wiki_url):
combination = (webhook["lang"], webhook["display"])
@ -244,6 +249,7 @@ async def generate_domain_groups():
:returns tuple[str, list]"""
domain_wikis = defaultdict(list)
async with db.pool().acquire() as connection:
async with connection.transaction():
async for db_wiki in connection.cursor('SELECT DISTINCT wiki, webhook, lang, display, rcid FROM rcgcdw WHERE rcid != -1 OR rcid IS NULL'):
domain_wikis[get_domain(db_wiki["wiki"])].append(QueuedWiki(db_wiki["wiki"], 20))
@ -396,6 +402,7 @@ async def message_sender():
async def discussion_handler():
try:
while True:
async with db.pool().acquire() as connection:
async with connection.transaction():
async for db_wiki in connection.cursor("SELECT DISTINCT wiki, rcid, postid FROM rcgcdw WHERE postid != '-1' OR postid IS NULL"):
try:
@ -424,7 +431,7 @@ async def discussion_handler():
error = discussion_feed_resp["error"]
if error == "NotFoundException": # Discussions disabled
if db_wiki["rcid"] != -1: # RC feed is disabled
await connection.execute("UPDATE rcgcdw SET postid = ? WHERE wiki = ?",
await connection.execute("UPDATE rcgcdw SET postid = $1 WHERE wiki = $2",
("-1", db_wiki["wiki"],))
else:
await local_wiki.remove(db_wiki["wiki"], 1000)
@ -543,8 +550,8 @@ async def main_loop():
global main_tasks
loop = asyncio.get_event_loop()
nest_asyncio.apply(loop)
await setup_connection()
logger.debug("Connection type: {}".format(connection))
await db.setup_connection()
logger.debug("Connection type: {}".format(db.connection))
await populate_allwikis()
try:
signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT)
@ -562,7 +569,7 @@ async def main_loop():
await asyncio.gather(main_tasks["wiki_scanner"], main_tasks["discussion_handler"], main_tasks["message_sender"])
except KeyboardInterrupt:
await DBHandler.update_db()
await shutdown_connection()
await db.shutdown_connection()
shutdown(loop)
except asyncio.CancelledError:
return

View file

@ -4,19 +4,38 @@ from typing import Optional
from src.config import settings
logger = logging.getLogger("rcgcdb.database")
connection: Optional[asyncpg.Connection] = None
# connection: Optional[asyncpg.Connection] = None
async def setup_connection():
global connection
class db_connection:
connection: Optional[asyncpg.Pool] = None
async def setup_connection(self):
# Establish a connection to an existing database named "test"
# as a "postgres" user.
logger.debug("Setting up the Database connection...")
connection = await asyncpg.connect(user=settings["pg_user"], host=settings.get("pg_host", "localhost"),
self.connection = await asyncpg.create_pool(user=settings["pg_user"], host=settings.get("pg_host", "localhost"),
database=settings.get("pg_db", "rcgcdb"), password=settings.get("pg_pass"))
logger.debug("Database connection established! Connection: {}".format(connection))
logger.debug("Database connection established! Connection: {}".format(self.connection))
async def shutdown_connection(self):
await self.connection.close()
def pool(self) -> asyncpg.Pool:
return self.connection
# Tried to make it a decorator but tbh won't probably work
# async def in_transaction(self, func):
# async def single_transaction():
# async with self.connection.acquire() as connection:
# async with connection.transaction():
# await func()
# return single_transaction
# async def query(self, string, *arg):
# async with self.connection.acquire() as connection:
# async with connection.transaction():
# return connection.cursor(string, *arg)
async def shutdown_connection():
global connection
await connection.close()
db = db_connection()

View file

@ -3,7 +3,7 @@ from collections import defaultdict
from src.misc import logger
from src.config import settings
from src.database import connection
from src.database import db
from src.i18n import langs
from src.exceptions import EmbedListFull
from asyncio import TimeoutError
@ -22,8 +22,9 @@ default_header["X-RateLimit-Precision"] = "millisecond"
# User facing webhook functions
async def wiki_removal(wiki_url, status):
async with db.pool().acquire() as connection:
async with connection.transaction():
async for observer in connection.cursor('SELECT webhook, lang FROM rcgcdw WHERE wiki = ?', wiki_url):
async for observer in connection.cursor('SELECT webhook, lang FROM rcgcdw WHERE wiki = $1', wiki_url):
_ = langs[observer["lang"]]["discord"].gettext
reasons = {410: _("wiki deleted"), 404: _("wiki deleted"), 401: _("wiki inaccessible"),
402: _("wiki inaccessible"), 403: _("wiki inaccessible"), 1000: _("discussions disabled")}
@ -239,7 +240,8 @@ async def handle_discord_http(code: int, formatted_embed: str, result: aiohttp.C
return 1
elif code == 401 or code == 404: # HTTP UNAUTHORIZED AND NOT FOUND
logger.error("Webhook URL is invalid or no longer in use, please replace it with proper one.")
await connection.execute("DELETE FROM rcgcdw WHERE webhook = ?", (webhook_url,))
async with db.pool().acquire() as connection:
await connection.execute("DELETE FROM rcgcdw WHERE webhook = $1", (webhook_url,))
await webhook_removal_monitor(webhook_url, code)
return 1
elif code == 429:

View file

@ -1,5 +1,5 @@
import logging
from src.database import connection
from src.database import db
logger = logging.getLogger("rcgcdb.queue_handler")
@ -15,13 +15,14 @@ class UpdateDB:
self.updated.clear()
async def update_db(self):
async with db.pool().acquire() as connection:
async with connection.transaction():
for update in self.updated:
if update[2] is None:
sql = "UPDATE rcgcdw SET rcid = ? WHERE wiki = ? AND ( rcid != -1 OR rcid IS NULL )"
sql = "UPDATE rcgcdw SET rcid = $2 WHERE wiki = $1 AND ( rcid != -1 OR rcid IS NULL )"
else:
sql = "UPDATE rcgcdw SET postid = ? WHERE wiki = ? AND ( postid != '-1' OR postid IS NULL )"
await connection.execute(sql)
sql = "UPDATE rcgcdw SET postid = $2 WHERE wiki = $1 AND ( postid != '-1' OR postid IS NULL )"
await connection.execute(sql, update[0], update[1])
self.clear_list()

View file

@ -2,7 +2,7 @@ from dataclasses import dataclass
import re
import logging, aiohttp
from src.exceptions import *
from src.database import connection
from src.database import db
from src.formatters.rc import embed_formatter, compact_formatter
from src.formatters.discussions import feeds_embed_formatter, feeds_compact_formatter
from src.misc import parse_link
@ -109,7 +109,8 @@ class Wiki:
logger.info("Removing a wiki {}".format(wiki_url))
await src.discord.wiki_removal(wiki_url, reason)
await src.discord.wiki_removal_monitor(wiki_url, reason)
result = await connection.execute('DELETE FROM rcgcdw WHERE wiki = ?', wiki_url)
async with db.pool().acquire() as connection:
result = await connection.execute('DELETE FROM rcgcdw WHERE wiki = $1', wiki_url)
logger.warning('{} rows affected by DELETE FROM rcgcdw WHERE wiki = "{}"'.format(result, wiki_url))
async def pull_comment(self, comment_id, WIKI_API_PATH, rate_limiter):