diff --git a/src/bot.py b/src/bot.py index 6d1941b..b9b0890 100644 --- a/src/bot.py +++ b/src/bot.py @@ -3,6 +3,7 @@ import asyncio import logging.config import signal import traceback +import nest_asyncio from collections import defaultdict, namedtuple from typing import Generator @@ -29,6 +30,7 @@ if command_line_args.debug: # Log Fail states with structure wiki_url: number of fail states all_wikis: dict = {} mw_msgs: dict = {} # will have the type of id: tuple +main_tasks: dict = {} # First populate the all_wikis list with every wiki # Reasons for this: 1. we require amount of wikis to calculate the cooldown between requests @@ -289,6 +291,8 @@ async def wiki_scanner(): await asyncio.sleep(20.0) await rcqueue.update_queues() except asyncio.CancelledError: + for item in rcqueue.domain_list.values(): # cancel running tasks + item["task"].cancel() raise @@ -297,7 +301,12 @@ async def message_sender(): try: while True: await messagequeue.resend_msgs() + if main_tasks["msg_queue_shield"].cancelled(): + raise asyncio.CancelledError except asyncio.CancelledError: + while len(messagequeue): + logger.info("Shutting down after sending {} more Discord messages...".format(len(messagequeue))) + await messagequeue.resend_msgs() pass except: if command_line_args.debug: @@ -382,16 +391,23 @@ async def discussion_handler(): def shutdown(loop, signal=None): + global main_tasks DBHandler.update_db() db_connection.close() + loop.remove_signal_handler(signal) if len(messagequeue) > 0: logger.warning("Some messages are still queued!") + for task in (main_tasks["wiki_scanner"], main_tasks["discussion_handler"], main_tasks["msg_queue_shield"]): + task.cancel() + loop.run_until_complete(main_tasks["message_sender"]) for task in asyncio.all_tasks(loop): logger.debug("Killing task") task.cancel() - loop.run_until_complete(asyncio.gather(*asyncio.all_tasks(loop))) - loop.stop() - logger.info("Script has shut down due to signal {}.".format(signal)) + try: + loop.run_until_complete(asyncio.gather(*asyncio.all_tasks(loop))) + except asyncio.CancelledError: + loop.stop() + logger.info("Script has shut down due to signal {}.".format(signal)) # sys.exit(0) @@ -406,7 +422,9 @@ def shutdown(loop, signal=None): async def main_loop(): + global main_tasks loop = asyncio.get_event_loop() + nest_asyncio.apply(loop) try: signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT) for s in signals: @@ -417,13 +435,13 @@ async def main_loop(): signals = (signal.SIGBREAK, signal.SIGTERM, signal.SIGINT) # loop.set_exception_handler(global_exception_handler) try: - task1 = asyncio.create_task(wiki_scanner()) - task2 = asyncio.create_task(message_sender()) - task3 = asyncio.create_task(discussion_handler()) - await task1 - await task2 - await task3 + main_tasks = {"wiki_scanner": asyncio.create_task(wiki_scanner()), "message_sender": asyncio.create_task(message_sender()), + "discussion_handler": asyncio.create_task(discussion_handler())} + main_tasks["msg_queue_shield"] = asyncio.shield(main_tasks["message_sender"]) + await asyncio.gather(main_tasks["wiki_scanner"], main_tasks["discussion_handler"], main_tasks["message_sender"]) except KeyboardInterrupt: shutdown(loop) + except asyncio.exceptions.CancelledError: + return asyncio.run(main_loop(), debug=False)