添加插件框架
parent
5d98bb5d14
commit
f886506c39
@ -0,0 +1,98 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import *
|
||||
|
||||
import tornado.web
|
||||
import tornado.websocket
|
||||
|
||||
import api.base
|
||||
import blcsdk.models as models
|
||||
import services.plugin
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _PluginHandlerBase(api.base.ApiHandler):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.plugin: Optional[services.plugin.Plugin] = None
|
||||
|
||||
def prepare(self):
|
||||
try:
|
||||
auth = self.request.headers['Authorization']
|
||||
if not auth.startswith('Bearer '):
|
||||
raise ValueError(f'Bad authorization: {auth}')
|
||||
token = auth[7:]
|
||||
|
||||
self.plugin = services.plugin.get_plugin_by_token(token)
|
||||
if self.plugin is None:
|
||||
raise ValueError(f'Token error: {token}')
|
||||
except (KeyError, ValueError) as e:
|
||||
logger.warning('client=%s failed to find plugin: %r', self.request.remote_ip, e)
|
||||
raise tornado.web.HTTPError(403)
|
||||
|
||||
super().prepare()
|
||||
|
||||
|
||||
def make_message_body(cmd, data, extra: Optional[dict] = None):
|
||||
body = {'cmd': cmd, 'data': data}
|
||||
if extra:
|
||||
body['extra'] = extra
|
||||
return json.dumps(body).encode('utf-8')
|
||||
|
||||
|
||||
class PluginWsHandler(_PluginHandlerBase, tornado.websocket.WebSocketHandler):
|
||||
HEARTBEAT_INTERVAL = 10
|
||||
RECEIVE_TIMEOUT = HEARTBEAT_INTERVAL + 5
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._heartbeat_timer_handle = None
|
||||
self._receive_timeout_timer_handle = None
|
||||
|
||||
def open(self):
|
||||
logger.info('plugin=%s connected, client=%s', self.plugin.id, self.request.remote_ip)
|
||||
self._heartbeat_timer_handle = asyncio.get_running_loop().call_later(
|
||||
self.HEARTBEAT_INTERVAL, self._on_send_heartbeat
|
||||
)
|
||||
self._refresh_receive_timeout_timer()
|
||||
|
||||
self.plugin.on_client_connect(self)
|
||||
|
||||
def _on_send_heartbeat(self):
|
||||
self.send_cmd_data(models.Command.HEARTBEAT, {})
|
||||
self._heartbeat_timer_handle = asyncio.get_running_loop().call_later(
|
||||
self.HEARTBEAT_INTERVAL, self._on_send_heartbeat
|
||||
)
|
||||
|
||||
def _refresh_receive_timeout_timer(self):
|
||||
if self._receive_timeout_timer_handle is not None:
|
||||
self._receive_timeout_timer_handle.cancel()
|
||||
self._receive_timeout_timer_handle = asyncio.get_running_loop().call_later(
|
||||
self.RECEIVE_TIMEOUT, self._on_receive_timeout
|
||||
)
|
||||
|
||||
def _on_receive_timeout(self):
|
||||
logger.info('plugin=%s timed out', self.plugin.id)
|
||||
self._receive_timeout_timer_handle = None
|
||||
self.close()
|
||||
|
||||
def on_close(self):
|
||||
logger.info('plugin=%s disconnected', self.plugin.id)
|
||||
self.plugin.on_client_close(self)
|
||||
|
||||
def send_cmd_data(self, cmd, data, extra: Optional[dict] = None):
|
||||
self.send_body_no_raise(make_message_body(cmd, data, extra))
|
||||
|
||||
def send_body_no_raise(self, body: Union[bytes, str, Dict[str, Any]]):
|
||||
try:
|
||||
self.write_message(body)
|
||||
except tornado.websocket.WebSocketClosedError:
|
||||
self.close()
|
||||
|
||||
|
||||
ROUTES = [
|
||||
(r'/api/plugin/websocket', PluginWsHandler),
|
||||
]
|
@ -0,0 +1,2 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
__version__ = '0.0.1'
|
@ -0,0 +1 @@
|
||||
# -*- coding: utf-8 -*-
|
@ -0,0 +1,6 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import enum
|
||||
|
||||
|
||||
class Command(enum.IntEnum):
|
||||
HEARTBEAT = 0
|
@ -0,0 +1,14 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
import blcsdk
|
||||
|
||||
|
||||
async def main():
|
||||
print('hello world!', blcsdk.__version__)
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(asyncio.run(main()))
|
@ -0,0 +1,252 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import dataclasses
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import string
|
||||
import subprocess
|
||||
from typing import *
|
||||
|
||||
import api.plugin
|
||||
import config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PLUGINS_PATH = os.path.join(config.DATA_PATH, 'plugins')
|
||||
|
||||
_plugins: Dict[str, 'Plugin'] = {}
|
||||
|
||||
|
||||
def init():
|
||||
plugin_ids = _discover_plugin_ids()
|
||||
if not plugin_ids:
|
||||
return
|
||||
logger.info('Found plugins: %s', plugin_ids)
|
||||
|
||||
for plugin_id in plugin_ids:
|
||||
plugin = _create_plugin(plugin_id)
|
||||
if plugin is not None:
|
||||
_plugins[plugin_id] = plugin
|
||||
|
||||
for plugin in _plugins.values():
|
||||
if plugin.enabled:
|
||||
try:
|
||||
plugin.start()
|
||||
except StartPluginError:
|
||||
pass
|
||||
|
||||
|
||||
def shut_down():
|
||||
for plugin in _plugins.values():
|
||||
plugin.stop()
|
||||
|
||||
|
||||
def _discover_plugin_ids():
|
||||
res = []
|
||||
try:
|
||||
with os.scandir(PLUGINS_PATH) as it:
|
||||
for entry in it:
|
||||
if entry.is_dir() and os.path.isfile(os.path.join(entry.path, 'plugin.json')):
|
||||
res.append(entry.name)
|
||||
except OSError:
|
||||
logger.exception('Failed to discover plugins:')
|
||||
return res
|
||||
|
||||
|
||||
def _create_plugin(plugin_id):
|
||||
config_path = os.path.join(PLUGINS_PATH, plugin_id, 'plugin.json')
|
||||
try:
|
||||
plugin_config = PluginConfig.from_file(config_path)
|
||||
except (OSError, json.JSONDecodeError, TypeError):
|
||||
logger.exception('plugin=%s failed to load config:', plugin_id)
|
||||
return None
|
||||
return Plugin(plugin_id, plugin_config)
|
||||
|
||||
|
||||
def iter_plugins() -> Iterable['Plugin']:
|
||||
return _plugins.values()
|
||||
|
||||
|
||||
def get_plugin_by_token(token):
|
||||
if token == '':
|
||||
return None
|
||||
for plugin in _plugins.values():
|
||||
if plugin.token == token:
|
||||
return plugin
|
||||
return None
|
||||
|
||||
|
||||
def broadcast_cmd_data(cmd, data, extra: Optional[dict] = None):
|
||||
body = api.plugin.make_message_body(cmd, data, extra)
|
||||
for plugin in _plugins.values():
|
||||
plugin.send_body_no_raise(body)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class PluginConfig:
|
||||
name: str = ''
|
||||
version: str = ''
|
||||
author: str = ''
|
||||
description: str = ''
|
||||
run_cmd: str = ''
|
||||
enabled: bool = False
|
||||
|
||||
@classmethod
|
||||
def from_file(cls, path):
|
||||
with open(path, encoding='utf-8') as f:
|
||||
cfg = json.load(f)
|
||||
if not isinstance(cfg, dict):
|
||||
raise TypeError(f'Config type error, type={type(cfg)}')
|
||||
|
||||
return cls(
|
||||
name=str(cfg.get('name', '')),
|
||||
version=str(cfg.get('version', '')),
|
||||
author=str(cfg.get('author', '')),
|
||||
description=str(cfg.get('description', '')),
|
||||
run_cmd=str(cfg.get('run', '')),
|
||||
enabled=bool(cfg.get('enabled', False)),
|
||||
)
|
||||
|
||||
def save(self, path):
|
||||
try:
|
||||
with open(path, encoding='utf-8') as f:
|
||||
cfg = json.load(f)
|
||||
if not isinstance(cfg, dict):
|
||||
raise TypeError(f'Config type error, type={type(cfg)}')
|
||||
except (OSError, json.JSONDecodeError, TypeError):
|
||||
cfg = {}
|
||||
|
||||
cfg['name'] = self.name
|
||||
cfg['version'] = self.version
|
||||
cfg['author'] = self.author
|
||||
cfg['description'] = self.description
|
||||
cfg['run_cmd'] = self.run_cmd
|
||||
cfg['enabled'] = self.enabled
|
||||
|
||||
tmp_path = path + '.tmp'
|
||||
with open(tmp_path, encoding='utf-8') as f:
|
||||
json.dump(cfg, f, ensure_ascii=False, indent=2)
|
||||
os.replace(tmp_path, path)
|
||||
|
||||
|
||||
class StartPluginError(Exception):
|
||||
"""启动插件时错误"""
|
||||
|
||||
|
||||
class StartTooFrequently(StartPluginError):
|
||||
"""启动插件太频繁"""
|
||||
|
||||
|
||||
class Plugin:
|
||||
def __init__(self, plugin_id, cfg: PluginConfig):
|
||||
self._id = plugin_id
|
||||
self._config = cfg
|
||||
|
||||
self._last_start_time = datetime.datetime.fromtimestamp(0)
|
||||
self._token = ''
|
||||
self._client: Optional['api.plugin.PluginWsHandler'] = None
|
||||
|
||||
@property
|
||||
def id(self):
|
||||
return self._id
|
||||
|
||||
@property
|
||||
def enabled(self):
|
||||
return self._config.enabled
|
||||
|
||||
@enabled.setter
|
||||
def enabled(self, value):
|
||||
if self._config.enabled == value:
|
||||
return
|
||||
self._config.enabled = value
|
||||
|
||||
config_path = os.path.join(self.base_path, 'plugin.json')
|
||||
try:
|
||||
self._config.save(config_path)
|
||||
except OSError:
|
||||
logger.exception('plugin=%s failed to save config', self._id)
|
||||
|
||||
if value:
|
||||
self.start()
|
||||
else:
|
||||
self.stop()
|
||||
|
||||
@property
|
||||
def base_path(self):
|
||||
return os.path.join(PLUGINS_PATH, self._id)
|
||||
|
||||
@property
|
||||
def token(self):
|
||||
return self._token
|
||||
|
||||
@property
|
||||
def is_started(self):
|
||||
return self._token != ''
|
||||
|
||||
@property
|
||||
def is_connected(self):
|
||||
return self._client is not None
|
||||
|
||||
def start(self):
|
||||
if self.is_started:
|
||||
return
|
||||
|
||||
cur_time = datetime.datetime.now()
|
||||
if cur_time - self._last_start_time < datetime.timedelta(seconds=3):
|
||||
raise StartTooFrequently(f'plugin={self._id} starts too frequently')
|
||||
self._last_start_time = cur_time
|
||||
|
||||
token = ''.join(random.choice(string.hexdigits) for _ in range(32))
|
||||
self._set_token(token)
|
||||
|
||||
env = {
|
||||
**os.environ,
|
||||
'BLC_PORT': str(12450), # TODO 读配置
|
||||
'BLC_TOKEN': self._token,
|
||||
}
|
||||
try:
|
||||
subprocess.Popen(
|
||||
self._config.run_cmd,
|
||||
shell=True,
|
||||
cwd=self.base_path,
|
||||
env=env,
|
||||
)
|
||||
except OSError as e:
|
||||
logger.exception('plugin=%s failed to start', self._id)
|
||||
raise StartPluginError(str(e))
|
||||
|
||||
def stop(self):
|
||||
if self.is_started:
|
||||
self._set_token('')
|
||||
|
||||
def _set_token(self, token):
|
||||
if self._token == token:
|
||||
return
|
||||
self._token = token
|
||||
|
||||
# 踢掉已经连接的客户端
|
||||
self._set_client(None)
|
||||
|
||||
def _set_client(self, client: Optional['api.plugin.PluginWsHandler']):
|
||||
if self._client is client:
|
||||
return
|
||||
if self._client is not None:
|
||||
self._client.close()
|
||||
self._client = client
|
||||
|
||||
def on_client_connect(self, client: 'api.plugin.PluginWsHandler'):
|
||||
self._set_client(client)
|
||||
|
||||
def on_client_close(self, client: 'api.plugin.PluginWsHandler'):
|
||||
if self._client is client:
|
||||
self._set_client(None)
|
||||
|
||||
def send_cmd_data(self, cmd, data, extra: Optional[dict] = None):
|
||||
if self._client is not None:
|
||||
self._client.send_cmd_data(cmd, data, extra)
|
||||
|
||||
def send_body_no_raise(self, body):
|
||||
if self._client is not None:
|
||||
self._client.send_body_no_raise(body)
|
Loading…
Reference in New Issue