Redo message storage: keep messages in the DB instead of in memory (not finished)

This commit is contained in:
Frisk 2024-08-08 14:40:45 +02:00
parent 4735e04924
commit bb7ca25a35
4 changed files with 143 additions and 67 deletions

View file

@ -59,13 +59,18 @@ class DiscordMessageMetadata:
return dict_obj return dict_obj
def matches(self, other: dict): def matches(self, other: dict):
"""Checks if all keys and values match given dictionary"""
for key, value in other.items(): for key, value in other.items():
if self.__dict__[key] != value: if isinstance(value, list) or isinstance(value, set):
return False if self.__dict__[key] not in value:
return True return False
else:
if self.__dict__[key] != value:
return False
return True
def dump_ids(self) -> (int, int, int, int): def dump_ids(self) -> (int, int, int, int):
return self.page_id, self.rev_id, self.log_id, self.message_display return self.log_id, self.page_id, self.rev_id, self.message_display
class DiscordMessage: class DiscordMessage:
@ -191,7 +196,7 @@ class MessageTooBig(BaseException):
pass pass
class StackedDiscordMessage(): class StackedDiscordMessage:
def __init__(self, m_type: int, wiki: Wiki): def __init__(self, m_type: int, wiki: Wiki):
self.message_list: list[DiscordMessage] = [] self.message_list: list[DiscordMessage] = []
self.length = 0 self.length = 0
@ -213,6 +218,12 @@ class StackedDiscordMessage():
message_structure["components"] = self.message_list[0].webhook_object["components"] message_structure["components"] = self.message_list[0].webhook_object["components"]
return json.dumps(message_structure) return json.dumps(message_structure)
def __iter__(self):
    """Iterate over the individual DiscordMessages stacked in this object."""
    return iter(self.message_list)
def is_empty(self):
    """Return True when no DiscordMessage is stacked in this object."""
    return not self.message_list
