Skip to content
Snippets Groups Projects
Commit e0872441 authored by Tulir Asokan's avatar Tulir Asokan
Browse files

Add support for shell execution and python exceptions

parent 328fd57a
No related branches found
No related tags found
No related merge requests found
# The message prefix to treat as exec commands
# The message prefix to treat as exec commands.
prefix: '!exec'
# Whether or not to enable "userbot" mode, where commands that the bot's user
# sends are handled and responded to with edits instead of replies.
......@@ -25,12 +25,17 @@ output:
Output:
{{ output }}
{% endif %}
{% if return_value %}
{% if return_value != None %}
Return:
{{ return_value }}
{% endif %}
{% if duration %}
{% if traceback != None %}
{% if traceback_header %}{{ traceback_header }}:{% endif %}
{{ traceback }}
{% endif %}
{% if duration != None %}
Took {{ duration | round(3) }} seconds
{% else %}
......@@ -45,11 +50,19 @@ output:
<h4>Output</h4>
<pre>{{ output }}</pre>
{% endif %}
{% if return_value %}
{% if return_value != None %}
{% if language in ("bash", "sh", "shell") %}
<h4>Return: <code>{{ return_value }}</code></h4>
{% else %}
<h4>Return</h4>
<pre>{{ return_value }}</pre>
{% endif %}
{% endif %}
{% if traceback != None %}
{% if traceback_header %}<h4>{{ traceback_header }}</h4>{% endif %}
<pre><code class="language-pytb">{{ traceback }}</code></pre>
{% endif %}
{% if duration %}
{% if duration != None %}
<h4>Took {{ duration | round(3) }} seconds</h4>
{% else %}
<h4>Running...</h4>
......
......@@ -15,8 +15,8 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Type, Set, Optional, Any
from io import StringIO
from html import escape as escape_orig
from time import time
from html import escape
from jinja2 import Template
......@@ -26,7 +26,11 @@ from mautrix.util.formatter import MatrixParser, EntityString, SimpleEntity, Ent
from maubot import Plugin, MessageEvent
from maubot.handlers import event
from .runner import PythonRunner, OutputType
from .runner import PythonRunner, ShellRunner, OutputType
def escape(val: Optional[str]) -> Optional[str]:
return escape(val) if val else None
class EntityParser(MatrixParser[EntityString]):
......@@ -70,26 +74,27 @@ class ExecBot(Plugin):
self.html_template = Template(self.config["output.html"], **template_args)
def format_status(self, code: str, language: str, output: str = "", output_html: str = "",
return_value: Any = None, duration: Optional[float] = None,
return_value: Any = None, traceback: Optional[str] = None,
traceback_header: Optional[str] = None, duration: Optional[float] = None,
msgtype: MessageType = MessageType.NOTICE) -> TextMessageEventContent:
return_value = repr(return_value) if return_value else ''
return_value = repr(return_value) if return_value is not None else None
content = TextMessageEventContent(
msgtype=msgtype, format=Format.HTML,
body=self.plaintext_template.render(
code=code, language=language, output=output, return_value=return_value,
duration=duration),
duration=duration, traceback=traceback, traceback_header=traceback_header),
formatted_body=self.html_template.render(
code=escape(code), language=language, output=output_html,
return_value=escape(return_value), duration=duration))
return_value=escape(return_value), duration=duration, traceback=escape(traceback),
traceback_header=escape(traceback_header)))
return content
@event.on(EventType.ROOM_MESSAGE)
async def exec(self, evt: MessageEvent) -> None:
if evt.sender not in self.whitelist:
return
elif not evt.content.body.startswith(self.prefix):
return
elif not evt.content.formatted_body:
if ((evt.content.msgtype != MessageType.TEXT
or evt.sender not in self.whitelist
or not evt.content.body.startswith(self.prefix)
or not evt.content.formatted_body)):
return
command = EntityParser.parse(evt.content.formatted_body)
......@@ -110,8 +115,15 @@ class ExecBot(Plugin):
if not code or not lang:
return
if lang != "python":
await evt.respond("Only python is currently supported")
if lang == "python":
runner = PythonRunner(namespace={
"client": self.client,
"event": evt,
})
elif lang in ("shell", "bash", "sh"):
runner = ShellRunner()
else:
await evt.respond(f'Unsupported language "{lang}"')
return
if self.userbot:
......@@ -124,10 +136,10 @@ class ExecBot(Plugin):
content = self.format_status(code, lang, msgtype=msgtype)
output_event_id = await evt.respond(content)
runner = PythonRunner()
output = StringIO()
output_html = StringIO()
return_value: Any = None
traceback, traceback_header = None, None
start_time = time()
prev_output = start_time
async for out_type, data in runner.run(code, stdin):
......@@ -140,6 +152,9 @@ class ExecBot(Plugin):
elif out_type == OutputType.RETURN:
return_value = data
continue
elif out_type == OutputType.EXCEPTION:
traceback, traceback_header = runner.format_exception(data)
continue
cur_time = time()
if prev_output + self.output_interval < cur_time:
......@@ -149,7 +164,9 @@ class ExecBot(Plugin):
await self.client.send_message(evt.room_id, content)
prev_output = cur_time
duration = time() - start_time
print(return_value)
content = self.format_status(code, lang, output.getvalue(), output_html.getvalue(),
return_value, duration, msgtype=msgtype)
return_value, traceback, traceback_header, duration,
msgtype=msgtype)
content.set_edit(output_event_id)
await self.client.send_message(evt.room_id, content)
......@@ -13,7 +13,8 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import AsyncGenerator
from typing import AsyncGenerator, Tuple, Optional, Any
from asyncio import AbstractEventLoop, Queue, Future, get_event_loop, ensure_future, CancelledError
from abc import ABC, abstractmethod
from enum import Enum, auto
......@@ -22,9 +23,47 @@ class OutputType(Enum):
STDOUT = auto()
STDERR = auto()
RETURN = auto()
EXCEPTION = auto()
class AsyncTextOutput:
loop: AbstractEventLoop
queue: Queue
read_task: Optional[Future]
closed: bool
def __init__(self, loop: Optional[AbstractEventLoop] = None) -> None:
self.loop = loop or get_event_loop()
self.read_task = None
self.queue = Queue(loop=self.loop)
self.closed = False
def __aiter__(self) -> 'AsyncTextOutput':
return self
async def __anext__(self) -> str:
if self.closed and self.queue.empty():
raise StopAsyncIteration
self.read_task = ensure_future(self.queue.get(), loop=self.loop)
try:
data = await self.read_task
except CancelledError:
raise StopAsyncIteration
self.queue.task_done()
return data
def close(self) -> None:
self.closed = True
if self.read_task and self.queue.empty():
self.read_task.cancel()
class Runner(ABC):
@abstractmethod
async def run(self, code: str, stdin: str = "") -> AsyncGenerator[str, None]:
async def run(self, code: str, stdin: str = "", loop: Optional[AbstractEventLoop] = None
) -> AsyncGenerator[Tuple[OutputType, Any], None]:
pass
@abstractmethod
def format_exception(self, exc_info: Any) -> Tuple[Optional[str], Optional[str]]:
pass
......@@ -13,95 +13,83 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import Dict, Any, Optional, AsyncGenerator
from typing import Dict, Any, Optional, Tuple, AsyncGenerator, Type, NamedTuple
from types import TracebackType
from io import IOBase, StringIO
import contextlib
import traceback
import asyncio
import ast
import sys
from mautrix.util.manhole import asyncify
from .base import Runner, OutputType
from .base import Runner, OutputType, AsyncTextOutput
class AsyncTextOutput:
loop: asyncio.AbstractEventLoop
queue: asyncio.Queue
writers: Dict[OutputType, 'ProxyOutput']
read_task: Optional[asyncio.Future]
closed: bool
class SyncTextProxy(AsyncTextOutput):
writers: Dict[OutputType, 'ProxyWriter']
def __init__(self, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
self.loop = loop or asyncio.get_event_loop()
self.read_task = None
self.queue = asyncio.Queue(loop=self.loop)
self.closed = False
super().__init__(loop)
self.writers = {}
def __aiter__(self) -> 'AsyncTextOutput':
return self
async def __anext__(self) -> str:
if self.closed:
raise StopAsyncIteration
self.read_task = asyncio.ensure_future(self.queue.get(), loop=self.loop)
try:
data = await self.read_task
except asyncio.CancelledError:
raise StopAsyncIteration
self.queue.task_done()
return data
def close(self) -> None:
self.closed = True
for proxy in self.writers.values():
proxy.close(_ato=True)
if self.read_task:
self.read_task.cancel()
proxy.close(_stp=True)
super().close()
def get_writer(self, output_type: OutputType) -> 'ProxyOutput':
def get_writer(self, output_type: OutputType) -> 'ProxyWriter':
try:
return self.writers[output_type]
except KeyError:
self.writers[output_type] = proxy = ProxyOutput(output_type, self)
self.writers[output_type] = proxy = ProxyWriter(output_type, self)
return proxy
class ProxyOutput(IOBase):
class ProxyWriter(IOBase):
type: OutputType
ato: AsyncTextOutput
stp: SyncTextProxy
def __init__(self, output_type: OutputType, ato: AsyncTextOutput) -> None:
def __init__(self, output_type: OutputType, stp: SyncTextProxy) -> None:
self.type = output_type
self.ato = ato
self.stp = stp
def write(self, data: str) -> None:
"""Write to the stdout queue"""
self.ato.queue.put_nowait((self.type, data))
self.stp.queue.put_nowait((self.type, data))
def writable(self) -> bool:
return True
def close(self, _ato: bool = False) -> None:
def close(self, _stp: bool = False) -> None:
super().close()
if not _ato:
self.ato.close()
if not _stp:
self.stp.close()
ExcInfo = NamedTuple('ExcInfo', type=Type[BaseException], exc=Exception, tb=TracebackType)
class PythonRunner(Runner):
namespace: Dict[str, Any]
per_run_namespace: bool
def __init__(self, namespace: Optional[Dict[str, Any]] = None) -> None:
def __init__(self, namespace: Optional[Dict[str, Any]] = None, per_run_namespace: bool = True
) -> None:
self.namespace = namespace or {}
self.per_run_namespace = per_run_namespace
async def _run_task(self, stdio: AsyncTextOutput) -> str:
value = await eval("__eval_async_expr()", self.namespace)
stdio.close()
@staticmethod
async def _wait_task(namespace: Dict[str, Any], stdio: SyncTextProxy) -> str:
try:
value = await eval("__eval_async_expr()", namespace)
finally:
stdio.close()
return value
@contextlib.contextmanager
def _redirect_io(self, output: AsyncTextOutput, stdin: StringIO) -> AsyncTextOutput:
def _redirect_io(self, output: SyncTextProxy, stdin: StringIO) -> SyncTextProxy:
old_stdout, old_stderr, old_stdin = sys.stdout, sys.stderr, sys.stdin
sys.stdout = output.get_writer(OutputType.STDOUT)
sys.stderr = output.get_writer(OutputType.STDERR)
......@@ -109,12 +97,45 @@ class PythonRunner(Runner):
yield output
sys.stdout, sys.stderr, sys.stdin = old_stdout, old_stderr, old_stdin
async def run(self, code: str, stdin: str = "") -> AsyncGenerator[str, None]:
codeobj = asyncify(compile(code, "<input>", "exec", optimize=1, flags=ast.PyCF_ONLY_AST))
exec(codeobj, self.namespace)
with self._redirect_io(AsyncTextOutput(), StringIO(stdin)) as output:
task = asyncio.ensure_future(self._run_task(output))
@staticmethod
def _format_exc(exception: Exception) -> str:
if len(exception.args) == 0:
return type(exception).__name__
elif len(exception.args) == 1:
return f"{type(exception).__name__}: {exception.args[0]}"
else:
return f"{type(exception).__name__}: {exception.args}"
def format_exception(self, exc_info: ExcInfo) -> Tuple[Optional[str], Optional[str]]:
if not exc_info:
return None, None
tb = traceback.extract_tb(exc_info.tb)
line: traceback.FrameSummary
for i, line in enumerate(tb):
if line.filename == "<input>":
line.name = "<module>"
tb = tb[i:]
break
return ("Traceback (most recent call last):",
f"{''.join(traceback.format_list(tb))}"
f"{self._format_exc(exc_info.exc)}")
async def run(self, code: str, stdin: str = "", loop: Optional[asyncio.AbstractEventLoop] = None
) -> AsyncGenerator[Tuple[OutputType, Any], None]:
loop = loop or asyncio.get_event_loop()
codeobj = asyncify(compile(code, "<input>", "exec", optimize=1, flags=ast.PyCF_ONLY_AST),
module="<input>")
namespace = {**self.namespace} if self.per_run_namespace else self.namespace
exec(codeobj, namespace)
with self._redirect_io(SyncTextProxy(loop), StringIO(stdin)) as output:
task = asyncio.ensure_future(self._wait_task(namespace, output), loop=loop)
async for part in output:
yield part
return_value = await task
yield (OutputType.RETURN, return_value)
try:
return_value = await task
except Exception:
yield (OutputType.EXCEPTION, sys.exc_info())
else:
yield (OutputType.RETURN, return_value)
......@@ -13,13 +13,86 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from typing import AsyncGenerator
from typing import AsyncGenerator, Tuple, Optional, Dict, Union, Any
import asyncio
from .base import Runner
from .base import Runner, OutputType, AsyncTextOutput
class AsyncTextProxy(AsyncTextOutput):
proxies: Dict[OutputType, 'StreamProxy']
def __init__(self, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
super().__init__(loop)
self.proxies = {}
def get_proxy(self, type: OutputType, stream: asyncio.StreamReader) -> 'StreamProxy':
try:
return self.proxies[type]
except KeyError:
self.proxies[type] = proxy = StreamProxy(type, self, stream, self.loop)
return proxy
def close(self) -> None:
for proxy in self.proxies.values():
proxy.stop()
super().close()
class StreamProxy:
type: OutputType
atp: AsyncTextProxy
input: asyncio.StreamReader
loop: asyncio.AbstractEventLoop
proxy_task: Optional[asyncio.Future]
def __init__(self, output_type: OutputType, atp: AsyncTextProxy, input: asyncio.StreamReader,
loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
self.type = output_type
self.atp = atp
self.input = input
self.loop = loop or asyncio.get_event_loop()
self.proxy_task = None
def start(self) -> None:
if self.proxy_task and not self.proxy_task.done():
raise RuntimeError("Can't re-start running proxy")
self.proxy_task = asyncio.ensure_future(self._proxy(), loop=self.loop)
def stop(self) -> None:
self.proxy_task.cancel()
async def _proxy(self) -> None:
while not self.input.at_eof():
data = await self.input.readline()
if data:
await self.atp.queue.put((self.type, data.decode("utf-8")))
class ShellRunner(Runner):
async def run(self, code: str, stdin: str = "") -> AsyncGenerator[str, None]:
proc = await asyncio.create_subprocess_shell(code)
await proc.wait()
@staticmethod
async def _wait_proc(proc: asyncio.subprocess.Process, output: AsyncTextProxy) -> int:
resp = await proc.wait()
output.close()
return resp
async def run(self, code: str, stdin: str = "", loop: Optional[asyncio.AbstractEventLoop] = None
) -> AsyncGenerator[Tuple[OutputType, Union[str, int]], None]:
loop = loop or asyncio.get_event_loop()
output = AsyncTextProxy()
proc = await asyncio.create_subprocess_shell(code, loop=loop,
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE)
output.get_proxy(OutputType.STDOUT, proc.stdout).start()
output.get_proxy(OutputType.STDERR, proc.stderr).start()
proc.stdin.write(stdin.encode("utf-8"))
proc.stdin.write_eof()
waiter = asyncio.ensure_future(self._wait_proc(proc, output), loop=loop)
async for part in output:
yield part
yield (OutputType.RETURN, await waiter)
def format_exception(self, exc_info: Any) -> Tuple[Optional[str], Optional[str]]:
# The user input never returns exceptions in run()
return None, None
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment