diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c31669a..7617e1f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -10,19 +10,24 @@ jobs: tox: runs-on: ubuntu-latest strategy: - max-parallel: 5 + max-parallel: 7 matrix: python-version: - - 3.6 - - 3.7 - 3.8 - - pypy3 + - 3.9 + - "3.10" + - "3.11" + - "3.12" + - "3.13" + - pypy-3.9 + - pypy-3.10 steps: - - uses: actions/checkout@v2 - - uses: actions/setup-python@v2 + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} + allow-prereleases: true - name: Install tox run: | python -m pip install --upgrade pip setuptools @@ -33,6 +38,6 @@ jobs: - name: Test with tox run: | tox --parallel 0 - - uses: codecov/codecov-action@v1 + - uses: codecov/codecov-action@v4 with: file: ./coverage.xml diff --git a/.readthedocs.yaml b/.readthedocs.yaml new file mode 100644 index 0000000..38d4fcc --- /dev/null +++ b/.readthedocs.yaml @@ -0,0 +1,17 @@ +version: 2 + +build: + os: ubuntu-22.04 + apt_packages: + - graphviz + tools: + python: "3.8" + +sphinx: + configuration: docs/source/conf.py + +python: + install: + - method: pip + path: . + - requirements: docs/requirements.txt diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 45a1ebf..27c4d15 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -48,7 +48,7 @@ other hand, the following are all very welcome: the project. With those projects installed the commands, black h11/ bench/ examples/ fuzz/ - isort --dont-skip __init__.py --apply --settings-path setup.cfg --recursive h11 bench examples fuzz + isort --profile black --dt h11 bench examples fuzz will format your code for you. diff --git a/MANIFEST.in b/MANIFEST.in index 1d4bdc5..f2f65de 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,5 +1,5 @@ -include LICENSE.txt README.rst notes.org tiny-client-demo.py +include LICENSE.txt README.rst notes.org tiny-client-demo.py h11/py.typed recursive-include docs * -recursive-include h11/tests/data * +recursive-include h11/tests * recursive-include fuzz * prune docs/build diff --git a/README.rst b/README.rst index dcfd53d..5f28616 100644 --- a/README.rst +++ b/README.rst @@ -112,7 +112,7 @@ library. It has a test suite with 100.0% coverage for both statements and branches. -Currently it supports Python 3 (testing on 3.6-3.9) and PyPy 3. +Currently it supports Python 3 (testing on 3.8-3.12) and PyPy 3. The last Python 2-compatible version was h11 0.11.x. (Originally it had a Cython wrapper for `http-parser `_ and a beautiful nested state diff --git a/bench/benchmarks/benchmarks.py b/bench/benchmarks/benchmarks.py index abc0079..73d078e 100644 --- a/bench/benchmarks/benchmarks.py +++ b/bench/benchmarks/benchmarks.py @@ -60,7 +60,7 @@ def _run_basic_get_repeatedly(): for _ in range(REPEAT): time_server_basic_get_with_realistic_headers() finish = default_timer() - print("{:.1f} requests/sec".format(REPEAT / (finish - start))) + print(f"{REPEAT / (finish - start):.1f} requests/sec") if __name__ == "__main__": diff --git a/docs/requirements.txt b/docs/requirements.txt index b33e9c4..1c6aca5 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,3 +1,6 @@ mistune jsonschema ipython +sphinx<4 +jinja2<3 +markupsafe<2 diff --git a/docs/source/changes.rst b/docs/source/changes.rst index b6a6ccd..db234fd 100644 --- a/docs/source/changes.rst +++ b/docs/source/changes.rst @@ -5,6 +5,83 @@ History of changes .. 
towncrier release notes start +H11 0.16.0 (2025-04-23) +----------------------- + +Security fix +~~~~~~~~~~~~ + +Reject certain malformed `Transfer-Encoding: chunked` bodies that were previously accepted. These could have enabled request-smuggling attacks when an h11-based HTTP server was placed behind a load balancer with a matching bug in its `chunked` handling. + +Advisory with more details: https://github.com/python-hyper/h11/security/advisories/GHSA-vqfr-h8mv-ghfj + +Reported by: Jeppe Bonde Weikop + +H11 0.15.0 (2025-04-23) +----------------------- + +Bugfixes +~~~~~~~~ + +- Reject Content-Lengths >= 1 zettabyte (1 billion terabytes) early, `without attempting to parse the integer `__ (`#178 `__) + + +Miscellaneous internal changes +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +- Remove the `tests` folder from wheel files. This reduces the zipped file size by 20KB (about 30%). (`#158 `__) + + +H11 0.14.0 (2022-09-25) +----------------------- + +Features +~~~~~~~~ + +- Allow additional trailing whitespace in chunk headers for additional + compatibility with existing servers. (`#133 + `__) +- Improve the type hints for Sentinel types, which should make it + easier to type hint h11 usage. (`#151 + `__ & `#144 + `__)) + +Deprecations and Removals +~~~~~~~~~~~~~~~~~~~~~~~~~ + +- Python 3.6 support is removed. h11 now requires Python>=3.7 + including PyPy 3. Users running `pip install h11` on Python 2 will + automatically get the last Python 2-compatible version. (`#138 + `__) + + +v0.13.0 (2022-01-19) +-------------------- + +Features +~~~~~~~~ + +- Clarify that the Headers class is a Sequence and inherit from the + collections Sequence abstract base class to also indicate this (and + gain the mixin methods). See also #104. (`#112 + `__) +- Switch event classes to dataclasses for easier typing and slightly + improved performance. (`#124 + `__) +- Shorten traceback of protocol errors for easier readability (`#132 + `__). +- Add typing including a PEP 561 marker for usage by type checkers + (`#135 `__). +- Expand the allowed status codes to [0, 999] from [0, 600] (`#134 + https://github.com/python-hyper/h11/issues/134`__). + +Backwards **in**\compatible changes +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +- Ensure request method is a valid token (`#141 + https://github.com/python-hyper/h11/pull/141>`__). + + v0.12.0 (2021-01-01) -------------------- diff --git a/docs/source/conf.py b/docs/source/conf.py index 0d8b494..b3627f5 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -1,5 +1,4 @@ #!/usr/bin/env python3 -# -*- coding: utf-8 -*- # # h11 documentation build configuration file, created by # sphinx-quickstart on Tue May 3 00:20:14 2016. diff --git a/docs/source/index.rst b/docs/source/index.rst index 31ca36f..ee02847 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -44,7 +44,7 @@ whatever. But h11 makes it much easier to implement something like Vital statistics ---------------- -* Requirements: Python 3.6+ (PyPy works great) +* Requirements: Python 3.8+ (PyPy works great) The last Python 2-compatible version was h11 0.11.x. 
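The 0.16.0 security fix described in the changelog above makes h11 verify the two bytes that terminate each chunk instead of blindly discarding them. A minimal sketch of how that surfaces to a caller of the public API; the request line, host, and the deliberately corrupted ``XX`` footer are invented for illustration::

    import h11

    conn = h11.Connection(h11.SERVER)
    conn.receive_data(
        b"POST /upload HTTP/1.1\r\n"
        b"Host: example.com\r\n"
        b"Transfer-Encoding: chunked\r\n"
        b"\r\n"
        # The 5-byte chunk "hello" must be followed by "\r\n"; here it is "XX".
        b"5\r\nhelloXX0\r\n\r\n"
    )

    assert type(conn.next_event()) is h11.Request
    assert conn.next_event() == h11.Data(data=b"hello", chunk_start=True, chunk_end=True)

    try:
        conn.next_event()  # hits the corrupted chunk footer
    except h11.RemoteProtocolError as exc:
        print("rejected:", exc)  # older releases silently skipped the two bogus bytes
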
diff --git a/docs/source/make-state-diagrams.py b/docs/source/make-state-diagrams.py index 16f033e..617efa5 100644 --- a/docs/source/make-state-diagrams.py +++ b/docs/source/make-state-diagrams.py @@ -38,16 +38,16 @@ def __init__(self): def e(self, source, target, label, color, italicize=False, weight=1): if italicize: - quoted_label = "<{}>".format(label) + quoted_label = f"<{label}>" else: - quoted_label = '<{}>'.format(label) + quoted_label = f'<{label}>' self.edges.append( - '{source} -> {target} [\n' - ' label={quoted_label},\n' - ' color="{color}", fontcolor="{color}",\n' - ' weight={weight},\n' - ']\n' - .format(**locals())) + f'{source} -> {target} [\n' + f' label={quoted_label},\n' + f' color="{color}", fontcolor="{color}",\n' + f' weight={weight},\n' + f']\n' + ) def write(self, f): self.edges.sort() @@ -150,7 +150,7 @@ def make_dot(role, out_path): else: (their_state, our_state) = state_pair edges.e(our_state, updates[role], - "peer in
<BR/> {}".format(their_state), +                    f"peer in <BR/>
{their_state}", color=_STATE_COLOR) if role is CLIENT: diff --git a/examples/trio-server.py b/examples/trio-server.py index 214ab51..996afb6 100644 --- a/examples/trio-server.py +++ b/examples/trio-server.py @@ -1,6 +1,5 @@ # A simple HTTP server implemented using h11 and Trio: # http://trio.readthedocs.io/en/latest/index.html -# (so requires python 3.5+). # # All requests get echoed back a JSON document containing information about # the request. @@ -75,21 +74,39 @@ # - We should probably do something cleverer with buffering responses and # TCP_CORK and suchlike. +import datetime +import email.utils import json from itertools import count -from wsgiref.handlers import format_date_time import trio import h11 -MAX_RECV = 2 ** 16 +MAX_RECV = 2**16 TIMEOUT = 10 + +# We are using email.utils.format_datetime to generate the Date header. +# It may sound weird, but it actually follows the RFC. +# Please see: https://stackoverflow.com/a/59416334/14723771 +# +# See also: +# [1] https://www.rfc-editor.org/rfc/rfc9110#section-5.6.7 +# [2] https://www.rfc-editor.org/rfc/rfc7231#section-7.1.1.1 +# [3] https://www.rfc-editor.org/rfc/rfc5322#section-3.3 +def format_date_time(dt=None): + """Generate a RFC 7231 / RFC 9110 IMF-fixdate string""" + if dt is None: + dt = datetime.datetime.now(datetime.timezone.utc) + return email.utils.format_datetime(dt, usegmt=True) + + ################################################################ # I/O adapter: h11 <-> trio ################################################################ + # The core of this could be factored out to be usable for trio-based clients # too, as well as servers. But as a simplified pedagogical example we don't # attempt this here. @@ -101,7 +118,7 @@ def __init__(self, stream): self.conn = h11.Connection(h11.SERVER) # Our Server: header self.ident = " ".join( - ["h11-example-trio-server/{}".format(h11.__version__), h11.PRODUCT_ID] + [f"h11-example-trio-server/{h11.__version__}", h11.PRODUCT_ID] ).encode("ascii") # A unique id for this connection, to include in debugging output # (useful for understanding what's going on if there are multiple @@ -114,7 +131,13 @@ async def send(self, event): # appropriate when 'data' is None. assert type(event) is not h11.ConnectionClosed data = self.conn.send(event) - await self.stream.send_all(data) + try: + await self.stream.send_all(data) + except BaseException: + # If send_all raises an exception (especially trio.Cancelled), + # we have no choice but to give it up. 
+ self.conn.send_failed() + raise async def _read_from_peer(self): if self.conn.they_are_waiting_for_100_continue: @@ -177,19 +200,20 @@ def basic_headers(self): # HTTP requires these headers in all responses (client would do # something different here) return [ - ("Date", format_date_time(None).encode("ascii")), + ("Date", format_date_time().encode("ascii")), ("Server", self.ident), ] def info(self, *args): # Little debugging method - print("{}:".format(self._obj_id), *args) + print(f"{self._obj_id}:", *args) ################################################################ # Server main loop ################################################################ + # General theory: # # If everything goes well: @@ -229,7 +253,7 @@ async def http_serve(stream): if type(event) is h11.Request: await send_echo_response(wrapper, event) except Exception as exc: - wrapper.info("Error during response handler: {!r}".format(exc)) + wrapper.info(f"Error during response handler: {exc!r}") await maybe_send_error_response(wrapper, exc) if wrapper.conn.our_state is h11.MUST_CLOSE: @@ -244,7 +268,7 @@ async def http_serve(stream): states = wrapper.conn.states wrapper.info("unexpected state", states, "-- bailing out") await maybe_send_error_response( - wrapper, RuntimeError("unexpected state {}".format(states)) + wrapper, RuntimeError(f"unexpected state {states}") ) await wrapper.shutdown_and_clean_up() return @@ -254,6 +278,7 @@ async def http_serve(stream): # Actual response handlers ################################################################ + # Helper function async def send_simple_response(wrapper, status_code, content_type, body): wrapper.info("Sending", status_code, "response with", len(body), "bytes") @@ -318,7 +343,7 @@ async def send_echo_response(wrapper, request): async def serve(port): - print("listening on http://localhost:{}".format(port)) + print(f"listening on http://localhost:{port}") try: await trio.serve_tcp(http_serve, port) except KeyboardInterrupt: diff --git a/format-requirements.txt b/format-requirements.txt new file mode 100644 index 0000000..a45e8c9 --- /dev/null +++ b/format-requirements.txt @@ -0,0 +1,2 @@ +black==23.3.0 +isort==5.12.0 \ No newline at end of file diff --git a/h11/__init__.py b/h11/__init__.py index ae39e01..989e92c 100644 --- a/h11/__init__.py +++ b/h11/__init__.py @@ -6,16 +6,57 @@ # semantics to check that what you're asking to write to the wire is sensible, # but at least it gets you out of dealing with the wire itself. 
-from ._connection import * -from ._events import * -from ._state import * -from ._util import LocalProtocolError, ProtocolError, RemoteProtocolError -from ._version import __version__ +from h11._connection import Connection, NEED_DATA, PAUSED +from h11._events import ( + ConnectionClosed, + Data, + EndOfMessage, + Event, + InformationalResponse, + Request, + Response, +) +from h11._state import ( + CLIENT, + CLOSED, + DONE, + ERROR, + IDLE, + MIGHT_SWITCH_PROTOCOL, + MUST_CLOSE, + SEND_BODY, + SEND_RESPONSE, + SERVER, + SWITCHED_PROTOCOL, +) +from h11._util import LocalProtocolError, ProtocolError, RemoteProtocolError +from h11._version import __version__ PRODUCT_ID = "python-h11/" + __version__ -__all__ = ["ProtocolError", "LocalProtocolError", "RemoteProtocolError"] -__all__ += _events.__all__ -__all__ += _connection.__all__ -__all__ += _state.__all__ +__all__ = ( + "Connection", + "NEED_DATA", + "PAUSED", + "ConnectionClosed", + "Data", + "EndOfMessage", + "Event", + "InformationalResponse", + "Request", + "Response", + "CLIENT", + "CLOSED", + "DONE", + "ERROR", + "IDLE", + "MUST_CLOSE", + "SEND_BODY", + "SEND_RESPONSE", + "SERVER", + "SWITCHED_PROTOCOL", + "ProtocolError", + "LocalProtocolError", + "RemoteProtocolError", +) diff --git a/h11/_abnf.py b/h11/_abnf.py index e6d49e1..933587f 100644 --- a/h11/_abnf.py +++ b/h11/_abnf.py @@ -125,5 +125,8 @@ chunk_header = ( r"(?P{chunk_size})" r"(?P{chunk_ext})?" - r"\r\n".format(**globals()) + r"{OWS}\r\n".format( + **globals() + ) # Even though the specification does not allow for extra whitespaces, + # we are lenient with trailing whitespaces because some servers on the wild use it. ) diff --git a/h11/_connection.py b/h11/_connection.py index 6f796ef..e37d82a 100644 --- a/h11/_connection.py +++ b/h11/_connection.py @@ -1,28 +1,64 @@ # This contains the main Connection class. Everything in h11 revolves around # this. +from typing import ( + Any, + Callable, + cast, + Dict, + List, + Optional, + overload, + Tuple, + Type, + Union, +) -from ._events import * # Import all event types +from ._events import ( + ConnectionClosed, + Data, + EndOfMessage, + Event, + InformationalResponse, + Request, + Response, +) from ._headers import get_comma_header, has_expect_100_continue, set_comma_header -from ._readers import READERS +from ._readers import READERS, ReadersType from ._receivebuffer import ReceiveBuffer -from ._state import * # Import all state sentinels -from ._state import _SWITCH_CONNECT, _SWITCH_UPGRADE, ConnectionState +from ._state import ( + _SWITCH_CONNECT, + _SWITCH_UPGRADE, + CLIENT, + ConnectionState, + DONE, + ERROR, + MIGHT_SWITCH_PROTOCOL, + SEND_BODY, + SERVER, + SWITCHED_PROTOCOL, +) from ._util import ( # Import the internal things we need LocalProtocolError, - make_sentinel, RemoteProtocolError, + Sentinel, ) -from ._writers import WRITERS +from ._writers import WRITERS, WritersType # Everything in __all__ gets re-exported as part of the h11 public API. __all__ = ["Connection", "NEED_DATA", "PAUSED"] -NEED_DATA = make_sentinel("NEED_DATA") -PAUSED = make_sentinel("PAUSED") + +class NEED_DATA(Sentinel, metaclass=Sentinel): + pass + + +class PAUSED(Sentinel, metaclass=Sentinel): + pass + # If we ever have this much buffered without it making a complete parseable # event, we error out. The only time we really buffer is when reading the -# request/reponse line + headers together, so this is effectively the limit on +# request/response line + headers together, so this is effectively the limit on # the size of that. 
# # Some precedents for defaults: @@ -32,6 +68,7 @@ # - Apache: <8 KiB per line> DEFAULT_MAX_INCOMPLETE_EVENT_SIZE = 16 * 1024 + # RFC 7230's rules for connection lifecycles: # - If either side says they want to close the connection, then the connection # must close. @@ -44,7 +81,7 @@ # our rule is: # - If someone says Connection: close, we will close # - If someone uses HTTP/1.0, we will close. -def _keep_alive(event): +def _keep_alive(event: Union[Request, Response]) -> bool: connection = get_comma_header(event.headers, b"connection") if b"close" in connection: return False @@ -53,7 +90,9 @@ def _keep_alive(event): return True -def _body_framing(request_method, event): +def _body_framing( + request_method: bytes, event: Union[Request, Response] +) -> Tuple[str, Union[Tuple[()], Tuple[int]]]: # Called when we enter SEND_BODY to figure out framing information for # this body. # @@ -126,13 +165,16 @@ class Connection: """ def __init__( - self, our_role, max_incomplete_event_size=DEFAULT_MAX_INCOMPLETE_EVENT_SIZE - ): + self, + our_role: Type[Sentinel], + max_incomplete_event_size: int = DEFAULT_MAX_INCOMPLETE_EVENT_SIZE, + ) -> None: self._max_incomplete_event_size = max_incomplete_event_size # State and role tracking if our_role not in (CLIENT, SERVER): - raise ValueError("expected CLIENT or SERVER, not {!r}".format(our_role)) + raise ValueError(f"expected CLIENT or SERVER, not {our_role!r}") self.our_role = our_role + self.their_role: Type[Sentinel] if our_role is CLIENT: self.their_role = SERVER else: @@ -155,14 +197,14 @@ def __init__( # These two are only used to interpret framing headers for figuring # out how to read/write response bodies. their_http_version is also # made available as a convenient public API. - self.their_http_version = None - self._request_method = None + self.their_http_version: Optional[bytes] = None + self._request_method: Optional[bytes] = None # This is pure flow-control and doesn't at all affect the set of legal # transitions, so no need to bother ConnectionState with it: self.client_is_waiting_for_100_continue = False @property - def states(self): + def states(self) -> Dict[Type[Sentinel], Type[Sentinel]]: """A dictionary like:: {CLIENT: , SERVER: } @@ -173,24 +215,24 @@ def states(self): return dict(self._cstate.states) @property - def our_state(self): + def our_state(self) -> Type[Sentinel]: """The current state of whichever role we are playing. See :ref:`state-machine` for details. """ return self._cstate.states[self.our_role] @property - def their_state(self): + def their_state(self) -> Type[Sentinel]: """The current state of whichever role we are NOT playing. See :ref:`state-machine` for details. """ return self._cstate.states[self.their_role] @property - def they_are_waiting_for_100_continue(self): + def they_are_waiting_for_100_continue(self) -> bool: return self.their_role is CLIENT and self.client_is_waiting_for_100_continue - def start_next_cycle(self): + def start_next_cycle(self) -> None: """Attempt to reset our connection state for a new request/response cycle. 
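``start_next_cycle()`` (whose signature is being annotated above) is what makes keep-alive reuse work: once both roles reach DONE, each side can be reset for the next request/response pair. A rough round-trip sketch using only public API; the host, target, and headers are placeholders::

    import h11

    client = h11.Connection(h11.CLIENT)
    server = h11.Connection(h11.SERVER)

    # Client sends a complete request...
    request_bytes = client.send(
        h11.Request(method="GET", target="/", headers=[("Host", "example.com")])
    ) + client.send(h11.EndOfMessage())

    # ...the server reads it and answers.
    server.receive_data(request_bytes)
    assert type(server.next_event()) is h11.Request
    assert type(server.next_event()) is h11.EndOfMessage
    response_bytes = server.send(
        h11.Response(status_code=200, headers=[("Content-Length", "0")])
    ) + server.send(h11.EndOfMessage())

    # The client reads the response; now both sides see DONE/DONE...
    client.receive_data(response_bytes)
    assert type(client.next_event()) is h11.Response
    assert type(client.next_event()) is h11.EndOfMessage
    assert client.states == server.states == {h11.CLIENT: h11.DONE, h11.SERVER: h11.DONE}

    # ...and both can be reset for the next keep-alive request.
    client.start_next_cycle()
    server.start_next_cycle()
    assert client.states == {h11.CLIENT: h11.IDLE, h11.SERVER: h11.IDLE}
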
@@ -210,12 +252,12 @@ def start_next_cycle(self): assert not self.client_is_waiting_for_100_continue self._respond_to_state_changes(old_states) - def _process_error(self, role): + def _process_error(self, role: Type[Sentinel]) -> None: old_states = dict(self._cstate.states) self._cstate.process_error(role) self._respond_to_state_changes(old_states) - def _server_switch_event(self, event): + def _server_switch_event(self, event: Event) -> Optional[Type[Sentinel]]: if type(event) is InformationalResponse and event.status_code == 101: return _SWITCH_UPGRADE if type(event) is Response: @@ -227,7 +269,7 @@ def _server_switch_event(self, event): return None # All events go through here - def _process_event(self, role, event): + def _process_event(self, role: Type[Sentinel], event: Event) -> None: # First, pass the event through the state machine to make sure it # succeeds. old_states = dict(self._cstate.states) @@ -243,16 +285,15 @@ def _process_event(self, role, event): # Then perform the updates triggered by it. - # self._request_method if type(event) is Request: self._request_method = event.method - # self.their_http_version if role is self.their_role and type(event) in ( Request, Response, InformationalResponse, ): + event = cast(Union[Request, Response, InformationalResponse], event) self.their_http_version = event.http_version # Keep alive handling @@ -261,7 +302,9 @@ def _process_event(self, role, event): # shows up on a 1xx InformationalResponse. I think the idea is that # this is not supposed to happen. In any case, if it does happen, we # ignore it. - if type(event) in (Request, Response) and not _keep_alive(event): + if type(event) in (Request, Response) and not _keep_alive( + cast(Union[Request, Response], event) + ): self._cstate.process_keep_alive_disabled() # 100-continue @@ -274,22 +317,33 @@ def _process_event(self, role, event): self._respond_to_state_changes(old_states, event) - def _get_io_object(self, role, event, io_dict): + def _get_io_object( + self, + role: Type[Sentinel], + event: Optional[Event], + io_dict: Union[ReadersType, WritersType], + ) -> Optional[Callable[..., Any]]: # event may be None; it's only used when entering SEND_BODY state = self._cstate.states[role] if state is SEND_BODY: # Special case: the io_dict has a dict of reader/writer factories # that depend on the request/response framing. - framing_type, args = _body_framing(self._request_method, event) - return io_dict[SEND_BODY][framing_type](*args) + framing_type, args = _body_framing( + cast(bytes, self._request_method), cast(Union[Request, Response], event) + ) + return io_dict[SEND_BODY][framing_type](*args) # type: ignore[index] else: # General case: the io_dict just has the appropriate reader/writer # for this state - return io_dict.get((role, state)) + return io_dict.get((role, state)) # type: ignore[return-value] # This must be called after any action that might have caused # self._cstate.states to change. 
- def _respond_to_state_changes(self, old_states, event=None): + def _respond_to_state_changes( + self, + old_states: Dict[Type[Sentinel], Type[Sentinel]], + event: Optional[Event] = None, + ) -> None: # Update reader/writer if self.our_state != old_states[self.our_role]: self._writer = self._get_io_object(self.our_role, event, WRITERS) @@ -297,7 +351,7 @@ def _respond_to_state_changes(self, old_states, event=None): self._reader = self._get_io_object(self.their_role, event, READERS) @property - def trailing_data(self): + def trailing_data(self) -> Tuple[bytes, bool]: """Data that has been received, but not yet processed, represented as a tuple with two elements, where the first is a byte-string containing the unprocessed data itself, and the second is a bool that is True if @@ -307,7 +361,7 @@ def trailing_data(self): """ return (bytes(self._receive_buffer), self._receive_buffer_closed) - def receive_data(self, data): + def receive_data(self, data: bytes) -> None: """Add data to our internal receive buffer. This does not actually do any processing on the data, just stores @@ -353,7 +407,9 @@ def receive_data(self, data): else: self._receive_buffer_closed = True - def _extract_next_receive_event(self): + def _extract_next_receive_event( + self, + ) -> Union[Event, Type[NEED_DATA], Type[PAUSED]]: state = self.their_state # We don't pause immediately when they enter DONE, because even in # DONE state we can still process a ConnectionClosed() event. But @@ -377,9 +433,9 @@ def _extract_next_receive_event(self): event = ConnectionClosed() if event is None: event = NEED_DATA - return event + return event # type: ignore[no-any-return] - def next_event(self): + def next_event(self) -> Union[Event, Type[NEED_DATA], Type[PAUSED]]: """Parse the next event out of our receive buffer, update our internal state, and return it. @@ -424,7 +480,7 @@ def next_event(self): try: event = self._extract_next_receive_event() if event not in [NEED_DATA, PAUSED]: - self._process_event(self.their_role, event) + self._process_event(self.their_role, cast(Event, event)) if event is NEED_DATA: if len(self._receive_buffer) > self._max_incomplete_event_size: # 431 is "Request header fields too large" which is pretty @@ -444,7 +500,21 @@ def next_event(self): else: raise - def send(self, event): + @overload + def send(self, event: ConnectionClosed) -> None: + ... + + @overload + def send( + self, event: Union[Request, InformationalResponse, Response, Data, EndOfMessage] + ) -> bytes: + ... + + @overload + def send(self, event: Event) -> Optional[bytes]: + ... + + def send(self, event: Event) -> Optional[bytes]: """Convert a high-level event into bytes that can be sent to the peer, while updating our internal state machine. 
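The new ``@overload`` signatures spell out behaviour ``send()`` already had: events that produce wire data return ``bytes`` (possibly empty), while ``ConnectionClosed`` returns ``None``. A small sketch, with a made-up Host header::

    import h11

    client = h11.Connection(h11.CLIENT)

    wire = client.send(
        h11.Request(method="GET", target="/", headers=[("Host", "example.com")])
    )
    assert isinstance(wire, bytes)                 # Request/Response/Data/EndOfMessage -> bytes

    assert client.send(h11.EndOfMessage()) == b""  # still bytes, just empty for this framing
    assert client.send(h11.ConnectionClosed()) is None  # the one event that maps to None
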
@@ -471,7 +541,7 @@ def send(self, event): else: return b"".join(data_list) - def send_with_data_passthrough(self, event): + def send_with_data_passthrough(self, event: Event) -> Optional[List[bytes]]: """Identical to :meth:`send`, except that in situations where :meth:`send` returns a single :term:`bytes-like object`, this instead returns a list of them -- and when sending a :class:`Data` event, this @@ -483,7 +553,7 @@ def send_with_data_passthrough(self, event): raise LocalProtocolError("Can't send data when our state is ERROR") try: if type(event) is Response: - self._clean_up_response_headers_for_sending(event) + event = self._clean_up_response_headers_for_sending(event) # We want to call _process_event before calling the writer, # because if someone tries to do something invalid then this will # give a sensible error message, while our writers all just assume @@ -497,14 +567,14 @@ def send_with_data_passthrough(self, event): # In any situation where writer is None, process_event should # have raised ProtocolError assert writer is not None - data_list = [] + data_list: List[bytes] = [] writer(event, data_list.append) return data_list except: self._process_error(self.our_role) raise - def send_failed(self): + def send_failed(self) -> None: """Notify the state machine that we failed to send the data it gave us. @@ -528,9 +598,8 @@ def send_failed(self): # # This function's *only* responsibility is making sure headers are set up # right -- everything downstream just looks at the headers. There are no - # side channels. It mutates the response event in-place (but not the - # response.headers list object). - def _clean_up_response_headers_for_sending(self, response): + # side channels. + def _clean_up_response_headers_for_sending(self, response: Response) -> Response: assert type(response) is Response headers = response.headers @@ -543,7 +612,7 @@ def _clean_up_response_headers_for_sending(self, response): # we're allowed to leave out the framing headers -- see # https://tools.ietf.org/html/rfc7231#section-4.3.2 . But it's just as # easy to get them right.) - method_for_choosing_headers = self._request_method + method_for_choosing_headers = cast(bytes, self._request_method) if method_for_choosing_headers == b"HEAD": method_for_choosing_headers = b"GET" framing_type, _ = _body_framing(method_for_choosing_headers, response) @@ -573,7 +642,7 @@ def _clean_up_response_headers_for_sending(self, response): if self._request_method != b"HEAD": need_close = True else: - headers = set_comma_header(headers, b"transfer-encoding", ["chunked"]) + headers = set_comma_header(headers, b"transfer-encoding", [b"chunked"]) if not self._cstate.keep_alive or need_close: # Make sure Connection: close is set @@ -582,4 +651,9 @@ def _clean_up_response_headers_for_sending(self, response): connection.add(b"close") headers = set_comma_header(headers, b"connection", sorted(connection)) - response.headers = headers + return Response( + headers=headers, + status_code=response.status_code, + http_version=response.http_version, + reason=response.reason, + ) diff --git a/h11/_events.py b/h11/_events.py index 1827930..ca1c3ad 100644 --- a/h11/_events.py +++ b/h11/_events.py @@ -6,13 +6,17 @@ # Don't subclass these. Stuff will break. import re +from abc import ABC +from dataclasses import dataclass +from typing import List, Tuple, Union -from . 
import _headers -from ._abnf import request_target +from ._abnf import method, request_target +from ._headers import Headers, normalize_and_validate from ._util import bytesify, LocalProtocolError, validate # Everything in __all__ gets re-exported as part of the h11 public API. __all__ = [ + "Event", "Request", "InformationalResponse", "Response", @@ -21,75 +25,20 @@ "ConnectionClosed", ] +method_re = re.compile(method.encode("ascii")) request_target_re = re.compile(request_target.encode("ascii")) -class _EventBundle: - _fields = [] - _defaults = {} - - def __init__(self, **kwargs): - _parsed = kwargs.pop("_parsed", False) - allowed = set(self._fields) - for kwarg in kwargs: - if kwarg not in allowed: - raise TypeError( - "unrecognized kwarg {} for {}".format( - kwarg, self.__class__.__name__ - ) - ) - required = allowed.difference(self._defaults) - for field in required: - if field not in kwargs: - raise TypeError( - "missing required kwarg {} for {}".format( - field, self.__class__.__name__ - ) - ) - self.__dict__.update(self._defaults) - self.__dict__.update(kwargs) - - # Special handling for some fields - - if "headers" in self.__dict__: - self.headers = _headers.normalize_and_validate( - self.headers, _parsed=_parsed - ) - - if not _parsed: - for field in ["method", "target", "http_version", "reason"]: - if field in self.__dict__: - self.__dict__[field] = bytesify(self.__dict__[field]) - - if "status_code" in self.__dict__: - if not isinstance(self.status_code, int): - raise LocalProtocolError("status code must be integer") - # Because IntEnum objects are instances of int, but aren't - # duck-compatible (sigh), see gh-72. - self.status_code = int(self.status_code) - - self._validate() - - def _validate(self): - pass - - def __repr__(self): - name = self.__class__.__name__ - kwarg_strs = [ - "{}={}".format(field, self.__dict__[field]) for field in self._fields - ] - kwarg_str = ", ".join(kwarg_strs) - return "{}({})".format(name, kwarg_str) - - # Useful for tests - def __eq__(self, other): - return self.__class__ == other.__class__ and self.__dict__ == other.__dict__ +class Event(ABC): + """ + Base class for h11 events. + """ - # This is an unhashable type. - __hash__ = None + __slots__ = () -class Request(_EventBundle): +@dataclass(init=False, frozen=True) +class Request(Event): """The beginning of an HTTP request. 
Fields: @@ -123,10 +72,38 @@ class Request(_EventBundle): """ - _fields = ["method", "target", "headers", "http_version"] - _defaults = {"http_version": b"1.1"} + __slots__ = ("method", "headers", "target", "http_version") + + method: bytes + headers: Headers + target: bytes + http_version: bytes + + def __init__( + self, + *, + method: Union[bytes, str], + headers: Union[Headers, List[Tuple[bytes, bytes]], List[Tuple[str, str]]], + target: Union[bytes, str], + http_version: Union[bytes, str] = b"1.1", + _parsed: bool = False, + ) -> None: + super().__init__() + if isinstance(headers, Headers): + object.__setattr__(self, "headers", headers) + else: + object.__setattr__( + self, "headers", normalize_and_validate(headers, _parsed=_parsed) + ) + if not _parsed: + object.__setattr__(self, "method", bytesify(method)) + object.__setattr__(self, "target", bytesify(target)) + object.__setattr__(self, "http_version", bytesify(http_version)) + else: + object.__setattr__(self, "method", method) + object.__setattr__(self, "target", target) + object.__setattr__(self, "http_version", http_version) - def _validate(self): # "A server MUST respond with a 400 (Bad Request) status code to any # HTTP/1.1 request message that lacks a Host header field and to any # request message that contains more than one Host header field or a @@ -141,14 +118,61 @@ def _validate(self): if host_count > 1: raise LocalProtocolError("Found multiple Host: headers") + validate(method_re, self.method, "Illegal method characters") validate(request_target_re, self.target, "Illegal target characters") + # This is an unhashable type. + __hash__ = None # type: ignore + + +@dataclass(init=False, frozen=True) +class _ResponseBase(Event): + __slots__ = ("headers", "http_version", "reason", "status_code") + + headers: Headers + http_version: bytes + reason: bytes + status_code: int + + def __init__( + self, + *, + headers: Union[Headers, List[Tuple[bytes, bytes]], List[Tuple[str, str]]], + status_code: int, + http_version: Union[bytes, str] = b"1.1", + reason: Union[bytes, str] = b"", + _parsed: bool = False, + ) -> None: + super().__init__() + if isinstance(headers, Headers): + object.__setattr__(self, "headers", headers) + else: + object.__setattr__( + self, "headers", normalize_and_validate(headers, _parsed=_parsed) + ) + if not _parsed: + object.__setattr__(self, "reason", bytesify(reason)) + object.__setattr__(self, "http_version", bytesify(http_version)) + if not isinstance(status_code, int): + raise LocalProtocolError("status code must be integer") + # Because IntEnum objects are instances of int, but aren't + # duck-compatible (sigh), see gh-72. + object.__setattr__(self, "status_code", int(status_code)) + else: + object.__setattr__(self, "reason", reason) + object.__setattr__(self, "http_version", http_version) + object.__setattr__(self, "status_code", status_code) + + self.__post_init__() + + def __post_init__(self) -> None: + pass -class _ResponseBase(_EventBundle): - _fields = ["status_code", "headers", "http_version", "reason"] - _defaults = {"http_version": b"1.1", "reason": b""} + # This is an unhashable type. + __hash__ = None # type: ignore +@dataclass(init=False, frozen=True) class InformationalResponse(_ResponseBase): """An HTTP informational response. 
@@ -179,14 +203,18 @@ class InformationalResponse(_ResponseBase): """ - def _validate(self): + def __post_init__(self) -> None: if not (100 <= self.status_code < 200): raise LocalProtocolError( "InformationalResponse status_code should be in range " "[100, 200), not {}".format(self.status_code) ) + # This is an unhashable type. + __hash__ = None # type: ignore + +@dataclass(init=False, frozen=True) class Response(_ResponseBase): """The beginning of an HTTP response. @@ -196,7 +224,7 @@ class Response(_ResponseBase): The status code of this response, as an integer. For an :class:`Response`, this is always in the range [200, - 600). + 1000). .. attribute:: headers @@ -216,16 +244,20 @@ class Response(_ResponseBase): """ - def _validate(self): - if not (200 <= self.status_code < 600): + def __post_init__(self) -> None: + if not (200 <= self.status_code < 1000): raise LocalProtocolError( - "Response status_code should be in range [200, 600), not {}".format( + "Response status_code should be in range [200, 1000), not {}".format( self.status_code ) ) + # This is an unhashable type. + __hash__ = None # type: ignore + -class Data(_EventBundle): +@dataclass(init=False, frozen=True) +class Data(Event): """Part of an HTTP message body. Fields: @@ -258,8 +290,21 @@ class Data(_EventBundle): """ - _fields = ["data", "chunk_start", "chunk_end"] - _defaults = {"chunk_start": False, "chunk_end": False} + __slots__ = ("data", "chunk_start", "chunk_end") + + data: bytes + chunk_start: bool + chunk_end: bool + + def __init__( + self, data: bytes, chunk_start: bool = False, chunk_end: bool = False + ) -> None: + object.__setattr__(self, "data", data) + object.__setattr__(self, "chunk_start", chunk_start) + object.__setattr__(self, "chunk_end", chunk_end) + + # This is an unhashable type. + __hash__ = None # type: ignore # XX FIXME: "A recipient MUST ignore (or consider as an error) any fields that @@ -267,7 +312,8 @@ class Data(_EventBundle): # present in the header section might bypass external security filters." # https://svn.tools.ietf.org/svn/wg/httpbis/specs/rfc7230.html#chunked.trailer.part # Unfortunately, the list of forbidden fields is long and vague :-/ -class EndOfMessage(_EventBundle): +@dataclass(init=False, frozen=True) +class EndOfMessage(Event): """The end of an HTTP message. Fields: @@ -284,11 +330,32 @@ class EndOfMessage(_EventBundle): """ - _fields = ["headers"] - _defaults = {"headers": []} + __slots__ = ("headers",) + + headers: Headers + + def __init__( + self, + *, + headers: Union[ + Headers, List[Tuple[bytes, bytes]], List[Tuple[str, str]], None + ] = None, + _parsed: bool = False, + ) -> None: + super().__init__() + if headers is None: + headers = Headers([]) + elif not isinstance(headers, Headers): + headers = normalize_and_validate(headers, _parsed=_parsed) + + object.__setattr__(self, "headers", headers) + + # This is an unhashable type. + __hash__ = None # type: ignore -class ConnectionClosed(_EventBundle): +@dataclass(frozen=True) +class ConnectionClosed(Event): """This event indicates that the sender has closed their outgoing connection. 
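With the events rewritten as frozen dataclasses, construction still takes the same keyword arguments and still normalizes ``str`` fields to ``bytes``, but instances are now immutable and compare by value. A rough sketch of the caller-visible behaviour (host name invented)::

    import dataclasses
    import h11

    req = h11.Request(
        method="GET",
        target="/",
        headers=[("Host", "example.com")],  # str pairs are normalized to bytes
    )
    assert req.method == b"GET" and req.target == b"/"
    assert req.headers == [(b"host", b"example.com")]  # iteration yields lower-cased names

    assert req == h11.Request(
        method=b"GET", target=b"/", headers=[(b"Host", b"example.com")]
    )  # dataclass equality compares field values

    try:
        req.method = b"POST"  # frozen dataclass: mutation is rejected
    except dataclasses.FrozenInstanceError:
        pass
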
diff --git a/h11/_headers.py b/h11/_headers.py index 7ed39bc..31da3e2 100644 --- a/h11/_headers.py +++ b/h11/_headers.py @@ -1,8 +1,20 @@ import re +from typing import AnyStr, cast, List, overload, Sequence, Tuple, TYPE_CHECKING, Union from ._abnf import field_name, field_value from ._util import bytesify, LocalProtocolError, validate +if TYPE_CHECKING: + from ._events import Request + +try: + from typing import Literal +except ImportError: + from typing_extensions import Literal # type: ignore + +CONTENT_LENGTH_MAX_DIGITS = 20 # allow up to 1 billion TB - 1 + + # Facts # ----- # @@ -57,12 +69,12 @@ # # Maybe a dict-of-lists would be better? -_content_length_re = re.compile(br"[0-9]+") +_content_length_re = re.compile(rb"[0-9]+") _field_name_re = re.compile(field_name.encode("ascii")) _field_value_re = re.compile(field_value.encode("ascii")) -class Headers: +class Headers(Sequence[Tuple[bytes, bytes]]): """ A list-like interface that allows iterating over headers as byte-pairs of (lowercased-name, value). @@ -89,34 +101,57 @@ class Headers: __slots__ = "_full_items" - def __init__(self, full_items): + def __init__(self, full_items: List[Tuple[bytes, bytes, bytes]]) -> None: self._full_items = full_items - def __iter__(self): - for _, name, value in self._full_items: - yield name, value - - def __bool__(self): + def __bool__(self) -> bool: return bool(self._full_items) - def __eq__(self, other): - return list(self) == list(other) + def __eq__(self, other: object) -> bool: + return list(self) == list(other) # type: ignore - def __len__(self): + def __len__(self) -> int: return len(self._full_items) - def __repr__(self): + def __repr__(self) -> str: return "" % repr(list(self)) - def __getitem__(self, idx): + def __getitem__(self, idx: int) -> Tuple[bytes, bytes]: # type: ignore[override] _, name, value = self._full_items[idx] return (name, value) - def raw_items(self): + def raw_items(self) -> List[Tuple[bytes, bytes]]: return [(raw_name, value) for raw_name, _, value in self._full_items] -def normalize_and_validate(headers, _parsed=False): +HeaderTypes = Union[ + List[Tuple[bytes, bytes]], + List[Tuple[bytes, str]], + List[Tuple[str, bytes]], + List[Tuple[str, str]], +] + + +@overload +def normalize_and_validate(headers: Headers, _parsed: Literal[True]) -> Headers: + ... + + +@overload +def normalize_and_validate(headers: HeaderTypes, _parsed: Literal[False]) -> Headers: + ... + + +@overload +def normalize_and_validate( + headers: Union[Headers, HeaderTypes], _parsed: bool = False +) -> Headers: + ... 
+ + +def normalize_and_validate( + headers: Union[Headers, HeaderTypes], _parsed: bool = False +) -> Headers: new_headers = [] seen_content_length = None saw_transfer_encoding = False @@ -129,6 +164,9 @@ def normalize_and_validate(headers, _parsed=False): value = bytesify(value) validate(_field_name_re, name, "Illegal header name {!r}", name) validate(_field_value_re, value, "Illegal header value {!r}", value) + assert isinstance(name, bytes) + assert isinstance(value, bytes) + raw_name = name name = name.lower() if name == b"content-length": @@ -137,6 +175,8 @@ def normalize_and_validate(headers, _parsed=False): raise LocalProtocolError("conflicting Content-Length headers") value = lengths.pop() validate(_content_length_re, value, "bad Content-Length") + if len(value) > CONTENT_LENGTH_MAX_DIGITS: + raise LocalProtocolError("bad Content-Length") if seen_content_length is None: seen_content_length = value new_headers.append((raw_name, name, value)) @@ -166,7 +206,7 @@ def normalize_and_validate(headers, _parsed=False): return Headers(new_headers) -def get_comma_header(headers, name): +def get_comma_header(headers: Headers, name: bytes) -> List[bytes]: # Should only be used for headers whose value is a list of # comma-separated, case-insensitive values. # @@ -202,7 +242,7 @@ def get_comma_header(headers, name): # Expect: the only legal value is the literal string # "100-continue". Splitting on commas is harmless. Case insensitive. # - out = [] + out: List[bytes] = [] for _, found_name, found_raw_value in headers._full_items: if found_name == name: found_raw_value = found_raw_value.lower() @@ -213,7 +253,7 @@ def get_comma_header(headers, name): return out -def set_comma_header(headers, name, new_values): +def set_comma_header(headers: Headers, name: bytes, new_values: List[bytes]) -> Headers: # The header name `name` is expected to be lower-case bytes. # # Note that when we store the header we use title casing for the header @@ -223,7 +263,7 @@ def set_comma_header(headers, name, new_values): # here given the cases where we're using `set_comma_header`... # # Connection, Content-Length, Transfer-Encoding. - new_headers = [] + new_headers: List[Tuple[bytes, bytes]] = [] for found_raw_name, found_name, found_raw_value in headers._full_items: if found_name != name: new_headers.append((found_raw_name, found_raw_value)) @@ -232,7 +272,7 @@ def set_comma_header(headers, name, new_values): return normalize_and_validate(new_headers) -def has_expect_100_continue(request): +def has_expect_100_continue(request: "Request") -> bool: # https://tools.ietf.org/html/rfc7231#section-5.1.1 # "A server that receives a 100-continue expectation in an HTTP/1.0 request # MUST ignore that expectation." 
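``CONTENT_LENGTH_MAX_DIGITS`` above implements the early rejection mentioned in the 0.15.0 notes: a Content-Length with more than 20 digits is refused during header validation, before any integer parsing. A quick sketch with invented header values::

    import h11

    ok = h11.Request(
        method="PUT",
        target="/upload",
        headers=[("Host", "example.com"), ("Content-Length", "9" * 20)],  # 20 digits: accepted
    )

    try:
        h11.Request(
            method="PUT",
            target="/upload",
            headers=[("Host", "example.com"), ("Content-Length", "9" * 21)],  # 21 digits
        )
    except h11.LocalProtocolError as exc:
        print("rejected early:", exc)
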
diff --git a/h11/_readers.py b/h11/_readers.py index 0ead0be..576804c 100644 --- a/h11/_readers.py +++ b/h11/_readers.py @@ -17,30 +17,39 @@ # - or, for body readers, a dict of per-framing reader factories import re +from typing import Any, Callable, Dict, Iterable, NoReturn, Optional, Tuple, Type, Union from ._abnf import chunk_header, header_field, request_line, status_line -from ._events import * -from ._state import * -from ._util import LocalProtocolError, RemoteProtocolError, validate +from ._events import Data, EndOfMessage, InformationalResponse, Request, Response +from ._receivebuffer import ReceiveBuffer +from ._state import ( + CLIENT, + CLOSED, + DONE, + IDLE, + MUST_CLOSE, + SEND_BODY, + SEND_RESPONSE, + SERVER, +) +from ._util import LocalProtocolError, RemoteProtocolError, Sentinel, validate __all__ = ["READERS"] header_field_re = re.compile(header_field.encode("ascii")) +obs_fold_re = re.compile(rb"[ \t]+") -# Remember that this has to run in O(n) time -- so e.g. the bytearray cast is -# critical. -obs_fold_re = re.compile(br"[ \t]+") - -def _obsolete_line_fold(lines): +def _obsolete_line_fold(lines: Iterable[bytes]) -> Iterable[bytes]: it = iter(lines) - last = None + last: Optional[bytes] = None for line in it: match = obs_fold_re.match(line) if match: if last is None: raise LocalProtocolError("continuation line at start of headers") if not isinstance(last, bytearray): + # Cast to a mutable type, avoiding copy on append to ensure O(n) time last = bytearray(last) last += b" " last += line[match.end() :] @@ -52,7 +61,9 @@ def _obsolete_line_fold(lines): yield last -def _decode_header_lines(lines): +def _decode_header_lines( + lines: Iterable[bytes], +) -> Iterable[Tuple[bytes, bytes]]: for line in _obsolete_line_fold(lines): matches = validate(header_field_re, line, "illegal header line: {!r}", line) yield (matches["field_name"], matches["field_value"]) @@ -61,7 +72,7 @@ def _decode_header_lines(lines): request_line_re = re.compile(request_line.encode("ascii")) -def maybe_read_from_IDLE_client(buf): +def maybe_read_from_IDLE_client(buf: ReceiveBuffer) -> Optional[Request]: lines = buf.maybe_extract_lines() if lines is None: if buf.is_next_line_obviously_invalid_request_line(): @@ -80,7 +91,9 @@ def maybe_read_from_IDLE_client(buf): status_line_re = re.compile(status_line.encode("ascii")) -def maybe_read_from_SEND_RESPONSE_server(buf): +def maybe_read_from_SEND_RESPONSE_server( + buf: ReceiveBuffer, +) -> Union[InformationalResponse, Response, None]: lines = buf.maybe_extract_lines() if lines is None: if buf.is_next_line_obviously_invalid_request_line(): @@ -89,22 +102,29 @@ def maybe_read_from_SEND_RESPONSE_server(buf): if not lines: raise LocalProtocolError("no response line received") matches = validate(status_line_re, lines[0], "illegal status line: {!r}", lines[0]) - # Tolerate missing reason phrases - if matches["reason"] is None: - matches["reason"] = b"" - status_code = matches["status_code"] = int(matches["status_code"]) - class_ = InformationalResponse if status_code < 200 else Response + http_version = ( + b"1.1" if matches["http_version"] is None else matches["http_version"] + ) + reason = b"" if matches["reason"] is None else matches["reason"] + status_code = int(matches["status_code"]) + class_: Union[Type[InformationalResponse], Type[Response]] = ( + InformationalResponse if status_code < 200 else Response + ) return class_( - headers=list(_decode_header_lines(lines[1:])), _parsed=True, **matches + headers=list(_decode_header_lines(lines[1:])), + _parsed=True, 
+ status_code=status_code, + reason=reason, + http_version=http_version, ) class ContentLengthReader: - def __init__(self, length): + def __init__(self, length: int) -> None: self._length = length self._remaining = length - def __call__(self, buf): + def __call__(self, buf: ReceiveBuffer) -> Union[Data, EndOfMessage, None]: if self._remaining == 0: return EndOfMessage() data = buf.maybe_extract_at_most(self._remaining) @@ -113,7 +133,7 @@ def __call__(self, buf): self._remaining -= len(data) return Data(data=data) - def read_eof(self): + def read_eof(self) -> NoReturn: raise RemoteProtocolError( "peer closed connection without sending complete message body " "(received {} bytes, expected {})".format( @@ -126,29 +146,32 @@ def read_eof(self): class ChunkedReader: - def __init__(self): + def __init__(self) -> None: self._bytes_in_chunk = 0 - # After reading a chunk, we have to throw away the trailing \r\n; if - # this is >0 then we discard that many bytes before resuming regular - # de-chunkification. - self._bytes_to_discard = 0 + # After reading a chunk, we have to throw away the trailing \r\n. + # This tracks the bytes that we need to match and throw away. + self._bytes_to_discard = b"" self._reading_trailer = False - def __call__(self, buf): + def __call__(self, buf: ReceiveBuffer) -> Union[Data, EndOfMessage, None]: if self._reading_trailer: lines = buf.maybe_extract_lines() if lines is None: return None return EndOfMessage(headers=list(_decode_header_lines(lines))) - if self._bytes_to_discard > 0: - data = buf.maybe_extract_at_most(self._bytes_to_discard) + if self._bytes_to_discard: + data = buf.maybe_extract_at_most(len(self._bytes_to_discard)) if data is None: return None - self._bytes_to_discard -= len(data) - if self._bytes_to_discard > 0: + if data != self._bytes_to_discard[: len(data)]: + raise LocalProtocolError( + f"malformed chunk footer: {data!r} (expected {self._bytes_to_discard!r})" + ) + self._bytes_to_discard = self._bytes_to_discard[len(data) :] + if self._bytes_to_discard: return None # else, fall through and read some more - assert self._bytes_to_discard == 0 + assert self._bytes_to_discard == b"" if self._bytes_in_chunk == 0: # We need to refill our chunk count chunk_header = buf.maybe_extract_next_line() @@ -174,13 +197,13 @@ def __call__(self, buf): return None self._bytes_in_chunk -= len(data) if self._bytes_in_chunk == 0: - self._bytes_to_discard = 2 + self._bytes_to_discard = b"\r\n" chunk_end = True else: chunk_end = False return Data(data=data, chunk_start=chunk_start, chunk_end=chunk_end) - def read_eof(self): + def read_eof(self) -> NoReturn: raise RemoteProtocolError( "peer closed connection without sending complete message body " "(incomplete chunked read)" @@ -188,23 +211,28 @@ def read_eof(self): class Http10Reader: - def __call__(self, buf): + def __call__(self, buf: ReceiveBuffer) -> Optional[Data]: data = buf.maybe_extract_at_most(999999999) if data is None: return None return Data(data=data) - def read_eof(self): + def read_eof(self) -> EndOfMessage: return EndOfMessage() -def expect_nothing(buf): +def expect_nothing(buf: ReceiveBuffer) -> None: if buf: raise LocalProtocolError("Got data when expecting EOF") return None -READERS = { +ReadersType = Dict[ + Union[Type[Sentinel], Tuple[Type[Sentinel], Type[Sentinel]]], + Union[Callable[..., Any], Dict[str, Callable[..., Any]]], +] + +READERS: ReadersType = { (CLIENT, IDLE): maybe_read_from_IDLE_client, (SERVER, IDLE): maybe_read_from_SEND_RESPONSE_server, (SERVER, SEND_RESPONSE): 
maybe_read_from_SEND_RESPONSE_server, diff --git a/h11/_receivebuffer.py b/h11/_receivebuffer.py index a3737f3..e5c4e08 100644 --- a/h11/_receivebuffer.py +++ b/h11/_receivebuffer.py @@ -1,5 +1,6 @@ import re import sys +from typing import List, Optional, Union __all__ = ["ReceiveBuffer"] @@ -44,26 +45,26 @@ class ReceiveBuffer: - def __init__(self): + def __init__(self) -> None: self._data = bytearray() self._next_line_search = 0 self._multiple_lines_search = 0 - def __iadd__(self, byteslike): + def __iadd__(self, byteslike: Union[bytes, bytearray]) -> "ReceiveBuffer": self._data += byteslike return self - def __bool__(self): + def __bool__(self) -> bool: return bool(len(self)) - def __len__(self): + def __len__(self) -> int: return len(self._data) # for @property unprocessed_data - def __bytes__(self): + def __bytes__(self) -> bytes: return bytes(self._data) - def _extract(self, count): + def _extract(self, count: int) -> bytearray: # extracting an initial slice of the data buffer and return it out = self._data[:count] del self._data[:count] @@ -73,7 +74,7 @@ def _extract(self, count): return out - def maybe_extract_at_most(self, count): + def maybe_extract_at_most(self, count: int) -> Optional[bytearray]: """ Extract a fixed number of bytes from the buffer. """ @@ -83,7 +84,7 @@ def maybe_extract_at_most(self, count): return self._extract(count) - def maybe_extract_next_line(self): + def maybe_extract_next_line(self) -> Optional[bytearray]: """ Extract the first line, if it is completed in the buffer. """ @@ -100,7 +101,7 @@ def maybe_extract_next_line(self): return self._extract(idx) - def maybe_extract_lines(self): + def maybe_extract_lines(self) -> Optional[List[bytearray]]: """ Extract everything up to the first blank line, and return a list of lines. """ @@ -143,7 +144,7 @@ def maybe_extract_lines(self): # This is especially interesting when peer is messing up with HTTPS and # sent us a TLS stream where we were expecting plain HTTP given all # versions of TLS so far start handshake with a 0x16 message type code. - def is_next_line_obviously_invalid_request_line(self): + def is_next_line_obviously_invalid_request_line(self) -> bool: try: # HTTP header line must not contain non-printable characters # and should not start with a space diff --git a/h11/_state.py b/h11/_state.py index 0f08a09..3ad444b 100644 --- a/h11/_state.py +++ b/h11/_state.py @@ -110,9 +110,10 @@ # tables. But it can't automatically read the transitions that are written # directly in Python code. So if you touch those, you need to also update the # script to keep it in sync! +from typing import cast, Dict, Optional, Set, Tuple, Type, Union from ._events import * -from ._util import LocalProtocolError, make_sentinel +from ._util import LocalProtocolError, Sentinel # Everything in __all__ gets re-exported as part of the h11 public API. 
__all__ = [ @@ -129,26 +130,70 @@ "ERROR", ] -CLIENT = make_sentinel("CLIENT") -SERVER = make_sentinel("SERVER") + +class CLIENT(Sentinel, metaclass=Sentinel): + pass + + +class SERVER(Sentinel, metaclass=Sentinel): + pass + # States -IDLE = make_sentinel("IDLE") -SEND_RESPONSE = make_sentinel("SEND_RESPONSE") -SEND_BODY = make_sentinel("SEND_BODY") -DONE = make_sentinel("DONE") -MUST_CLOSE = make_sentinel("MUST_CLOSE") -CLOSED = make_sentinel("CLOSED") -ERROR = make_sentinel("ERROR") +class IDLE(Sentinel, metaclass=Sentinel): + pass + + +class SEND_RESPONSE(Sentinel, metaclass=Sentinel): + pass + + +class SEND_BODY(Sentinel, metaclass=Sentinel): + pass + + +class DONE(Sentinel, metaclass=Sentinel): + pass + + +class MUST_CLOSE(Sentinel, metaclass=Sentinel): + pass + + +class CLOSED(Sentinel, metaclass=Sentinel): + pass + + +class ERROR(Sentinel, metaclass=Sentinel): + pass + # Switch types -MIGHT_SWITCH_PROTOCOL = make_sentinel("MIGHT_SWITCH_PROTOCOL") -SWITCHED_PROTOCOL = make_sentinel("SWITCHED_PROTOCOL") +class MIGHT_SWITCH_PROTOCOL(Sentinel, metaclass=Sentinel): + pass + + +class SWITCHED_PROTOCOL(Sentinel, metaclass=Sentinel): + pass + + +class _SWITCH_UPGRADE(Sentinel, metaclass=Sentinel): + pass + -_SWITCH_UPGRADE = make_sentinel("_SWITCH_UPGRADE") -_SWITCH_CONNECT = make_sentinel("_SWITCH_CONNECT") +class _SWITCH_CONNECT(Sentinel, metaclass=Sentinel): + pass -EVENT_TRIGGERED_TRANSITIONS = { + +EventTransitionType = Dict[ + Type[Sentinel], + Dict[ + Type[Sentinel], + Dict[Union[Type[Event], Tuple[Type[Event], Type[Sentinel]]], Type[Sentinel]], + ], +] + +EVENT_TRIGGERED_TRANSITIONS: EventTransitionType = { CLIENT: { IDLE: {Request: SEND_BODY, ConnectionClosed: CLOSED}, SEND_BODY: {Data: SEND_BODY, EndOfMessage: DONE}, @@ -181,9 +226,13 @@ }, } +StateTransitionType = Dict[ + Tuple[Type[Sentinel], Type[Sentinel]], Dict[Type[Sentinel], Type[Sentinel]] +] + # NB: there are also some special-case state-triggered transitions hard-coded # into _fire_state_triggered_transitions below. -STATE_TRIGGERED_TRANSITIONS = { +STATE_TRIGGERED_TRANSITIONS: StateTransitionType = { # (Client state, Server state) -> new states # Protocol negotiation (MIGHT_SWITCH_PROTOCOL, SWITCHED_PROTOCOL): {CLIENT: SWITCHED_PROTOCOL}, @@ -198,7 +247,7 @@ class ConnectionState: - def __init__(self): + def __init__(self) -> None: # Extra bits of state that don't quite fit into the state model. # If this is False then it enables the automatic DONE -> MUST_CLOSE @@ -207,55 +256,64 @@ def __init__(self): # This is a subset of {UPGRADE, CONNECT}, containing the proposals # made by the client for switching protocols. 
- self.pending_switch_proposals = set() + self.pending_switch_proposals: Set[Type[Sentinel]] = set() - self.states = {CLIENT: IDLE, SERVER: IDLE} + self.states: Dict[Type[Sentinel], Type[Sentinel]] = {CLIENT: IDLE, SERVER: IDLE} - def process_error(self, role): + def process_error(self, role: Type[Sentinel]) -> None: self.states[role] = ERROR self._fire_state_triggered_transitions() - def process_keep_alive_disabled(self): + def process_keep_alive_disabled(self) -> None: self.keep_alive = False self._fire_state_triggered_transitions() - def process_client_switch_proposal(self, switch_event): + def process_client_switch_proposal(self, switch_event: Type[Sentinel]) -> None: self.pending_switch_proposals.add(switch_event) self._fire_state_triggered_transitions() - def process_event(self, role, event_type, server_switch_event=None): + def process_event( + self, + role: Type[Sentinel], + event_type: Type[Event], + server_switch_event: Optional[Type[Sentinel]] = None, + ) -> None: + _event_type: Union[Type[Event], Tuple[Type[Event], Type[Sentinel]]] = event_type if server_switch_event is not None: assert role is SERVER if server_switch_event not in self.pending_switch_proposals: raise LocalProtocolError( - "Received server {} event without a pending proposal".format( - server_switch_event - ) + "Received server _SWITCH_UPGRADE event without a pending proposal" ) - event_type = (event_type, server_switch_event) - if server_switch_event is None and event_type is Response: + _event_type = (event_type, server_switch_event) + if server_switch_event is None and _event_type is Response: self.pending_switch_proposals = set() - self._fire_event_triggered_transitions(role, event_type) + self._fire_event_triggered_transitions(role, _event_type) # Special case: the server state does get to see Request # events. - if event_type is Request: + if _event_type is Request: assert role is CLIENT self._fire_event_triggered_transitions(SERVER, (Request, CLIENT)) self._fire_state_triggered_transitions() - def _fire_event_triggered_transitions(self, role, event_type): + def _fire_event_triggered_transitions( + self, + role: Type[Sentinel], + event_type: Union[Type[Event], Tuple[Type[Event], Type[Sentinel]]], + ) -> None: state = self.states[role] try: new_state = EVENT_TRIGGERED_TRANSITIONS[role][state][event_type] except KeyError: + event_type = cast(Type[Event], event_type) raise LocalProtocolError( "can't handle event type {} when role={} and state={}".format( event_type.__name__, role, self.states[role] ) - ) + ) from None self.states[role] = new_state - def _fire_state_triggered_transitions(self): + def _fire_state_triggered_transitions(self) -> None: # We apply these rules repeatedly until converging on a fixed point while True: start_states = dict(self.states) @@ -295,10 +353,10 @@ def _fire_state_triggered_transitions(self): # Fixed point reached return - def start_next_cycle(self): + def start_next_cycle(self) -> None: if self.states != {CLIENT: DONE, SERVER: DONE}: raise LocalProtocolError( - "not in a reusable state. self.states={}".format(self.states) + f"not in a reusable state. self.states={self.states}" ) # Can't reach DONE/DONE with any of these active, but still, let's be # sure. 
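The sentinels are now classes, but they keep the property the old ``make_sentinel`` objects had: each one is its own type, so both ``is`` checks and ``type()`` dispatch work, and they still serve as the keys and values of ``Connection.states``. A short sketch::

    import h11

    # Each sentinel is a class that is its own type, so repr() is clean and
    # both "is" and type() checks work for dispatching on states and special values.
    assert repr(h11.IDLE) == "IDLE"
    assert type(h11.NEED_DATA) is h11.NEED_DATA

    conn = h11.Connection(h11.CLIENT)
    assert conn.states == {h11.CLIENT: h11.IDLE, h11.SERVER: h11.IDLE}
    assert conn.next_event() is h11.NEED_DATA  # nothing received yet
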
diff --git a/h11/_util.py b/h11/_util.py index eb1a5cd..6718445 100644 --- a/h11/_util.py +++ b/h11/_util.py @@ -1,9 +1,10 @@ +from typing import Any, Dict, NoReturn, Pattern, Tuple, Type, TypeVar, Union + __all__ = [ "ProtocolError", "LocalProtocolError", "RemoteProtocolError", "validate", - "make_sentinel", "bytesify", ] @@ -37,7 +38,7 @@ class ProtocolError(Exception): """ - def __init__(self, msg, error_status_hint=400): + def __init__(self, msg: str, error_status_hint: int = 400) -> None: if type(self) is ProtocolError: raise TypeError("tried to directly instantiate ProtocolError") Exception.__init__(self, msg) @@ -56,14 +57,14 @@ def __init__(self, msg, error_status_hint=400): # LocalProtocolError is for local errors and RemoteProtocolError is for # remote errors. class LocalProtocolError(ProtocolError): - def _reraise_as_remote_protocol_error(self): + def _reraise_as_remote_protocol_error(self) -> NoReturn: # After catching a LocalProtocolError, use this method to re-raise it # as a RemoteProtocolError. This method must be called from inside an # except: block. # # An easy way to get an equivalent RemoteProtocolError is just to # modify 'self' in place. - self.__class__ = RemoteProtocolError + self.__class__ = RemoteProtocolError # type: ignore # But the re-raising is somewhat non-trivial -- you might think that # now that we've modified the in-flight exception object, that just # doing 'raise' to re-raise it would be enough. But it turns out that @@ -80,7 +81,9 @@ class RemoteProtocolError(ProtocolError): pass -def validate(regex, data, msg="malformed data", *format_args): +def validate( + regex: Pattern[bytes], data: bytes, msg: str = "malformed data", *format_args: Any +) -> Dict[str, bytes]: match = regex.fullmatch(data) if not match: if format_args: @@ -97,21 +100,31 @@ def validate(regex, data, msg="malformed data", *format_args): # # The bonus property is useful if you want to take the return value from # next_event() and do some sort of dispatch based on type(event). -class _SentinelBase(type): - def __repr__(self): - return self.__name__ + +_T_Sentinel = TypeVar("_T_Sentinel", bound="Sentinel") -def make_sentinel(name): - cls = _SentinelBase(name, (_SentinelBase,), {}) - cls.__class__ = cls - return cls +class Sentinel(type): + def __new__( + cls: Type[_T_Sentinel], + name: str, + bases: Tuple[type, ...], + namespace: Dict[str, Any], + **kwds: Any + ) -> _T_Sentinel: + assert bases == (Sentinel,) + v = super().__new__(cls, name, bases, namespace, **kwds) + v.__class__ = v # type: ignore + return v + + def __repr__(self) -> str: + return self.__name__ # Used for methods, request targets, HTTP versions, header names, and header # values. Accepts ascii-strings, or bytes/bytearray/memoryview/..., and always # returns bytes. -def bytesify(s): +def bytesify(s: Union[bytes, bytearray, memoryview, int, str]) -> bytes: # Fast-path: if type(s) is bytes: return s diff --git a/h11/_version.py b/h11/_version.py index cb5c2c3..76e7327 100644 --- a/h11/_version.py +++ b/h11/_version.py @@ -13,4 +13,4 @@ # want. (Contrast with the special suffix 1.0.0.dev, which sorts *before* # 1.0.0.) 
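Illustrative sketch (not part of the diff): validate() is the small helper whose signature gains annotations above; it fullmatch()es a bytes regex and either returns the named groups or raises LocalProtocolError, optionally str.format()-ing the message with the extra args. The regex below is made up for the example and is not one of h11's own patterns:

import re

from h11._util import LocalProtocolError, validate

version_re = re.compile(rb"HTTP/(?P<major>[0-9])\.(?P<minor>[0-9])")  # example-only pattern

assert validate(version_re, b"HTTP/1.1") == {"major": b"1", "minor": b"1"}

try:
    validate(version_re, b"HTTP/1.x", "bad version line: {!r}", b"HTTP/1.x")
except LocalProtocolError as exc:
    assert "bad version line" in str(exc)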
-__version__ = "0.12.0" +__version__ = "0.16.0" diff --git a/h11/_writers.py b/h11/_writers.py index cb5e8a8..939cdb9 100644 --- a/h11/_writers.py +++ b/h11/_writers.py @@ -7,14 +7,19 @@ # - a writer # - or, for body writers, a dict of framin-dependent writer factories -from ._events import Data, EndOfMessage +from typing import Any, Callable, Dict, List, Tuple, Type, Union + +from ._events import Data, EndOfMessage, Event, InformationalResponse, Request, Response +from ._headers import Headers from ._state import CLIENT, IDLE, SEND_BODY, SEND_RESPONSE, SERVER -from ._util import LocalProtocolError +from ._util import LocalProtocolError, Sentinel __all__ = ["WRITERS"] +Writer = Callable[[bytes], Any] + -def write_headers(headers, write): +def write_headers(headers: Headers, write: Writer) -> None: # "Since the Host field-value is critical information for handling a # request, a user agent SHOULD generate Host as the first header field # following the request-line." - RFC 7230 @@ -28,7 +33,7 @@ def write_headers(headers, write): write(b"\r\n") -def write_request(request, write): +def write_request(request: Request, write: Writer) -> None: if request.http_version != b"1.1": raise LocalProtocolError("I only send HTTP/1.1") write(b"%s %s HTTP/1.1\r\n" % (request.method, request.target)) @@ -36,7 +41,9 @@ def write_request(request, write): # Shared between InformationalResponse and Response -def write_any_response(response, write): +def write_any_response( + response: Union[InformationalResponse, Response], write: Writer +) -> None: if response.http_version != b"1.1": raise LocalProtocolError("I only send HTTP/1.1") status_bytes = str(response.status_code).encode("ascii") @@ -53,7 +60,7 @@ def write_any_response(response, write): class BodyWriter: - def __call__(self, event, write): + def __call__(self, event: Event, write: Writer) -> None: if type(event) is Data: self.send_data(event.data, write) elif type(event) is EndOfMessage: @@ -61,6 +68,12 @@ def __call__(self, event, write): else: # pragma: no cover assert False + def send_data(self, data: bytes, write: Writer) -> None: + pass + + def send_eom(self, headers: Headers, write: Writer) -> None: + pass + # # These are all careful not to do anything to 'data' except call len(data) and @@ -69,16 +82,16 @@ def __call__(self, event, write): # sendfile(2). # class ContentLengthWriter(BodyWriter): - def __init__(self, length): + def __init__(self, length: int) -> None: self._length = length - def send_data(self, data, write): + def send_data(self, data: bytes, write: Writer) -> None: self._length -= len(data) if self._length < 0: raise LocalProtocolError("Too much data for declared Content-Length") write(data) - def send_eom(self, headers, write): + def send_eom(self, headers: Headers, write: Writer) -> None: if self._length != 0: raise LocalProtocolError("Too little data for declared Content-Length") if headers: @@ -86,7 +99,7 @@ def send_eom(self, headers, write): class ChunkedWriter(BodyWriter): - def send_data(self, data, write): + def send_data(self, data: bytes, write: Writer) -> None: # if we encoded 0-length data in the naive way, it would look like an # end-of-message. 
if not data: @@ -95,23 +108,32 @@ def send_data(self, data, write): write(data) write(b"\r\n") - def send_eom(self, headers, write): + def send_eom(self, headers: Headers, write: Writer) -> None: write(b"0\r\n") write_headers(headers, write) class Http10Writer(BodyWriter): - def send_data(self, data, write): + def send_data(self, data: bytes, write: Writer) -> None: write(data) - def send_eom(self, headers, write): + def send_eom(self, headers: Headers, write: Writer) -> None: if headers: raise LocalProtocolError("can't send trailers to HTTP/1.0 client") # no need to close the socket ourselves, that will be taken care of by # Connection: close machinery -WRITERS = { +WritersType = Dict[ + Union[Tuple[Type[Sentinel], Type[Sentinel]], Type[Sentinel]], + Union[ + Dict[str, Type[BodyWriter]], + Callable[[Union[InformationalResponse, Response], Writer], None], + Callable[[Request, Writer], None], + ], +] + +WRITERS: WritersType = { (CLIENT, IDLE): write_request, (SERVER, IDLE): write_any_response, (SERVER, SEND_RESPONSE): write_any_response, diff --git a/h11/py.typed b/h11/py.typed new file mode 100644 index 0000000..f5642f7 --- /dev/null +++ b/h11/py.typed @@ -0,0 +1 @@ +Marker diff --git a/h11/tests/helpers.py b/h11/tests/helpers.py index 9d2cf38..571be44 100644 --- a/h11/tests/helpers.py +++ b/h11/tests/helpers.py @@ -1,36 +1,55 @@ -from .._connection import * -from .._events import * -from .._state import * +from typing import cast, List, Type, Union, ValuesView +from .._connection import Connection, NEED_DATA, PAUSED +from .._events import ( + ConnectionClosed, + Data, + EndOfMessage, + Event, + InformationalResponse, + Request, + Response, +) +from .._state import CLIENT, CLOSED, DONE, MUST_CLOSE, SERVER +from .._util import Sentinel -def get_all_events(conn): +try: + from typing import Literal +except ImportError: + from typing_extensions import Literal # type: ignore + + +def get_all_events(conn: Connection) -> List[Event]: got_events = [] while True: event = conn.next_event() if event in (NEED_DATA, PAUSED): break + event = cast(Event, event) got_events.append(event) if type(event) is ConnectionClosed: break return got_events -def receive_and_get(conn, data): +def receive_and_get(conn: Connection, data: bytes) -> List[Event]: conn.receive_data(data) return get_all_events(conn) # Merges adjacent Data events, converts payloads to bytestrings, and removes # chunk boundaries. -def normalize_data_events(in_events): - out_events = [] +def normalize_data_events(in_events: List[Event]) -> List[Event]: + out_events: List[Event] = [] for event in in_events: if type(event) is Data: - event.data = bytes(event.data) - event.chunk_start = False - event.chunk_end = False + event = Data(data=bytes(event.data), chunk_start=False, chunk_end=False) if out_events and type(out_events[-1]) is type(event) is Data: - out_events[-1].data += event.data + out_events[-1] = Data( + data=out_events[-1].data + event.data, + chunk_start=out_events[-1].chunk_start, + chunk_end=out_events[-1].chunk_end, + ) else: out_events.append(event) return out_events @@ -41,16 +60,21 @@ def normalize_data_events(in_events): # of pushing them through two Connections with a fake network link in # between. 
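Illustrative sketch (not part of the diff): the WRITERS table above maps a (role, state) pair either to an event writer or, under the SEND_BODY key, to a dict of body-writer classes, which is why its annotation needs the WritersType union. Rendering an event by hand, mirroring the dowrite() helper used later in test_io.py (private modules, sketch only):

from h11._events import Request
from h11._state import CLIENT, IDLE
from h11._writers import WRITERS

chunks = []
write_request = WRITERS[CLIENT, IDLE]          # the client-side request writer
write_request(
    Request(method="GET", target="/", headers=[("Host", "example.com")]),
    chunks.append,                             # any bytes-accepting callable works as a Writer
)
assert b"".join(chunks) == b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n"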
class ConnectionPair: - def __init__(self): + def __init__(self) -> None: self.conn = {CLIENT: Connection(CLIENT), SERVER: Connection(SERVER)} self.other = {CLIENT: SERVER, SERVER: CLIENT} @property - def conns(self): + def conns(self) -> ValuesView[Connection]: return self.conn.values() # expect="match" if expect=send_events; expect=[...] to say what expected - def send(self, role, send_events, expect="match"): + def send( + self, + role: Type[Sentinel], + send_events: Union[List[Event], Event], + expect: Union[List[Event], Event, Literal["match"]] = "match", + ) -> bytes: if not isinstance(send_events, list): send_events = [send_events] data = b"" diff --git a/h11/tests/test_against_stdlib_http.py b/h11/tests/test_against_stdlib_http.py index e6c5db4..3f66a10 100644 --- a/h11/tests/test_against_stdlib_http.py +++ b/h11/tests/test_against_stdlib_http.py @@ -5,13 +5,16 @@ import threading from contextlib import closing, contextmanager from http.server import SimpleHTTPRequestHandler +from typing import Callable, Generator from urllib.request import urlopen import h11 @contextmanager -def socket_server(handler): +def socket_server( + handler: Callable[..., socketserver.BaseRequestHandler], +) -> Generator[socketserver.TCPServer, None, None]: httpd = socketserver.TCPServer(("127.0.0.1", 0), handler) thread = threading.Thread( target=httpd.serve_forever, kwargs={"poll_interval": 0.01} @@ -30,13 +33,13 @@ def socket_server(handler): class SingleMindedRequestHandler(SimpleHTTPRequestHandler): - def translate_path(self, path): + def translate_path(self, path: str) -> str: return test_file_path -def test_h11_as_client(): +def test_h11_as_client() -> None: with socket_server(SingleMindedRequestHandler) as httpd: - with closing(socket.create_connection(httpd.server_address)) as s: + with closing(socket.create_connection(httpd.server_address)) as s: # type: ignore[arg-type] c = h11.Connection(h11.CLIENT) s.sendall( @@ -67,7 +70,7 @@ def test_h11_as_client(): class H11RequestHandler(socketserver.BaseRequestHandler): - def handle(self): + def handle(self) -> None: with closing(self.request) as s: c = h11.Connection(h11.SERVER) request = None @@ -82,6 +85,7 @@ def handle(self): request = event if type(event) is h11.EndOfMessage: break + assert request is not None info = json.dumps( { "method": request.method.decode("ascii"), @@ -97,10 +101,10 @@ def handle(self): s.sendall(c.send(h11.EndOfMessage())) -def test_h11_as_server(): +def test_h11_as_server() -> None: with socket_server(H11RequestHandler) as httpd: host, port = httpd.server_address - url = "http://{}:{}/some-path".format(host, port) + url = f"http://{host}:{port}/some-path" # type: ignore[str-bytes-safe] with closing(urlopen(url)) as f: assert f.getcode() == 200 data = f.read() diff --git a/h11/tests/test_connection.py b/h11/tests/test_connection.py index baadec8..01260dc 100644 --- a/h11/tests/test_connection.py +++ b/h11/tests/test_connection.py @@ -1,13 +1,33 @@ +from typing import Any, cast, Dict, List, Optional, Tuple, Type + import pytest from .._connection import _body_framing, _keep_alive, Connection, NEED_DATA, PAUSED -from .._events import * -from .._state import * -from .._util import LocalProtocolError, RemoteProtocolError +from .._events import ( + ConnectionClosed, + Data, + EndOfMessage, + InformationalResponse, + Request, + Response, +) +from .._state import ( + CLIENT, + CLOSED, + DONE, + ERROR, + MIGHT_SWITCH_PROTOCOL, + MUST_CLOSE, + SEND_BODY, + SEND_RESPONSE, + SERVER, + SWITCHED_PROTOCOL, +) +from .._util import 
LocalProtocolError, RemoteProtocolError, Sentinel from .helpers import ConnectionPair, get_all_events, receive_and_get -def test__keep_alive(): +def test__keep_alive() -> None: assert _keep_alive( Request(method="GET", target="/", headers=[("Host", "Example.com")]) ) @@ -37,8 +57,8 @@ def test__keep_alive(): assert not _keep_alive(Response(status_code=200, headers=[], http_version="1.0")) -def test__body_framing(): - def headers(cl, te): +def test__body_framing() -> None: + def headers(cl: Optional[int], te: bool) -> List[Tuple[str, str]]: headers = [] if cl is not None: headers.append(("Content-Length", str(cl))) @@ -46,16 +66,19 @@ def headers(cl, te): headers.append(("Transfer-Encoding", "chunked")) return headers - def resp(status_code=200, cl=None, te=False): + def resp( + status_code: int = 200, cl: Optional[int] = None, te: bool = False + ) -> Response: return Response(status_code=status_code, headers=headers(cl, te)) - def req(cl=None, te=False): + def req(cl: Optional[int] = None, te: bool = False) -> Request: h = headers(cl, te) h += [("Host", "example.com")] return Request(method="GET", target="/", headers=h) # Special cases where the headers are ignored: for kwargs in [{}, {"cl": 100}, {"te": True}, {"cl": 100, "te": True}]: + kwargs = cast(Dict[str, Any], kwargs) for meth, r in [ (b"HEAD", resp(**kwargs)), (b"GET", resp(status_code=204, **kwargs)), @@ -65,21 +88,22 @@ def req(cl=None, te=False): # Transfer-encoding for kwargs in [{"te": True}, {"cl": 100, "te": True}]: - for meth, r in [(None, req(**kwargs)), (b"GET", resp(**kwargs))]: + kwargs = cast(Dict[str, Any], kwargs) + for meth, r in [(None, req(**kwargs)), (b"GET", resp(**kwargs))]: # type: ignore assert _body_framing(meth, r) == ("chunked", ()) # Content-Length - for meth, r in [(None, req(cl=100)), (b"GET", resp(cl=100))]: + for meth, r in [(None, req(cl=100)), (b"GET", resp(cl=100))]: # type: ignore assert _body_framing(meth, r) == ("content-length", (100,)) # No headers - assert _body_framing(None, req()) == ("content-length", (0,)) + assert _body_framing(None, req()) == ("content-length", (0,)) # type: ignore assert _body_framing(b"GET", resp()) == ("http/1.0", ()) -def test_Connection_basics_and_content_length(): +def test_Connection_basics_and_content_length() -> None: with pytest.raises(ValueError): - Connection("CLIENT") + Connection("CLIENT") # type: ignore p = ConnectionPair() assert p.conn[CLIENT].our_role is CLIENT @@ -144,7 +168,7 @@ def test_Connection_basics_and_content_length(): assert conn.states == {CLIENT: DONE, SERVER: DONE} -def test_chunked(): +def test_chunked() -> None: p = ConnectionPair() p.send( @@ -175,7 +199,7 @@ def test_chunked(): assert conn.states == {CLIENT: DONE, SERVER: DONE} -def test_chunk_boundaries(): +def test_chunk_boundaries() -> None: conn = Connection(our_role=SERVER) request = ( @@ -214,7 +238,7 @@ def test_chunk_boundaries(): assert conn.next_event() == EndOfMessage() -def test_client_talking_to_http10_server(): +def test_client_talking_to_http10_server() -> None: c = Connection(CLIENT) c.send(Request(method="GET", target="/", headers=[("Host", "example.com")])) c.send(EndOfMessage()) @@ -230,7 +254,7 @@ def test_client_talking_to_http10_server(): assert c.their_state is CLOSED -def test_server_talking_to_http10_client(): +def test_server_talking_to_http10_client() -> None: c = Connection(SERVER) # No content-length, so no body # NB: no host header @@ -267,7 +291,7 @@ def test_server_talking_to_http10_client(): assert receive_and_get(c, b"") == [ConnectionClosed()] -def 
test_automatic_transfer_encoding_in_response(): +def test_automatic_transfer_encoding_in_response() -> None: # Check that in responses, the user can specify either Transfer-Encoding: # chunked or no framing at all, and in both cases we automatically select # the right option depending on whether the peer speaks HTTP/1.0 or @@ -279,6 +303,7 @@ def test_automatic_transfer_encoding_in_response(): # because if both are set then Transfer-Encoding wins [("Transfer-Encoding", "chunked"), ("Content-Length", "100")], ]: + user_headers = cast(List[Tuple[str, str]], user_headers) p = ConnectionPair() p.send( CLIENT, @@ -308,7 +333,7 @@ def test_automatic_transfer_encoding_in_response(): assert c.send(Data(data=b"12345")) == b"12345" -def test_automagic_connection_close_handling(): +def test_automagic_connection_close_handling() -> None: p = ConnectionPair() # If the user explicitly sets Connection: close, then we notice and # respect it @@ -340,8 +365,8 @@ def test_automagic_connection_close_handling(): assert conn.states == {CLIENT: MUST_CLOSE, SERVER: MUST_CLOSE} -def test_100_continue(): - def setup(): +def test_100_continue() -> None: + def setup() -> ConnectionPair: p = ConnectionPair() p.send( CLIENT, @@ -385,7 +410,7 @@ def setup(): assert not conn.they_are_waiting_for_100_continue -def test_max_incomplete_event_size_countermeasure(): +def test_max_incomplete_event_size_countermeasure() -> None: # Infinitely long headers are definitely not okay c = Connection(SERVER) c.receive_data(b"GET / HTTP/1.0\r\nEndless: ") @@ -461,13 +486,19 @@ def test_max_incomplete_event_size_countermeasure(): c.next_event() -def test_reuse_simple(): +def test_reuse_simple() -> None: p = ConnectionPair() p.send( CLIENT, [Request(method="GET", target="/", headers=[("Host", "a")]), EndOfMessage()], ) - p.send(SERVER, [Response(status_code=200, headers=[]), EndOfMessage()]) + p.send( + SERVER, + [ + Response(status_code=200, headers=[(b"transfer-encoding", b"chunked")]), + EndOfMessage(), + ], + ) for conn in p.conns: assert conn.states == {CLIENT: DONE, SERVER: DONE} conn.start_next_cycle() @@ -479,10 +510,16 @@ def test_reuse_simple(): EndOfMessage(), ], ) - p.send(SERVER, [Response(status_code=404, headers=[]), EndOfMessage()]) + p.send( + SERVER, + [ + Response(status_code=404, headers=[(b"transfer-encoding", b"chunked")]), + EndOfMessage(), + ], + ) -def test_pipelining(): +def test_pipelining() -> None: # Client doesn't support pipelining, so we have to do this by hand c = Connection(SERVER) assert c.next_event() is NEED_DATA @@ -554,16 +591,16 @@ def test_pipelining(): c.receive_data(b"FDSA") -def test_protocol_switch(): - for (req, deny, accept) in [ +def test_protocol_switch() -> None: + for req, deny, accept in [ ( Request( method="CONNECT", target="example.com:443", headers=[("Host", "foo"), ("Content-Length", "1")], ), - Response(status_code=404, headers=[]), - Response(status_code=200, headers=[]), + Response(status_code=404, headers=[(b"transfer-encoding", b"chunked")]), + Response(status_code=200, headers=[(b"transfer-encoding", b"chunked")]), ), ( Request( @@ -571,7 +608,7 @@ def test_protocol_switch(): target="/", headers=[("Host", "foo"), ("Content-Length", "1"), ("Upgrade", "a, b")], ), - Response(status_code=200, headers=[]), + Response(status_code=200, headers=[(b"transfer-encoding", b"chunked")]), InformationalResponse(status_code=101, headers=[("Upgrade", "a")]), ), ( @@ -580,9 +617,9 @@ def test_protocol_switch(): target="example.com:443", headers=[("Host", "foo"), ("Content-Length", "1"), 
("Upgrade", "a, b")], ), - Response(status_code=404, headers=[]), + Response(status_code=404, headers=[(b"transfer-encoding", b"chunked")]), # Accept CONNECT, not upgrade - Response(status_code=200, headers=[]), + Response(status_code=200, headers=[(b"transfer-encoding", b"chunked")]), ), ( Request( @@ -590,13 +627,13 @@ def test_protocol_switch(): target="example.com:443", headers=[("Host", "foo"), ("Content-Length", "1"), ("Upgrade", "a, b")], ), - Response(status_code=404, headers=[]), + Response(status_code=404, headers=[(b"transfer-encoding", b"chunked")]), # Accept Upgrade, not CONNECT InformationalResponse(status_code=101, headers=[("Upgrade", "b")]), ), ]: - def setup(): + def setup() -> ConnectionPair: p = ConnectionPair() p.send(CLIENT, req) # No switch-related state change stuff yet; the client has to @@ -661,7 +698,7 @@ def setup(): p = setup() sc = p.conn[SERVER] - sc.receive_data(b"") == [] + sc.receive_data(b"") assert sc.next_event() is PAUSED sc.send(deny) assert sc.next_event() == ConnectionClosed() @@ -679,12 +716,12 @@ def setup(): p.conn[SERVER].send(Data(data=b"123")) -def test_close_simple(): +def test_close_simple() -> None: # Just immediately closing a new connection without anything having # happened yet. - for (who_shot_first, who_shot_second) in [(CLIENT, SERVER), (SERVER, CLIENT)]: + for who_shot_first, who_shot_second in [(CLIENT, SERVER), (SERVER, CLIENT)]: - def setup(): + def setup() -> ConnectionPair: p = ConnectionPair() p.send(who_shot_first, ConnectionClosed()) for conn in p.conns: @@ -720,12 +757,15 @@ def setup(): p.conn[who_shot_first].next_event() -def test_close_different_states(): +def test_close_different_states() -> None: req = [ Request(method="GET", target="/foo", headers=[("Host", "a")]), EndOfMessage(), ] - resp = [Response(status_code=200, headers=[]), EndOfMessage()] + resp = [ + Response(status_code=200, headers=[(b"transfer-encoding", b"chunked")]), + EndOfMessage(), + ] # Client before request p = ConnectionPair() @@ -783,7 +823,7 @@ def test_close_different_states(): # Receive several requests and then client shuts down their side of the # connection; we can respond to each -def test_pipelined_close(): +def test_pipelined_close() -> None: c = Connection(SERVER) # 2 requests then a close c.receive_data( @@ -825,41 +865,43 @@ def test_pipelined_close(): assert c.states == {CLIENT: CLOSED, SERVER: CLOSED} -def test_sendfile(): +def test_sendfile() -> None: class SendfilePlaceholder: - def __len__(self): + def __len__(self) -> int: return 10 placeholder = SendfilePlaceholder() - def setup(header, http_version): + def setup( + header: Tuple[str, str], http_version: str + ) -> Tuple[Connection, Optional[List[bytes]]]: c = Connection(SERVER) receive_and_get( - c, "GET / HTTP/{}\r\nHost: a\r\n\r\n".format(http_version).encode("ascii") + c, f"GET / HTTP/{http_version}\r\nHost: a\r\n\r\n".encode("ascii") ) headers = [] if header: headers.append(header) c.send(Response(status_code=200, headers=headers)) - return c, c.send_with_data_passthrough(Data(data=placeholder)) + return c, c.send_with_data_passthrough(Data(data=placeholder)) # type: ignore c, data = setup(("Content-Length", "10"), "1.1") - assert data == [placeholder] + assert data == [placeholder] # type: ignore # Raises an error if the connection object doesn't think we've sent # exactly 10 bytes c.send(EndOfMessage()) _, data = setup(("Transfer-Encoding", "chunked"), "1.1") - assert placeholder in data - data[data.index(placeholder)] = b"x" * 10 - assert b"".join(data) == 
b"a\r\nxxxxxxxxxx\r\n" + assert placeholder in data # type: ignore + data[data.index(placeholder)] = b"x" * 10 # type: ignore + assert b"".join(data) == b"a\r\nxxxxxxxxxx\r\n" # type: ignore - c, data = setup(None, "1.0") - assert data == [placeholder] + c, data = setup(None, "1.0") # type: ignore + assert data == [placeholder] # type: ignore assert c.our_state is SEND_BODY -def test_errors(): +def test_errors() -> None: # After a receive error, you can't receive for role in [CLIENT, SERVER]: c = Connection(our_role=role) @@ -882,7 +924,7 @@ def test_errors(): # After an error sending, you can no longer send # (This is especially important for things like content-length errors, # where there's complex internal state being modified) - def conn(role): + def conn(role: Type[Sentinel]) -> Connection: c = Connection(our_role=role) if role is SERVER: # Put it into the state where it *could* send a response... @@ -902,8 +944,8 @@ def conn(role): http_version="1.0", ) elif role is SERVER: - good = Response(status_code=200, headers=[]) - bad = Response(status_code=200, headers=[], http_version="1.0") + good = Response(status_code=200, headers=[]) # type: ignore[assignment] + bad = Response(status_code=200, headers=[], http_version="1.0") # type: ignore[assignment] # Make sure 'good' actually is good c = conn(role) c.send(good) @@ -929,14 +971,14 @@ def conn(role): assert c.their_state is not ERROR -def test_idle_receive_nothing(): +def test_idle_receive_nothing() -> None: # At one point this incorrectly raised an error for role in [CLIENT, SERVER]: c = Connection(role) assert c.next_event() is NEED_DATA -def test_connection_drop(): +def test_connection_drop() -> None: c = Connection(SERVER) c.receive_data(b"GET /") assert c.next_event() is NEED_DATA @@ -945,15 +987,15 @@ def test_connection_drop(): c.next_event() -def test_408_request_timeout(): +def test_408_request_timeout() -> None: # Should be able to send this spontaneously as a server without seeing # anything from client p = ConnectionPair() - p.send(SERVER, Response(status_code=408, headers=[])) + p.send(SERVER, Response(status_code=408, headers=[(b"connection", b"close")])) # This used to raise IndexError -def test_empty_request(): +def test_empty_request() -> None: c = Connection(SERVER) c.receive_data(b"\r\n") with pytest.raises(RemoteProtocolError): @@ -961,7 +1003,7 @@ def test_empty_request(): # This used to raise IndexError -def test_empty_response(): +def test_empty_response() -> None: c = Connection(CLIENT) c.send(Request(method="GET", target="/", headers=[("Host", "a")])) c.receive_data(b"\r\n") @@ -977,7 +1019,7 @@ def test_empty_response(): b"\x16\x03\x01\x00\xa5", # Typical start of a TLS Client Hello ], ) -def test_early_detection_of_invalid_request(data): +def test_early_detection_of_invalid_request(data: bytes) -> None: c = Connection(SERVER) # Early detection should occur before even receiving a `\r\n` c.receive_data(data) @@ -993,7 +1035,7 @@ def test_early_detection_of_invalid_request(data): b"\x16\x03\x03\x00\x31", # Typical start of a TLS Server Hello ], ) -def test_early_detection_of_invalid_response(data): +def test_early_detection_of_invalid_response(data: bytes) -> None: c = Connection(CLIENT) # Early detection should occur before even receiving a `\r\n` c.receive_data(data) @@ -1005,8 +1047,8 @@ def test_early_detection_of_invalid_response(data): # The correct way to handle HEAD is to put whatever headers we *would* have # put if it were a GET -- even though we know that for HEAD, those headers # will be ignored. 
-def test_HEAD_framing_headers(): - def setup(method, http_version): +def test_HEAD_framing_headers() -> None: + def setup(method: bytes, http_version: bytes) -> Connection: c = Connection(SERVER) c.receive_data( method + b" / HTTP/" + http_version + b"\r\n" + b"Host: example.com\r\n\r\n" @@ -1047,7 +1089,7 @@ def setup(method, http_version): ) -def test_special_exceptions_for_lost_connection_in_message_body(): +def test_special_exceptions_for_lost_connection_in_message_body() -> None: c = Connection(SERVER) c.receive_data( b"POST / HTTP/1.1\r\n" b"Host: example.com\r\n" b"Content-Length: 100\r\n\r\n" @@ -1071,7 +1113,7 @@ def test_special_exceptions_for_lost_connection_in_message_body(): assert type(c.next_event()) is Request assert c.next_event() is NEED_DATA c.receive_data(b"8\r\n012345") - assert c.next_event().data == b"012345" + assert c.next_event().data == b"012345" # type: ignore c.receive_data(b"") with pytest.raises(RemoteProtocolError) as excinfo: c.next_event() diff --git a/h11/tests/test_events.py b/h11/tests/test_events.py index e20f741..d691545 100644 --- a/h11/tests/test_events.py +++ b/h11/tests/test_events.py @@ -2,58 +2,18 @@ import pytest -from .. import _events -from .._events import * +from .._events import ( + ConnectionClosed, + Data, + EndOfMessage, + InformationalResponse, + Request, + Response, +) from .._util import LocalProtocolError -def test_event_bundle(): - class T(_events._EventBundle): - _fields = ["a", "b"] - _defaults = {"b": 1} - - def _validate(self): - if self.a == 0: - raise ValueError - - # basic construction and methods - t = T(a=1, b=0) - assert repr(t) == "T(a=1, b=0)" - assert t == T(a=1, b=0) - assert not (t == T(a=2, b=0)) - assert not (t != T(a=1, b=0)) - assert t != T(a=2, b=0) - with pytest.raises(TypeError): - hash(t) - - # check defaults - t = T(a=10) - assert t.a == 10 - assert t.b == 1 - - # no positional args - with pytest.raises(TypeError): - T(1) - - with pytest.raises(TypeError): - T(1, a=1, b=0) - - # unknown field - with pytest.raises(TypeError): - T(a=1, b=0, c=10) - - # missing required field - with pytest.raises(TypeError) as exc: - T(b=0) - # make sure we error on the right missing kwarg - assert "kwarg a" in str(exc.value) - - # _validate is called - with pytest.raises(ValueError): - T(a=0, b=0) - - -def test_events(): +def test_events() -> None: with pytest.raises(LocalProtocolError): # Missing Host: req = Request( @@ -114,14 +74,23 @@ def test_events(): ) # Request target is validated - for bad_char in b"\x00\x20\x7f\xee": + for bad_byte in b"\x00\x20\x7f\xee": target = bytearray(b"/") - target.append(bad_char) + target.append(bad_byte) with pytest.raises(LocalProtocolError): Request( method="GET", target=target, headers=[("Host", "a")], http_version="1.1" ) + # Request method is validated + with pytest.raises(LocalProtocolError): + Request( + method="GET / HTTP/1.1", + target=target, + headers=[("Host", "a")], + http_version="1.1", + ) + ir = InformationalResponse(status_code=100, headers=[("Host", "a")]) assert ir.status_code == 100 assert ir.headers == [(b"host", b"a")] @@ -139,10 +108,10 @@ def test_events(): resp = Response(status_code=100, headers=[], http_version="1.0") with pytest.raises(LocalProtocolError): - Response(status_code="100", headers=[], http_version="1.0") + Response(status_code="100", headers=[], http_version="1.0") # type: ignore[arg-type] with pytest.raises(LocalProtocolError): - InformationalResponse(status_code=b"100", headers=[], http_version="1.0") + InformationalResponse(status_code=b"100", 
headers=[], http_version="1.0") # type: ignore[arg-type] d = Data(data=b"asdf") assert d.data == b"asdf" @@ -154,7 +123,7 @@ def test_events(): assert repr(cc) == "ConnectionClosed()" -def test_intenum_status_code(): +def test_intenum_status_code() -> None: # https://github.com/python-hyper/h11/issues/72 r = Response(status_code=HTTPStatus.OK, headers=[], http_version="1.0") @@ -163,7 +132,7 @@ def test_intenum_status_code(): assert type(r.status_code) is int -def test_header_casing(): +def test_header_casing() -> None: r = Request( method="GET", target="/", diff --git a/h11/tests/test_headers.py b/h11/tests/test_headers.py index ff3dc8d..b57274c 100644 --- a/h11/tests/test_headers.py +++ b/h11/tests/test_headers.py @@ -1,9 +1,17 @@ import pytest -from .._headers import * - - -def test_normalize_and_validate(): +from .._events import Request +from .._headers import ( + get_comma_header, + has_expect_100_continue, + Headers, + normalize_and_validate, + set_comma_header, +) +from .._util import LocalProtocolError + + +def test_normalize_and_validate() -> None: assert normalize_and_validate([("foo", "bar")]) == [(b"foo", b"bar")] assert normalize_and_validate([(b"foo", b"bar")]) == [(b"foo", b"bar")] @@ -66,6 +74,8 @@ def test_normalize_and_validate(): ) with pytest.raises(LocalProtocolError): normalize_and_validate([("Content-Length", "1 , 1,2")]) + with pytest.raises(LocalProtocolError): + normalize_and_validate([("Content-Length", "1" * 21)]) # 1 billion TB # transfer-encoding assert normalize_and_validate([("Transfer-Encoding", "chunked")]) == [ @@ -84,7 +94,7 @@ def test_normalize_and_validate(): assert excinfo.value.error_status_hint == 501 # Not Implemented -def test_get_set_comma_header(): +def test_get_set_comma_header() -> None: headers = normalize_and_validate( [ ("Connection", "close"), @@ -95,10 +105,10 @@ def test_get_set_comma_header(): assert get_comma_header(headers, b"connection") == [b"close", b"foo", b"bar"] - headers = set_comma_header(headers, b"newthing", ["a", "b"]) + headers = set_comma_header(headers, b"newthing", ["a", "b"]) # type: ignore with pytest.raises(LocalProtocolError): - set_comma_header(headers, b"newthing", [" a", "b"]) + set_comma_header(headers, b"newthing", [" a", "b"]) # type: ignore assert headers == [ (b"connection", b"close"), @@ -108,7 +118,7 @@ def test_get_set_comma_header(): (b"newthing", b"b"), ] - headers = set_comma_header(headers, b"whatever", ["different thing"]) + headers = set_comma_header(headers, b"whatever", ["different thing"]) # type: ignore assert headers == [ (b"connection", b"close"), @@ -119,9 +129,7 @@ def test_get_set_comma_header(): ] -def test_has_100_continue(): - from .._events import Request - +def test_has_100_continue() -> None: assert has_expect_100_continue( Request( method="GET", diff --git a/h11/tests/test_helpers.py b/h11/tests/test_helpers.py index 1477947..9a30dc6 100644 --- a/h11/tests/test_helpers.py +++ b/h11/tests/test_helpers.py @@ -1,7 +1,8 @@ -from .helpers import * +from .._events import Data, EndOfMessage, Response +from .helpers import normalize_data_events -def test_normalize_data_events(): +def test_normalize_data_events() -> None: assert normalize_data_events( [ Data(data=bytearray(b"1")), diff --git a/h11/tests/test_io.py b/h11/tests/test_io.py index 459a627..407e044 100644 --- a/h11/tests/test_io.py +++ b/h11/tests/test_io.py @@ -1,6 +1,15 @@ +from typing import Any, Callable, Generator, List + import pytest -from .._events import * +from .._events import ( + Data, + EndOfMessage, + Event, + 
InformationalResponse, + Request, + Response, +) from .._headers import Headers, normalize_and_validate from .._readers import ( _obsolete_line_fold, @@ -10,7 +19,7 @@ READERS, ) from .._receivebuffer import ReceiveBuffer -from .._state import * +from .._state import CLIENT, IDLE, SEND_RESPONSE, SERVER from .._util import LocalProtocolError from .._writers import ( ChunkedWriter, @@ -58,30 +67,29 @@ ] -def dowrite(writer, obj): - got_list = [] +def dowrite(writer: Callable[..., None], obj: Any) -> bytes: + got_list: List[bytes] = [] writer(obj, got_list.append) return b"".join(got_list) -def tw(writer, obj, expected): +def tw(writer: Any, obj: Any, expected: Any) -> None: got = dowrite(writer, obj) assert got == expected -def makebuf(data): +def makebuf(data: bytes) -> ReceiveBuffer: buf = ReceiveBuffer() buf += data return buf -def tr(reader, data, expected): - def check(got): +def tr(reader: Any, data: bytes, expected: Any) -> None: + def check(got: Any) -> None: assert got == expected # Headers should always be returned as bytes, not e.g. bytearray # https://github.com/python-hyper/wsproto/pull/54#issuecomment-377709478 for name, value in getattr(got, "headers", []): - print(name, value) assert type(name) is bytes assert type(value) is bytes @@ -104,17 +112,17 @@ def check(got): assert bytes(buf) == b"trailing" -def test_writers_simple(): - for ((role, state), event, binary) in SIMPLE_CASES: +def test_writers_simple() -> None: + for (role, state), event, binary in SIMPLE_CASES: tw(WRITERS[role, state], event, binary) -def test_readers_simple(): - for ((role, state), event, binary) in SIMPLE_CASES: +def test_readers_simple() -> None: + for (role, state), event, binary in SIMPLE_CASES: tr(READERS[role, state], binary, event) -def test_writers_unusual(): +def test_writers_unusual() -> None: # Simple test of the write_headers utility routine tw( write_headers, @@ -145,7 +153,7 @@ def test_writers_unusual(): ) -def test_readers_unusual(): +def test_readers_unusual() -> None: # Reading HTTP/1.0 tr( READERS[CLIENT, IDLE], @@ -305,7 +313,7 @@ def test_readers_unusual(): tr(READERS[CLIENT, IDLE], b"HEAD /foo HTTP/1.1\r\n" b": line\r\n\r\n", None) -def test__obsolete_line_fold_bytes(): +def test__obsolete_line_fold_bytes() -> None: # _obsolete_line_fold has a defensive cast to bytearray, which is # necessary to protect against O(n^2) behavior in case anyone ever passes # in regular bytestrings... 
but right now we never pass in regular @@ -318,7 +326,9 @@ def test__obsolete_line_fold_bytes(): ] -def _run_reader_iter(reader, buf, do_eof): +def _run_reader_iter( + reader: Any, buf: bytes, do_eof: bool +) -> Generator[Any, None, None]: while True: event = reader(buf) if event is None: @@ -333,27 +343,39 @@ def _run_reader_iter(reader, buf, do_eof): yield reader.read_eof() -def _run_reader(*args): +def _run_reader(*args: Any) -> List[Event]: events = list(_run_reader_iter(*args)) return normalize_data_events(events) -def t_body_reader(thunk, data, expected, do_eof=False): +def t_body_reader(thunk: Any, data: bytes, expected: Any, do_eof: bool = False) -> None: # Simple: consume whole thing print("Test 1") buf = makebuf(data) - assert _run_reader(thunk(), buf, do_eof) == expected + try: + assert _run_reader(thunk(), buf, do_eof) == expected + except LocalProtocolError: + if LocalProtocolError in expected: + pass + else: + raise # Incrementally growing buffer print("Test 2") reader = thunk() buf = ReceiveBuffer() events = [] - for i in range(len(data)): - events += _run_reader(reader, buf, False) - buf += data[i : i + 1] - events += _run_reader(reader, buf, do_eof) - assert normalize_data_events(events) == expected + try: + for i in range(len(data)): + events += _run_reader(reader, buf, False) + buf += data[i : i + 1] + events += _run_reader(reader, buf, do_eof) + assert normalize_data_events(events) == expected + except LocalProtocolError: + if LocalProtocolError in expected: + pass + else: + raise is_complete = any(type(event) is EndOfMessage for event in expected) if is_complete and not do_eof: @@ -361,7 +383,7 @@ def t_body_reader(thunk, data, expected, do_eof=False): assert _run_reader(thunk(), buf, False) == expected -def test_ContentLengthReader(): +def test_ContentLengthReader() -> None: t_body_reader(lambda: ContentLengthReader(0), b"", [EndOfMessage()]) t_body_reader( @@ -371,7 +393,7 @@ def test_ContentLengthReader(): ) -def test_Http10Reader(): +def test_Http10Reader() -> None: t_body_reader(Http10Reader, b"", [EndOfMessage()], do_eof=True) t_body_reader(Http10Reader, b"asdf", [Data(data=b"asdf")], do_eof=False) t_body_reader( @@ -379,7 +401,7 @@ def test_Http10Reader(): ) -def test_ChunkedReader(): +def test_ChunkedReader() -> None: t_body_reader(ChunkedReader, b"0\r\n\r\n", [EndOfMessage()]) t_body_reader( @@ -414,14 +436,12 @@ def test_ChunkedReader(): ) # refuses arbitrarily long chunk integers - with pytest.raises(LocalProtocolError): - # Technically this is legal HTTP/1.1, but we refuse to process chunk - # sizes that don't fit into 20 characters of hex - t_body_reader(ChunkedReader, b"9" * 100 + b"\r\nxxx", [Data(data=b"xxx")]) + # Technically this is legal HTTP/1.1, but we refuse to process chunk + # sizes that don't fit into 20 characters of hex + t_body_reader(ChunkedReader, b"9" * 100 + b"\r\nxxx", [LocalProtocolError]) # refuses garbage in the chunk count - with pytest.raises(LocalProtocolError): - t_body_reader(ChunkedReader, b"10\x00\r\nxxx", None) + t_body_reader(ChunkedReader, b"10\x00\r\nxxx", [LocalProtocolError]) # handles (and discards) "chunk extensions" omg wtf t_body_reader( @@ -433,8 +453,27 @@ def test_ChunkedReader(): [Data(data=b"xxxxx"), EndOfMessage()], ) + t_body_reader( + ChunkedReader, + b"5 \t \r\n01234\r\n" + b"0\r\n\r\n", + [Data(data=b"01234"), EndOfMessage()], + ) + + # Chunked encoding with bad chunk termination characters are refused. 
Originally we + # simply dropped the 2 bytes after a chunk, instead of validating that the bytes + # were \r\n -- so we would successfully decode the data below as b"xxxa". And + # apparently there are other HTTP processors that ignore the chunk length and just + # keep reading until they see \r\n, so they would decode it as b"xxx__1a". Any time + # two HTTP processors accept the same input but interpret it differently, there's a + # possibility of request smuggling shenanigans. So we now reject this. + t_body_reader(ChunkedReader, b"3\r\nxxx__1a\r\n", [LocalProtocolError]) + + # Confirm we check both bytes individually + t_body_reader(ChunkedReader, b"3\r\nxxx\r_1a\r\n", [LocalProtocolError]) + t_body_reader(ChunkedReader, b"3\r\nxxx_\n1a\r\n", [LocalProtocolError]) -def test_ContentLengthWriter(): + +def test_ContentLengthWriter() -> None: w = ContentLengthWriter(5) assert dowrite(w, Data(data=b"123")) == b"123" assert dowrite(w, Data(data=b"45")) == b"45" @@ -455,13 +494,13 @@ def test_ContentLengthWriter(): dowrite(w, EndOfMessage()) w = ContentLengthWriter(5) - dowrite(w, Data(data=b"123")) == b"123" - dowrite(w, Data(data=b"45")) == b"45" + assert dowrite(w, Data(data=b"123")) == b"123" + assert dowrite(w, Data(data=b"45")) == b"45" with pytest.raises(LocalProtocolError): dowrite(w, EndOfMessage(headers=[("Etag", "asdf")])) -def test_ChunkedWriter(): +def test_ChunkedWriter() -> None: w = ChunkedWriter() assert dowrite(w, Data(data=b"aaa")) == b"3\r\naaa\r\n" assert dowrite(w, Data(data=b"a" * 20)) == b"14\r\n" + b"a" * 20 + b"\r\n" @@ -476,7 +515,7 @@ def test_ChunkedWriter(): ) -def test_Http10Writer(): +def test_Http10Writer() -> None: w = Http10Writer() assert dowrite(w, Data(data=b"1234")) == b"1234" assert dowrite(w, EndOfMessage()) == b"" @@ -485,12 +524,12 @@ def test_Http10Writer(): dowrite(w, EndOfMessage(headers=[("Etag", "asdf")])) -def test_reject_garbage_after_request_line(): +def test_reject_garbage_after_request_line() -> None: with pytest.raises(LocalProtocolError): tr(READERS[SERVER, SEND_RESPONSE], b"HTTP/1.0 200 OK\x00xxxx\r\n\r\n", None) -def test_reject_garbage_after_response_line(): +def test_reject_garbage_after_response_line() -> None: with pytest.raises(LocalProtocolError): tr( READERS[CLIENT, IDLE], @@ -499,7 +538,7 @@ def test_reject_garbage_after_response_line(): ) -def test_reject_garbage_in_header_line(): +def test_reject_garbage_in_header_line() -> None: with pytest.raises(LocalProtocolError): tr( READERS[CLIENT, IDLE], @@ -508,7 +547,7 @@ def test_reject_garbage_in_header_line(): ) -def test_reject_non_vchar_in_path(): +def test_reject_non_vchar_in_path() -> None: for bad_char in b"\x00\x20\x7f\xee": message = bytearray(b"HEAD /") message.append(bad_char) @@ -518,7 +557,7 @@ def test_reject_non_vchar_in_path(): # https://github.com/python-hyper/h11/issues/57 -def test_allow_some_garbage_in_cookies(): +def test_allow_some_garbage_in_cookies() -> None: tr( READERS[CLIENT, IDLE], b"HEAD /foo HTTP/1.1\r\n" @@ -536,7 +575,7 @@ def test_allow_some_garbage_in_cookies(): ) -def test_host_comes_first(): +def test_host_comes_first() -> None: tw( write_headers, normalize_and_validate([("foo", "bar"), ("Host", "example.com")]), diff --git a/h11/tests/test_receivebuffer.py b/h11/tests/test_receivebuffer.py index 3a61f9d..21a3870 100644 --- a/h11/tests/test_receivebuffer.py +++ b/h11/tests/test_receivebuffer.py @@ -1,11 +1,12 @@ import re +from typing import Tuple import pytest from .._receivebuffer import ReceiveBuffer -def test_receivebuffer(): +def 
test_receivebuffer() -> None: b = ReceiveBuffer() assert not b assert len(b) == 0 @@ -118,7 +119,7 @@ def test_receivebuffer(): ), ], ) -def test_receivebuffer_for_invalid_delimiter(data): +def test_receivebuffer_for_invalid_delimiter(data: Tuple[bytes]) -> None: b = ReceiveBuffer() for line in data: diff --git a/h11/tests/test_state.py b/h11/tests/test_state.py index efe83f0..bc974e6 100644 --- a/h11/tests/test_state.py +++ b/h11/tests/test_state.py @@ -1,12 +1,33 @@ import pytest -from .._events import * -from .._state import * -from .._state import _SWITCH_CONNECT, _SWITCH_UPGRADE, ConnectionState +from .._events import ( + ConnectionClosed, + Data, + EndOfMessage, + Event, + InformationalResponse, + Request, + Response, +) +from .._state import ( + _SWITCH_CONNECT, + _SWITCH_UPGRADE, + CLIENT, + CLOSED, + ConnectionState, + DONE, + IDLE, + MIGHT_SWITCH_PROTOCOL, + MUST_CLOSE, + SEND_BODY, + SEND_RESPONSE, + SERVER, + SWITCHED_PROTOCOL, +) from .._util import LocalProtocolError -def test_ConnectionState(): +def test_ConnectionState() -> None: cs = ConnectionState() # Basic event-triggered transitions @@ -38,7 +59,7 @@ def test_ConnectionState(): assert cs.states == {CLIENT: MUST_CLOSE, SERVER: CLOSED} -def test_ConnectionState_keep_alive(): +def test_ConnectionState_keep_alive() -> None: # keep_alive = False cs = ConnectionState() cs.process_event(CLIENT, Request) @@ -51,7 +72,7 @@ def test_ConnectionState_keep_alive(): assert cs.states == {CLIENT: MUST_CLOSE, SERVER: MUST_CLOSE} -def test_ConnectionState_keep_alive_in_DONE(): +def test_ConnectionState_keep_alive_in_DONE() -> None: # Check that if keep_alive is disabled when the CLIENT is already in DONE, # then this is sufficient to immediately trigger the DONE -> MUST_CLOSE # transition @@ -63,7 +84,7 @@ def test_ConnectionState_keep_alive_in_DONE(): assert cs.states[CLIENT] is MUST_CLOSE -def test_ConnectionState_switch_denied(): +def test_ConnectionState_switch_denied() -> None: for switch_type in (_SWITCH_CONNECT, _SWITCH_UPGRADE): for deny_early in (True, False): cs = ConnectionState() @@ -107,7 +128,7 @@ def test_ConnectionState_switch_denied(): } -def test_ConnectionState_protocol_switch_accepted(): +def test_ConnectionState_protocol_switch_accepted() -> None: for switch_event in [_SWITCH_UPGRADE, _SWITCH_CONNECT]: cs = ConnectionState() cs.process_client_switch_proposal(switch_event) @@ -125,7 +146,7 @@ def test_ConnectionState_protocol_switch_accepted(): assert cs.states == {CLIENT: SWITCHED_PROTOCOL, SERVER: SWITCHED_PROTOCOL} -def test_ConnectionState_double_protocol_switch(): +def test_ConnectionState_double_protocol_switch() -> None: # CONNECT + Upgrade is legal! Very silly, but legal. So we support # it. Because sometimes doing the silly thing is easier than not. 
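Illustrative sketch (not part of the diff): the chunk-termination tests a little further above are the unit-level face of the 0.16.0 security fix; end to end, a server Connection now refuses a chunk whose trailing two bytes are not CRLF instead of silently resynchronising on the next CRLF. A rough check through the public API, under the assumption that the error surfaces from next_event() once the malformed terminator bytes have arrived:

import h11

conn = h11.Connection(h11.SERVER)
conn.receive_data(
    b"POST / HTTP/1.1\r\nHost: a\r\nTransfer-Encoding: chunked\r\n\r\n"
)
assert type(conn.next_event()) is h11.Request

conn.receive_data(b"3\r\nxxx__1a\r\n")         # "__" sits where the chunk-ending \r\n belongs
try:
    while conn.next_event() is not h11.NEED_DATA:
        pass
except h11.RemoteProtocolError:
    pass                                       # rejected instead of being decoded as b"xxxa"
else:
    raise AssertionError("malformed chunk terminator was accepted")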
for server_switch in [None, _SWITCH_UPGRADE, _SWITCH_CONNECT]: @@ -144,7 +165,7 @@ def test_ConnectionState_double_protocol_switch(): assert cs.states == {CLIENT: SWITCHED_PROTOCOL, SERVER: SWITCHED_PROTOCOL} -def test_ConnectionState_inconsistent_protocol_switch(): +def test_ConnectionState_inconsistent_protocol_switch() -> None: for client_switches, server_switch in [ ([], _SWITCH_CONNECT), ([], _SWITCH_UPGRADE), @@ -152,14 +173,14 @@ def test_ConnectionState_inconsistent_protocol_switch(): ([_SWITCH_CONNECT], _SWITCH_UPGRADE), ]: cs = ConnectionState() - for client_switch in client_switches: + for client_switch in client_switches: # type: ignore[attr-defined] cs.process_client_switch_proposal(client_switch) cs.process_event(CLIENT, Request) with pytest.raises(LocalProtocolError): cs.process_event(SERVER, Response, server_switch) -def test_ConnectionState_keepalive_protocol_switch_interaction(): +def test_ConnectionState_keepalive_protocol_switch_interaction() -> None: # keep_alive=False + pending_switch_proposals cs = ConnectionState() cs.process_client_switch_proposal(_SWITCH_UPGRADE) @@ -177,7 +198,7 @@ def test_ConnectionState_keepalive_protocol_switch_interaction(): assert cs.states == {CLIENT: MUST_CLOSE, SERVER: SEND_BODY} -def test_ConnectionState_reuse(): +def test_ConnectionState_reuse() -> None: cs = ConnectionState() with pytest.raises(LocalProtocolError): @@ -242,7 +263,7 @@ def test_ConnectionState_reuse(): assert cs.states == {CLIENT: IDLE, SERVER: IDLE} -def test_server_request_is_illegal(): +def test_server_request_is_illegal() -> None: # There used to be a bug in how we handled the Request special case that # made this allowed... cs = ConnectionState() diff --git a/h11/tests/test_util.py b/h11/tests/test_util.py index d851bdc..79bc095 100644 --- a/h11/tests/test_util.py +++ b/h11/tests/test_util.py @@ -1,18 +1,26 @@ import re import sys import traceback +from typing import NoReturn import pytest -from .._util import * +from .._util import ( + bytesify, + LocalProtocolError, + ProtocolError, + RemoteProtocolError, + Sentinel, + validate, +) -def test_ProtocolError(): +def test_ProtocolError() -> None: with pytest.raises(TypeError): ProtocolError("abstract base class") -def test_LocalProtocolError(): +def test_LocalProtocolError() -> None: try: raise LocalProtocolError("foo") except LocalProtocolError as e: @@ -25,7 +33,7 @@ def test_LocalProtocolError(): assert str(e) == "foo" assert e.error_status_hint == 418 - def thunk(): + def thunk() -> NoReturn: raise LocalProtocolError("a", error_status_hint=420) try: @@ -42,8 +50,8 @@ def thunk(): assert new_traceback.endswith(orig_traceback) -def test_validate(): - my_re = re.compile(br"(?P[0-9]+)\.(?P[0-9]+)") +def test_validate() -> None: + my_re = re.compile(rb"(?P[0-9]+)\.(?P[0-9]+)") with pytest.raises(LocalProtocolError): validate(my_re, b"0.") @@ -57,8 +65,8 @@ def test_validate(): validate(my_re, b"0.1\n") -def test_validate_formatting(): - my_re = re.compile(br"foo") +def test_validate_formatting() -> None: + my_re = re.compile(rb"foo") with pytest.raises(LocalProtocolError) as excinfo: validate(my_re, b"", "oops") @@ -73,21 +81,26 @@ def test_validate_formatting(): assert "oops 10 xx" in str(excinfo.value) -def test_make_sentinel(): - S = make_sentinel("S") +def test_make_sentinel() -> None: + class S(Sentinel, metaclass=Sentinel): + pass + assert repr(S) == "S" assert S == S assert type(S).__name__ == "S" assert S in {S} assert type(S) is S - S2 = make_sentinel("S2") + + class S2(Sentinel, metaclass=Sentinel): + pass 
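Illustrative sketch (not part of the diff): the thunk() test above exercises LocalProtocolError._reraise_as_remote_protocol_error(), whose annotation is now NoReturn. The pattern, used internally by Connection and private from the outside, converts a validation error into a peer-blame error without losing the message or the status hint; parse_peer_bytes below is a made-up stand-in for a reader:

from h11._util import LocalProtocolError, RemoteProtocolError

def parse_peer_bytes(data: bytes) -> None:
    # made-up validation step that fails the way h11's readers do
    raise LocalProtocolError("malformed frame", error_status_hint=400)

try:
    try:
        parse_peer_bytes(b"garbage")
    except LocalProtocolError as exc:
        # must be called from inside an except: block, as the comment in _util.py notes
        exc._reraise_as_remote_protocol_error()
except RemoteProtocolError as exc:
    assert str(exc) == "malformed frame"
    assert exc.error_status_hint == 400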
+ assert repr(S2) == "S2" assert S != S2 assert S not in {S2} assert type(S) is not type(S2) -def test_bytesify(): +def test_bytesify() -> None: assert bytesify(b"123") == b"123" assert bytesify(bytearray(b"123")) == b"123" assert bytesify("123") == b"123" diff --git a/pyproject.toml b/pyproject.toml index edd11ae..64a6883 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,3 +38,9 @@ showcontent = true directory = "misc" name = "Miscellaneous internal changes" showcontent = true + +[tool.mypy] +strict = true +warn_unused_configs = true +warn_unused_ignores = true +show_error_codes = true diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 0bd1262..0000000 --- a/setup.cfg +++ /dev/null @@ -1,10 +0,0 @@ -[isort] -combine_as_imports=True -force_grid_wrap=0 -include_trailing_comma=True -known_first_party=h11, test -known_third_party=pytest -line_length=88 -multi_line_output=3 -no_lines_before=LOCALFOLDER -order_by_type=False diff --git a/setup.py b/setup.py index 93b6b62..73713e2 100644 --- a/setup.py +++ b/setup.py @@ -12,12 +12,10 @@ author="Nathaniel J. Smith", author_email="njs@pobox.com", license="MIT", - packages=find_packages(), + packages=find_packages(exclude=["h11.tests"]), + package_data={'h11': ['py.typed']}, url="https://github.com/python-hyper/h11", - # This means, just install *everything* you see under h11/, even if it - # doesn't look like a source file, so long as it appears in MANIFEST.in: - include_package_data=True, - python_requires=">=3.6", + python_requires=">=3.8", classifiers=[ "Development Status :: 3 - Alpha", "Intended Audience :: Developers", @@ -26,10 +24,11 @@ "Programming Language :: Python :: Implementation :: PyPy", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3 :: Only", - "Programming Language :: Python :: 3.6", - "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "Topic :: Internet :: WWW/HTTP", "Topic :: System :: Networking", ], diff --git a/tox.ini b/tox.ini index d4a2272..6614ecf 100644 --- a/tox.ini +++ b/tox.ini @@ -1,13 +1,16 @@ [tox] -envlist = format, py36, py37, py38, py39, pypy3 +envlist = format, py{38, 39, 310, 311, 312, py3}, mypy [gh-actions] python = - 3.6: py36 - 3.7: py37 - 3.8: py38, format + 3.8: py38, format, mypy 3.9: py39 - pypy3: pypy3 + 3.10: py310 + 3.11: py311 + 3.12: py312 + 3.13: py313 + pypy-3.9: pypy3 + pypy-3.10: pypy3 [testenv] deps = -r{toxinidir}/test-requirements.txt @@ -15,9 +18,15 @@ commands = pytest --cov=h11 --cov-config=.coveragerc h11 [testenv:format] basepython = python3.8 -deps = - black - isort +deps = -r{toxinidir}/format-requirements.txt commands = black --check --diff h11/ bench/ examples/ fuzz/ - isort --check --diff h11 bench examples fuzz + isort --check --diff --profile black --dt h11 bench examples fuzz + +[testenv:mypy] +basepython = python3.8 +deps = + mypy==1.8.0 + pytest +commands = + mypy h11/
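Illustrative sketch (not part of the diff): the packaging and tox changes mean the annotations ship to users (the PEP 561 py.typed marker) and are checked in CI with mypy in strict mode. Downstream code can then be strictly typed against h11 too; the helper below is illustrative, not part of h11, and mirrors the test suite's get_all_events(). It assumes the top-level Event re-export added alongside the typing work; if your version lacks it, import Event from h11._events instead:

from typing import List, cast

import h11

def drain_events(conn: h11.Connection) -> List[h11.Event]:
    """Collect currently-available events, stopping at NEED_DATA/PAUSED or close."""
    events: List[h11.Event] = []
    while True:
        event = conn.next_event()
        if event in (h11.NEED_DATA, h11.PAUSED):
            break
        event = cast(h11.Event, event)   # narrow Union[Event, Type[Sentinel]] for mypy
        events.append(event)
        if type(event) is h11.ConnectionClosed:
            break
    return events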