From d0e7bcf55c2f056e51d32abd85090c21cdabc756 Mon Sep 17 00:00:00 2001 From: Joshua Bronson Date: Thu, 1 Oct 2020 09:46:24 -0400 Subject: [PATCH] Use h11! --- requirements.in | 1 + requirements.txt | 1 + trio_http_proxy.py | 72 +++++++++++++++++++++------------------------- 3 files changed, 34 insertions(+), 40 deletions(-) diff --git a/requirements.in b/requirements.in index ae0d704..1363e7a 100644 --- a/requirements.in +++ b/requirements.in @@ -1 +1,2 @@ +h11 trio diff --git a/requirements.txt b/requirements.txt index 215be53..e65ef44 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,6 +6,7 @@ # async-generator==1.10 # via trio attrs==20.2.0 # via trio +h11==0.10.0 # via -r requirements.in idna==2.10 # via trio outcome==1.0.1 # via trio sniffio==1.1.0 # via trio diff --git a/trio_http_proxy.py b/trio_http_proxy.py index 38017d6..6f69889 100755 --- a/trio_http_proxy.py +++ b/trio_http_proxy.py @@ -6,47 +6,48 @@ # License, v. 2.0. If a copy of the MPL was not distributed with this # file, You can obtain one at http://mozilla.org/MPL/2.0/. +from contextvars import ContextVar from functools import partial +from io import DEFAULT_BUFFER_SIZE from itertools import count from os import getenv from textwrap import indent from traceback import format_exc -from contextvars import ContextVar -from trio import open_nursery, open_tcp_stream, run, serve_tcp +import h11 +import trio DEFAULT_PORT = 8080 -PORT = int(getenv('PORT', DEFAULT_PORT)) # pylint: disable=invalid-envvar-default -BUFMAXLEN = 16384 +PORT = int(getenv('PORT', DEFAULT_PORT)) OK_CONNECT_PORTS = {443, 8443} -prn = partial(print, end='') # pylint: disable=C0103 -indented = partial(indent, prefix=' ') # pylint: disable=C0103 -decoded_and_indented = lambda some_bytes: indented(some_bytes.decode()) # pylint: disable=C0103 +prn = partial(print, end='') +indented = partial(indent, prefix=' ') +decoded_and_indented = lambda some_bytes: indented(some_bytes.decode()) CV_CLIENT_STREAM = ContextVar('client_stream', default=None) CV_DEST_STREAM = ContextVar('dest_stream', default=None) CV_PIPE_FROM = ContextVar('pipe_from', default=None) -async def http_proxy(client_stream, _connidgen=count(1)): - client_stream.id = next(_connidgen) +async def http_proxy(client_stream, _nextid=count(1).__next__): + client_stream.id = _nextid() CV_CLIENT_STREAM.set(client_stream) async with client_stream: try: dest_stream = await tunnel(client_stream) - async with dest_stream, open_nursery() as nursery: + async with dest_stream, trio.open_nursery() as nursery: nursery.start_soon(pipe, client_stream, dest_stream) nursery.start_soon(pipe, dest_stream, client_stream) - except Exception: # pylint: disable=broad-except + except Exception: log(f'\n{indented(format_exc())}') async def start_server(server=http_proxy, port=PORT): print(f'* Starting {server.__name__} on port {port or "(OS-selected port)"}...') try: - await serve_tcp(server, port) + await trio.serve_tcp(server, port) except KeyboardInterrupt: print('\nGoodbye for now.') @@ -57,52 +58,43 @@ async def tunnel(client_stream): and notify the client when the end-to-end connection has been established. Return the destination stream and the corresponding host. """ - desthost, destport = await process_as_http_connect_request(client_stream) + desthost, destport = await process_as_connect_request(client_stream) log(f'Got CONNECT request for {desthost}:{destport}, connecting...') - dest_stream = await open_tcp_stream(desthost, destport) + dest_stream = await trio.open_tcp_stream(desthost, destport) dest_stream.host = desthost dest_stream.port = destport CV_DEST_STREAM.set(dest_stream) - log(f'Connected to {desthost}, sending 200 response...') + log(f'Connected to {desthost}, sending 200 to client...') await client_stream.send_all(b'HTTP/1.1 200 Connection established\r\n\r\n') log('Sent 200 to client, tunnel established.') return dest_stream -async def process_as_http_connect_request(stream, bufmaxlen=BUFMAXLEN): +async def process_as_connect_request(stream, bufmaxsz=DEFAULT_BUFFER_SIZE, maxreqsz=16384): """Read a stream expected to contain a valid HTTP CONNECT request to desthost:destport. Parse and return the destination host. Validate (lightly) and raise if request invalid. See https://tools.ietf.org/html/rfc7231#section-4.3.6 for the CONNECT spec. """ + # TODO: give client 'bad request' errors on assertion failure log(f'Reading...') - bytes_read = await stream.receive_some(bufmaxlen) - assert bytes_read.endswith(b'\r\n\r\n'), f'CONNECT request did not fit in {bufmaxlen} bytes?\n{decoded_and_indented(bytes_read)}' - # Only examine the first two tokens (e.g. "CONNECT example.com:443 [ignored...]"). - # The Host header should duplicate the CONNECT request's authority and should therefore be safe - # to ignore. Plus apparently some clients (iOS, Facebook) don't even send a Host header in - # CONNECT requests according to https://go-review.googlesource.com/c/go/+/44004. - split = bytes_read.split(maxsplit=2) - assert len(split) == 3, f'Expected " ..."\n{decoded_and_indented(bytes_read)}' - method, authority, _ = split - assert method == b'CONNECT', f'Expected "CONNECT", "{method}" unsupported\n{decoded_and_indented(bytes_read)}' - desthost, colon, destport = authority.partition(b':') - assert colon and destport, f'Expected ":" in {authority}\n{decoded_and_indented(bytes_read)}' + h11_conn = h11.Connection(our_role=h11.SERVER) + total_bytes_read = 0 + while (h11_nextevt := h11_conn.next_event()) == h11.NEED_DATA: + bytes_read = await stream.receive_some(bufmaxsz) + total_bytes_read += len(bytes_read) + assert total_bytes_read < maxreqsz, f'Request did not fit in {maxreqsz} bytes' + h11_conn.receive_data(bytes_read) + assert isinstance(h11_nextevt, h11.Request), f'{h11_nextevt=} is not a h11.Request' + assert h11_nextevt.method == b'CONNECT', f'{h11_nextevt.method=} != CONNECT' + desthost, _, destport = h11_nextevt.target.partition(b':') destport = int(destport.decode()) - assert destport in OK_CONNECT_PORTS, f'Forbidden destination port: {destport}' + assert destport in OK_CONNECT_PORTS, f'{destport=} not in {OK_CONNECT_PORTS}' return desthost.decode(), destport -async def read_all(stream, bufmaxlen=BUFMAXLEN): - while True: - chunk = await stream.receive_some(bufmaxlen) - if not chunk: - break - yield chunk - - -async def pipe(from_stream, to_stream, bufmaxlen=BUFMAXLEN): +async def pipe(from_stream, to_stream, bufmaxsz=DEFAULT_BUFFER_SIZE): CV_PIPE_FROM.set(from_stream) - async for chunk in read_all(from_stream, bufmaxlen=bufmaxlen): # pylint: disable=E1133; https://github.com/PyCQA/pylint/issues/2311 + async for chunk in from_stream: await to_stream.send_all(chunk) log(f'Forwarded {len(chunk)} bytes') log(f'Pipe finished') @@ -124,4 +116,4 @@ def log(*args, **kw): if __name__ == '__main__': - run(start_server) + trio.run(start_server)