From e0872441e810d8cba3298ec84678032b439f9c32 Mon Sep 17 00:00:00 2001
From: Tulir Asokan <tulir@maunium.net>
Date: Sat, 31 Aug 2019 15:04:14 +0300
Subject: [PATCH] Add support for shell execution and python exceptions

---
 base-config.yaml      |  23 ++++++--
 exec/bot.py           |  47 +++++++++++-----
 exec/runner/base.py   |  43 ++++++++++++++-
 exec/runner/python.py | 125 ++++++++++++++++++++++++------------------
 exec/runner/shell.py  |  83 ++++++++++++++++++++++++++--
 5 files changed, 242 insertions(+), 79 deletions(-)

diff --git a/base-config.yaml b/base-config.yaml
index a664942..5f52c85 100644
--- a/base-config.yaml
+++ b/base-config.yaml
@@ -1,4 +1,4 @@
-# The message prefix to treat as exec commands
+# The message prefix to treat as exec commands.
 prefix: '!exec'
 # Whether or not to enable "userbot" mode, where commands that the bot's user
 # sends are handled and responded to with edits instead of replies.
@@ -25,12 +25,17 @@ output:
         Output:
         {{ output }}
         {% endif %}
-        {% if return_value %}
+        {% if return_value != None %}
 
         Return:
         {{ return_value }}
         {% endif %}
-        {% if duration %}
+        {% if traceback != None %}
+
+        {% if traceback_header %}{{ traceback_header }}:{% endif %}
+        {{ traceback }}
+        {% endif %}
+        {% if duration != None %}
 
         Took {{ duration | round(3) }} seconds
         {% else %}
@@ -45,11 +50,19 @@ output:
         <h4>Output</h4>
         <pre>{{ output }}</pre>
         {% endif %}
-        {% if return_value %}
+        {% if return_value != None %}
+            {% if language in ("bash", "sh", "shell") %}
+            <h4>Return: <code>{{ return_value }}</code></h4>
+            {% else %}
             <h4>Return</h4>
             <pre>{{ return_value }}</pre>
+            {% endif %}
+        {% endif %}
+        {% if traceback != None %}
+            {% if traceback_header %}<h4>{{ traceback_header }}</h4>{% endif %}
+            <pre><code class="language-pytb">{{ traceback }}</code></pre>
         {% endif %}
-        {% if duration %}
+        {% if duration != None %}
             <h4>Took {{ duration | round(3) }} seconds</h4>
         {% else %}
             <h4>Running...</h4>
diff --git a/exec/bot.py b/exec/bot.py
index 1cce8d7..0a6219c 100644
--- a/exec/bot.py
+++ b/exec/bot.py
@@ -15,8 +15,8 @@
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
 from typing import Type, Set, Optional, Any
 from io import StringIO
+from html import escape as escape_orig
 from time import time
-from html import escape
 
 from jinja2 import Template
 
@@ -26,7 +26,11 @@ from mautrix.util.formatter import MatrixParser, EntityString, SimpleEntity, Ent
 from maubot import Plugin, MessageEvent
 from maubot.handlers import event
 
-from .runner import PythonRunner, OutputType
+from .runner import PythonRunner, ShellRunner, OutputType
+
+
+def escape(val: Optional[str]) -> Optional[str]:
+    return escape_orig(val) if val else None  # wraps html.escape; falsy input -> None
 
 
 class EntityParser(MatrixParser[EntityString]):
@@ -70,26 +74,27 @@ class ExecBot(Plugin):
         self.html_template = Template(self.config["output.html"], **template_args)
 
     def format_status(self, code: str, language: str, output: str = "", output_html: str = "",
-                      return_value: Any = None, duration: Optional[float] = None,
+                      return_value: Any = None, traceback: Optional[str] = None,
+                      traceback_header: Optional[str] = None, duration: Optional[float] = None,
                       msgtype: MessageType = MessageType.NOTICE) -> TextMessageEventContent:
-        return_value = repr(return_value) if return_value else ''
+        return_value = repr(return_value) if return_value is not None else None
         content = TextMessageEventContent(
             msgtype=msgtype, format=Format.HTML,
             body=self.plaintext_template.render(
                 code=code, language=language, output=output, return_value=return_value,
-                duration=duration),
+                duration=duration, traceback=traceback, traceback_header=traceback_header),
             formatted_body=self.html_template.render(
                 code=escape(code), language=language, output=output_html,
-                return_value=escape(return_value), duration=duration))
+                return_value=escape(return_value), duration=duration, traceback=escape(traceback),
+                traceback_header=escape(traceback_header)))
         return content
 
     @event.on(EventType.ROOM_MESSAGE)
     async def exec(self, evt: MessageEvent) -> None:
-        if evt.sender not in self.whitelist:
-            return
-        elif not evt.content.body.startswith(self.prefix):
-            return
-        elif not evt.content.formatted_body:
+        if ((evt.content.msgtype != MessageType.TEXT
+             or evt.sender not in self.whitelist
+             or not evt.content.body.startswith(self.prefix)
+             or not evt.content.formatted_body)):
             return
 
         command = EntityParser.parse(evt.content.formatted_body)
@@ -110,8 +115,15 @@ class ExecBot(Plugin):
         if not code or not lang:
             return
 
-        if lang != "python":
-            await evt.respond("Only python is currently supported")
+        if lang == "python":
+            runner = PythonRunner(namespace={
+                "client": self.client,
+                "event": evt,
+            })
+        elif lang in ("shell", "bash", "sh"):
+            runner = ShellRunner()
+        else:
+            await evt.respond(f'Unsupported language "{lang}"')
             return
 
         if self.userbot:
@@ -124,10 +136,10 @@ class ExecBot(Plugin):
             content = self.format_status(code, lang, msgtype=msgtype)
             output_event_id = await evt.respond(content)
 
-        runner = PythonRunner()
         output = StringIO()
         output_html = StringIO()
         return_value: Any = None
+        traceback, traceback_header = None, None
         start_time = time()
         prev_output = start_time
         async for out_type, data in runner.run(code, stdin):
@@ -140,6 +152,9 @@ class ExecBot(Plugin):
             elif out_type == OutputType.RETURN:
                 return_value = data
                 continue
+            elif out_type == OutputType.EXCEPTION:
+                traceback_header, traceback = runner.format_exception(data)  # (header, body)
+                continue
 
             cur_time = time()
             if prev_output + self.output_interval < cur_time:
@@ -149,7 +164,8 @@ class ExecBot(Plugin):
                 await self.client.send_message(evt.room_id, content)
                 prev_output = cur_time
         duration = time() - start_time
         content = self.format_status(code, lang, output.getvalue(), output_html.getvalue(),
-                                     return_value, duration, msgtype=msgtype)
+                                     return_value, traceback, traceback_header, duration,
+                                     msgtype=msgtype)
         content.set_edit(output_event_id)
         await self.client.send_message(evt.room_id, content)
diff --git a/exec/runner/base.py b/exec/runner/base.py
index 2fafe95..cc73320 100644
--- a/exec/runner/base.py
+++ b/exec/runner/base.py
@@ -13,7 +13,8 @@
 #
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
-from typing import AsyncGenerator
+from typing import AsyncGenerator, Tuple, Optional, Any
+from asyncio import AbstractEventLoop, Queue, Future, get_event_loop, ensure_future, CancelledError
 from abc import ABC, abstractmethod
 from enum import Enum, auto
 
@@ -22,9 +23,47 @@ class OutputType(Enum):
     STDOUT = auto()
     STDERR = auto()
     RETURN = auto()
+    EXCEPTION = auto()
+
+
+class AsyncTextOutput:
+    loop: AbstractEventLoop
+    queue: Queue
+    read_task: Optional[Future]
+    closed: bool
+
+    def __init__(self, loop: Optional[AbstractEventLoop] = None) -> None:
+        self.loop = loop or get_event_loop()
+        self.read_task = None
+        self.queue = Queue(loop=self.loop)
+        self.closed = False
+
+    def __aiter__(self) -> 'AsyncTextOutput':
+        return self
+
+    async def __anext__(self) -> str:
+        if self.closed and self.queue.empty():
+            raise StopAsyncIteration
+        self.read_task = ensure_future(self.queue.get(), loop=self.loop)
+        try:
+            data = await self.read_task
+        except CancelledError:
+            raise StopAsyncIteration
+        self.queue.task_done()
+        return data
+
+    def close(self) -> None:
+        self.closed = True
+        if self.read_task and self.queue.empty():
+            self.read_task.cancel()
 
 
 class Runner(ABC):
     @abstractmethod
