From 597a907791b8c62fbee5bdad82c4649d19d8c7df Mon Sep 17 00:00:00 2001 From: Frisk Date: Sat, 3 Jul 2021 14:07:47 +0200 Subject: [PATCH] Little progress, future-proofed SQL statements from SQL injections --- src/domain.py | 2 +- src/queue_handler.py | 17 ++++++++++++++--- src/wiki.py | 18 +++++++++++------- 3 files changed, 26 insertions(+), 11 deletions(-) diff --git a/src/domain.py b/src/domain.py index c88fcf4..9077ffa 100644 --- a/src/domain.py +++ b/src/domain.py @@ -61,7 +61,7 @@ class Domain: async def run_wiki_scan(self, wiki: src.wiki.Wiki): await self.rate_limiter.timeout_wait() - await wiki.scan(self.rate_limiter) + await wiki.scan() self.wikis.move_to_end(wiki.script_url) self.rate_limiter.timeout_add(1.0) diff --git a/src/queue_handler.py b/src/queue_handler.py index 7e2f781..e76c882 100644 --- a/src/queue_handler.py +++ b/src/queue_handler.py @@ -1,5 +1,10 @@ import asyncio +import collections import logging +from typing import Union + +import asyncpg + from src.database import db logger = logging.getLogger("rcgcdb.queue_handler") @@ -7,7 +12,7 @@ logger = logging.getLogger("rcgcdb.queue_handler") class UpdateDB: def __init__(self): - self.updated = [] + self.updated: list[tuple[str, tuple[Union[str, int]]]] = [] def add(self, sql_expression): self.updated.append(sql_expression) @@ -15,6 +20,12 @@ class UpdateDB: def clear_list(self): self.updated.clear() + async def fetch_rows(self, SQLstatement: str, args: Union[str, int]) -> collections.AsyncIterable: + async with db.pool().acquire() as connection: + async with connection.transaction(): + async for row in connection.cursor(SQLstatement, *args): + yield row + async def update_db(self): try: while True: @@ -22,7 +33,7 @@ class UpdateDB: async with db.pool().acquire() as connection: async with connection.transaction(): for update in self.updated: - await connection.execute(update) + await connection.execute(update[0], *update[1]) self.clear_list() await asyncio.sleep(10.0) except asyncio.CancelledError: @@ -30,7 +41,7 @@ class UpdateDB: async with db.pool().acquire() as connection: async with connection.transaction(): for update in self.updated: - await connection.execute(update) + await connection.execute(update[0], *update[1]) self.clear_list() await db.shutdown_connection() diff --git a/src/wiki.py b/src/wiki.py index 69acd66..e762c10 100644 --- a/src/wiki.py +++ b/src/wiki.py @@ -62,16 +62,14 @@ class Wiki: else: self.fail_times -= 1 - def generate_targets(self) -> defaultdict[namedtuple, list[str]]: + async def generate_targets(self) -> defaultdict[namedtuple, list[str]]: """This function generates all possible varations of outputs that we need to generate messages for. :returns defaultdict[namedtuple, list[str]] - where namedtuple is a named tuple with settings for given webhooks in list""" Settings = namedtuple("Settings", ["lang", "display"]) target_settings: defaultdict[Settings, list[str]] = defaultdict(list) - async with db.pool().acquire() as connection: - async with connection.transaction(): - async for webhook in connection.cursor('SELECT webhook, lang, display FROM rcgcdw WHERE wiki = $1', self.script_url): - target_settings[Settings(webhook["lang"], webhook["display"])].append(webhook["webhook"]) + async for webhook in DBHandler.fetch_rows("SELECT webhook, lang, display FROM rcgcdw WHERE wiki = $1 AND (rcid != -1 OR rcid IS NULL)", self.script_url): + target_settings[Settings(webhook["lang"], webhook["display"])].append(webhook["webhook"]) return target_settings def parse_mw_request_info(self, request_data: dict, url: str): @@ -178,7 +176,7 @@ class Wiki: raise WikiServerError return response - def scan(self): + async def scan(self): try: request = await self.fetch_wiki() except WikiServerError: @@ -202,9 +200,15 @@ class Wiki: if self.rc_id in (0, None, -1): if len(recent_changes) > 0: self.statistics.last_action = recent_changes[-1]["rcid"] + DBHandler.add(("UPDATE rcgcdw SET rcid = $1 WHERE wiki = $2 AND ( rcid != -1 OR rcid IS NULL )", + (recent_changes[-1]["rcid"], self.script_url))) else: self.statistics.last_action = 0 - DBHandler.add("UPDATE rcgcdw SET rcid = 0 WHERE wiki = {} AND ( rcid != -1 OR rcid IS NULL )".format(self.script_url)) + DBHandler.add(("UPDATE rcgcdw SET rcid = 0 WHERE wiki = $1 AND ( rcid != -1 OR rcid IS NULL )", (self.script_url))) + return # TODO Add a log entry? + categorize_events = {} + targets = await self.generate_targets() + @dataclass class Wiki_old: