mirror of
https://github.com/NohamR/trio_http_proxy.git
synced 2025-05-24 00:59:27 +00:00
Use h11!
This commit is contained in:
parent
4b4fc54ea3
commit
d0e7bcf55c
@ -1 +1,2 @@
|
||||
h11
|
||||
trio
|
||||
|
@ -6,6 +6,7 @@
|
||||
#
|
||||
async-generator==1.10 # via trio
|
||||
attrs==20.2.0 # via trio
|
||||
h11==0.10.0 # via -r requirements.in
|
||||
idna==2.10 # via trio
|
||||
outcome==1.0.1 # via trio
|
||||
sniffio==1.1.0 # via trio
|
||||
|
@ -6,47 +6,48 @@
|
||||
# License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
# file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
from contextvars import ContextVar
|
||||
from functools import partial
|
||||
from io import DEFAULT_BUFFER_SIZE
|
||||
from itertools import count
|
||||
from os import getenv
|
||||
from textwrap import indent
|
||||
from traceback import format_exc
|
||||
|
||||
from contextvars import ContextVar
|
||||
from trio import open_nursery, open_tcp_stream, run, serve_tcp
|
||||
import h11
|
||||
import trio
|
||||
|
||||
|
||||
DEFAULT_PORT = 8080
|
||||
PORT = int(getenv('PORT', DEFAULT_PORT)) # pylint: disable=invalid-envvar-default
|
||||
BUFMAXLEN = 16384
|
||||
PORT = int(getenv('PORT', DEFAULT_PORT))
|
||||
OK_CONNECT_PORTS = {443, 8443}
|
||||
|
||||
prn = partial(print, end='') # pylint: disable=C0103
|
||||
indented = partial(indent, prefix=' ') # pylint: disable=C0103
|
||||
decoded_and_indented = lambda some_bytes: indented(some_bytes.decode()) # pylint: disable=C0103
|
||||
prn = partial(print, end='')
|
||||
indented = partial(indent, prefix=' ')
|
||||
decoded_and_indented = lambda some_bytes: indented(some_bytes.decode())
|
||||
|
||||
CV_CLIENT_STREAM = ContextVar('client_stream', default=None)
|
||||
CV_DEST_STREAM = ContextVar('dest_stream', default=None)
|
||||
CV_PIPE_FROM = ContextVar('pipe_from', default=None)
|
||||
|
||||
|
||||
async def http_proxy(client_stream, _connidgen=count(1)):
|
||||
client_stream.id = next(_connidgen)
|
||||
async def http_proxy(client_stream, _nextid=count(1).__next__):
|
||||
client_stream.id = _nextid()
|
||||
CV_CLIENT_STREAM.set(client_stream)
|
||||
async with client_stream:
|
||||
try:
|
||||
dest_stream = await tunnel(client_stream)
|
||||
async with dest_stream, open_nursery() as nursery:
|
||||
async with dest_stream, trio.open_nursery() as nursery:
|
||||
nursery.start_soon(pipe, client_stream, dest_stream)
|
||||
nursery.start_soon(pipe, dest_stream, client_stream)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
except Exception:
|
||||
log(f'\n{indented(format_exc())}')
|
||||
|
||||
|
||||
async def start_server(server=http_proxy, port=PORT):
|
||||
print(f'* Starting {server.__name__} on port {port or "(OS-selected port)"}...')
|
||||
try:
|
||||
await serve_tcp(server, port)
|
||||
await trio.serve_tcp(server, port)
|
||||
except KeyboardInterrupt:
|
||||
print('\nGoodbye for now.')
|
||||
|
||||
@ -57,52 +58,43 @@ async def tunnel(client_stream):
|
||||
and notify the client when the end-to-end connection has been established.
|
||||
Return the destination stream and the corresponding host.
|
||||
"""
|
||||
desthost, destport = await process_as_http_connect_request(client_stream)
|
||||
desthost, destport = await process_as_connect_request(client_stream)
|
||||
log(f'Got CONNECT request for {desthost}:{destport}, connecting...')
|
||||
dest_stream = await open_tcp_stream(desthost, destport)
|
||||
dest_stream = await trio.open_tcp_stream(desthost, destport)
|
||||
dest_stream.host = desthost
|
||||
dest_stream.port = destport
|
||||
CV_DEST_STREAM.set(dest_stream)
|
||||
log(f'Connected to {desthost}, sending 200 response...')
|
||||
log(f'Connected to {desthost}, sending 200 to client...')
|
||||
await client_stream.send_all(b'HTTP/1.1 200 Connection established\r\n\r\n')
|
||||
log('Sent 200 to client, tunnel established.')
|
||||
return dest_stream
|
||||
|
||||
|
||||
async def process_as_http_connect_request(stream, bufmaxlen=BUFMAXLEN):
|
||||
async def process_as_connect_request(stream, bufmaxsz=DEFAULT_BUFFER_SIZE, maxreqsz=16384):
|
||||
"""Read a stream expected to contain a valid HTTP CONNECT request to desthost:destport.
|
||||
Parse and return the destination host. Validate (lightly) and raise if request invalid.
|
||||
See https://tools.ietf.org/html/rfc7231#section-4.3.6 for the CONNECT spec.
|
||||
"""
|
||||
# TODO: give client 'bad request' errors on assertion failure
|
||||
log(f'Reading...')
|
||||
bytes_read = await stream.receive_some(bufmaxlen)
|
||||
assert bytes_read.endswith(b'\r\n\r\n'), f'CONNECT request did not fit in {bufmaxlen} bytes?\n{decoded_and_indented(bytes_read)}'
|
||||
# Only examine the first two tokens (e.g. "CONNECT example.com:443 [ignored...]").
|
||||
# The Host header should duplicate the CONNECT request's authority and should therefore be safe
|
||||
# to ignore. Plus apparently some clients (iOS, Facebook) don't even send a Host header in
|
||||
# CONNECT requests according to https://go-review.googlesource.com/c/go/+/44004.
|
||||
split = bytes_read.split(maxsplit=2)
|
||||
assert len(split) == 3, f'Expected "<method> <authority> ..."\n{decoded_and_indented(bytes_read)}'
|
||||
method, authority, _ = split
|
||||
assert method == b'CONNECT', f'Expected "CONNECT", "{method}" unsupported\n{decoded_and_indented(bytes_read)}'
|
||||
desthost, colon, destport = authority.partition(b':')
|
||||
assert colon and destport, f'Expected ":<port>" in {authority}\n{decoded_and_indented(bytes_read)}'
|
||||
h11_conn = h11.Connection(our_role=h11.SERVER)
|
||||
total_bytes_read = 0
|
||||
while (h11_nextevt := h11_conn.next_event()) == h11.NEED_DATA:
|
||||
bytes_read = await stream.receive_some(bufmaxsz)
|
||||
total_bytes_read += len(bytes_read)
|
||||
assert total_bytes_read < maxreqsz, f'Request did not fit in {maxreqsz} bytes'
|
||||
h11_conn.receive_data(bytes_read)
|
||||
assert isinstance(h11_nextevt, h11.Request), f'{h11_nextevt=} is not a h11.Request'
|
||||
assert h11_nextevt.method == b'CONNECT', f'{h11_nextevt.method=} != CONNECT'
|
||||
desthost, _, destport = h11_nextevt.target.partition(b':')
|
||||
destport = int(destport.decode())
|
||||
assert destport in OK_CONNECT_PORTS, f'Forbidden destination port: {destport}'
|
||||
assert destport in OK_CONNECT_PORTS, f'{destport=} not in {OK_CONNECT_PORTS}'
|
||||
return desthost.decode(), destport
|
||||
|
||||
|
||||
async def read_all(stream, bufmaxlen=BUFMAXLEN):
|
||||
while True:
|
||||
chunk = await stream.receive_some(bufmaxlen)
|
||||
if not chunk:
|
||||
break
|
||||
yield chunk
|
||||
|
||||
|
||||
async def pipe(from_stream, to_stream, bufmaxlen=BUFMAXLEN):
|
||||
async def pipe(from_stream, to_stream, bufmaxsz=DEFAULT_BUFFER_SIZE):
|
||||
CV_PIPE_FROM.set(from_stream)
|
||||
async for chunk in read_all(from_stream, bufmaxlen=bufmaxlen): # pylint: disable=E1133; https://github.com/PyCQA/pylint/issues/2311
|
||||
async for chunk in from_stream:
|
||||
await to_stream.send_all(chunk)
|
||||
log(f'Forwarded {len(chunk)} bytes')
|
||||
log(f'Pipe finished')
|
||||
@ -124,4 +116,4 @@ def log(*args, **kw):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
run(start_server)
|
||||
trio.run(start_server)
|
||||
|
Loading…
x
Reference in New Issue
Block a user