def json(self) -> dict: def json(self) -> dict:
dict_obj = { dict_obj = {
"length": self.length, "length": self.length,

View file

@ -3,6 +3,7 @@ from __future__ import annotations
import inspect import inspect
import json import json
import time import time
from collections import OrderedDict
from typing import TYPE_CHECKING, Callable, Optional from typing import TYPE_CHECKING, Callable, Optional
from urllib.parse import urlparse, urlunparse from urllib.parse import urlparse, urlunparse
import logging import logging
@ -20,6 +21,8 @@ from contextlib import redirect_stdout
from src.wiki import Wiki from src.wiki import Wiki
import tldextract import tldextract
from src.statistics import Log, LogType
logger = logging.getLogger("rcgcdb.domain_manager") logger = logging.getLogger("rcgcdb.domain_manager")
@ -117,10 +120,20 @@ class DomainManager:
wiki = domain.get_wiki(split_payload[3]) wiki = domain.get_wiki(split_payload[3])
if wiki is not None: if wiki is not None:
logger.debug("Wiki specified in pub/sub message has been found. Preparing and sending dump.") logger.debug("Wiki specified in pub/sub message has been found. Preparing and sending dump.")
wiki_json = wiki.json()
try:
wiki.statistics.update(Log(type=LogType.SCAN_REASON, title="Debug request for the wiki"))
params = OrderedDict({"action": "query", "format": "json", "uselang": "content", "list": "tags|recentchanges",
"meta": "siteinfo", "utf8": 1, "rcshow": "!bot",
"rcprop": "title|redirect|timestamp|ids|loginfo|parsedcomment|sizes|flags|tags|user|userid",
"rclimit": 500, "rctype": "edit|new|log|categorize", "siprop": "namespaces|general"})
wiki_json["wiki_rc"] = await wiki.api_request(params=params, timeout=5)
except Exception:
wiki_json["wiki_rc"] = None
json_string: str = json.dumps(wiki.json()) json_string: str = json.dumps(wiki_json)
for json_part in self.chunkstring(json_string, 7950): for json_part in self.chunkstring(json_string, 7950):
await connection.execute("select pg_notify('debugresponse', 'SITE CHUNK ' || $1 || ' ' || $2);", await connection.execute("select pg_notify('debugresponse', 'SITE CHUNK ' || $1 || ' ' || $2);",
req_id, json_part) req_id, json_part)
await connection.execute("select pg_notify('debugresponse', 'SITE END ' || $1);", await connection.execute("select pg_notify('debugresponse', 'SITE END ' || $1);",
req_id) req_id)
else: else:

View file

@ -13,7 +13,7 @@ class UpdateDB:
def __init__(self): def __init__(self):
self.updated: list[tuple[str, tuple[Union[str, int], ...]]] = [] self.updated: list[tuple[str, tuple[Union[str, int], ...]]] = []
def add(self, sql_expression: tuple[str, tuple[Union[str, int], ...]]): def add(self, sql_expression: tuple[str, tuple[Union[str, int, bytes], ...]]):
self.updated.append(sql_expression) self.updated.append(sql_expression)
def clear_list(self): def clear_list(self):
@ -32,10 +32,11 @@ class UpdateDB:
if self.updated: if self.updated:
async with db.pool().acquire() as connection: async with db.pool().acquire() as connection:
async with connection.transaction(): async with connection.transaction():
for update in self.updated: while len(self.updated) > 0:
update = self.updated[0]
logger.debug("Executing: {} {}".format(update[0], update[1])) logger.debug("Executing: {} {}".format(update[0], update[1]))
await connection.execute(update[0], *update[1]) await connection.execute(update[0], *update[1])
self.clear_list() self.updated.pop(0)
await asyncio.sleep(10.0) await asyncio.sleep(10.0)
except asyncio.CancelledError: except asyncio.CancelledError:
logger.info("Shutting down after updating DB with {} more entries...".format(len(self.updated))) logger.info("Shutting down after updating DB with {} more entries...".format(len(self.updated)))

View file

@ -2,10 +2,13 @@ from __future__ import annotations
import datetime import datetime
import functools import functools
import pickle
import time import time
import re import re
import logging, aiohttp import logging, aiohttp
import asyncio import asyncio
from contextlib import asynccontextmanager
import requests import requests
from src.api.util import default_message from src.api.util import default_message
@ -21,6 +24,7 @@ from src.discord.message import DiscordMessage, DiscordMessageMetadata, StackedD
from src.i18n import langs from src.i18n import langs
from src.statistics import Statistics, Log, LogType from src.statistics import Statistics, Log, LogType
from src.config import settings from src.config import settings
from src.database import db
# noinspection PyPackageRequirements # noinspection PyPackageRequirements
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
from collections import OrderedDict, defaultdict, namedtuple from collections import OrderedDict, defaultdict, namedtuple
@ -38,6 +42,65 @@ if TYPE_CHECKING:
MESSAGE_LIMIT = settings.get("message_limit", 30) MESSAGE_LIMIT = settings.get("message_limit", 30)
class MessageHistoryRetriever:
    """Database-backed accessor for the Discord message history of one wiki.

    Replaces the previous in-memory ``list[StackedDiscordMessage]`` history.
    Stacked messages are stored pickled in ``rcgcdb_msg_history`` keyed by the
    Discord message id, with one ``rcgcdb_msg_metadata`` row per stacked event.
    """

    def __init__(self, wiki: Wiki):
        self.wiki = wiki  # wiki whose message history this object manages

    def __len__(self):
        # Length would require a DB round-trip; deliberately unsupported.
        # Fixed: the original *returned* NotImplementedError instead of raising.
        raise NotImplementedError

    async def find_all_revids(self, page_id: int) -> list[int]:
        """Return the distinct rev_ids of all stored messages for *page_id* on this wiki."""
        result = []
        # Fixed: the original joined rcgcdb_msg_metadata to itself; the history
        # table must be joined to reach the webhook column used below.
        async for item in dbmanager.fetch_rows(
                "SELECT DISTINCT rev_id FROM rcgcdb_msg_metadata "
                "INNER JOIN rcgcdb_msg_history ON rcgcdb_msg_metadata.message_id = rcgcdb_msg_history.message_id "
                "INNER JOIN rcgcdb ON rcgcdb_msg_history.webhook = rcgcdb.webhook "
                "WHERE rcgcdb.wiki = $1 AND page_id = $2",
                (self.wiki.script_url, page_id)):
            result.append(item["rev_id"])
        return result

    async def fetch_stacked_from_db(self, params: dict[str, Union[str, int, list]]):
        """Yield ``(StackedDiscordMessage, webhook)`` pairs whose metadata matches *params*.

        *params* maps metadata column names (message_display, log_id, page_id,
        rev_id, stacked_index) to either a scalar (SQL equality) or a list/set
        (SQL ``IN``). After each yield the possibly-mutated message is written
        back, so callers may edit or empty the yielded message in place.

        NOTE(review): this is a plain async generator. The original decorated it
        with @asynccontextmanager while yielding inside a loop (a context manager
        must yield exactly once) and every call site iterates it with ``async for``.
        """
        async with db.pool().acquire() as connection:
            async with connection.transaction():
                query_template, query_parameters = [], []
                for query_key, query_parameter in params.items():
                    # Fixed: ints (e.g. page_id) are scalars too; the original
                    # routed them to the IN branch and crashed on len(int).
                    if isinstance(query_parameter, (str, int)):
                        query_template.append(f"{query_key} = ${len(query_parameters) + 1}")
                        query_parameters.append(query_parameter)
                    else:  # an iterable of accepted values -> IN (...)
                        placeholders = ", ".join(
                            "${}".format(x + len(query_parameters) + 1) for x in range(len(query_parameter)))
                        query_template.append(f"{query_key} IN ({placeholders})")
                        query_parameters.extend(query_parameter)
                # Fixed: parameterize the wiki URL — the original interpolated it
                # unquoted into the SQL string (syntax error / injection risk).
                wiki_placeholder = "${}".format(len(query_parameters) + 1)
                query_parameters.append(self.wiki.script_url)
                conditions = " AND ".join(query_template)
                query = ("SELECT message_id, webhook, message_object FROM rcgcdb_msg_history "
                         "INNER JOIN rcgcdb_msg_metadata ON rcgcdb_msg_metadata.message_id = rcgcdb_msg_history.message_id "
                         "INNER JOIN rcgcdb ON rcgcdb_msg_history.webhook = rcgcdb.webhook "
                         "WHERE rcgcdb.wiki = " + wiki_placeholder + " AND " + conditions)
                # Fixed: asyncpg's cursor() takes positional *args, not a sequence.
                async for stacked_message in connection.cursor(query, *query_parameters):
                    # Data was pickled by us on insert; never feed untrusted bytes here.
                    unpickled_message = pickle.loads(stacked_message["message_object"])
                    yield unpickled_message, stacked_message["webhook"]
                    await self.update_message(unpickled_message, connection, stacked_message["message_id"])

    @staticmethod
    async def update_message(stacked_message: StackedDiscordMessage, connection, message_id: str):
        """Persist *stacked_message* back under *message_id*, or delete it when emptied."""
        if stacked_message.is_empty():
            await connection.execute("DELETE FROM rcgcdb_msg_history WHERE message_id = $1;", message_id)
            # NOTE(review): metadata rows are presumably removed by an FK cascade —
            # TODO confirm the schema; otherwise delete rcgcdb_msg_metadata here too.
        else:
            await connection.execute("UPDATE rcgcdb_msg_history SET message_object = $1 WHERE message_id = $2;",
                                     pickle.dumps(stacked_message), message_id)
            # Rebuild the metadata rows to match the (possibly changed) stack order.
            await connection.execute("DELETE FROM rcgcdb_msg_metadata WHERE message_id = $1;", message_id)
            await connection.executemany(
                "INSERT INTO rcgcdb_msg_metadata(message_id, log_id, page_id, rev_id, message_display, stacked_index) VALUES ($1, $2, $3, $4, $5, $6);",
                [(message_id, *message.metadata.dump_ids(), num)
                 for num, message in enumerate(stacked_message.message_list)])

    @staticmethod
    def register_message(stacked_message: StackedDiscordMessage):
        """Queue DB inserts registering *stacked_message* and its per-event metadata."""
        dbmanager.add(("INSERT INTO rcgcdb_msg_history(message_id, webhook, message_object) VALUES ($1, $2, $3);",
                       (stacked_message.discord_callback_message_id, stacked_message.webhook,
                        pickle.dumps(stacked_message))))
        for stack_id, message in enumerate(stacked_message):
            dbmanager.add(("INSERT INTO rcgcdb_msg_metadata(message_id, log_id, page_id, rev_id, message_display, stacked_index) VALUES ($1, $2, $3, $4, $5, $6);",
                           (stacked_message.discord_callback_message_id, *message.metadata.dump_ids(), stack_id)))
class Wiki: class Wiki:
def __init__(self, script_url: str, rc_id: Optional[int], discussion_id: Optional[str]): def __init__(self, script_url: str, rc_id: Optional[int], discussion_id: Optional[str]):
self.script_url: str = script_url self.script_url: str = script_url
@ -49,7 +112,7 @@ class Wiki:
self.rc_targets: Optional[defaultdict[Settings, list[str]]] = None self.rc_targets: Optional[defaultdict[Settings, list[str]]] = None
self.discussion_targets: Optional[defaultdict[Settings, list[str]]] = None self.discussion_targets: Optional[defaultdict[Settings, list[str]]] = None
self.client: Client = Client(formatter_hooks, self) self.client: Client = Client(formatter_hooks, self)
self.message_history: list[StackedDiscordMessage] = list() self.message_history: MessageHistoryRetriever = MessageHistoryRetriever(self)
self.namespaces: Optional[dict] = None self.namespaces: Optional[dict] = None
self.recache_requested: bool = False self.recache_requested: bool = False
self.session_requests = requests.Session() self.session_requests = requests.Session()
@ -109,78 +172,67 @@ class Wiki:
# logger.warning('{} rows affected by DELETE FROM rcgcdw WHERE wiki = "{}"'.format(result, self.script_url)) # logger.warning('{} rows affected by DELETE FROM rcgcdw WHERE wiki = "{}"'.format(result, self.script_url))
def add_message(self, message: StackedDiscordMessage): def add_message(self, message: StackedDiscordMessage):
self.message_history.append(message) self.message_history.register_message(message)
if len(self.message_history) > MESSAGE_LIMIT*len(self.rc_targets):
self.message_history = self.message_history[len(self.message_history)-MESSAGE_LIMIT*len(self.rc_targets):]
def set_domain(self, domain: Domain): def set_domain(self, domain: Domain):
self.domain = domain self.domain = domain
def find_middle_next(self, ids: List[str], pageid: int) -> list: async def find_middle_next(self, ids: List[str], pageid: int) -> set:
"""To address #235 RcGcDw should now remove diffs in next revs relative to redacted revs to protect information in revs that revert revdeleted information. """To address #235 RcGcDw should now remove diffs in next revs relative to redacted revs to protect information in revs that revert revdeleted information.
What this function does, is it fetches all messages for given page and finds revids of the messages that come next after ids What this function does, is it fetches all messages for given page and finds revids of the messages that come next after ids
:arg ids - list :arg ids - list
:arg pageid - int :arg pageid - int
:return list""" :return set"""
def extract_revid(item: tuple[StackedDiscordMessage, list[int]]):
rev_ids = set()
for message_id in sorted(item[1], reverse=True):
rev_ids.add(item[0].message_list[message_id].metadata.rev_id)
return rev_ids
ids = [int(x) for x in ids] ids = [int(x) for x in ids]
result = set() result = set()
ids.sort() # Just to be sure, sort the list to make sure it's always sorted ids.sort() # Just to be sure, sort the list to make sure it's always sorted
search = self.search_message_history({"message_display": 3, "page_id": pageid}) all_revids = sorted(await self.message_history.find_all_revids(pageid))
# messages = db_cursor.execute("SELECT revid FROM event WHERE pageid = ? AND revid >= ? ORDER BY revid",
# (pageid, ids[0],))
all_in_page = sorted(set([x for row in map(extract_revid, search) for x in row])) # Flatten the result
for ID in ids: for ID in ids:
try: try:
result.add(all_in_page[all_in_page.index(ID) + 1]) result.add(all_revids[all_revids.index(ID) + 1])
except (KeyError, ValueError): except (KeyError, ValueError):
logger.debug(f"Value {ID} not in {all_in_page} or no value after that.") logger.debug(f"Value {ID} not in {all_revids} or no value after that.")
return list(result - set(ids)) return result
def search_message_history(self, params: dict) -> list[tuple[StackedDiscordMessage, list[int]]]: def search_message_history(self, stacked_message: StackedDiscordMessage, params: dict) -> list[int]:
"""Search self.message_history for messages which match all properties in params and return them in a list """Search self.message_history for messages which match all properties in params and return ids of those in a list
:param params is a dictionary of which messages are compared against. All name and values must be equal for match to return true :param params is a dictionary of which messages are compared against. All name and values must be equal for match to return true
Matches metadata from discord.message.DiscordMessageMetadata Matches metadata from discord.message.DiscordMessageMetadata
:returns [(StackedDiscordMessage, [index ids of matching messages in that StackedMessage])]""" :returns [index ids of matching messages in that StackedMessage]"""
output = [] output = []
for message in self.message_history: for num, message in enumerate(stacked_message.message_list):
returned_matches_for_stacked = message.filter(params) if message.metadata.matches(params):
if returned_matches_for_stacked: output.append(num)
output.append((message, [x[0] for x in returned_matches_for_stacked]))
return output return output
def delete_messages(self, params: dict): async def delete_messages(self, params: dict):
"""Delete certain messages from message_history which DiscordMessageMetadata matches all properties in params""" """Delete certain messages from message_history which DiscordMessageMetadata matches all properties in params"""
# Delete all messages with given IDs # Delete all messages with given IDs
for stacked_message, ids in self.search_message_history(params): async for stacked_message in self.message_history.fetch_stacked_from_db(params):
stacked_message.delete_message_by_id(ids) for message_ids in self.search_message_history(stacked_message, params):
# If all messages were removed, send a DELETE to Discord stacked_message.delete_message_by_id(message_ids)
if len(stacked_message.message_list) == 0: # If all messages were removed, send a DELETE to Discord
messagequeue.add_message(QueueEntry(stacked_message, [stacked_message.webhook], self, method="DELETE")) if len(stacked_message.message_list) == 0:
else: messagequeue.add_message(QueueEntry(stacked_message, [stacked_message.webhook], self, method="DELETE"))
messagequeue.add_message(QueueEntry(stacked_message, [stacked_message.webhook], self, method="PATCH")) else:
messagequeue.add_message(QueueEntry(stacked_message, [stacked_message.webhook], self, method="PATCH"))
def redact_messages(self, context: Context, ids: list[int], mode: str, censored_properties: dict): async def redact_messages(self, context: Context, ids: list[int] | set[int], mode: str, censored_properties: dict):
# ids can refer to multiple events, and search does not support additive mode, so we have to loop it for all ids # ids can refer to multiple events, and search does not support additive mode, so we have to loop it for all ids
for revlogid in ids: async for stacked_message, webhook in self.message_history.fetch_stacked_from_db({mode: ids, "message_display": [1, 2, 3]}):
for stacked_message, ids in self.search_message_history({mode: revlogid}): # This might not work depending on how Python handles it, but hey, learning experience for message in [message for message in stacked_message.message_list if message.metadata.matches({mode: ids})]:
for message in [message for num, message in enumerate(stacked_message.message_list) if num in ids]: if "user" in censored_properties and "url" in message["author"]:
if "user" in censored_properties and "url" in message["author"]: message["author"]["name"] = context._("hidden")
message["author"]["name"] = context._("hidden") message["author"].pop("url")
message["author"].pop("url") if "action" in censored_properties and "url" in message:
if "action" in censored_properties and "url" in message: message["title"] = context._("~~hidden~~")
message["title"] = context._("~~hidden~~") message["embed"].pop("url")
message["embed"].pop("url") if "content" in censored_properties and "fields" in message:
if "content" in censored_properties and "fields" in message: message["embed"].pop("fields")
message["embed"].pop("fields") if "comment" in censored_properties:
if "comment" in censored_properties: message["description"] = context._("~~hidden~~")
message["description"] = context._("~~hidden~~") messagequeue.add_message(QueueEntry(stacked_message, [stacked_message.webhook], self, method="PATCH"))
messagequeue.add_message(QueueEntry(stacked_message, [stacked_message.webhook], self, method="PATCH"))
# async def downtime_controller(self, down, reason=None): # async def downtime_controller(self, down, reason=None):
# if down: # if down:
@ -198,7 +250,7 @@ class Wiki:
discussion_targets: defaultdict[Settings, list[str]] = defaultdict(list) discussion_targets: defaultdict[Settings, list[str]] = defaultdict(list)
async for webhook in dbmanager.fetch_rows("SELECT webhook, lang, display, rcid, postid, buttons FROM rcgcdb WHERE wiki = $1", self.script_url): async for webhook in dbmanager.fetch_rows("SELECT webhook, lang, display, rcid, postid, buttons FROM rcgcdb WHERE wiki = $1", self.script_url):
if webhook['rcid'] == -1 and webhook['postid'] == '-1': if webhook['rcid'] == -1 and webhook['postid'] == '-1':
await self.remove_wiki_from_db(4) await self.remove_webhook_from_db(webhook['webhook'], "Webhook has invalid settings", True)
if webhook['rcid'] != -1: if webhook['rcid'] != -1:
target_settings[Settings(webhook["lang"], webhook["display"], webhook["buttons"])].append(webhook["webhook"]) target_settings[Settings(webhook["lang"], webhook["display"], webhook["buttons"])].append(webhook["webhook"])
if webhook['postid'] != '-1': if webhook['postid'] != '-1':
@ -418,6 +470,7 @@ class Wiki:
async def remove_webhook_from_db(self, webhook_url: str, reason: str, send_reason=False): async def remove_webhook_from_db(self, webhook_url: str, reason: str, send_reason=False):
logger.info(f"Removing a webhook with ID of {webhook_url.split("/")[0]} from the database due to {reason}.") logger.info(f"Removing a webhook with ID of {webhook_url.split("/")[0]} from the database due to {reason}.")
# TODO Write a reason for removal to a webhook if send_reason
dbmanager.add(("DELETE FROM rcgcdb WHERE webhook = $1", (webhook_url,))) dbmanager.add(("DELETE FROM rcgcdb WHERE webhook = $1", (webhook_url,)))
async def remove_wiki_from_db(self, reason: str): async def remove_wiki_from_db(self, reason: str):
@ -535,24 +588,22 @@ async def rc_processor(wiki: Wiki, change: dict, changed_categories: dict, displ
else: else:
raise raise
if identification_string in ("delete/delete", "delete/delete_redir"): # TODO Move it into a hook? if identification_string in ("delete/delete", "delete/delete_redir"): # TODO Move it into a hook?
wiki.delete_messages(dict(page_id=change.get("pageid"))) await wiki.delete_messages(dict(page_id=change.get("pageid")))
elif identification_string == "delete/event": elif identification_string == "delete/event":
logparams = change.get('logparams', {"ids": []}) logparams = change.get('logparams', {"ids": []})
if context.message_type == "embed": await wiki.redact_messages(context, logparams.get("ids", []), "log_id", logparams.get("new", {}))
wiki.redact_messages(context, logparams.get("ids", []), "log_id", logparams.get("new", {})) await wiki.delete_messages(dict(log_id=logparams.get("ids", []), message_display=0))
else:
for logid in logparams.get("ids", []):
wiki.delete_messages(dict(logid=logid))
elif identification_string == "delete/revision": elif identification_string == "delete/revision":
logparams = change.get('logparams', {"ids": []}) logparams = change.get('logparams', {"ids": []})
if context.message_type == "embed": if context.message_type == "embed":
wiki.redact_messages(context, logparams.get("ids", []), "rev_id", logparams.get("new", {}))
if display_options.display == 3: if display_options.display == 3:
wiki.redact_messages(context, wiki.find_middle_next(logparams.get("ids", []), change.get("pageid", -1)), "rev_id", await wiki.redact_messages(context, (await wiki.find_middle_next(logparams.get("ids", []), change.get("pageid", -1))).union(logparams.get("ids", [])), "rev_id",
{"content": ""}) {"content": ""})
else:
await wiki.redact_messages(context, logparams.get("ids", []), "rev_id", logparams.get("new", {}))
else: else:
for revid in logparams.get("ids", []): for revid in logparams.get("ids", []):
wiki.delete_messages(dict(revid=revid)) await wiki.delete_messages(dict(rev_id=revid))
run_hooks(post_hooks, discord_message, metadata, context, change) run_hooks(post_hooks, discord_message, metadata, context, change)
if discord_message: # TODO How to react when none? (crash in formatter), probably bad handling atm if discord_message: # TODO How to react when none? (crash in formatter), probably bad handling atm
discord_message.finish_embed() discord_message.finish_embed()