-    async def run(self, code: str, stdin: str = "") -> AsyncGenerator[str, None]:
+    async def run(self, code: str, stdin: str = "", loop: Optional[AbstractEventLoop] = None
+                  ) -> AsyncGenerator[Tuple[OutputType, Any], None]:
+        pass
+
+    @abstractmethod
+    def format_exception(self, exc_info: Any) -> Tuple[Optional[str], Optional[str]]:
         pass
diff --git a/exec/runner/python.py b/exec/runner/python.py
index 7c4e238..090da0d 100644
--- a/exec/runner/python.py
+++ b/exec/runner/python.py
@@ -13,95 +13,83 @@
 #
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
-from typing import Dict, Any, Optional, AsyncGenerator
+from typing import Dict, Any, Optional, Tuple, AsyncGenerator, Type, NamedTuple
+from types import TracebackType
 from io import IOBase, StringIO
 import contextlib
+import traceback
 import asyncio
 import ast
 import sys
 
 from mautrix.util.manhole import asyncify
 
-from .base import Runner, OutputType
+from .base import Runner, OutputType, AsyncTextOutput
 
 
-class AsyncTextOutput:
-    loop: asyncio.AbstractEventLoop
-    queue: asyncio.Queue
-    writers: Dict[OutputType, 'ProxyOutput']
-    read_task: Optional[asyncio.Future]
-    closed: bool
+class SyncTextProxy(AsyncTextOutput):
+    writers: Dict[OutputType, 'ProxyWriter']
 
     def __init__(self, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
-        self.loop = loop or asyncio.get_event_loop()
-        self.read_task = None
-        self.queue = asyncio.Queue(loop=self.loop)
-        self.closed = False
+        super().__init__(loop)
         self.writers = {}
 
-    def __aiter__(self) -> 'AsyncTextOutput':
-        return self
-
-    async def __anext__(self) -> str:
-        if self.closed:
-            raise StopAsyncIteration
-        self.read_task = asyncio.ensure_future(self.queue.get(), loop=self.loop)
-        try:
-            data = await self.read_task
-        except asyncio.CancelledError:
-            raise StopAsyncIteration
-        self.queue.task_done()
-        return data
-
     def close(self) -> None:
-        self.closed = True
         for proxy in self.writers.values():
-            proxy.close(_ato=True)
-        if self.read_task:
-            self.read_task.cancel()
+            proxy.close(_stp=True)
+        super().close()
 
-    def get_writer(self, output_type: OutputType) -> 'ProxyOutput':
+    def get_writer(self, output_type: OutputType) -> 'ProxyWriter':
         try:
             return self.writers[output_type]
         except KeyError:
-            self.writers[output_type] = proxy = ProxyOutput(output_type, self)
+            self.writers[output_type] = proxy = ProxyWriter(output_type, self)
             return proxy
 
 
-class ProxyOutput(IOBase):
+class ProxyWriter(IOBase):
     type: OutputType
-    ato: AsyncTextOutput
+    stp: SyncTextProxy
 
-    def __init__(self, output_type: OutputType, ato: AsyncTextOutput) -> None:
+    def __init__(self, output_type: OutputType, stp: SyncTextProxy) -> None:
         self.type = output_type
-        self.ato = ato
+        self.stp = stp
 
     def write(self, data: str) -> None:
         """Write to the stdout queue"""
-        self.ato.queue.put_nowait((self.type, data))
+        self.stp.queue.put_nowait((self.type, data))
 
     def writable(self) -> bool:
         return True
 
-    def close(self, _ato: bool = False) -> None:
+    def close(self, _stp: bool = False) -> None:
         super().close()
-        if not _ato:
-            self.ato.close()
+        if not _stp:
+            self.stp.close()
+
+
+ExcInfo = NamedTuple('ExcInfo', type=Type[BaseException], exc=Exception, tb=TracebackType)
 
 
 class PythonRunner(Runner):
     namespace: Dict[str, Any]
+    per_run_namespace: bool
 
-    def __init__(self, namespace: Optional[Dict[str, Any]] = None) -> None:
+    def __init__(self, namespace: Optional[Dict[str, Any]] = None, per_run_namespace: bool = True
+                 ) -> None:
         self.namespace = namespace or {}
+        self.per_run_namespace = per_run_namespace
 
-    async def _run_task(self, stdio: AsyncTextOutput) -> str:
-        value = await eval("__eval_async_expr()", self.namespace)
-        stdio.close()
+    @staticmethod
+    async def _wait_task(namespace: Dict[str, Any], stdio: SyncTextProxy) -> str:
+        try:
+            value = await eval("__eval_async_expr()", namespace)
+        finally:
+            stdio.close()
         return value
 
     @contextlib.contextmanager
-    def _redirect_io(self, output: AsyncTextOutput, stdin: StringIO) -> AsyncTextOutput:
+    def _redirect_io(self, output: SyncTextProxy, stdin: StringIO) -> SyncTextProxy:
         old_stdout, old_stderr, old_stdin = sys.stdout, sys.stderr, sys.stdin
         sys.stdout = output.get_writer(OutputType.STDOUT)
         sys.stderr = output.get_writer(OutputType.STDERR)
@@ -109,12 +97,45 @@ class PythonRunner(Runner):
         yield output
         sys.stdout, sys.stderr, sys.stdin = old_stdout, old_stderr, old_stdin
 
-    async def run(self, code: str, stdin: str = "") -> AsyncGenerator[str, None]:
-        codeobj = asyncify(compile(code, "<input>", "exec", optimize=1, flags=ast.PyCF_ONLY_AST))
-        exec(codeobj, self.namespace)
-        with self._redirect_io(AsyncTextOutput(), StringIO(stdin)) as output:
-            task = asyncio.ensure_future(self._run_task(output))
+    @staticmethod
+    def _format_exc(exception: Exception) -> str:
+        if len(exception.args) == 0:
+            return type(exception).__name__
+        elif len(exception.args) == 1:
+            return f"{type(exception).__name__}: {exception.args[0]}"
+        else:
+            return f"{type(exception).__name__}: {exception.args}"
+
+    def format_exception(self, exc_info: ExcInfo) -> Tuple[Optional[str], Optional[str]]:
+        if not exc_info:
+            return None, None
+        tb = traceback.extract_tb(exc_info.tb)
+
+        line: traceback.FrameSummary
+        for i, line in enumerate(tb):
+            if line.filename == "<input>":
+                line.name = "<module>"
+                tb = tb[i:]
+                break
+
+        return ("Traceback (most recent call last):",
+                f"{''.join(traceback.format_list(tb))}"
+                f"{self._format_exc(exc_info.exc)}")
+
+    async def run(self, code: str, stdin: str = "", loop: Optional[asyncio.AbstractEventLoop] = None
+                  ) -> AsyncGenerator[Tuple[OutputType, Any], None]:
+        loop = loop or asyncio.get_event_loop()
+        codeobj = asyncify(compile(code, "<input>", "exec", optimize=1, flags=ast.PyCF_ONLY_AST),
+                           module="<input>")
+        namespace = {**self.namespace} if self.per_run_namespace else self.namespace
+        exec(codeobj, namespace)
+        with self._redirect_io(SyncTextProxy(loop), StringIO(stdin)) as output:
+            task = asyncio.ensure_future(self._wait_task(namespace, output), loop=loop)
             async for part in output:
                 yield part
-            return_value = await task
-            yield (OutputType.RETURN, return_value)
+            try:
+                return_value = await task
+            except Exception:
+                yield (OutputType.EXCEPTION, ExcInfo(*sys.exc_info()))  # named .exc/.tb fields
+            else:
+                yield (OutputType.RETURN, return_value)
diff --git a/exec/runner/shell.py b/exec/runner/shell.py
index cd684cd..5694dc7 100644
--- a/exec/runner/shell.py
+++ b/exec/runner/shell.py
@@ -13,13 +13,86 @@
 #
 # You should have received a copy of the GNU Affero General Public License
 # along with this program.  If not, see <https://www.gnu.org/licenses/>.
-from typing import AsyncGenerator
+from typing import AsyncGenerator, Tuple, Optional, Dict, Union, Any
 import asyncio
 
-from .base import Runner
+from .base import Runner, OutputType, AsyncTextOutput
+
+
+class AsyncTextProxy(AsyncTextOutput):
+    proxies: Dict[OutputType, 'StreamProxy']
+
+    def __init__(self, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
+        super().__init__(loop)
+        self.proxies = {}
+
+    def get_proxy(self, type: OutputType, stream: asyncio.StreamReader) -> 'StreamProxy':
+        try:
+            return self.proxies[type]
+        except KeyError:
+            self.proxies[type] = proxy = StreamProxy(type, self, stream, self.loop)
+            return proxy
+
+    def close(self) -> None:
+        for proxy in self.proxies.values():
+            proxy.stop()
+        super().close()
+
+
+class StreamProxy:
+    type: OutputType
+    atp: AsyncTextProxy
+    input: asyncio.StreamReader
+    loop: asyncio.AbstractEventLoop
+    proxy_task: Optional[asyncio.Future]
+
+    def __init__(self, output_type: OutputType, atp: AsyncTextProxy, input: asyncio.StreamReader,
+                 loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
+        self.type = output_type
+        self.atp = atp
+        self.input = input
+        self.loop = loop or asyncio.get_event_loop()
+        self.proxy_task = None
+
+    def start(self) -> None:
+        if self.proxy_task and not self.proxy_task.done():
+            raise RuntimeError("Can't re-start running proxy")
+        self.proxy_task = asyncio.ensure_future(self._proxy(), loop=self.loop)
+
+    def stop(self) -> None:
+        self.proxy_task.cancel()
+
+    async def _proxy(self) -> None:
+        while not self.input.at_eof():
+            data = await self.input.readline()
+            if data:
+                await self.atp.queue.put((self.type, data.decode("utf-8")))
 
 
 class ShellRunner(Runner):
-    async def run(self, code: str, stdin: str = "") -> AsyncGenerator[str, None]:
-        proc = await asyncio.create_subprocess_shell(code)
-        await proc.wait()
+    @staticmethod
+    async def _wait_proc(proc: asyncio.subprocess.Process, output: AsyncTextProxy) -> int:
+        resp = await proc.wait()
+        output.close()
+        return resp
+
+    async def run(self, code: str, stdin: str = "", loop: Optional[asyncio.AbstractEventLoop] = None
+                  ) -> AsyncGenerator[Tuple[OutputType, Union[str, int]], None]:
+        loop = loop or asyncio.get_event_loop()
+        output = AsyncTextProxy(loop)  # same loop as the subprocess and the waiter task
+        proc = await asyncio.create_subprocess_shell(code, loop=loop,
+                                                     stdin=asyncio.subprocess.PIPE,
+                                                     stdout=asyncio.subprocess.PIPE,
+                                                     stderr=asyncio.subprocess.PIPE)
+        output.get_proxy(OutputType.STDOUT, proc.stdout).start()
+        output.get_proxy(OutputType.STDERR, proc.stderr).start()
+        proc.stdin.write(stdin.encode("utf-8"))
+        proc.stdin.write_eof()
+        waiter = asyncio.ensure_future(self._wait_proc(proc, output), loop=loop)
+        async for part in output:
+            yield part
+        yield (OutputType.RETURN, await waiter)
+
+    def format_exception(self, exc_info: Any) -> Tuple[Optional[str], Optional[str]]:
+        # run() never yields OutputType.EXCEPTION for shell code, so there is nothing to format.
+        return None, None
-- 
GitLab