From e0872441e810d8cba3298ec84678032b439f9c32 Mon Sep 17 00:00:00 2001 From: Tulir Asokan <tulir@maunium.net> Date: Sat, 31 Aug 2019 15:04:14 +0300 Subject: [PATCH] Add support for shell execution and python exceptions --- base-config.yaml | 23 ++++++-- exec/bot.py | 47 +++++++++++----- exec/runner/base.py | 43 ++++++++++++++- exec/runner/python.py | 125 ++++++++++++++++++++++++------------------ exec/runner/shell.py | 83 ++++++++++++++++++++++++++-- 5 files changed, 242 insertions(+), 79 deletions(-) diff --git a/base-config.yaml b/base-config.yaml index a664942..5f52c85 100644 --- a/base-config.yaml +++ b/base-config.yaml @@ -1,4 +1,4 @@ -# The message prefix to treat as exec commands +# The message prefix to treat as exec commands. prefix: '!exec' # Whether or not to enable "userbot" mode, where commands that the bot's user # sends are handled and responded to with edits instead of replies. @@ -25,12 +25,17 @@ output: Output: {{ output }} {% endif %} - {% if return_value %} + {% if return_value != None %} Return: {{ return_value }} {% endif %} - {% if duration %} + {% if traceback != None %} + + {% if traceback_header %}{{ traceback_header }}:{% endif %} + {{ traceback }} + {% endif %} + {% if duration != None %} Took {{ duration | round(3) }} seconds {% else %} @@ -45,11 +50,19 @@ output: <h4>Output</h4> <pre>{{ output }}</pre> {% endif %} - {% if return_value %} + {% if return_value != None %} + {% if language in ("bash", "sh", "shell") %} + <h4>Return: <code>{{ return_value }}</code></h4> + {% else %} <h4>Return</h4> <pre>{{ return_value }}</pre> + {% endif %} + {% endif %} + {% if traceback != None %} + {% if traceback_header %}<h4>{{ traceback_header }}</h4>{% endif %} + <pre><code class="language-pytb">{{ traceback }}</code></pre> {% endif %} - {% if duration %} + {% if duration != None %} <h4>Took {{ duration | round(3) }} seconds</h4> {% else %} <h4>Running...</h4> diff --git a/exec/bot.py b/exec/bot.py index 1cce8d7..0a6219c 100644 --- a/exec/bot.py +++ b/exec/bot.py @@ -15,8 +15,8 @@ # along with this program. If not, see <https://www.gnu.org/licenses/>. from typing import Type, Set, Optional, Any from io import StringIO +from html import escape as escape_orig from time import time -from html import escape from jinja2 import Template @@ -26,7 +26,11 @@ from mautrix.util.formatter import MatrixParser, EntityString, SimpleEntity, Ent from maubot import Plugin, MessageEvent from maubot.handlers import event -from .runner import PythonRunner, OutputType +from .runner import PythonRunner, ShellRunner, OutputType + + +def escape(val: Optional[str]) -> Optional[str]: + return escape(val) if val else None class EntityParser(MatrixParser[EntityString]): @@ -70,26 +74,27 @@ class ExecBot(Plugin): self.html_template = Template(self.config["output.html"], **template_args) def format_status(self, code: str, language: str, output: str = "", output_html: str = "", - return_value: Any = None, duration: Optional[float] = None, + return_value: Any = None, traceback: Optional[str] = None, + traceback_header: Optional[str] = None, duration: Optional[float] = None, msgtype: MessageType = MessageType.NOTICE) -> TextMessageEventContent: - return_value = repr(return_value) if return_value else '' + return_value = repr(return_value) if return_value is not None else None content = TextMessageEventContent( msgtype=msgtype, format=Format.HTML, body=self.plaintext_template.render( code=code, language=language, output=output, return_value=return_value, - duration=duration), + duration=duration, traceback=traceback, traceback_header=traceback_header), formatted_body=self.html_template.render( code=escape(code), language=language, output=output_html, - return_value=escape(return_value), duration=duration)) + return_value=escape(return_value), duration=duration, traceback=escape(traceback), + traceback_header=escape(traceback_header))) return content @event.on(EventType.ROOM_MESSAGE) async def exec(self, evt: MessageEvent) -> None: - if evt.sender not in self.whitelist: - return - elif not evt.content.body.startswith(self.prefix): - return - elif not evt.content.formatted_body: + if ((evt.content.msgtype != MessageType.TEXT + or evt.sender not in self.whitelist + or not evt.content.body.startswith(self.prefix) + or not evt.content.formatted_body)): return command = EntityParser.parse(evt.content.formatted_body) @@ -110,8 +115,15 @@ class ExecBot(Plugin): if not code or not lang: return - if lang != "python": - await evt.respond("Only python is currently supported") + if lang == "python": + runner = PythonRunner(namespace={ + "client": self.client, + "event": evt, + }) + elif lang in ("shell", "bash", "sh"): + runner = ShellRunner() + else: + await evt.respond(f'Unsupported language "{lang}"') return if self.userbot: @@ -124,10 +136,10 @@ class ExecBot(Plugin): content = self.format_status(code, lang, msgtype=msgtype) output_event_id = await evt.respond(content) - runner = PythonRunner() output = StringIO() output_html = StringIO() return_value: Any = None + traceback, traceback_header = None, None start_time = time() prev_output = start_time async for out_type, data in runner.run(code, stdin): @@ -140,6 +152,9 @@ class ExecBot(Plugin): elif out_type == OutputType.RETURN: return_value = data continue + elif out_type == OutputType.EXCEPTION: + traceback, traceback_header = runner.format_exception(data) + continue cur_time = time() if prev_output + self.output_interval < cur_time: @@ -149,7 +164,9 @@ class ExecBot(Plugin): await self.client.send_message(evt.room_id, content) prev_output = cur_time duration = time() - start_time + print(return_value) content = self.format_status(code, lang, output.getvalue(), output_html.getvalue(), - return_value, duration, msgtype=msgtype) + return_value, traceback, traceback_header, duration, + msgtype=msgtype) content.set_edit(output_event_id) await self.client.send_message(evt.room_id, content) diff --git a/exec/runner/base.py b/exec/runner/base.py index 2fafe95..cc73320 100644 --- a/exec/runner/base.py +++ b/exec/runner/base.py @@ -13,7 +13,8 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see <https://www.gnu.org/licenses/>. -from typing import AsyncGenerator +from typing import AsyncGenerator, Tuple, Optional, Any +from asyncio import AbstractEventLoop, Queue, Future, get_event_loop, ensure_future, CancelledError from abc import ABC, abstractmethod from enum import Enum, auto @@ -22,9 +23,47 @@ class OutputType(Enum): STDOUT = auto() STDERR = auto() RETURN = auto() + EXCEPTION = auto() + + +class AsyncTextOutput: + loop: AbstractEventLoop + queue: Queue + read_task: Optional[Future] + closed: bool + + def __init__(self, loop: Optional[AbstractEventLoop] = None) -> None: + self.loop = loop or get_event_loop() + self.read_task = None + self.queue = Queue(loop=self.loop) + self.closed = False + + def __aiter__(self) -> 'AsyncTextOutput': + return self + + async def __anext__(self) -> str: + if self.closed and self.queue.empty(): + raise StopAsyncIteration + self.read_task = ensure_future(self.queue.get(), loop=self.loop) + try: + data = await self.read_task + except CancelledError: + raise StopAsyncIteration + self.queue.task_done() + return data + + def close(self) -> None: + self.closed = True + if self.read_task and self.queue.empty(): + self.read_task.cancel() class Runner(ABC): @abstractmethod - async def run(self, code: str, stdin: str = "") -> AsyncGenerator[str, None]: + async def run(self, code: str, stdin: str = "", loop: Optional[AbstractEventLoop] = None + ) -> AsyncGenerator[Tuple[OutputType, Any], None]: + pass + + @abstractmethod + def format_exception(self, exc_info: Any) -> Tuple[Optional[str], Optional[str]]: pass diff --git a/exec/runner/python.py b/exec/runner/python.py index 7c4e238..090da0d 100644 --- a/exec/runner/python.py +++ b/exec/runner/python.py @@ -13,95 +13,83 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see <https://www.gnu.org/licenses/>. -from typing import Dict, Any, Optional, AsyncGenerator +from typing import Dict, Any, Optional, Tuple, AsyncGenerator, Type, NamedTuple +from types import TracebackType from io import IOBase, StringIO import contextlib +import traceback import asyncio import ast import sys from mautrix.util.manhole import asyncify -from .base import Runner, OutputType +from .base import Runner, OutputType, AsyncTextOutput -class AsyncTextOutput: - loop: asyncio.AbstractEventLoop - queue: asyncio.Queue - writers: Dict[OutputType, 'ProxyOutput'] - read_task: Optional[asyncio.Future] - closed: bool +class SyncTextProxy(AsyncTextOutput): + writers: Dict[OutputType, 'ProxyWriter'] def __init__(self, loop: Optional[asyncio.AbstractEventLoop] = None) -> None: - self.loop = loop or asyncio.get_event_loop() - self.read_task = None - self.queue = asyncio.Queue(loop=self.loop) - self.closed = False + super().__init__(loop) self.writers = {} - def __aiter__(self) -> 'AsyncTextOutput': - return self - - async def __anext__(self) -> str: - if self.closed: - raise StopAsyncIteration - self.read_task = asyncio.ensure_future(self.queue.get(), loop=self.loop) - try: - data = await self.read_task - except asyncio.CancelledError: - raise StopAsyncIteration - self.queue.task_done() - return data - def close(self) -> None: - self.closed = True for proxy in self.writers.values(): - proxy.close(_ato=True) - if self.read_task: - self.read_task.cancel() + proxy.close(_stp=True) + super().close() - def get_writer(self, output_type: OutputType) -> 'ProxyOutput': + def get_writer(self, output_type: OutputType) -> 'ProxyWriter': try: return self.writers[output_type] except KeyError: - self.writers[output_type] = proxy = ProxyOutput(output_type, self) + self.writers[output_type] = proxy = ProxyWriter(output_type, self) return proxy -class ProxyOutput(IOBase): +class ProxyWriter(IOBase): type: OutputType - ato: AsyncTextOutput + stp: SyncTextProxy - def __init__(self, output_type: OutputType, ato: AsyncTextOutput) -> None: + def __init__(self, output_type: OutputType, stp: SyncTextProxy) -> None: self.type = output_type - self.ato = ato + self.stp = stp def write(self, data: str) -> None: """Write to the stdout queue""" - self.ato.queue.put_nowait((self.type, data)) + self.stp.queue.put_nowait((self.type, data)) def writable(self) -> bool: return True - def close(self, _ato: bool = False) -> None: + def close(self, _stp: bool = False) -> None: super().close() - if not _ato: - self.ato.close() + if not _stp: + self.stp.close() + + +ExcInfo = NamedTuple('ExcInfo', type=Type[BaseException], exc=Exception, tb=TracebackType) class PythonRunner(Runner): namespace: Dict[str, Any] + per_run_namespace: bool - def __init__(self, namespace: Optional[Dict[str, Any]] = None) -> None: + def __init__(self, namespace: Optional[Dict[str, Any]] = None, per_run_namespace: bool = True + ) -> None: self.namespace = namespace or {} + self.per_run_namespace = per_run_namespace - async def _run_task(self, stdio: AsyncTextOutput) -> str: - value = await eval("__eval_async_expr()", self.namespace) - stdio.close() + @staticmethod + async def _wait_task(namespace: Dict[str, Any], stdio: SyncTextProxy) -> str: + try: + value = await eval("__eval_async_expr()", namespace) + finally: + stdio.close() return value @contextlib.contextmanager - def _redirect_io(self, output: AsyncTextOutput, stdin: StringIO) -> AsyncTextOutput: + def _redirect_io(self, output: SyncTextProxy, stdin: StringIO) -> SyncTextProxy: old_stdout, old_stderr, old_stdin = sys.stdout, sys.stderr, sys.stdin sys.stdout = output.get_writer(OutputType.STDOUT) sys.stderr = output.get_writer(OutputType.STDERR) @@ -109,12 +97,45 @@ class PythonRunner(Runner): yield output sys.stdout, sys.stderr, sys.stdin = old_stdout, old_stderr, old_stdin - async def run(self, code: str, stdin: str = "") -> AsyncGenerator[str, None]: - codeobj = asyncify(compile(code, "<input>", "exec", optimize=1, flags=ast.PyCF_ONLY_AST)) - exec(codeobj, self.namespace) - with self._redirect_io(AsyncTextOutput(), StringIO(stdin)) as output: - task = asyncio.ensure_future(self._run_task(output)) + @staticmethod + def _format_exc(exception: Exception) -> str: + if len(exception.args) == 0: + return type(exception).__name__ + elif len(exception.args) == 1: + return f"{type(exception).__name__}: {exception.args[0]}" + else: + return f"{type(exception).__name__}: {exception.args}" + + def format_exception(self, exc_info: ExcInfo) -> Tuple[Optional[str], Optional[str]]: + if not exc_info: + return None, None + tb = traceback.extract_tb(exc_info.tb) + + line: traceback.FrameSummary + for i, line in enumerate(tb): + if line.filename == "<input>": + line.name = "<module>" + tb = tb[i:] + break + + return ("Traceback (most recent call last):", + f"{''.join(traceback.format_list(tb))}" + f"{self._format_exc(exc_info.exc)}") + + async def run(self, code: str, stdin: str = "", loop: Optional[asyncio.AbstractEventLoop] = None + ) -> AsyncGenerator[Tuple[OutputType, Any], None]: + loop = loop or asyncio.get_event_loop() + codeobj = asyncify(compile(code, "<input>", "exec", optimize=1, flags=ast.PyCF_ONLY_AST), + module="<input>") + namespace = {**self.namespace} if self.per_run_namespace else self.namespace + exec(codeobj, namespace) + with self._redirect_io(SyncTextProxy(loop), StringIO(stdin)) as output: + task = asyncio.ensure_future(self._wait_task(namespace, output), loop=loop) async for part in output: yield part - return_value = await task - yield (OutputType.RETURN, return_value) + try: + return_value = await task + except Exception: + yield (OutputType.EXCEPTION, sys.exc_info()) + else: + yield (OutputType.RETURN, return_value) diff --git a/exec/runner/shell.py b/exec/runner/shell.py index cd684cd..5694dc7 100644 --- a/exec/runner/shell.py +++ b/exec/runner/shell.py @@ -13,13 +13,86 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see <https://www.gnu.org/licenses/>. -from typing import AsyncGenerator +from typing import AsyncGenerator, Tuple, Optional, Dict, Union, Any import asyncio -from .base import Runner +from .base import Runner, OutputType, AsyncTextOutput + + +class AsyncTextProxy(AsyncTextOutput): + proxies: Dict[OutputType, 'StreamProxy'] + + def __init__(self, loop: Optional[asyncio.AbstractEventLoop] = None) -> None: + super().__init__(loop) + self.proxies = {} + + def get_proxy(self, type: OutputType, stream: asyncio.StreamReader) -> 'StreamProxy': + try: + return self.proxies[type] + except KeyError: + self.proxies[type] = proxy = StreamProxy(type, self, stream, self.loop) + return proxy + + def close(self) -> None: + for proxy in self.proxies.values(): + proxy.stop() + super().close() + + +class StreamProxy: + type: OutputType + atp: AsyncTextProxy + input: asyncio.StreamReader + loop: asyncio.AbstractEventLoop + proxy_task: Optional[asyncio.Future] + + def __init__(self, output_type: OutputType, atp: AsyncTextProxy, input: asyncio.StreamReader, + loop: Optional[asyncio.AbstractEventLoop] = None) -> None: + self.type = output_type + self.atp = atp + self.input = input + self.loop = loop or asyncio.get_event_loop() + self.proxy_task = None + + def start(self) -> None: + if self.proxy_task and not self.proxy_task.done(): + raise RuntimeError("Can't re-start running proxy") + self.proxy_task = asyncio.ensure_future(self._proxy(), loop=self.loop) + + def stop(self) -> None: + self.proxy_task.cancel() + + async def _proxy(self) -> None: + while not self.input.at_eof(): + data = await self.input.readline() + if data: + await self.atp.queue.put((self.type, data.decode("utf-8"))) class ShellRunner(Runner): - async def run(self, code: str, stdin: str = "") -> AsyncGenerator[str, None]: - proc = await asyncio.create_subprocess_shell(code) - await proc.wait() + @staticmethod + async def _wait_proc(proc: asyncio.subprocess.Process, output: AsyncTextProxy) -> int: + resp = await proc.wait() + output.close() + return resp + + async def run(self, code: str, stdin: str = "", loop: Optional[asyncio.AbstractEventLoop] = None + ) -> AsyncGenerator[Tuple[OutputType, Union[str, int]], None]: + loop = loop or asyncio.get_event_loop() + output = AsyncTextProxy() + proc = await asyncio.create_subprocess_shell(code, loop=loop, + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE) + output.get_proxy(OutputType.STDOUT, proc.stdout).start() + output.get_proxy(OutputType.STDERR, proc.stderr).start() + proc.stdin.write(stdin.encode("utf-8")) + proc.stdin.write_eof() + waiter = asyncio.ensure_future(self._wait_proc(proc, output), loop=loop) + async for part in output: + yield part + yield (OutputType.RETURN, await waiter) + + def format_exception(self, exc_info: Any) -> Tuple[Optional[str], Optional[str]]: + # The user input never returns exceptions in run() + return None, None -- GitLab