diff --git a/api/chat.py b/api/chat.py index c1fcbe1..9c756e4 100644 --- a/api/chat.py +++ b/api/chat.py @@ -17,6 +17,7 @@ import config import services.avatar import services.chat import services.translate +import utils.async_io import utils.request logger = logging.getLogger(__name__) @@ -224,7 +225,7 @@ class ChatHandler(tornado.websocket.WebSocketHandler): pass services.chat.client_room_manager.add_client(self.room_key, self) - asyncio.create_task(self._on_joined_room()) + utils.async_io.create_task_with_ref(self._on_joined_room()) self._refresh_receive_timeout_timer() diff --git a/services/avatar.py b/services/avatar.py index e727f14..b67c6be 100644 --- a/services/avatar.py +++ b/services/avatar.py @@ -14,6 +14,7 @@ import sqlalchemy.exc import config import models.bilibili as bl_models import models.database +import utils.async_io import utils.request logger = logging.getLogger(__name__) @@ -41,7 +42,7 @@ def init(): global _avatar_url_cache, _task_queue _avatar_url_cache = cachetools.TTLCache(cfg.avatar_cache_size, 10 * 60) _task_queue = asyncio.Queue(cfg.fetch_avatar_max_queue_size) - asyncio.get_running_loop().create_task(_do_init()) + utils.async_io.create_task_with_ref(_do_init()) async def _do_init(): @@ -89,7 +90,7 @@ async def get_avatar_url_or_none(user_id) -> Optional[str]: _update_avatar_cache_in_memory(user_id, avatar_url) # 如果距离数据库上次更新太久,则在后台从接口获取,并更新所有缓存 if (datetime.datetime.now() - user.update_time).days >= 1: - asyncio.create_task(_refresh_avatar_cache_from_web(user_id)) + utils.async_io.create_task_with_ref(_refresh_avatar_cache_from_web(user_id)) return avatar_url # 从接口获取 @@ -249,7 +250,7 @@ class AvatarFetcher: self._cool_down_timer_handle = None async def init(self): - asyncio.create_task(self._fetch_consumer()) + utils.async_io.create_task_with_ref(self._fetch_consumer()) return True @property diff --git a/services/chat.py b/services/chat.py index d16fd32..449b61f 100644 --- a/services/chat.py +++ b/services/chat.py @@ -14,6 +14,7 @@ import blivedm.blivedm.models.web as dm_web_models import config import services.avatar import services.translate +import utils.async_io import utils.request logger = logging.getLogger(__name__) @@ -239,7 +240,7 @@ class OpenLiveClient(blivedm.OpenLiveClient): self._game_heartbeat_timer_handle = asyncio.get_running_loop().call_later( sleep_time, self._on_send_game_heartbeat ) - asyncio.create_task(self._send_game_heartbeat()) + utils.async_io.create_task_with_ref(self._send_game_heartbeat()) async def _send_game_heartbeat(self): if self._game_id in (None, ''): @@ -412,7 +413,7 @@ class LiveMsgHandler(blivedm.BaseHandler): _live_client_manager.del_live_client(client.room_key) def _on_danmaku(self, client: WebLiveClient, message: dm_web_models.DanmakuMessage): - asyncio.create_task(self.__on_danmaku(client, message)) + utils.async_io.create_task_with_ref(self.__on_danmaku(client, message)) async def __on_danmaku(self, client: WebLiveClient, message: dm_web_models.DanmakuMessage): # 先异步调用再获取房间,因为返回时房间可能已经不存在了 @@ -500,7 +501,7 @@ class LiveMsgHandler(blivedm.BaseHandler): }) def _on_buy_guard(self, client: WebLiveClient, message: dm_web_models.GuardBuyMessage): - asyncio.create_task(self.__on_buy_guard(client, message)) + utils.async_io.create_task_with_ref(self.__on_buy_guard(client, message)) @staticmethod async def __on_buy_guard(client: WebLiveClient, message: dm_web_models.GuardBuyMessage): @@ -552,7 +553,7 @@ class LiveMsgHandler(blivedm.BaseHandler): }) if need_translate: - asyncio.create_task(self._translate_and_response( + utils.async_io.create_task_with_ref(self._translate_and_response( message.message, room.room_key, msg_id, services.translate.Priority.HIGH )) @@ -649,7 +650,9 @@ class LiveMsgHandler(blivedm.BaseHandler): )) if need_translate: - asyncio.create_task(self._translate_and_response(message.msg, room.room_key, message.msg_id)) + utils.async_io.create_task_with_ref(self._translate_and_response( + message.msg, room.room_key, message.msg_id + )) def _on_open_live_gift(self, client: OpenLiveClient, message: dm_open_models.GiftMessage): avatar_url = services.avatar.process_avatar_url(message.uface) @@ -723,7 +726,7 @@ class LiveMsgHandler(blivedm.BaseHandler): }) if need_translate: - asyncio.create_task(self._translate_and_response( + utils.async_io.create_task_with_ref(self._translate_and_response( message.message, room.room_key, msg_id, services.translate.Priority.HIGH )) diff --git a/services/open_live.py b/services/open_live.py index 314daa1..4cb4311 100644 --- a/services/open_live.py +++ b/services/open_live.py @@ -7,6 +7,7 @@ from typing import * import api.open_live import config +import utils.async_io logger = logging.getLogger(__name__) @@ -24,7 +25,7 @@ def init(): cfg = config.get_config() # 批量心跳只支持配置了开放平台的公共服务器,私有服务器用的人少,意义不大 if cfg.is_open_live_configured: - asyncio.create_task(_game_heartbeat_consumer()) + utils.async_io.create_task_with_ref(_game_heartbeat_consumer()) async def send_game_heartbeat(game_id) -> dict: diff --git a/services/translate.py b/services/translate.py index 7ccf3e2..4838ea4 100644 --- a/services/translate.py +++ b/services/translate.py @@ -19,6 +19,7 @@ import aiohttp import cachetools import config +import utils.async_io import utils.request logger = logging.getLogger(__name__) @@ -56,7 +57,7 @@ def init(): _translate_cache = cachetools.LRUCache(cfg.translation_cache_size) # 总队列长度会超过translate_max_queue_size,不用这么严格 _task_queues = [asyncio.Queue(cfg.translate_max_queue_size) for _ in range(len(Priority))] - asyncio.get_running_loop().create_task(_do_init()) + utils.async_io.create_task_with_ref(_do_init()) async def _do_init(): @@ -229,7 +230,7 @@ class TranslateProvider: self._be_available_event.set() async def init(self): - asyncio.create_task(self._translate_consumer()) + utils.async_io.create_task_with_ref(self._translate_consumer()) return True @property diff --git a/update.py b/update.py index 1ec3ad7..80cb41f 100644 --- a/update.py +++ b/update.py @@ -3,13 +3,14 @@ import asyncio import aiohttp +import utils.async_io import utils.request VERSION = 'v1.8.2' def check_update(): - asyncio.get_running_loop().create_task(_do_check_update()) + utils.async_io.create_task_with_ref(_do_check_update()) async def _do_check_update(): diff --git a/utils/async_io.py b/utils/async_io.py new file mode 100644 index 0000000..f0bb9f4 --- /dev/null +++ b/utils/async_io.py @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- +import asyncio + +# 只用于持有Task的引用 +_task_refs = set() + + +def create_task_with_ref(*args, **kwargs): + """创建Task并保持引用,防止协程执行完之前就被GC""" + task = asyncio.create_task(*args, **kwargs) + _task_refs.add(task) + task.add_done_callback(_task_refs.discard) + return task