"""
Asgineer test utilities.
"""
import os
import sys
import time
import inspect
import asyncio
import tempfile
import subprocess
from collections import namedtuple
from wsgiref.handlers import format_date_time
from urllib.parse import unquote, urlparse
import requests
import asgineer
# Result type returned by the request methods of the test servers.
Response = namedtuple("Response", ["status", "headers", "body"])

# Path of the temporary script that the subprocess-based server executes.
# The pid is part of the name, so test processes with different pids use
# different files.
testfilename = os.path.join(
    tempfile.gettempdir(), f"asgineer_test_script_{os.getpid()}.py"
)

# The port is derived from the pid and always lies in 49152 .. 65534.
PORT = 49152 + os.getpid() % 16383  # hash pid to ephemeral port number
URL = f"http://127.0.0.1:{PORT}"

# TODO: allow running multiple processes at the same time, by including a sequence number
class BaseTestServer:
    """Base class for test servers.

    Objects of this class represent an ASGI server instance that can be
    used to test your server implementation.

    The ``app`` object passed to the constructor can be an ASGI application
    or an async (Asgineer-style) handler.

    The server can be started/stopped by using it as a context manager.
    The ``url`` attribute represents the url that can be used to make
    requests to the server. When the server has stopped, the ``out``
    attribute contains the server output (stdout and stderr).

    Only one instance of this class (per process) should be used (as a
    context manager) at any given time.

    Subclasses are expected to provide ``_start_server()``,
    ``_stop_server()`` (returning the captured output as a str),
    ``_co_request()`` and ``_co_ws_communicate()``.
    """

    def __init__(self, app, server_description, *, loop=None):
        """Store the app and pick the event loop used to run requests.

        Arguments:
            app: the ASGI application or async handler to serve.
            server_description (str): name of the server, used in log output.
            loop: the asyncio event loop to use (optional). When omitted,
                the current event loop is used; if there is none, a new
                loop is created and installed as the current one.
        """
        self._app = app
        self._server = server_description
        if loop is None:
            try:
                loop = asyncio.get_event_loop()
            except RuntimeError:
                # No current loop (e.g. newer Python versions or a non-main
                # thread): create one explicitly instead of failing.
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)
        self._loop = loop
        self._out = ""
        # Get stdout funcs because the mock server hijacks them
        self._stdout_write = sys.stdout.write
        self._stdout_flush = sys.stdout.flush

    @property
    def app(self):
        """The application object that was given at instantiation."""
        return self._app

    @property
    def url(self):
        """The url at which the server is listening."""
        return URL

    @property
    def out(self):
        """The stdout / stderr of the server. This gets set when the
        with-statement using this object exits.
        """
        return self._out

    def __enter__(self):
        """Start the server (logging how long that took) and return self."""
        self.log(f" Create {self._server} server .. ", end="")
        self._out = ""
        t0 = time.time()
        self._start_server()
        self.log(f" {time.time()-t0:0.1f}s ", end="")
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Stop the server and collect its output into ``out``.

        When the with-block raised, the server output is logged as well.
        The exception itself is not suppressed.
        """
        failed = exc_value is not None
        self.log("Error .. " if failed else "- Closing .. ", end="")
        t0 = time.time()
        out = self._stop_server()
        self._out = "\n".join(self.filter_lines(out.splitlines()))
        if failed:
            self.log("Process output:")
            self.log(self.out)
        else:
            self.log(f" {time.time()-t0:0.1f}s ")

    def get(self, path, data=None, headers=None, **kwargs):
        """Send a GET request to the server. See request() for details."""
        return self.request("GET", path, data=data, headers=headers, **kwargs)

    def put(self, path, data=None, headers=None, **kwargs):
        """Send a PUT request to the server. See request() for details."""
        return self.request("PUT", path, data=data, headers=headers, **kwargs)

    def post(self, path, data=None, headers=None, **kwargs):
        """Send a POST request to the server. See request() for details."""
        return self.request("POST", path, data=data, headers=headers, **kwargs)

    def delete(self, path, data=None, headers=None, **kwargs):
        """Send a DELETE request to the server. See request() for details."""
        return self.request("DELETE", path, data=data, headers=headers, **kwargs)

    def request(self, method, path, data=None, headers=None, **kwargs):
        """Send a request to the server. Returns a named tuple ``(status, headers, body)``.

        Arguments:
            method (str): the HTTP method (e.g. "GET")
            path (str): path or url (also see the ``url`` property).
            data: the bytes to send (optional).
            headers: headers to send (optional).
            kwargs: additional arguments to pass to ``requests.request()``.
        """
        assert isinstance(method, str)
        assert isinstance(path, str)
        # A path that already looks like a url is used as-is
        if path.startswith("http"):
            url = path
        else:
            url = self.url + "/" + path.lstrip("/")
        co = self._co_request(method, url, data=data, headers=headers, **kwargs)
        status, headers, body = self._loop.run_until_complete(co)
        return Response(status, headers, body)

    def ws_communicate(self, path, client_co_func, loop=None):
        """Do a websocket request and communicate over the connection.

        The ``client_co_func`` object must be an async function, it receives
        a ws object as an argument, which has methods ``send``, ``receive`` and
        ``close``, and it can be iterated over. Messages are either str or bytes.

        The ``loop`` argument defaults to the loop of this server object.
        Returns the result of ``client_co_func``.
        """
        # Only swap the scheme, not any later occurrence of "http"
        url = self.url.replace("http", "ws", 1) + "/" + path.lstrip("/")
        if loop is None:
            # Use the same loop as request() does
            loop = self._loop
        co = self._co_ws_communicate(url, client_co_func, loop)
        return loop.run_until_complete(co)

    def log(self, *messages, sep=" ", end="\n"):
        """Log a message. Overloadable. Default write to stdout."""
        msg = sep.join(str(m) for m in messages)
        self._stdout_write(msg + end)
        self._stdout_flush()

    def filter_lines(self, lines):
        """Overloadable line filter."""
        return lines
# Template of the script that is run in the server subprocess. The markers
# "ASGISERVER", "PORT" and the line "app = APP" are substituted before the
# script is written to disk. The script stops the server as soon as its own
# file is deleted, which is how the parent process asks it to shut down.
# The watcher thread is a daemon thread, so that a server that fails to start
# (or stops by itself) does not leave the process hanging.
START_CODE = """
import os
import sys
import time
import threading
import _thread

