From bbb81469eba0727c42a831d09d547195b5b7857f Mon Sep 17 00:00:00 2001 From: Frisk Date: Fri, 9 Aug 2024 14:00:31 +0200 Subject: [PATCH] Fixes and progress --- src/config.py | 2 +- src/discord/message.py | 15 ++++++++ src/discord/queue.py | 24 +++++++------ src/domain.py | 4 ++- src/domain_manager.py | 2 +- src/wiki.py | 78 +++++++++++++++++++++++++----------------- 6 files changed, 80 insertions(+), 45 deletions(-) diff --git a/src/config.py b/src/config.py index cca038b..84a1c59 100644 --- a/src/config.py +++ b/src/config.py @@ -4,7 +4,7 @@ try: # load settings with open("settings.json", encoding="utf8") as sfile: settings = json.load(sfile) if "user-agent" in settings["header"]: - settings["header"]["user-agent"] = settings["header"]["user-agent"].format(version="1.9.1 Beta") # set the version in the useragent + settings["header"]["user-agent"] = settings["header"]["user-agent"].format(version="2.0 Beta") # set the version in the useragent except FileNotFoundError: logging.critical("No config file could be found. Please make sure settings.json is in the directory.") sys.exit(1) \ No newline at end of file diff --git a/src/discord/message.py b/src/discord/message.py index ca91b92..662b0ea 100644 --- a/src/discord/message.py +++ b/src/discord/message.py @@ -47,6 +47,11 @@ class DiscordMessageMetadata: def __str__(self): return f"" + def __getstate__(self): + obj_copy = self.__dict__.copy() + del obj_copy["domain"] + return obj_copy + def json(self) -> dict: dict_obj = { "method": self.method, @@ -111,6 +116,11 @@ class DiscordMessage: def __len__(self): return self.length + def __getstate__(self): + obj_copy = self.__dict__.copy() + del obj_copy['wiki'] + return obj_copy + def json(self): dict_obj = { "length": self.length, @@ -221,6 +231,11 @@ class StackedDiscordMessage: def __iter__(self): return self.message_list.__iter__() + def __getstate__(self): + obj_copy = self.__dict__.copy() + del obj_copy['wiki'] + return obj_copy + def is_empty(self): return len(self.message_list) == 0 diff --git a/src/discord/queue.py b/src/discord/queue.py index 586de5c..4eae6c2 100644 --- a/src/discord/queue.py +++ b/src/discord/queue.py @@ -161,7 +161,7 @@ class MessageQueue: StackedDiscordMessages using self.pack_messages or returns a message if it's not POST request. If stacked message succeeds in changing, its status in _queue is changed to sent for given webhook.""" - webhook_url, messages = msg_set # str("daosdkosakda/adkahfwegr34", list(DiscordMessage, DiscordMessage, DiscordMessage) + webhook_url, messages = msg_set # str("daosdkosakda/adkahfwegr34", list(QueueEntry, QueueEntry, QueueEntry) async for msg, index, method in self.pack_massages(messages): if msg is None: # Msg can be None if last message was not POST continue @@ -176,7 +176,7 @@ class MessageQueue: client_error = True except (aiohttp.ServerConnectionError, aiohttp.ServerTimeoutError, asyncio.TimeoutError): # Retry on next Discord message sent attempt - logger.debug(f"Received timeout or connection error when sending a Discord message for {msg.wiki.script_url}.") + logger.debug(f"Received timeout or connection error when sending a Discord message for {msg.wiki.script_url if hasattr(msg, "wiki") else "PATCH OR DELETE MESSAGE"}.") return except ExhaustedDiscordBucket as e: if e.is_global: @@ -184,7 +184,7 @@ class MessageQueue: await asyncio.sleep(e.remaining / 1000) return else: - if status == 0: + if status == 0 and method == "POST": message = None for message in msg.message_list: if message.metadata.domain is not None and message.metadata.time_of_change is not None: @@ -193,17 +193,19 @@ class MessageQueue: if message and message.metadata.domain is not None: message.metadata.domain.discord_message_registration() if client_error is False: - msg.webhook = webhook_url - msg.wiki.add_message(msg) + if method == "POST": + msg.webhook = webhook_url + msg.wiki.add_message(msg) for queue_message in messages[max(index-len(msg.message_list), 0):index+1]: queue_message.confirm_sent_status(webhook_url) else: - webhook_id = webhook_url.split("/")[0] - if webhook_id in self.webhook_suspensions: - await msg.wiki.remove_webhook_from_db(webhook_url, "Attempts to send a message to a webhook result in client error.", send=False) - self.webhook_suspensions[webhook_id].cancel() - else: - self.webhook_suspensions[webhook_id] = asyncio.create_task(self.suspension_check(webhook_url), name="DC Sus Check for {}".format(webhook_id)) + if hasattr(msg, "wiki"): # PATCH and DELETE can not have wiki attribute + webhook_id = webhook_url.split("/")[0] + if webhook_id in self.webhook_suspensions: + await msg.wiki.remove_webhook_from_db(webhook_url, "Attempts to send a message to a webhook result in client error.", send=False) + self.webhook_suspensions[webhook_id].cancel() + else: + self.webhook_suspensions[webhook_id] = asyncio.create_task(self.suspension_check(webhook_url), name="DC Sus Check for {}".format(webhook_id)) async def resend_msgs(self): """Main function for orchestrating Discord message sending. It's a task that runs every half a second.""" diff --git a/src/domain.py b/src/domain.py index 627627b..2d4b1f4 100644 --- a/src/domain.py +++ b/src/domain.py @@ -187,6 +187,7 @@ class Domain: except Exception as e: if command_line_args.debug: logger.exception("IRC scheduler task for domain {} failed!".format(self.name)) + raise asyncio.exceptions.CancelledError() else: # production if not (time.time()-172800 > self.last_failure_report): # If we haven't reported for more than 2 days or at all return @@ -199,11 +200,12 @@ class Domain: async def regular_scheduler(self): try: while True: - await asyncio.sleep(self.calculate_sleep_time(len(self))) # To make sure that we don't spam domains with one wiki every second we calculate a sane timeout for domains with few wikis await self.run_wiki_scan(next(iter(self.wikis.values())), "regular check") + await asyncio.sleep(self.calculate_sleep_time(len(self))) # To make sure that we don't spam domains with one wiki every second we calculate a sane timeout for domains with few wikis except Exception as e: if command_line_args.debug: logger.exception("Regular scheduler task for domain {} failed!".format(self.name)) + raise asyncio.exceptions.CancelledError() else: if not (time.time()-172800 > self.last_failure_report): # If we haven't reported for more than 2 days or at all return diff --git a/src/domain_manager.py b/src/domain_manager.py index 031cc82..fd697af 100644 --- a/src/domain_manager.py +++ b/src/domain_manager.py @@ -107,7 +107,7 @@ class DomainManager: for name, domain in self.domains.items(): json_object["domains"][name] = domain.json() for message in messagequeue._queue: - json_object["queued_messages"].append({"metadata": str(message.discord_message.metadata), "url": message.wiki.script_url}) + json_object["queued_messages"].append({"metadata": str(message.discord_message.metadata), "url": message.wiki.script_url if hasattr(message, "wiki") else "#######"}) req_id: str = split_payload[2] json_string: str = json.dumps(json_object) for json_part in self.chunkstring(json_string, 7950): diff --git a/src/wiki.py b/src/wiki.py index 5d1f106..f0cc227 100644 --- a/src/wiki.py +++ b/src/wiki.py @@ -43,6 +43,33 @@ MESSAGE_LIMIT = settings.get("message_limit", 30) class MessageHistoryRetriever: + def __init__(self, wiki: Wiki, message_history: MessageHistory, params: dict[str, Union[str, list]]): + self.wiki = wiki + self.params = params + self.message_history_obj = message_history + + async def __aiter__(self): + async with db.pool().acquire() as connection: + async with connection.transaction(): + query_template, query_parameters = [], [] + for query_key, query_parameter in self.params.items(): + if isinstance(query_parameter, str) or isinstance(query_parameter, int): + query_template.append(f"{query_key} = ${len(query_parameters) + 1}") + query_parameters.append(query_parameter) + else: # an iterator/list + query_template.append( + f"{query_key} IN ({', '.join(['${}'.format(x + len(query_parameters) + 1) for x in range(len(query_parameter))])})") + query_parameters.extend(query_parameter) + logger.debug(f"SELECT rcgcdb_msg_history.message_id, rcgcdb_msg_history.webhook, message_object FROM rcgcdb_msg_history INNER JOIN rcgcdb_msg_metadata ON rcgcdb_msg_metadata.message_id = rcgcdb_msg_history.message_id INNER JOIN rcgcdb ON rcgcdb_msg_history.webhook = rcgcdb.webhook WHERE rcgcdb.wiki = '{self.wiki.script_url}' AND {" AND ".join(query_template)}") + logger.debug(query_parameters) + async for stacked_message in connection.cursor( + f"SELECT rcgcdb_msg_history.message_id, rcgcdb_msg_history.webhook, message_object FROM rcgcdb_msg_history INNER JOIN rcgcdb_msg_metadata ON rcgcdb_msg_metadata.message_id = rcgcdb_msg_history.message_id INNER JOIN rcgcdb ON rcgcdb_msg_history.webhook = rcgcdb.webhook WHERE rcgcdb.wiki = '{self.wiki.script_url}' AND {" AND ".join(query_template)}", + *query_parameters): + unpickled_message = pickle.loads(stacked_message["message_object"]) + yield unpickled_message, stacked_message["webhook"] + await self.message_history_obj.update_message(unpickled_message, connection, stacked_message["message_id"]) + +class MessageHistory: def __init__(self, wiki: Wiki): self.wiki = wiki @@ -50,13 +77,13 @@ class MessageHistoryRetriever: return NotImplementedError async def find_all_revids(self, page_id: int) -> list[int]: + """Function to find all revisions for a page in message history""" result = [] async for item in dbmanager.fetch_rows(f"SELECT DISTINCT rev_id FROM rcgcdb_msg_metadata INNER JOIN rcgcdb_msg_metadata ON rcgcdb_msg_metadata.message_id = rcgcdb_msg_history.message_id INNER JOIN rcgcdb ON rcgcdb_msg_history.webhook = rcgcdb.webhook WHERE rcgcdb.wiki = $1 AND page_id = $2", (self.wiki.script_url, page_id)): result.append(item["rev_id"]) return result - @asynccontextmanager - async def fetch_stacked_from_db(self, params: dict[str, Union[str, list]]) -> tuple[StackedDiscordMessage, str]: + def fetch_stacked_from_db(self, params: dict[str, Union[str, list]]): # All relevant fields: # message_display 0-3, # log_id @@ -64,21 +91,7 @@ class MessageHistoryRetriever: # rev_id # stacked_index # async for message in dbmanager.fetch_rows(f"SELECT {', '.join(['{key} = ${val}'.format(key=key, val=num) for num, key in enumerate(params.keys())])}", (*params.values(),)): # What is this monster - message_cache = [] # Temporary message storage required to write the messages back to the DB after potential changes - async with db.pool().acquire() as connection: - async with connection.transaction(): - query_template, query_parameters = [], [] - for query_key, query_parameter in params.items(): - if isinstance(query_parameter, str): - query_template.append(f"{query_key} = ${len(query_parameters)+1}") - query_parameters.append(query_parameter) - else: # an iterator/list - query_template.append(f"{query_key} IN ({', '.join(['${}'.format(x+len(query_parameters)+1) for x in range(len(query_parameter))])})") - query_parameters.extend(query_parameter) - async for stacked_message in connection.cursor(f"SELECT message_id, webhook, message_object FROM rcgcdb_msg_history INNER JOIN rcgcdb_msg_metadata ON rcgcdb_msg_metadata.message_id = rcgcdb_msg_history.message_id INNER JOIN rcgcdb ON rcgcdb_msg_history.webhook = rcgcdb.webhook WHERE rcgcdb.wiki = {self.wiki.script_url} AND {" AND ".join(query_template)}", query_parameters): - unpickled_message = pickle.loads(stacked_message["message_object"]) - yield unpickled_message, stacked_message["webhook"] - await self.update_message(unpickled_message, connection, stacked_message["message_id"]) + return MessageHistoryRetriever(self.wiki, self, params) @staticmethod @@ -112,7 +125,7 @@ class Wiki: self.rc_targets: Optional[defaultdict[Settings, list[str]]] = None self.discussion_targets: Optional[defaultdict[Settings, list[str]]] = None self.client: Client = Client(formatter_hooks, self) - self.message_history: MessageHistoryRetriever = MessageHistoryRetriever(self) + self.message_history: MessageHistory = MessageHistory(self) self.namespaces: Optional[dict] = None self.recache_requested: bool = False self.session_requests = requests.Session() @@ -177,7 +190,7 @@ class Wiki: def set_domain(self, domain: Domain): self.domain = domain - async def find_middle_next(self, ids: List[str], pageid: int) -> set: + async def find_middle_next(self, ids: List[str | int], pageid: int) -> set: """To address #235 RcGcDw should now remove diffs in next revs relative to redacted revs to protect information in revs that revert revdeleted information. What this function does, is it fetches all messages for given page and finds revids of the messages that come next after ids :arg ids - list @@ -218,7 +231,7 @@ class Wiki: else: messagequeue.add_message(QueueEntry(stacked_message, [stacked_message.webhook], self, method="PATCH")) - async def redact_messages(self, context: Context, ids: list[int] | set[int], mode: str, censored_properties: dict): + async def redact_messages(self, context: Context, ids: list[int] | set[int], mode: str, censored_properties: dict, page_id=None): # ids can refer to multiple events, and search does not support additive mode, so we have to loop it for all ids async for stacked_message, webhook in self.message_history.fetch_stacked_from_db({mode: ids, "message_display": [1, 2, 3]}): for message in [message for message in stacked_message.message_list if message.metadata.matches({mode: ids})]: @@ -232,7 +245,17 @@ class Wiki: message["embed"].pop("fields") if "comment" in censored_properties: message["description"] = context._("~~hidden~~") + logger.debug(f"Rev-deleting contents of message {stacked_message.discord_callback_message_id} due to being in list of ids {ids}.") messagequeue.add_message(QueueEntry(stacked_message, [stacked_message.webhook], self, method="PATCH")) + if mode == "rev_id" and "content" in censored_properties: # Erase content field from messages coming AFTER revdel'd edit + middle_ids = (await self.find_middle_next(ids, page_id)).difference(ids) + async for stacked_message, webhook in self.message_history.fetch_stacked_from_db({"rev_id": middle_ids, "message_display": 3}): + for message in [message for message in stacked_message.message_list if + message.metadata.matches({"rev_id": middle_ids})]: + if "content" in censored_properties and "fields" in message: + message["embed"].pop("fields") + logger.debug(f"Rev-deleting content of message {stacked_message.discord_callback_message_id} due to being in list of middle ids {middle_ids}.") + messagequeue.add_message(QueueEntry(stacked_message, [stacked_message.webhook], self, method="PATCH")) # async def downtime_controller(self, down, reason=None): # if down: @@ -588,22 +611,15 @@ async def rc_processor(wiki: Wiki, change: dict, changed_categories: dict, displ else: raise if identification_string in ("delete/delete", "delete/delete_redir"): # TODO Move it into a hook? - await wiki.delete_messages(dict(page_id=change.get("pageid"))) + await wiki.delete_messages(dict(page_id=int(change.get("pageid")))) elif identification_string == "delete/event": logparams = change.get('logparams', {"ids": []}) await wiki.redact_messages(context, logparams.get("ids", []), "log_id", logparams.get("new", {})) - await wiki.delete_messages(dict(log_id=logparams.get("ids", []), message_display=0)) + await wiki.delete_messages(dict(log_id=[int(x) for x in logparams.get("ids", [])], message_display=0)) elif identification_string == "delete/revision": logparams = change.get('logparams', {"ids": []}) - if context.message_type == "embed": - if display_options.display == 3: - await wiki.redact_messages(context, (await wiki.find_middle_next(logparams.get("ids", []), change.get("pageid", -1))).union(logparams.get("ids", [])), "rev_id", - {"content": ""}) - else: - await wiki.redact_messages(context, logparams.get("ids", []), "rev_id", logparams.get("new", {})) - else: - for revid in logparams.get("ids", []): - await wiki.delete_messages(dict(rev_id=revid)) + await wiki.redact_messages(context, logparams.get("ids", []), "rev_id", logparams.get("new", {}), page_id=change.get("pageid", -1)) + await wiki.delete_messages(dict(rev_id=[int(x) for x in logparams.get("ids", [])], message_display=0)) run_hooks(post_hooks, discord_message, metadata, context, change) if discord_message: # TODO How to react when none? (crash in formatter), probably bad handling atm discord_message.finish_embed()