import asgineer

def closer():
    while os.path.isfile(__file__):
        time.sleep(0.01)
    _thread.interrupt_main()

app = APP

async def proxy_app(scope, receive, send):
    if scope["path"].startswith("/specialtestpath/"):
        await send({"type": "http.response.start", "status": 200, "headers": []})
        await send({"type": "http.response.body", "body": b""})
    else:
        return await app(scope, receive, send)

if __name__ == "__main__":
    threading.Thread(target=closer, daemon=True).start()
    asgineer.run("__main__:proxy_app", "ASGISERVER", "localhost:PORT")
    sys.stderr.flush()
    sys.stdout.flush()
    sys.exit(0)
"""
# Code snippet that defines ``load_module(name, filename)`` for the server
# script. It relies on ``os`` and ``sys`` being imported by the script that
# it is inserted into. Parent packages are loaded first (from their
# ``__init__.py``), and loaded modules are registered in ``sys.modules``.
LOAD_MODULE_CODE = """
import importlib.util

def load_module(name, filename):
    assert filename.endswith('.py')
    if name in sys.modules:
        return sys.modules[name]
    if '.' in name:
        # The parent package lives one directory up if this is a package itself
        pkg_dir = os.path.dirname(filename)
        if os.path.basename(filename) == '__init__.py':
            pkg_dir = os.path.dirname(pkg_dir)
        load_module(name.rsplit('.', 1)[0], os.path.join(pkg_dir, '__init__.py'))
    spec = importlib.util.spec_from_file_location(name, filename)
    module = importlib.util.module_from_spec(spec)
    sys.modules[name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        sys.modules.pop(name, None)
        raise
    return module
"""
class ProcessTestServer(BaseTestServer):
    """Subclass of BaseTestServer that runs an actual server in a
    subprocess. The ``server`` argument must be a server supported by
    Asgineer's ``run()`` function, like "uvicorn", "hypercorn" or "daphne".

    This provides a very realistic approach to test server applications, though
    the overhead of starting and stopping the server costs about a second,
    and it's hard to measure code coverage in this way. Therefore this approach
    is most suited for higher level / integration tests.

    Requests can be done via the methods of this object, or using any other
    request library.
    """

    def __init__(self, app, server, **kwargs):
        super().__init__(app, server, **kwargs)
        self._app_code = self._get_app_code(app)

    def _get_app_code(self, app):
        """Get the source code that defines ``app`` in the server script.

        The given function must take either 1 argument (an Asgineer-style
        handler) or 3 arguments (an ASGI application). If the function can
        be found by its name in its module, the code imports it from that
        module's file. Otherwise the source of the function itself is used.
        """
        assert app.__code__.co_argcount in (1, 3)
        mod = inspect.getmodule(app)
        modname = "_main_" if mod.__name__ == "__main__" else mod.__name__
        is_handler = app.__code__.co_argcount == 1
        name1 = app.__name__
        name2 = "handler" if is_handler else "app"
        if getattr(mod, name1, None) is app:
            # We can import the app - safest option since app may have deps
            code = LOAD_MODULE_CODE
            code += "sys.path.insert(0, '')\n"
            if "." not in mod.__name__:
                code += f"sys.path.insert(0, {os.path.dirname(mod.__file__)!r})\n"
            code += f"{name2} = load_module({modname!r}, {mod.__file__!r}).{name1}"
        else:
            # Likely an app defined inside a function. Get app from source code.
            # This will not work if the app has dependencies.
            sourcelines = inspect.getsourcelines(app)[0]
            indent = inspect.indentsize(sourcelines[0])
            code = "\n".join(line[indent:] for line in sourcelines)
            code = code.replace("def " + app.__name__, f"def {name2}")
        if is_handler:
            code += f"\napp = asgineer.to_asgi({name2})"
        return code

    def _start_server(self):
        """Write the server script to disk and launch it in a subprocess."""
        # Prepare code
        code = START_CODE.replace("ASGISERVER", self._server).replace("PORT", str(PORT))
        code = code.replace("app = APP", self._app_code)
        with open(testfilename, "wb") as f:
            f.write(code.encode())
        # Start server, clean up the temp filename on failure since __exit__ wont be called.
        try:
            self._start_subprocess()
        except Exception:
            self._delfile()
            raise

    def _start_subprocess(self):
        """Start the subprocess and wait until the server responds.

        Raises RuntimeError (including the process output) if the process
        exits before it handled a request.
        """
        # Start subprocess. Don't use stdin; it breaks multiprocessing somehow!
        self._p = subprocess.Popen(
            [sys.executable, testfilename],
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
        )
        # Wait for process to start, and make sure it is not dead
        while self._p.poll() is None:
            time.sleep(0.02)
            try:
                requests.get(URL + "/specialtestpath/init", timeout=0.01)
                break
            except (requests.ConnectionError, requests.ReadTimeout):
                pass
        if self._p.poll() is not None:
            raise RuntimeError(
                "Process failed to start!\n" + self._p.stdout.read().decode()
            )

    def _stop_server(self):
        """Stop the subprocess and return its output (stdout and stderr).

        Raises RuntimeError if the process is still alive after repeated
        attempts to terminate it.
        """
        # Ask process to stop: the script watches for its file to disappear
        self._delfile()
        # Force it to stop if needed
        for _ in range(5):
            etime = time.time() + 5
            while self._p.poll() is None and time.time() < etime:
                time.sleep(0.01)
            if self._p.poll() is not None:
                break
            self._p.terminate()
        else:
            raise RuntimeError("Runaway server process failed to terminate!")
        if self._p.poll():
            self.log(f"nonzero exit code {self._p.poll()}")
        # Get output, and release the pipe
        out = self._p.stdout.read().decode(errors="ignore")
        self._p.stdout.close()
        return out

    def _delfile(self):
        """Remove the server script, ignoring the case where that fails."""
        try:
            os.remove(testfilename)
        except OSError:
            pass

    async def _co_request(self, method, url, **kwargs):
        """Do the request with the requests library; return (status, headers, body)."""
        r = requests.request(method, url, **kwargs)
        return r.status_code, r.headers, r.content

    async def _co_ws_communicate(self, url, client_co_func, loop):
        """Connect with the websockets library and run ``client_co_func(ws)``.

        Returns None if the connection is refused with an invalid status
        code, otherwise the result of the client function. The connection
        is closed afterwards, also when the client function raises.
        """
        import websockets

        try:
            ws = await websockets.connect(url)
        except websockets.InvalidStatusCode:
            return None
        ws.receive = ws.recv  # offer the same API as the mock ws object
        try:
            return await client_co_func(ws)
        finally:
            await ws.close()
class MockTestServer(BaseTestServer):
    """Subclass of BaseTestServer that mocks an ASGI server and
    operates in-process. This is a less realistic approach, but faster
    and allows tracking test coverage, so it's more suited for unit
    tests.

    Requests *must* be done via the methods of this object. The used url
    can be anything.
    """

    def __init__(self, app, **kwargs):
        super().__init__(app, "mock", **kwargs)
        # A function with 3 args is taken to be an ASGI app, else a handler
        if app.__code__.co_argcount == 3:
            self._asgi_app = app
        else:
            self._asgi_app = asgineer.to_asgi(app)
        self._out_writes = []
        # Strong references to running websocket app tasks
        self._ws_tasks = set()

    def _write(self, msg):
        """Replacement for the write method of stdout/stderr; collects output."""
        self._out_writes.append(msg)

    def _start_server(self):
        """Hijack stdout/stderr and run the lifespan startup of the app."""
        self._out_writes = []
        self._ori_streams = sys.stdout.write, sys.stderr.write
        sys.stdout.write = sys.stderr.write = self._write
        try:
            self._lifespan_messages = []
            self._lifespan_completes = []
            self._lifespan_task = self._make_lifespan_task()
            self._wait_for_lifespan_complete("startup")
        except Exception:
            # __exit__ wont be called, so restore the streams here
            self._restore_streams()
            raise

    def _restore_streams(self):
        """Put back the original write methods of stdout and stderr."""
        sys.stdout.write, sys.stderr.write = self._ori_streams

    def _stop_server(self):
        """Run the lifespan shutdown of the app and return the collected output."""
        try:
            self._wait_for_lifespan_complete("shutdown")
        finally:
            self._restore_streams()
        return "".join(self._out_writes)

    def _make_lifespan_task(self):
        """Create a task that runs the app with a lifespan scope.

        Lifespan messages are fed to the app via ``_lifespan_messages``, and
        the types of the messages sent by the app are collected in
        ``_lifespan_completes``.
        """
        scope = {"type": "lifespan"}

        async def receive():
            while True:
                if self._lifespan_messages:
                    return self._lifespan_messages.pop(0)
                await asyncio.sleep(0.02)

        async def send(m):
            self._lifespan_completes.append(m["type"])

        return self._loop.create_task(self._asgi_app(scope, receive, send))

    def _wait_for_lifespan_complete(self, what, timeout=5):
        """Send the ``lifespan.<what>`` message and wait for its completion.

        Raises RuntimeError if the lifespan task finishes without sending
        ``lifespan.<what>.complete``, or if that takes longer than
        ``timeout`` seconds.
        """
        what_complete = f"lifespan.{what}.complete"

        async def waiter():
            etime = time.time() + timeout
            while what_complete not in self._lifespan_completes:
                if self._lifespan_task.done():
                    raise RuntimeError(
                        f"Lifespan task finished without producing {what}"
                    )
                if time.time() > etime:
                    raise RuntimeError(
                        f"Timeout for {what}, has {self._lifespan_completes}"
                    )
                await asyncio.sleep(0.02)

        self._lifespan_messages.append({"type": f"lifespan.{what}"})
        self._loop.run_until_complete(waiter())

    def _make_scope(self, request):
        """Create the ASGI scope dict for the given (prepared) request.

        Produces an "http" scope for http(s) urls and a "websocket" scope
        for ws(s) urls. Raises RuntimeError for other schemes.
        """
        parsed = urlparse(request.url)
        scheme, netloc = parsed.scheme, parsed.netloc
        path, query = parsed.path, parsed.query
        if ":" in netloc:
            host, port = netloc.split(":", 1)
            port = int(port)
        else:
            host = netloc
            port = {"http": 80, "ws": 80, "https": 443, "wss": 443}[scheme]
        # Include the 'host' header.
        if "host" in request.headers:
            headers = []
        elif port == 80:
            headers = [[b"host", host.encode()]]
        else:
            headers = [[b"host", f"{host}:{port}".encode()]]
        # Include other request headers.
        headers += [
            [key.lower().encode(), value.encode()]
            for key, value in request.headers.items()
        ]
        if scheme.startswith("http"):
            return {
                "type": "http",
                "http_version": "1.1",
                "method": request.method,
                "scheme": scheme,
                "path": unquote(path),
                "root_path": "",
                "query_string": query.encode(),
                "headers": headers,
                "client": ["testclient", 50000],
                "server": [host, port],
            }
        elif scheme.startswith("ws"):
            return {
                "type": "websocket",
                "scheme": scheme,
                "path": unquote(path),
                "root_path": "",
                "query_string": query.encode(),
                "headers": headers,
                "client": ["testclient", 50000],
                "server": [host, port],
                "subprotocols": [],
            }
        else:
            raise RuntimeError(f"Unknown scheme: {scheme}")

    async def _co_request(self, method, url, **kwargs):
        """Feed a request directly to the app; return (status, headers, body).

        If the app does not send a response start message, the status is
        9999 and the headers are empty.
        """
        req = requests.Request(method, url, **kwargs)
        p = req.prepare()  # Get the "resolved" request
        p.headers.setdefault("user-agent", "asgi_mock_server")
        scope = self._make_scope(p)
        # ---
        client_to_server = [p.body if p.body is not None else b""]
        server_to_client = []
        response = []

        async def receive():
            if client_to_server:
                chunk = client_to_server.pop(0)
                return {
                    "type": "http.request",
                    "body": chunk,
                    "more_body": bool(client_to_server),
                }
            else:
                # The body is consumed; what happens next depends on the method
                if method == "GET":
                    # We wait ... this is us mimicking an open connection
                    await asyncio.sleep(9999)
                elif method == "PUT":
                    return {"type": "http.disconnect"}
                else:
                    # Let's be a bad server and return None instead
                    return None

        async def send(m):
            if m["type"] == "http.response.start":
                headers = {h[0].decode(): h[1].decode() for h in m["headers"]}
                headers.setdefault("date", format_date_time(time.time()))
                headers.setdefault("server", "asgineer_mock_server")
                response.extend([m["status"], headers])
            elif m["type"] == "http.response.body":
                server_to_client.append(m["body"])
            else:
                pass  # ignore?

        await self._asgi_app(scope, receive, send)
        if not response:
            response.extend([9999, {}])
        response.append(b"".join(server_to_client))
        return tuple(response)

    async def _co_ws_communicate(self, url, client_co_func, loop):
        """Run the app with a websocket scope and let the client talk to it.

        The app runs as a task on the given loop. The client function gets
        a mock ws object with ``send``, ``receive`` and ``close`` methods
        that can also be iterated over. Returns the result of the client
        function.
        """
        req = requests.Request("GET", url)
        p = req.prepare()  # Get the "resolved" request
        p.headers.setdefault("user-agent", "asgi_mock_server")
        scope = self._make_scope(p)
        # ---
        client_to_server = []
        server_to_client = []

        async def receive():
            while not client_to_server:
                await asyncio.sleep(0.02)
            return client_to_server.pop(0)

        async def send(m):
            server_to_client.append(m)

        class WS:
            def __init__(self):
                self._closed_server = False
                self._accepted = False

            async def send(self, value):
                if self._closed_server:
                    raise IOError("ConnectionClosed")
                if isinstance(value, bytes):
                    m = {"type": "websocket.receive", "bytes": value}
                elif isinstance(value, str):
                    m = {"type": "websocket.receive", "text": value}
                else:
                    raise TypeError("Can only send bytes/str.")
                client_to_server.append(m)

            async def receive(self):
                # Wait for message to become available
                if self._closed_server:
                    raise IOError("WS is closed")
                while not server_to_client:
                    await asyncio.sleep(0.02)
                # Get message and handle special cases
                m = server_to_client.pop(0)
                if m["type"] in ("websocket.disconnect", "websocket.close"):
                    self._closed_server = True
                    raise IOError("WS closed")
                if m["type"] == "websocket.accept":
                    self._accepted = True
                    return await self.receive()
                # Return
                return m.get("bytes", None) or m.get("text", None) or b""

            async def close(self):
                client_to_server.append({"type": "websocket.disconnect"})

            async def __aiter__(self):
                while True:
                    try:
                        yield await self.receive()
                    except IOError:
                        return

        # Keep a reference to the task for as long as it runs, because the
        # event loop itself only holds a weak reference to it.
        task = loop.create_task(self._asgi_app(scope, receive, send))
        self._ws_tasks.add(task)
        task.add_done_callback(self._ws_tasks.discard)
        client_to_server.append({"type": "websocket.connect"})
        ws = WS()
        result = await client_co_func(ws)
        client_to_server.append({"type": "websocket.disconnect"})
        return result