diff --git a/.coveragerc b/.coveragerc index 0f1b93d..5724e30 100644 --- a/.coveragerc +++ b/.coveragerc @@ -5,7 +5,7 @@ branch = False # branch = True omit = - multipart/tests/* + tests/* [report] # Regexes for lines to exclude from consideration diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..cf07f01 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,18 @@ +version: 2 +updates: + - package-ecosystem: "pip" + directory: "/" + schedule: + interval: "monthly" + groups: + python-packages: + patterns: + - "*" + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: monthly + groups: + github-actions: + patterns: + - "*" diff --git a/.github/workflows/test.yaml b/.github/workflows/main.yaml similarity index 61% rename from .github/workflows/test.yaml rename to .github/workflows/main.yaml index 21428c5..7b4fd5c 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/main.yaml @@ -14,20 +14,22 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.7", "3.8", "3.9", "3.10", "3.11"] + python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | python -m pip install --upgrade pip pip install .[dev] + - name: Lint + if: matrix.python-version == '3.8' + run: | + ruff multipart tests - name: Test with pytest run: | - # This should be inv test but invoke does not have python3.11 support yet. - # See https://github.com/pyinvoke/invoke/issues/891 for details - pytest + inv test diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml index d35bfa7..cc38611 100644 --- a/.github/workflows/publish.yaml +++ b/.github/workflows/publish.yaml @@ -10,22 +10,20 @@ name: Upload Python Package on: push: - tags: - - '[0-9]+.[0-9]+.[0-9]+.*' # Run on every git tag with semantic versioning. i.e: 1.5.0 or 1.5.0rc1 + tags: + - '*' permissions: contents: read jobs: deploy: - runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 - + - uses: actions/checkout@v4 - name: Set up Python 3.10 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: "3.10" @@ -38,7 +36,7 @@ jobs: run: python -m build - name: Publish package - uses: pypa/gh-action-pypi-publish@v1.6.4 + uses: pypa/gh-action-pypi-publish@v1.8.11 with: user: __token__ password: ${{ secrets.PYPI_API_TOKEN }} diff --git a/.gitignore b/.gitignore index cfa9998..f7f7b71 100644 --- a/.gitignore +++ b/.gitignore @@ -28,7 +28,7 @@ lib64 pip-log.txt # Unit test / coverage reports -.coverage +.coverage* .tox nosetests.xml diff --git a/CHANGELOG.md b/CHANGELOG.md index 7373025..23c0fff 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,23 @@ # Changelog +## 0.0.9 (2024-02-10) + +* Add support for Python 3.12 [#85](https://github.com/Kludex/python-multipart/pull/85). +* Drop support for Python 3.7 [#95](https://github.com/Kludex/python-multipart/pull/95). +* Add `MultipartState(IntEnum)` [#96](https://github.com/Kludex/python-multipart/pull/96). +* Add `QuerystringState` [#97](https://github.com/Kludex/python-multipart/pull/97). +* Add `TypedDict` callbacks [#98](https://github.com/Kludex/python-multipart/pull/98). +* Add config `TypedDict`s [#99](https://github.com/Kludex/python-multipart/pull/99). + +## 0.0.8 (2024-02-09) + +* Check if Message.get_params return 3-tuple instead of str on parse_options_header [#79](https://github.com/Kludex/python-multipart/pull/79). +* Cleanup unused regex patterns [#82](https://github.com/Kludex/python-multipart/pull/82). + +## 0.0.7 (2024-02-03) + +* Refactor header option parser to use the standard library instead of a custom RegEx [#75](https://github.com/andrew-d/python-multipart/pull/75). + ## 0.0.6 (2023-02-27) * Migrate package installation to `pyproject.toml` (PEP 621) [#54](https://github.com/andrew-d/python-multipart/pull/54). diff --git a/MANIFEST.in b/MANIFEST.in deleted file mode 100644 index 864fc99..0000000 --- a/MANIFEST.in +++ /dev/null @@ -1,4 +0,0 @@ -include requirements.txt README.md LICENSE.txt -recursive-include multipart *.py *.yaml *.bare *.http LICENSE*.* -recursive-exclude multipart *.pyc *.pyo *.pyd - diff --git a/README.rst b/README.rst index 0a7dac4..7bdc864 100644 --- a/README.rst +++ b/README.rst @@ -24,5 +24,5 @@ If you want to test: .. code-block:: bash - $ pip install .[dev] + $ pip install '.[dev]' $ inv test diff --git a/docs_requirements.txt b/docs_requirements.txt index ef0e956..8c3eda5 100644 --- a/docs_requirements.txt +++ b/docs_requirements.txt @@ -1,19 +1,19 @@ -Jinja2==2.11.3 -PyYAML==5.4 -Pygments==2.7.4 -Sphinx==1.2b1 -cov-core==1.7 -coverage==3.6 -distribute==0.6.34 -docutils==0.10 -invoke==0.2.0 +Jinja2==3.1.3 +PyYAML==6.0.1 +Pygments==2.17.2 +Sphinx==7.2.6 +cov-core==1.15.0 +coverage==7.4.1 +distribute==0.7.3 +docutils==0.20.1 +invoke==2.2.0 pexpect-u==2.5.1 -py==1.10.0 -pytest==6.2.6 +py==1.11.0 +pytest==8.0.0 pytest-capturelog==0.7 -pytest-cov==1.6 -pytest-timeout==0.3 -sphinx-bootstrap-theme==0.2.0 -tox==1.4.3 -virtualenv==1.9.1 +pytest-cov==4.1.0 +pytest-timeout==2.2.0 +sphinx-bootstrap-theme==0.8.1 +tox==4.12.1 +virtualenv==20.25.0 wsgiref==0.1.2 diff --git a/multipart/__init__.py b/multipart/__init__.py index 309d698..dc13f13 100644 --- a/multipart/__init__.py +++ b/multipart/__init__.py @@ -1,15 +1,16 @@ # This is the canonical package information. -__author__ = 'Andrew Dunham' -__license__ = 'Apache' +__author__ = "Andrew Dunham" +__license__ = "Apache" __copyright__ = "Copyright (c) 2012-2013, Andrew Dunham" -__version__ = "0.0.6" +__version__ = "0.0.9" +from .multipart import FormParser, MultipartParser, OctetStreamParser, QuerystringParser, create_form_parser, parse_form -from .multipart import ( - FormParser, - MultipartParser, - QuerystringParser, - OctetStreamParser, - create_form_parser, - parse_form, +__all__ = ( + "FormParser", + "MultipartParser", + "OctetStreamParser", + "QuerystringParser", + "create_form_parser", + "parse_form", ) diff --git a/multipart/decoders.py b/multipart/decoders.py index 0d7ab32..417650c 100644 --- a/multipart/decoders.py +++ b/multipart/decoders.py @@ -59,8 +59,7 @@ def write(self, data): try: decoded = base64.b64decode(val) except binascii.Error: - raise DecodeError('There was an error raised while decoding ' - 'base64-encoded data.') + raise DecodeError("There was an error raised while decoding base64-encoded data.") self.underlying.write(decoded) @@ -69,7 +68,7 @@ def write(self, data): if remaining_len > 0: self.cache = data[-remaining_len:] else: - self.cache = b'' + self.cache = b"" # Return the length of the data to indicate no error. return len(data) @@ -78,7 +77,7 @@ def close(self): """Close this decoder. If the underlying object has a `close()` method, this function will call it. """ - if hasattr(self.underlying, 'close'): + if hasattr(self.underlying, "close"): self.underlying.close() def finalize(self): @@ -91,11 +90,11 @@ def finalize(self): call it. """ if len(self.cache) > 0: - raise DecodeError('There are %d bytes remaining in the ' - 'Base64Decoder cache when finalize() is called' - % len(self.cache)) + raise DecodeError( + "There are %d bytes remaining in the Base64Decoder cache when finalize() is called" % len(self.cache) + ) - if hasattr(self.underlying, 'finalize'): + if hasattr(self.underlying, "finalize"): self.underlying.finalize() def __repr__(self): @@ -111,8 +110,9 @@ class QuotedPrintableDecoder: :param underlying: the underlying object to pass writes to """ + def __init__(self, underlying): - self.cache = b'' + self.cache = b"" self.underlying = underlying def write(self, data): @@ -128,11 +128,11 @@ def write(self, data): # If the last 2 characters have an '=' sign in it, then we won't be # able to decode the encoded value and we'll need to save it for the # next decoding step. - if data[-2:].find(b'=') != -1: + if data[-2:].find(b"=") != -1: enc, rest = data[:-2], data[-2:] else: enc = data - rest = b'' + rest = b"" # Encode and write, if we have data. if len(enc) > 0: @@ -146,7 +146,7 @@ def close(self): """Close this decoder. If the underlying object has a `close()` method, this function will call it. """ - if hasattr(self.underlying, 'close'): + if hasattr(self.underlying, "close"): self.underlying.close() def finalize(self): @@ -161,10 +161,10 @@ def finalize(self): # If we have a cache, write and then remove it. if len(self.cache) > 0: self.underlying.write(binascii.a2b_qp(self.cache)) - self.cache = b'' + self.cache = b"" # Finalize our underlying stream. - if hasattr(self.underlying, 'finalize'): + if hasattr(self.underlying, "finalize"): self.underlying.finalize() def __repr__(self): diff --git a/multipart/exceptions.py b/multipart/exceptions.py index 016e7f7..cc3671f 100644 --- a/multipart/exceptions.py +++ b/multipart/exceptions.py @@ -1,6 +1,5 @@ class FormParserError(ValueError): """Base error class for our form parser.""" - pass class ParseError(FormParserError): @@ -17,30 +16,19 @@ class MultipartParseError(ParseError): """This is a specific error that is raised when the MultipartParser detects an error while parsing. """ - pass class QuerystringParseError(ParseError): """This is a specific error that is raised when the QuerystringParser detects an error while parsing. """ - pass class DecodeError(ParseError): """This exception is raised when there is a decoding error - for example with the Base64Decoder or QuotedPrintableDecoder. """ - pass - - -# On Python 3.3, IOError is the same as OSError, so we don't want to inherit -# from both of them. We handle this case below. -if IOError is not OSError: # pragma: no cover - class FileError(FormParserError, IOError, OSError): - """Exception class for problems with the File class.""" - pass -else: # pragma: no cover - class FileError(FormParserError, OSError): - """Exception class for problems with the File class.""" - pass + + +class FileError(FormParserError, OSError): + """Exception class for problems with the File class.""" diff --git a/multipart/multipart.py b/multipart/multipart.py index a9f1f9f..221bb71 100644 --- a/multipart/multipart.py +++ b/multipart/multipart.py @@ -1,119 +1,176 @@ -from .decoders import * -from .exceptions import * +from __future__ import annotations +import logging import os -import re -import sys import shutil -import logging +import sys import tempfile +from email.message import Message +from enum import IntEnum from io import BytesIO from numbers import Number +from typing import TYPE_CHECKING + +from .decoders import Base64Decoder, QuotedPrintableDecoder +from .exceptions import FileError, FormParserError, MultipartParseError, QuerystringParseError + +if TYPE_CHECKING: # pragma: no cover + from typing import Callable, TypedDict + + class QuerystringCallbacks(TypedDict, total=False): + on_field_start: Callable[[], None] + on_field_name: Callable[[bytes, int, int], None] + on_field_data: Callable[[bytes, int, int], None] + on_field_end: Callable[[], None] + on_end: Callable[[], None] + + class OctetStreamCallbacks(TypedDict, total=False): + on_start: Callable[[], None] + on_data: Callable[[bytes, int, int], None] + on_end: Callable[[], None] + + class MultipartCallbacks(TypedDict, total=False): + on_part_begin: Callable[[], None] + on_part_data: Callable[[bytes, int, int], None] + on_part_end: Callable[[], None] + on_headers_begin: Callable[[], None] + on_header_field: Callable[[bytes, int, int], None] + on_header_value: Callable[[bytes, int, int], None] + on_header_end: Callable[[], None] + on_headers_finished: Callable[[], None] + on_end: Callable[[], None] + + class FormParserConfig(TypedDict, total=False): + UPLOAD_DIR: str | None + UPLOAD_KEEP_FILENAME: bool + UPLOAD_KEEP_EXTENSIONS: bool + UPLOAD_ERROR_ON_BAD_CTE: bool + MAX_MEMORY_FILE_SIZE: int + MAX_BODY_SIZE: float + + class FileConfig(TypedDict, total=False): + UPLOAD_DIR: str | None + UPLOAD_DELETE_TMP: bool + UPLOAD_KEEP_FILENAME: bool + UPLOAD_KEEP_EXTENSIONS: bool + MAX_MEMORY_FILE_SIZE: int + # Unique missing object. _missing = object() -# States for the querystring parser. -STATE_BEFORE_FIELD = 0 -STATE_FIELD_NAME = 1 -STATE_FIELD_DATA = 2 - -# States for the multipart parser -STATE_START = 0 -STATE_START_BOUNDARY = 1 -STATE_HEADER_FIELD_START = 2 -STATE_HEADER_FIELD = 3 -STATE_HEADER_VALUE_START = 4 -STATE_HEADER_VALUE = 5 -STATE_HEADER_VALUE_ALMOST_DONE = 6 -STATE_HEADERS_ALMOST_DONE = 7 -STATE_PART_DATA_START = 8 -STATE_PART_DATA = 9 -STATE_PART_DATA_END = 10 -STATE_END = 11 - -STATES = [ - "START", - "START_BOUNDARY", "HEADER_FIELD_START", "HEADER_FIELD", "HEADER_VALUE_START", "HEADER_VALUE", - "HEADER_VALUE_ALMOST_DONE", "HEADRES_ALMOST_DONE", "PART_DATA_START", "PART_DATA", "PART_DATA_END", "END" -] + +class QuerystringState(IntEnum): + """Querystring parser states. + + These are used to keep track of the state of the parser, and are used to determine + what to do when new data is encountered. + """ + + BEFORE_FIELD = 0 + FIELD_NAME = 1 + FIELD_DATA = 2 + + +class MultipartState(IntEnum): + """Multipart parser states. + + These are used to keep track of the state of the parser, and are used to determine + what to do when new data is encountered. + """ + + START = 0 + START_BOUNDARY = 1 + HEADER_FIELD_START = 2 + HEADER_FIELD = 3 + HEADER_VALUE_START = 4 + HEADER_VALUE = 5 + HEADER_VALUE_ALMOST_DONE = 6 + HEADERS_ALMOST_DONE = 7 + PART_DATA_START = 8 + PART_DATA = 9 + PART_DATA_END = 10 + END = 11 # Flags for the multipart parser. -FLAG_PART_BOUNDARY = 1 -FLAG_LAST_BOUNDARY = 2 +FLAG_PART_BOUNDARY = 1 +FLAG_LAST_BOUNDARY = 2 # Get constants. Since iterating over a str on Python 2 gives you a 1-length # string, but iterating over a bytes object on Python 3 gives you an integer, # we need to save these constants. -CR = b'\r'[0] -LF = b'\n'[0] -COLON = b':'[0] -SPACE = b' '[0] -HYPHEN = b'-'[0] -AMPERSAND = b'&'[0] -SEMICOLON = b';'[0] -LOWER_A = b'a'[0] -LOWER_Z = b'z'[0] -NULL = b'\x00'[0] +CR = b"\r"[0] +LF = b"\n"[0] +COLON = b":"[0] +SPACE = b" "[0] +HYPHEN = b"-"[0] +AMPERSAND = b"&"[0] +SEMICOLON = b";"[0] +LOWER_A = b"a"[0] +LOWER_Z = b"z"[0] +NULL = b"\x00"[0] + # Lower-casing a character is different, because of the difference between # str on Py2, and bytes on Py3. Same with getting the ordinal value of a byte, # and joining a list of bytes together. # These functions abstract that. -lower_char = lambda c: c | 0x20 -ord_char = lambda c: c -join_bytes = lambda b: bytes(list(b)) - -# These are regexes for parsing header values. -SPECIAL_CHARS = re.escape(b'()<>@,;:\\"/[]?={} \t') -QUOTED_STR = br'"(?:\\.|[^"])*"' -VALUE_STR = br'(?:[^' + SPECIAL_CHARS + br']+|' + QUOTED_STR + br')' -OPTION_RE_STR = ( - br'(?:;|^)\s*([^' + SPECIAL_CHARS + br']+)\s*=\s*(' + VALUE_STR + br')' -) -OPTION_RE = re.compile(OPTION_RE_STR) -QUOTE = b'"'[0] - - -def parse_options_header(value): +def lower_char(c): + return c | 0x20 + + +def ord_char(c): + return c + + +def join_bytes(b): + return bytes(list(b)) + + +def parse_options_header(value: str | bytes) -> tuple[bytes, dict[bytes, bytes]]: """ Parses a Content-Type header into a value in the following format: (content_type, {parameters}) """ + # Uses email.message.Message to parse the header as described in PEP 594. + # Ref: https://peps.python.org/pep-0594/#cgi if not value: - return (b'', {}) + return (b"", {}) + + # If we are passed bytes, we assume that it conforms to WSGI, encoding in latin-1. + if isinstance(value, bytes): # pragma: no cover + value = value.decode("latin-1") - # If we are passed a string, we assume that it conforms to WSGI and does - # not contain any code point that's not in latin-1. - if isinstance(value, str): # pragma: no cover - value = value.encode('latin-1') + # For types + assert isinstance(value, str), "Value should be a string by now" # If we have no options, return the string as-is. - if b';' not in value: - return (value.lower().strip(), {}) + if ";" not in value: + return (value.lower().strip().encode("latin-1"), {}) # Split at the first semicolon, to get our value and then options. - ctype, rest = value.split(b';', 1) + # ctype, rest = value.split(b';', 1) + message = Message() + message["content-type"] = value + params = message.get_params() + # If there were no parameters, this would have already returned above + assert params, "At least the content type value should be present" + ctype = params.pop(0)[0].encode("latin-1") options = {} - - # Parse the options. - for match in OPTION_RE.finditer(rest): - key = match.group(1).lower() - value = match.group(2) - if value[0] == QUOTE and value[-1] == QUOTE: - # Unquote the value. - value = value[1:-1] - value = value.replace(b'\\\\', b'\\').replace(b'\\"', b'"') - + for param in params: + key, value = param + # If the value returned from get_params() is a 3-tuple, the last + # element corresponds to the value. + # See: https://docs.python.org/3/library/email.compat32-message.html + if isinstance(value, tuple): + value = value[-1] # If the value is a filename, we need to fix a bug on IE6 that sends # the full file path instead of the filename. - if key == b'filename': - if value[1:3] == b':\\' or value[:2] == b'\\\\': - value = value.split(b'\\')[-1] - - options[key] = value - + if key == "filename": + if value[1:3] == ":\\" or value[:2] == "\\\\": + value = value.split("\\")[-1] + options[key.encode("latin-1")] = value.encode("latin-1") return ctype, options @@ -132,15 +189,16 @@ class Field: :param name: the name of the form field """ - def __init__(self, name): + + def __init__(self, name: str): self._name = name - self._value = [] + self._value: list[bytes] = [] # We cache the joined version of _value for speed. self._cache = _missing @classmethod - def from_value(klass, name, value): + def from_value(cls, name: str, value: bytes | None) -> Field: """Create an instance of a :class:`Field`, and set the corresponding value - either None or an actual value. This method will also finalize the Field itself. @@ -150,7 +208,7 @@ def from_value(klass, name, value): None """ - f = klass(name) + f = cls(name) if value is None: f.set_none() else: @@ -158,14 +216,14 @@ def from_value(klass, name, value): f.finalize() return f - def write(self, data): + def write(self, data: bytes) -> int: """Write some data into the form field. :param data: a bytestring """ return self.on_data(data) - def on_data(self, data): + def on_data(self, data: bytes) -> int: """This method is a callback that will be called whenever data is written to the Field. @@ -175,27 +233,24 @@ def on_data(self, data): self._cache = _missing return len(data) - def on_end(self): - """This method is called whenever the Field is finalized. - """ + def on_end(self) -> None: + """This method is called whenever the Field is finalized.""" if self._cache is _missing: - self._cache = b''.join(self._value) + self._cache = b"".join(self._value) - def finalize(self): - """Finalize the form field. - """ + def finalize(self) -> None: + """Finalize the form field.""" self.on_end() - def close(self): - """Close the Field object. This will free any underlying cache. - """ + def close(self) -> None: + """Close the Field object. This will free any underlying cache.""" # Free our value array. if self._cache is _missing: - self._cache = b''.join(self._value) + self._cache = b"".join(self._value) del self._value - def set_none(self): + def set_none(self) -> None: """Some fields in a querystring can possibly have a value of None - for example, the string "foo&bar=&baz=asdf" will have a field with the name "foo" and value None, one with name "bar" and value "", and one @@ -205,7 +260,7 @@ def set_none(self): self._cache = None @property - def field_name(self): + def field_name(self) -> str: """This property returns the name of the field.""" return self._name @@ -213,20 +268,17 @@ def field_name(self): def value(self): """This property returns the value of the form field.""" if self._cache is _missing: - self._cache = b''.join(self._value) + self._cache = b"".join(self._value) return self._cache - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, Field): - return ( - self.field_name == other.field_name and - self.value == other.value - ) + return self.field_name == other.field_name and self.value == other.value else: return NotImplemented - def __repr__(self): + def __repr__(self) -> str: if len(self.value) > 97: # We get the repr, and then insert three dots before the final # quote. @@ -234,11 +286,7 @@ def __repr__(self): else: v = repr(self.value) - return "{}(field_name={!r}, value={})".format( - self.__class__.__name__, - self.field_name, - v - ) + return "{}(field_name={!r}, value={})".format(self.__class__.__name__, self.field_name, v) class File: @@ -300,7 +348,8 @@ class File: :param config: The configuration for this File. See above for valid configuration keys and their corresponding values. """ - def __init__(self, file_name, field_name=None, config={}): + + def __init__(self, file_name: bytes | None, field_name: bytes | None = None, config: FileConfig = {}): # Save configuration, set other variables default. self.logger = logging.getLogger(__name__) self._config = config @@ -323,16 +372,15 @@ def __init__(self, file_name, field_name=None, config={}): self._ext = ext @property - def field_name(self): + def field_name(self) -> bytes | None: """The form field associated with this file. May be None if there isn't one, for example when we have an application/octet-stream upload. """ return self._field_name @property - def file_name(self): - """The file name given in the upload request. - """ + def file_name(self) -> bytes | None: + """The file name given in the upload request.""" return self._file_name @property @@ -358,13 +406,13 @@ def size(self): return self._bytes_written @property - def in_memory(self): + def in_memory(self) -> bool: """A boolean representing whether or not this file object is currently stored in-memory or on-disk. """ return self._in_memory - def flush_to_disk(self): + def flush_to_disk(self) -> None: """If the file is already on-disk, do nothing. Otherwise, copy from the in-memory buffer to a disk file, and then reassign our internal file object to this new disk file. @@ -373,9 +421,7 @@ def flush_to_disk(self): warning will be logged to this module's logger. """ if not self._in_memory: - self.logger.warning( - "Trying to flush to disk when we're not in memory" - ) + self.logger.warning("Trying to flush to disk when we're not in memory") return # Go back to the start of our file. @@ -401,14 +447,13 @@ def flush_to_disk(self): old_fileobj.close() def _get_disk_file(self): - """This function is responsible for getting a file object on-disk for us. - """ + """This function is responsible for getting a file object on-disk for us.""" self.logger.info("Opening a file on disk") - file_dir = self._config.get('UPLOAD_DIR') - keep_filename = self._config.get('UPLOAD_KEEP_FILENAME', False) - keep_extensions = self._config.get('UPLOAD_KEEP_EXTENSIONS', False) - delete_tmp = self._config.get('UPLOAD_DELETE_TMP', True) + file_dir = self._config.get("UPLOAD_DIR") + keep_filename = self._config.get("UPLOAD_KEEP_FILENAME", False) + keep_extensions = self._config.get("UPLOAD_KEEP_EXTENSIONS", False) + delete_tmp = self._config.get("UPLOAD_DELETE_TMP", True) # If we have a directory and are to keep the filename... if file_dir is not None and keep_filename: @@ -423,8 +468,8 @@ def _get_disk_file(self): path = os.path.join(file_dir, fname) try: self.logger.info("Opening file: %r", path) - tmp_file = open(path, 'w+b') - except OSError as e: + tmp_file = open(path, "w+b") + except OSError: tmp_file = None self.logger.exception("Error opening temporary file") @@ -439,18 +484,17 @@ def _get_disk_file(self): if isinstance(ext, bytes): ext = ext.decode(sys.getfilesystemencoding()) - options['suffix'] = ext + options["suffix"] = ext if file_dir is not None: d = file_dir if isinstance(d, bytes): d = d.decode(sys.getfilesystemencoding()) - options['dir'] = d - options['delete'] = delete_tmp + options["dir"] = d + options["delete"] = delete_tmp # Create a temporary (named) file with the appropriate settings. - self.logger.info("Creating a temporary file with options: %r", - options) + self.logger.info("Creating a temporary file with options: %r", options) try: tmp_file = tempfile.NamedTemporaryFile(**options) except OSError: @@ -466,14 +510,14 @@ def _get_disk_file(self): self._actual_file_name = fname return tmp_file - def write(self, data): + def write(self, data: bytes): """Write some data to the File. :param data: a bytestring """ return self.on_data(data) - def on_data(self, data): + def on_data(self, data: bytes): """This method is a callback that will be called whenever data is written to the File. @@ -487,49 +531,44 @@ def on_data(self, data): # If the bytes written isn't the same as the length, just return. if bwritten != len(data): - self.logger.warning("bwritten != len(data) (%d != %d)", bwritten, - len(data)) + self.logger.warning("bwritten != len(data) (%d != %d)", bwritten, len(data)) return bwritten # Keep track of how many bytes we've written. self._bytes_written += bwritten # If we're in-memory and are over our limit, we create a file. - if (self._in_memory and - self._config.get('MAX_MEMORY_FILE_SIZE') is not None and - (self._bytes_written > - self._config.get('MAX_MEMORY_FILE_SIZE'))): + if ( + self._in_memory + and self._config.get("MAX_MEMORY_FILE_SIZE") is not None + and (self._bytes_written > self._config.get("MAX_MEMORY_FILE_SIZE")) + ): self.logger.info("Flushing to disk") self.flush_to_disk() # Return the number of bytes written. return bwritten - def on_end(self): - """This method is called whenever the Field is finalized. - """ + def on_end(self) -> None: + """This method is called whenever the Field is finalized.""" # Flush the underlying file object self._fileobj.flush() - def finalize(self): + def finalize(self) -> None: """Finalize the form file. This will not close the underlying file, but simply signal that we are finished writing to the File. """ self.on_end() - def close(self): + def close(self) -> None: """Close the File object. This will actually close the underlying file object (whether it's a :class:`io.BytesIO` or an actual file object). """ self._fileobj.close() - def __repr__(self): - return "{}(file_name={!r}, field_name={!r})".format( - self.__class__.__name__, - self.file_name, - self.field_name - ) + def __repr__(self) -> str: + return "{}(file_name={!r}, field_name={!r})".format(self.__class__.__name__, self.file_name, self.field_name) class BaseParser: @@ -552,10 +591,11 @@ class BaseParser: The callback is not passed a copy of the data, since copying severely hurts performance. """ + def __init__(self): self.logger = logging.getLogger(__name__) - def callback(self, name, data=None, start=None, end=None): + def callback(self, name: str, data=None, start=None, end=None): """This function calls a provided callback with some data. If the callback is not set, will do nothing. @@ -586,7 +626,7 @@ def callback(self, name, data=None, start=None, end=None): self.logger.debug("Calling %s with no data", name) func() - def set_callback(self, name, new_func): + def set_callback(self, name: str, new_func): """Update the function for a callback. Removes from the callbacks dict if new_func is None. @@ -597,15 +637,15 @@ def set_callback(self, name, new_func): exist). """ if new_func is None: - self.callbacks.pop('on_' + name, None) + self.callbacks.pop("on_" + name, None) else: - self.callbacks['on_' + name] = new_func + self.callbacks["on_" + name] = new_func def close(self): - pass # pragma: no cover + pass # pragma: no cover def finalize(self): - pass # pragma: no cover + pass # pragma: no cover def __repr__(self): return "%s()" % self.__class__.__name__ @@ -638,25 +678,25 @@ class OctetStreamParser(BaseParser): :param max_size: The maximum size of body to parse. Defaults to infinity - i.e. unbounded. """ - def __init__(self, callbacks={}, max_size=float('inf')): + + def __init__(self, callbacks: OctetStreamCallbacks = {}, max_size=float("inf")): super().__init__() self.callbacks = callbacks self._started = False if not isinstance(max_size, Number) or max_size < 1: - raise ValueError("max_size must be a positive number, not %r" % - max_size) + raise ValueError("max_size must be a positive number, not %r" % max_size) self.max_size = max_size self._current_size = 0 - def write(self, data): + def write(self, data: bytes): """Write some data to the parser, which will perform size verification, and then pass the data to the underlying callback. :param data: a bytestring """ if not self._started: - self.callback('start') + self.callback("start") self._started = True # Truncate data length. @@ -664,24 +704,27 @@ def write(self, data): if (self._current_size + data_len) > self.max_size: # We truncate the length of data that we are to process. new_size = int(self.max_size - self._current_size) - self.logger.warning("Current size is %d (max %d), so truncating " - "data length from %d to %d", - self._current_size, self.max_size, data_len, - new_size) + self.logger.warning( + "Current size is %d (max %d), so truncating data length from %d to %d", + self._current_size, + self.max_size, + data_len, + new_size, + ) data_len = new_size # Increment size, then callback, in case there's an exception. self._current_size += data_len - self.callback('data', data, 0, data_len) + self.callback("data", data, 0, data_len) return data_len - def finalize(self): + def finalize(self) -> None: """Finalize this parser, which signals to that we are finished parsing, and sends the on_end callback. """ - self.callback('end') + self.callback("end") - def __repr__(self): + def __repr__(self) -> str: return "%s()" % self.__class__.__name__ @@ -730,25 +773,26 @@ class QuerystringParser(BaseParser): :param max_size: The maximum size of body to parse. Defaults to infinity - i.e. unbounded. """ - def __init__(self, callbacks={}, strict_parsing=False, - max_size=float('inf')): + + state: QuerystringState + + def __init__(self, callbacks: QuerystringCallbacks = {}, strict_parsing: bool = False, max_size=float("inf")): super().__init__() - self.state = STATE_BEFORE_FIELD + self.state = QuerystringState.BEFORE_FIELD self._found_sep = False self.callbacks = callbacks # Max-size stuff if not isinstance(max_size, Number) or max_size < 1: - raise ValueError("max_size must be a positive number, not %r" % - max_size) + raise ValueError("max_size must be a positive number, not %r" % max_size) self.max_size = max_size self._current_size = 0 # Should parsing be strict? self.strict_parsing = strict_parsing - def write(self, data): + def write(self, data: bytes) -> int: """Write some data to the parser, which will perform size verification, parse into either a field name or value, and then pass the corresponding data to the underlying callback. If an error is @@ -763,10 +807,13 @@ def write(self, data): if (self._current_size + data_len) > self.max_size: # We truncate the length of data that we are to process. new_size = int(self.max_size - self._current_size) - self.logger.warning("Current size is %d (max %d), so truncating " - "data length from %d to %d", - self._current_size, self.max_size, data_len, - new_size) + self.logger.warning( + "Current size is %d (max %d), so truncating data length from %d to %d", + self._current_size, + self.max_size, + data_len, + new_size, + ) data_len = new_size l = 0 @@ -777,7 +824,7 @@ def write(self, data): return l - def _internal_write(self, data, length): + def _internal_write(self, data: bytes, length: int) -> int: state = self.state strict_parsing = self.strict_parsing found_sep = self._found_sep @@ -787,7 +834,7 @@ def _internal_write(self, data, length): ch = data[i] # Depending on our state... - if state == STATE_BEFORE_FIELD: + if state == QuerystringState.BEFORE_FIELD: # If the 'found_sep' flag is set, we've already encountered # and skipped a single separator. If so, we check our strict # parsing flag and decide what to do. Otherwise, we haven't @@ -798,15 +845,11 @@ def _internal_write(self, data, length): if found_sep: # If we're parsing strictly, we disallow blank chunks. if strict_parsing: - e = QuerystringParseError( - "Skipping duplicate ampersand/semicolon at " - "%d" % i - ) + e = QuerystringParseError("Skipping duplicate ampersand/semicolon at %d" % i) e.offset = i raise e else: - self.logger.debug("Skipping duplicate ampersand/" - "semicolon at %d", i) + self.logger.debug("Skipping duplicate ampersand/semicolon at %d", i) else: # This case is when we're skipping the (first) # separator between fields, so we just set our flag @@ -816,98 +859,97 @@ def _internal_write(self, data, length): # Emit a field-start event, and go to that state. Also, # reset the "found_sep" flag, for the next time we get to # this state. - self.callback('field_start') + self.callback("field_start") i -= 1 - state = STATE_FIELD_NAME + state = QuerystringState.FIELD_NAME found_sep = False - elif state == STATE_FIELD_NAME: + elif state == QuerystringState.FIELD_NAME: # Try and find a separator - we ensure that, if we do, we only # look for the equal sign before it. - sep_pos = data.find(b'&', i) + sep_pos = data.find(b"&", i) if sep_pos == -1: - sep_pos = data.find(b';', i) + sep_pos = data.find(b";", i) # See if we can find an equals sign in the remaining data. If # so, we can immediately emit the field name and jump to the # data state. if sep_pos != -1: - equals_pos = data.find(b'=', i, sep_pos) + equals_pos = data.find(b"=", i, sep_pos) else: - equals_pos = data.find(b'=', i) + equals_pos = data.find(b"=", i) if equals_pos != -1: # Emit this name. - self.callback('field_name', data, i, equals_pos) + self.callback("field_name", data, i, equals_pos) # Jump i to this position. Note that it will then have 1 # added to it below, which means the next iteration of this # loop will inspect the character after the equals sign. i = equals_pos - state = STATE_FIELD_DATA + state = QuerystringState.FIELD_DATA else: # No equals sign found. if not strict_parsing: - # See also comments in the STATE_FIELD_DATA case below. + # See also comments in the QuerystringState.FIELD_DATA case below. # If we found the separator, we emit the name and just # end - there's no data callback at all (not even with # a blank value). if sep_pos != -1: - self.callback('field_name', data, i, sep_pos) - self.callback('field_end') + self.callback("field_name", data, i, sep_pos) + self.callback("field_end") i = sep_pos - 1 - state = STATE_BEFORE_FIELD + state = QuerystringState.BEFORE_FIELD else: # Otherwise, no separator in this block, so the # rest of this chunk must be a name. - self.callback('field_name', data, i, length) + self.callback("field_name", data, i, length) i = length else: # We're parsing strictly. If we find a separator, # this is an error - we require an equals sign. if sep_pos != -1: - e = QuerystringParseError( + e = QuerystringParseError( "When strict_parsing is True, we require an " "equals sign in all field chunks. Did not " - "find one in the chunk that starts at %d" % - (i,) + "find one in the chunk that starts at %d" % (i,) ) e.offset = i raise e # No separator in the rest of this chunk, so it's just # a field name. - self.callback('field_name', data, i, length) + self.callback("field_name", data, i, length) i = length - elif state == STATE_FIELD_DATA: + elif state == QuerystringState.FIELD_DATA: # Try finding either an ampersand or a semicolon after this # position. - sep_pos = data.find(b'&', i) + sep_pos = data.find(b"&", i) if sep_pos == -1: - sep_pos = data.find(b';', i) + sep_pos = data.find(b";", i) # If we found it, callback this bit as data and then go back # to expecting to find a field. if sep_pos != -1: - self.callback('field_data', data, i, sep_pos) - self.callback('field_end') + self.callback("field_data", data, i, sep_pos) + self.callback("field_end") # Note that we go to the separator, which brings us to the # "before field" state. This allows us to properly emit # "field_start" events only when we actually have data for # a field of some sort. i = sep_pos - 1 - state = STATE_BEFORE_FIELD + state = QuerystringState.BEFORE_FIELD # Otherwise, emit the rest as data and finish. else: - self.callback('field_data', data, i, length) + self.callback("field_data", data, i, length) i = length - else: # pragma: no cover (error case) + else: # pragma: no cover (error case) msg = "Reached an unknown state %d at %d" % (state, i) self.logger.warning(msg) e = QuerystringParseError(msg) @@ -920,20 +962,19 @@ def _internal_write(self, data, length): self._found_sep = found_sep return len(data) - def finalize(self): + def finalize(self) -> None: """Finalize this parser, which signals to that we are finished parsing, if we're still in the middle of a field, an on_field_end callback, and then the on_end callback. """ # If we're currently in the middle of a field, we finish it. - if self.state == STATE_FIELD_DATA: - self.callback('field_end') - self.callback('end') + if self.state == QuerystringState.FIELD_DATA: + self.callback("field_end") + self.callback("end") - def __repr__(self): + def __repr__(self) -> str: return "{}(strict_parsing={!r}, max_size={!r})".format( - self.__class__.__name__, - self.strict_parsing, self.max_size + self.__class__.__name__, self.strict_parsing, self.max_size ) @@ -992,17 +1033,16 @@ class MultipartParser(BaseParser): i.e. unbounded. """ - def __init__(self, boundary, callbacks={}, max_size=float('inf')): + def __init__(self, boundary: bytes | str, callbacks: MultipartCallbacks = {}, max_size=float("inf")): # Initialize parser state. super().__init__() - self.state = STATE_START + self.state = MultipartState.START self.index = self.flags = 0 self.callbacks = callbacks if not isinstance(max_size, Number) or max_size < 1: - raise ValueError("max_size must be a positive number, not %r" % - max_size) + raise ValueError("max_size must be a positive number, not %r" % max_size) self.max_size = max_size self._current_size = 0 @@ -1019,9 +1059,9 @@ def __init__(self, boundary, callbacks={}, max_size=float('inf')): # self.skip = tuple(skip) # Save our boundary. - if isinstance(boundary, str): # pragma: no cover - boundary = boundary.encode('latin-1') - self.boundary = b'\r\n--' + boundary + if isinstance(boundary, str): # pragma: no cover + boundary = boundary.encode("latin-1") + self.boundary = b"\r\n--" + boundary # Get a set of characters that belong to our boundary. self.boundary_chars = frozenset(self.boundary) @@ -1032,7 +1072,7 @@ def __init__(self, boundary, callbacks={}, max_size=float('inf')): # '--\r\n' is 8 bytes. self.lookbehind = [NULL for x in range(len(boundary) + 8)] - def write(self, data): + def write(self, data: bytes) -> int: """Write some data to the parser, which will perform size verification, and then parse the data into the appropriate location (e.g. header, data, etc.), and pass this on to the underlying callback. If an error @@ -1047,10 +1087,13 @@ def write(self, data): if (self._current_size + data_len) > self.max_size: # We truncate the length of data that we are to process. new_size = int(self.max_size - self._current_size) - self.logger.warning("Current size is %d (max %d), so truncating " - "data length from %d to %d", - self._current_size, self.max_size, data_len, - new_size) + self.logger.warning( + "Current size is %d (max %d), so truncating data length from %d to %d", + self._current_size, + self.max_size, + data_len, + new_size, + ) data_len = new_size l = 0 @@ -1061,7 +1104,7 @@ def write(self, data): return l - def _internal_write(self, data, length): + def _internal_write(self, data: bytes, length: int) -> int: # Get values from locals. boundary = self.boundary @@ -1108,7 +1151,7 @@ def data_callback(name, remaining=False): while i < length: c = data[i] - if state == STATE_START: + if state == MultipartState.START: # Skip leading newlines if c == CR or c == LF: i += 1 @@ -1120,10 +1163,10 @@ def data_callback(name, remaining=False): # Move to the next state, but decrement i so that we re-process # this character. - state = STATE_START_BOUNDARY + state = MultipartState.START_BOUNDARY i -= 1 - elif state == STATE_START_BOUNDARY: + elif state == MultipartState.START_BOUNDARY: # Check to ensure that the last 2 characters in our boundary # are CRLF. if index == len(boundary) - 2: @@ -1149,16 +1192,15 @@ def data_callback(name, remaining=False): index = 0 # Callback for the start of a part. - self.callback('part_begin') + self.callback("part_begin") # Move to the next character and state. - state = STATE_HEADER_FIELD_START + state = MultipartState.HEADER_FIELD_START else: # Check to ensure our boundary matches if c != boundary[index + 2]: - msg = "Did not find boundary character %r at index " \ - "%d" % (c, index + 2) + msg = "Did not find boundary character %r at index " "%d" % (c, index + 2) self.logger.warning(msg) e = MultipartParseError(msg) e.offset = i @@ -1167,25 +1209,25 @@ def data_callback(name, remaining=False): # Increment index into boundary and continue. index += 1 - elif state == STATE_HEADER_FIELD_START: + elif state == MultipartState.HEADER_FIELD_START: # Mark the start of a header field here, reset the index, and # continue parsing our header field. index = 0 # Set a mark of our header field. - set_mark('header_field') + set_mark("header_field") # Move to parsing header fields. - state = STATE_HEADER_FIELD + state = MultipartState.HEADER_FIELD i -= 1 - elif state == STATE_HEADER_FIELD: + elif state == MultipartState.HEADER_FIELD: # If we've reached a CR at the beginning of a header, it means # that we've reached the second of 2 newlines, and so there are # no more headers to parse. if c == CR: - delete_mark('header_field') - state = STATE_HEADERS_ALMOST_DONE + delete_mark("header_field") + state = MultipartState.HEADERS_ALMOST_DONE i += 1 continue @@ -1207,49 +1249,47 @@ def data_callback(name, remaining=False): raise e # Call our callback with the header field. - data_callback('header_field') + data_callback("header_field") # Move to parsing the header value. - state = STATE_HEADER_VALUE_START + state = MultipartState.HEADER_VALUE_START else: # Lower-case this character, and ensure that it is in fact # a valid letter. If not, it's an error. cl = lower_char(c) if cl < LOWER_A or cl > LOWER_Z: - msg = "Found non-alphanumeric character %r in " \ - "header at %d" % (c, i) + msg = "Found non-alphanumeric character %r in " "header at %d" % (c, i) self.logger.warning(msg) e = MultipartParseError(msg) e.offset = i raise e - elif state == STATE_HEADER_VALUE_START: + elif state == MultipartState.HEADER_VALUE_START: # Skip leading spaces. if c == SPACE: i += 1 continue # Mark the start of the header value. - set_mark('header_value') + set_mark("header_value") # Move to the header-value state, reprocessing this character. - state = STATE_HEADER_VALUE + state = MultipartState.HEADER_VALUE i -= 1 - elif state == STATE_HEADER_VALUE: + elif state == MultipartState.HEADER_VALUE: # If we've got a CR, we're nearly done our headers. Otherwise, # we do nothing and just move past this character. if c == CR: - data_callback('header_value') - self.callback('header_end') - state = STATE_HEADER_VALUE_ALMOST_DONE + data_callback("header_value") + self.callback("header_end") + state = MultipartState.HEADER_VALUE_ALMOST_DONE - elif state == STATE_HEADER_VALUE_ALMOST_DONE: + elif state == MultipartState.HEADER_VALUE_ALMOST_DONE: # The last character should be a LF. If not, it's an error. if c != LF: - msg = "Did not find LF character at end of header " \ - "(found %r)" % (c,) + msg = "Did not find LF character at end of header " "(found %r)" % (c,) self.logger.warning(msg) e = MultipartParseError(msg) e.offset = i @@ -1258,9 +1298,9 @@ def data_callback(name, remaining=False): # Move back to the start of another header. Note that if that # state detects ANOTHER newline, it'll trigger the end of our # headers. - state = STATE_HEADER_FIELD_START + state = MultipartState.HEADER_FIELD_START - elif state == STATE_HEADERS_ALMOST_DONE: + elif state == MultipartState.HEADERS_ALMOST_DONE: # We're almost done our headers. This is reached when we parse # a CR at the beginning of a header, so our next character # should be a LF, or it's an error. @@ -1271,18 +1311,18 @@ def data_callback(name, remaining=False): e.offset = i raise e - self.callback('headers_finished') - state = STATE_PART_DATA_START + self.callback("headers_finished") + state = MultipartState.PART_DATA_START - elif state == STATE_PART_DATA_START: + elif state == MultipartState.PART_DATA_START: # Mark the start of our part data. - set_mark('part_data') + set_mark("part_data") # Start processing part data, including this character. - state = STATE_PART_DATA + state = MultipartState.PART_DATA i -= 1 - elif state == STATE_PART_DATA: + elif state == MultipartState.PART_DATA: # We're processing our part data right now. During this, we # need to efficiently search for our boundary, since any data # on any number of lines can be a part of the current data. @@ -1324,7 +1364,7 @@ def data_callback(name, remaining=False): # If we found a match for our boundary, we send the # existing data. if index == 0: - data_callback('part_data') + data_callback("part_data") # The current character matches, so continue! index += 1 @@ -1360,23 +1400,23 @@ def data_callback(name, remaining=False): # We need a LF character next. if c == LF: # Unset the part boundary flag. - flags &= (~FLAG_PART_BOUNDARY) + flags &= ~FLAG_PART_BOUNDARY # Callback indicating that we've reached the end of # a part, and are starting a new one. - self.callback('part_end') - self.callback('part_begin') + self.callback("part_end") + self.callback("part_begin") # Move to parsing new headers. index = 0 - state = STATE_HEADER_FIELD_START + state = MultipartState.HEADER_FIELD_START i += 1 continue # We didn't find an LF character, so no match. Reset # our index and clear our flag. index = 0 - flags &= (~FLAG_PART_BOUNDARY) + flags &= ~FLAG_PART_BOUNDARY # Otherwise, if we're at the last boundary (i.e. we've # seen a hyphen already)... @@ -1385,9 +1425,9 @@ def data_callback(name, remaining=False): if c == HYPHEN: # Callback to end the current part, and then the # message. - self.callback('part_end') - self.callback('end') - state = STATE_END + self.callback("part_end") + self.callback("end") + state = MultipartState.END else: # No match, so reset index. index = 0 @@ -1404,24 +1444,24 @@ def data_callback(name, remaining=False): elif prev_index > 0: # Callback to write the saved data. lb_data = join_bytes(self.lookbehind) - self.callback('part_data', lb_data, 0, prev_index) + self.callback("part_data", lb_data, 0, prev_index) # Overwrite our previous index. prev_index = 0 # Re-set our mark for part data. - set_mark('part_data') + set_mark("part_data") # Re-consider the current character, since this could be # the start of the boundary itself. i -= 1 - elif state == STATE_END: + elif state == MultipartState.END: # Do nothing and just consume a byte in the end state. if c not in (CR, LF): self.logger.warning("Consuming a byte '0x%x' in the end state", c) - else: # pragma: no cover (error case) + else: # pragma: no cover (error case) # We got into a strange state somehow! Just stop processing. msg = "Reached an unknown state %d at %d" % (state, i) self.logger.warning(msg) @@ -1440,9 +1480,9 @@ def data_callback(name, remaining=False): # that we haven't yet reached the end of this 'thing'. So, by setting # the mark to 0, we cause any data callbacks that take place in future # calls to this function to start from the beginning of that buffer. - data_callback('header_field', True) - data_callback('header_value', True) - data_callback('part_data', True) + data_callback("header_field", True) + data_callback("header_value", True) + data_callback("part_data", True) # Save values to locals. self.state = state @@ -1453,14 +1493,14 @@ def data_callback(name, remaining=False): # all of it. return length - def finalize(self): + def finalize(self) -> None: """Finalize this parser, which signals to that we are finished parsing. Note: It does not currently, but in the future, it will verify that we are in the final state of the parser (i.e. the end of the multipart message is well-formed), and, if not, throw an error. """ - # TODO: verify that we're in the state STATE_END, otherwise throw an + # TODO: verify that we're in the state MultipartState.END, otherwise throw an # error or otherwise state that we're not finished parsing. pass @@ -1520,23 +1560,31 @@ class FormParser: default values. """ + #: This is the default configuration for our form parser. #: Note: all file sizes should be in bytes. - DEFAULT_CONFIG = { - 'MAX_BODY_SIZE': float('inf'), - 'MAX_MEMORY_FILE_SIZE': 1 * 1024 * 1024, - 'UPLOAD_DIR': None, - 'UPLOAD_KEEP_FILENAME': False, - 'UPLOAD_KEEP_EXTENSIONS': False, - + DEFAULT_CONFIG: FormParserConfig = { + "MAX_BODY_SIZE": float("inf"), + "MAX_MEMORY_FILE_SIZE": 1 * 1024 * 1024, + "UPLOAD_DIR": None, + "UPLOAD_KEEP_FILENAME": False, + "UPLOAD_KEEP_EXTENSIONS": False, # Error on invalid Content-Transfer-Encoding? - 'UPLOAD_ERROR_ON_BAD_CTE': False, + "UPLOAD_ERROR_ON_BAD_CTE": False, } - def __init__(self, content_type, on_field, on_file, on_end=None, - boundary=None, file_name=None, FileClass=File, - FieldClass=Field, config={}): - + def __init__( + self, + content_type, + on_field, + on_file, + on_end=None, + boundary=None, + file_name=None, + FileClass=File, + FieldClass=Field, + config: FormParserConfig = {}, + ): self.logger = logging.getLogger(__name__) # Save variables. @@ -1559,18 +1607,18 @@ def __init__(self, content_type, on_field, on_file, on_end=None, self.config.update(config) # Depending on the Content-Type, we instantiate the correct parser. - if content_type == 'application/octet-stream': + if content_type == "application/octet-stream": # Work around the lack of 'nonlocal' in Py2 class vars: f = None - def on_start(): + def on_start() -> None: vars.f = FileClass(file_name, None, config=self.config) - def on_data(data, start, end): + def on_data(data: bytes, start: int, end: int) -> None: vars.f.write(data[start:end]) - def on_end(): + def on_end() -> None: # Finalize the file itself. vars.f.finalize() @@ -1581,42 +1629,36 @@ def on_end(): if self.on_end is not None: self.on_end() - callbacks = { - 'on_start': on_start, - 'on_data': on_data, - 'on_end': on_end, - } - # Instantiate an octet-stream parser - parser = OctetStreamParser(callbacks, - max_size=self.config['MAX_BODY_SIZE']) - - elif (content_type == 'application/x-www-form-urlencoded' or - content_type == 'application/x-url-encoded'): + parser = OctetStreamParser( + callbacks={"on_start": on_start, "on_data": on_data, "on_end": on_end}, + max_size=self.config["MAX_BODY_SIZE"], + ) - name_buffer = [] + elif content_type == "application/x-www-form-urlencoded" or content_type == "application/x-url-encoded": + name_buffer: list[bytes] = [] class vars: f = None - def on_field_start(): + def on_field_start() -> None: pass - def on_field_name(data, start, end): + def on_field_name(data: bytes, start: int, end: int) -> None: name_buffer.append(data[start:end]) - def on_field_data(data, start, end): + def on_field_data(data: bytes, start: int, end: int) -> None: if vars.f is None: - vars.f = FieldClass(b''.join(name_buffer)) + vars.f = FieldClass(b"".join(name_buffer)) del name_buffer[:] vars.f.write(data[start:end]) - def on_field_end(): + def on_field_end() -> None: # Finalize and call callback. if vars.f is None: # If we get here, it's because there was no field data. # We create a field, set it to None, and then continue. - vars.f = FieldClass(b''.join(name_buffer)) + vars.f = FieldClass(b"".join(name_buffer)) del name_buffer[:] vars.f.set_none() @@ -1624,32 +1666,29 @@ def on_field_end(): on_field(vars.f) vars.f = None - def on_end(): + def on_end() -> None: if self.on_end is not None: self.on_end() - # Setup callbacks. - callbacks = { - 'on_field_start': on_field_start, - 'on_field_name': on_field_name, - 'on_field_data': on_field_data, - 'on_field_end': on_field_end, - 'on_end': on_end, - } - # Instantiate parser. parser = QuerystringParser( - callbacks=callbacks, - max_size=self.config['MAX_BODY_SIZE'] + callbacks={ + "on_field_start": on_field_start, + "on_field_name": on_field_name, + "on_field_data": on_field_data, + "on_field_end": on_field_end, + "on_end": on_end, + }, + max_size=self.config["MAX_BODY_SIZE"], ) - elif content_type == 'multipart/form-data': + elif content_type == "multipart/form-data": if boundary is None: self.logger.error("No boundary given") raise FormParserError("No boundary given") - header_name = [] - header_value = [] + header_name: list[bytes] = [] + header_value: list[bytes] = [] headers = {} # No 'nonlocal' on Python 2 :-( @@ -1661,41 +1700,41 @@ class vars: def on_part_begin(): pass - def on_part_data(data, start, end): + def on_part_data(data: bytes, start: int, end: int): bytes_processed = vars.writer.write(data[start:end]) # TODO: check for error here. return bytes_processed - def on_part_end(): + def on_part_end() -> None: vars.f.finalize() if vars.is_file: on_file(vars.f) else: on_field(vars.f) - def on_header_field(data, start, end): + def on_header_field(data: bytes, start: int, end: int): header_name.append(data[start:end]) - def on_header_value(data, start, end): + def on_header_value(data: bytes, start: int, end: int): header_value.append(data[start:end]) def on_header_end(): - headers[b''.join(header_name)] = b''.join(header_value) + headers[b"".join(header_name)] = b"".join(header_value) del header_name[:] del header_value[:] - def on_headers_finished(): + def on_headers_finished() -> None: # Reset the 'is file' flag. vars.is_file = False # Parse the content-disposition header. # TODO: handle mixed case - content_disp = headers.get(b'Content-Disposition') + content_disp = headers.get(b"Content-Disposition") disp, options = parse_options_header(content_disp) # Get the field and filename. - field_name = options.get(b'name') - file_name = options.get(b'filename') + field_name = options.get(b"name") + file_name = options.get(b"filename") # TODO: check for errors # Create the proper class. @@ -1708,64 +1747,54 @@ def on_headers_finished(): # Parse the given Content-Transfer-Encoding to determine what # we need to do with the incoming data. # TODO: check that we properly handle 8bit / 7bit encoding. - transfer_encoding = headers.get(b'Content-Transfer-Encoding', - b'7bit') + transfer_encoding = headers.get(b"Content-Transfer-Encoding", b"7bit") - if (transfer_encoding == b'binary' or - transfer_encoding == b'8bit' or - transfer_encoding == b'7bit'): + if transfer_encoding == b"binary" or transfer_encoding == b"8bit" or transfer_encoding == b"7bit": vars.writer = vars.f - elif transfer_encoding == b'base64': + elif transfer_encoding == b"base64": vars.writer = Base64Decoder(vars.f) - elif transfer_encoding == b'quoted-printable': + elif transfer_encoding == b"quoted-printable": vars.writer = QuotedPrintableDecoder(vars.f) else: - self.logger.warning("Unknown Content-Transfer-Encoding: " - "%r", transfer_encoding) - if self.config['UPLOAD_ERROR_ON_BAD_CTE']: - raise FormParserError( - 'Unknown Content-Transfer-Encoding "{}"'.format( - transfer_encoding - ) - ) + self.logger.warning("Unknown Content-Transfer-Encoding: %r", transfer_encoding) + if self.config["UPLOAD_ERROR_ON_BAD_CTE"]: + raise FormParserError('Unknown Content-Transfer-Encoding "{}"'.format(transfer_encoding)) else: # If we aren't erroring, then we just treat this as an # unencoded Content-Transfer-Encoding. vars.writer = vars.f - def on_end(): + def on_end() -> None: vars.writer.finalize() if self.on_end is not None: self.on_end() - # These are our callbacks for the parser. - callbacks = { - 'on_part_begin': on_part_begin, - 'on_part_data': on_part_data, - 'on_part_end': on_part_end, - 'on_header_field': on_header_field, - 'on_header_value': on_header_value, - 'on_header_end': on_header_end, - 'on_headers_finished': on_headers_finished, - 'on_end': on_end, - } - # Instantiate a multipart parser. - parser = MultipartParser(boundary, callbacks, - max_size=self.config['MAX_BODY_SIZE']) + parser = MultipartParser( + boundary, + callbacks={ + "on_part_begin": on_part_begin, + "on_part_data": on_part_data, + "on_part_end": on_part_end, + "on_header_field": on_header_field, + "on_header_value": on_header_value, + "on_header_end": on_header_end, + "on_headers_finished": on_headers_finished, + "on_end": on_end, + }, + max_size=self.config["MAX_BODY_SIZE"], + ) else: self.logger.warning("Unknown Content-Type: %r", content_type) - raise FormParserError("Unknown Content-Type: {}".format( - content_type - )) + raise FormParserError("Unknown Content-Type: {}".format(content_type)) self.parser = parser - def write(self, data): + def write(self, data: bytes): """Write some data. The parser will forward this to the appropriate underlying parser. @@ -1775,26 +1804,21 @@ def write(self, data): # TODO: check the parser's return value for errors? return self.parser.write(data) - def finalize(self): + def finalize(self) -> None: """Finalize the parser.""" - if self.parser is not None and hasattr(self.parser, 'finalize'): + if self.parser is not None and hasattr(self.parser, "finalize"): self.parser.finalize() - def close(self): + def close(self) -> None: """Close the parser.""" - if self.parser is not None and hasattr(self.parser, 'close'): + if self.parser is not None and hasattr(self.parser, "close"): self.parser.close() - def __repr__(self): - return "{}(content_type={!r}, parser={!r})".format( - self.__class__.__name__, - self.content_type, - self.parser, - ) + def __repr__(self) -> str: + return "{}(content_type={!r}, parser={!r})".format(self.__class__.__name__, self.content_type, self.parser) -def create_form_parser(headers, on_field, on_file, trust_x_headers=False, - config={}): +def create_form_parser(headers, on_field, on_file, trust_x_headers=False, config={}): """This function is a helper function to aid in creating a FormParser instances. Given a dictionary-like headers object, it will determine the correct information needed, instantiate a FormParser with the @@ -1814,7 +1838,7 @@ def create_form_parser(headers, on_field, on_file, trust_x_headers=False, :param config: Configuration variables to pass to the FormParser. """ - content_type = headers.get('Content-Type') + content_type = headers.get("Content-Type") if content_type is None: logging.getLogger(__name__).warning("No Content-Type header given") raise ValueError("No Content-Type header given!") @@ -1822,28 +1846,22 @@ def create_form_parser(headers, on_field, on_file, trust_x_headers=False, # Boundaries are optional (the FormParser will raise if one is needed # but not given). content_type, params = parse_options_header(content_type) - boundary = params.get(b'boundary') + boundary = params.get(b"boundary") # We need content_type to be a string, not a bytes object. - content_type = content_type.decode('latin-1') + content_type = content_type.decode("latin-1") # File names are optional. - file_name = headers.get('X-File-Name') + file_name = headers.get("X-File-Name") # Instantiate a form parser. - form_parser = FormParser(content_type, - on_field, - on_file, - boundary=boundary, - file_name=file_name, - config=config) + form_parser = FormParser(content_type, on_field, on_file, boundary=boundary, file_name=file_name, config=config) # Return our parser. return form_parser -def parse_form(headers, input_stream, on_field, on_file, chunk_size=1048576, - **kwargs): +def parse_form(headers, input_stream, on_field, on_file, chunk_size=1048576, **kwargs): """This function is useful if you just want to parse a request body, without too much work. Pass it a dictionary-like object of the request's headers, and a file-like object for the input stream, along with two @@ -1868,11 +1886,11 @@ def parse_form(headers, input_stream, on_field, on_file, chunk_size=1048576, # Read chunks of 100KiB and write to the parser, but never read more than # the given Content-Length, if any. - content_length = headers.get('Content-Length') + content_length = headers.get("Content-Length") if content_length is not None: content_length = int(content_length) else: - content_length = float('inf') + content_length = float("inf") bytes_read = 0 while True: diff --git a/pyproject.toml b/pyproject.toml index 2220432..5833d83 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,10 +8,8 @@ dynamic = ["version"] description = "A streaming multipart parser for Python" readme = "README.rst" license = "Apache-2.0" -requires-python = ">=3.7" -authors = [ - { name = "Andrew Dunham", email = "andrew@du.nham.ca" }, -] +requires-python = ">=3.8" +authors = [{ name = "Andrew Dunham", email = "andrew@du.nham.ca" }] classifiers = [ 'Development Status :: 5 - Production/Stable', 'Environment :: Web Environment', @@ -20,46 +18,56 @@ classifiers = [ 'Operating System :: OS Independent', 'Programming Language :: Python :: 3 :: Only', 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', + 'Programming Language :: Python :: 3.12', 'Topic :: Software Development :: Libraries :: Python Modules', ] dependencies = [] [project.optional-dependencies] dev = [ - "atomicwrites==1.2.1", - "attrs==19.2.0", - "coverage==6.5.0", - "more-itertools==4.3.0", - "pbr==4.3.0", - "pluggy==1.0.0", + "atomicwrites==1.4.1", + "attrs==23.2.0", + "coverage==7.4.1", + "more-itertools==10.2.0", + "pbr==6.0.0", + "pluggy==1.4.0", "py==1.11.0", - "pytest==7.2.0", - "pytest-cov==4.0.0", - "PyYAML==5.1", - "invoke==1.7.3", - "pytest-timeout==2.1.0", + "pytest==8.0.0", + "pytest-cov==4.1.0", + "PyYAML==6.0.1", + "invoke==2.2.0", + "pytest-timeout==2.2.0", + "ruff==0.2.1", "hatch", ] [project.urls] Homepage = "https://github.com/andrew-d/python-multipart" Documentation = "https://andrew-d.github.io/python-multipart/" -Changelog = "https://github.com/andrew-d/python-multipart/tags" +Changelog = "https://github.com/andrew-d/python-multipart/blob/master/CHANGELOG.md" Source = "https://github.com/andrew-d/python-multipart" [tool.hatch.version] path = "multipart/__init__.py" [tool.hatch.build.targets.wheel] -packages = [ - "multipart", -] +packages = ["multipart"] + [tool.hatch.build.targets.sdist] -include = [ - "/multipart", -] +include = ["/multipart", "/tests"] + +[tool.ruff] +line-length = 120 +select = ["E", "F", "I", "FA"] +ignore = ["B904", "B028", "F841", "E741"] + +[tool.ruff.format] +skip-magic-trailing-comma = true + +[tool.ruff.lint.isort] +combine-as-imports = true +split-on-trailing-comma = false diff --git a/requirements.txt b/requirements.txt index ad6bb55..23baf78 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,9 +1,10 @@ -atomicwrites==1.2.1 -attrs==19.2.0 -coverage==4.5.1 -more-itertools==4.3.0 -pbr==4.3.0 -pluggy==1.0.0 +atomicwrites==1.4.1 +attrs==23.2.0 +coverage==7.4.1 +more-itertools==10.2.0 +pbr==6.0.0 +pluggy==1.4.0 py==1.11.0 -pytest==7.2.0 -PyYAML==5.4 +pytest==8.0.0 +PyYAML==6.0.1 +ruff==0.2.1 diff --git a/tasks.py b/tasks.py index 3ac7419..f2cec11 100644 --- a/tasks.py +++ b/tasks.py @@ -28,7 +28,7 @@ def test(ctx, all=False): test_cmd.append('-m "not slow_test"') # Test in this directory - test_cmd.append(os.path.join("multipart", "tests")) + test_cmd.append("tests") # Run the command. # TODO: why does this fail with pty=True? diff --git a/multipart/tests/__init__.py b/tests/__init__.py similarity index 100% rename from multipart/tests/__init__.py rename to tests/__init__.py diff --git a/multipart/tests/compat.py b/tests/compat.py similarity index 84% rename from multipart/tests/compat.py rename to tests/compat.py index 897188d..8b0ccae 100644 --- a/multipart/tests/compat.py +++ b/tests/compat.py @@ -1,8 +1,8 @@ +import functools import os import re import sys import types -import functools def ensure_in_path(path): @@ -10,7 +10,7 @@ def ensure_in_path(path): Ensure that a given path is in the sys.path array """ if not os.path.isdir(path): - raise RuntimeError('Tried to add nonexisting path') + raise RuntimeError("Tried to add nonexisting path") def _samefile(x, y): try: @@ -44,7 +44,9 @@ def _samefile(x, y): xfail = pytest.mark.xfail else: - slow_test = lambda x: x + + def slow_test(x): + return x def xfail(*args, **kwargs): if len(args) > 0 and isinstance(args[0], types.FunctionType): @@ -64,8 +66,8 @@ def parametrize(field_names, field_values): # Create a decorator that saves this list of field names and values on the # function for later parametrizing. def decorator(func): - func.__dict__['param_names'] = field_names - func.__dict__['param_values'] = field_values + func.__dict__["param_names"] = field_names + func.__dict__["param_values"] = field_values return func return decorator @@ -73,7 +75,7 @@ def decorator(func): # This is a metaclass that actually performs the parametrization. class ParametrizingMetaclass(type): - IDENTIFIER_RE = re.compile('[^A-Za-z0-9]') + IDENTIFIER_RE = re.compile("[^A-Za-z0-9]") def __new__(klass, name, bases, attrs): new_attrs = attrs.copy() @@ -82,8 +84,8 @@ def __new__(klass, name, bases, attrs): if not isinstance(attr, types.FunctionType): continue - param_names = attr.__dict__.pop('param_names', None) - param_values = attr.__dict__.pop('param_values', None) + param_names = attr.__dict__.pop("param_names", None) + param_values = attr.__dict__.pop("param_values", None) if param_names is None or param_values is None: continue @@ -92,9 +94,7 @@ def __new__(klass, name, bases, attrs): assert len(param_names) == len(values) # Get a repr of the values, and fix it to be a valid identifier - human = '_'.join( - [klass.IDENTIFIER_RE.sub('', repr(x)) for x in values] - ) + human = "_".join([klass.IDENTIFIER_RE.sub("", repr(x)) for x in values]) # Create a new name. # new_name = attr.__name__ + "_%d" % i @@ -128,6 +128,4 @@ def new_func(self): # This is a class decorator that actually applies the above metaclass. def parametrize_class(klass): - return ParametrizingMetaclass(klass.__name__, - klass.__bases__, - klass.__dict__) + return ParametrizingMetaclass(klass.__name__, klass.__bases__, klass.__dict__) diff --git a/multipart/tests/test_data/http/CR_in_header.http b/tests/test_data/http/CR_in_header.http similarity index 100% rename from multipart/tests/test_data/http/CR_in_header.http rename to tests/test_data/http/CR_in_header.http diff --git a/multipart/tests/test_data/http/CR_in_header.yaml b/tests/test_data/http/CR_in_header.yaml similarity index 100% rename from multipart/tests/test_data/http/CR_in_header.yaml rename to tests/test_data/http/CR_in_header.yaml diff --git a/multipart/tests/test_data/http/CR_in_header_value.http b/tests/test_data/http/CR_in_header_value.http similarity index 100% rename from multipart/tests/test_data/http/CR_in_header_value.http rename to tests/test_data/http/CR_in_header_value.http diff --git a/multipart/tests/test_data/http/CR_in_header_value.yaml b/tests/test_data/http/CR_in_header_value.yaml similarity index 100% rename from multipart/tests/test_data/http/CR_in_header_value.yaml rename to tests/test_data/http/CR_in_header_value.yaml diff --git a/multipart/tests/test_data/http/almost_match_boundary.http b/tests/test_data/http/almost_match_boundary.http similarity index 100% rename from multipart/tests/test_data/http/almost_match_boundary.http rename to tests/test_data/http/almost_match_boundary.http diff --git a/multipart/tests/test_data/http/almost_match_boundary.yaml b/tests/test_data/http/almost_match_boundary.yaml similarity index 100% rename from multipart/tests/test_data/http/almost_match_boundary.yaml rename to tests/test_data/http/almost_match_boundary.yaml diff --git a/multipart/tests/test_data/http/almost_match_boundary_without_CR.http b/tests/test_data/http/almost_match_boundary_without_CR.http similarity index 100% rename from multipart/tests/test_data/http/almost_match_boundary_without_CR.http rename to tests/test_data/http/almost_match_boundary_without_CR.http diff --git a/multipart/tests/test_data/http/almost_match_boundary_without_CR.yaml b/tests/test_data/http/almost_match_boundary_without_CR.yaml similarity index 100% rename from multipart/tests/test_data/http/almost_match_boundary_without_CR.yaml rename to tests/test_data/http/almost_match_boundary_without_CR.yaml diff --git a/multipart/tests/test_data/http/almost_match_boundary_without_LF.http b/tests/test_data/http/almost_match_boundary_without_LF.http similarity index 100% rename from multipart/tests/test_data/http/almost_match_boundary_without_LF.http rename to tests/test_data/http/almost_match_boundary_without_LF.http diff --git a/multipart/tests/test_data/http/almost_match_boundary_without_LF.yaml b/tests/test_data/http/almost_match_boundary_without_LF.yaml similarity index 100% rename from multipart/tests/test_data/http/almost_match_boundary_without_LF.yaml rename to tests/test_data/http/almost_match_boundary_without_LF.yaml diff --git a/multipart/tests/test_data/http/almost_match_boundary_without_final_hyphen.http b/tests/test_data/http/almost_match_boundary_without_final_hyphen.http similarity index 100% rename from multipart/tests/test_data/http/almost_match_boundary_without_final_hyphen.http rename to tests/test_data/http/almost_match_boundary_without_final_hyphen.http diff --git a/multipart/tests/test_data/http/almost_match_boundary_without_final_hyphen.yaml b/tests/test_data/http/almost_match_boundary_without_final_hyphen.yaml similarity index 100% rename from multipart/tests/test_data/http/almost_match_boundary_without_final_hyphen.yaml rename to tests/test_data/http/almost_match_boundary_without_final_hyphen.yaml diff --git a/multipart/tests/test_data/http/bad_end_of_headers.http b/tests/test_data/http/bad_end_of_headers.http similarity index 100% rename from multipart/tests/test_data/http/bad_end_of_headers.http rename to tests/test_data/http/bad_end_of_headers.http diff --git a/multipart/tests/test_data/http/bad_end_of_headers.yaml b/tests/test_data/http/bad_end_of_headers.yaml similarity index 100% rename from multipart/tests/test_data/http/bad_end_of_headers.yaml rename to tests/test_data/http/bad_end_of_headers.yaml diff --git a/multipart/tests/test_data/http/bad_header_char.http b/tests/test_data/http/bad_header_char.http similarity index 100% rename from multipart/tests/test_data/http/bad_header_char.http rename to tests/test_data/http/bad_header_char.http diff --git a/multipart/tests/test_data/http/bad_header_char.yaml b/tests/test_data/http/bad_header_char.yaml similarity index 100% rename from multipart/tests/test_data/http/bad_header_char.yaml rename to tests/test_data/http/bad_header_char.yaml diff --git a/multipart/tests/test_data/http/bad_initial_boundary.http b/tests/test_data/http/bad_initial_boundary.http similarity index 100% rename from multipart/tests/test_data/http/bad_initial_boundary.http rename to tests/test_data/http/bad_initial_boundary.http diff --git a/multipart/tests/test_data/http/bad_initial_boundary.yaml b/tests/test_data/http/bad_initial_boundary.yaml similarity index 100% rename from multipart/tests/test_data/http/bad_initial_boundary.yaml rename to tests/test_data/http/bad_initial_boundary.yaml diff --git a/multipart/tests/test_data/http/base64_encoding.http b/tests/test_data/http/base64_encoding.http similarity index 100% rename from multipart/tests/test_data/http/base64_encoding.http rename to tests/test_data/http/base64_encoding.http diff --git a/multipart/tests/test_data/http/base64_encoding.yaml b/tests/test_data/http/base64_encoding.yaml similarity index 100% rename from multipart/tests/test_data/http/base64_encoding.yaml rename to tests/test_data/http/base64_encoding.yaml diff --git a/multipart/tests/test_data/http/empty_header.http b/tests/test_data/http/empty_header.http similarity index 100% rename from multipart/tests/test_data/http/empty_header.http rename to tests/test_data/http/empty_header.http diff --git a/multipart/tests/test_data/http/empty_header.yaml b/tests/test_data/http/empty_header.yaml similarity index 100% rename from multipart/tests/test_data/http/empty_header.yaml rename to tests/test_data/http/empty_header.yaml diff --git a/multipart/tests/test_data/http/multiple_fields.http b/tests/test_data/http/multiple_fields.http similarity index 100% rename from multipart/tests/test_data/http/multiple_fields.http rename to tests/test_data/http/multiple_fields.http diff --git a/multipart/tests/test_data/http/multiple_fields.yaml b/tests/test_data/http/multiple_fields.yaml similarity index 100% rename from multipart/tests/test_data/http/multiple_fields.yaml rename to tests/test_data/http/multiple_fields.yaml diff --git a/multipart/tests/test_data/http/multiple_files.http b/tests/test_data/http/multiple_files.http similarity index 100% rename from multipart/tests/test_data/http/multiple_files.http rename to tests/test_data/http/multiple_files.http diff --git a/multipart/tests/test_data/http/multiple_files.yaml b/tests/test_data/http/multiple_files.yaml similarity index 100% rename from multipart/tests/test_data/http/multiple_files.yaml rename to tests/test_data/http/multiple_files.yaml diff --git a/multipart/tests/test_data/http/quoted_printable_encoding.http b/tests/test_data/http/quoted_printable_encoding.http similarity index 100% rename from multipart/tests/test_data/http/quoted_printable_encoding.http rename to tests/test_data/http/quoted_printable_encoding.http diff --git a/multipart/tests/test_data/http/quoted_printable_encoding.yaml b/tests/test_data/http/quoted_printable_encoding.yaml similarity index 100% rename from multipart/tests/test_data/http/quoted_printable_encoding.yaml rename to tests/test_data/http/quoted_printable_encoding.yaml diff --git a/multipart/tests/test_data/http/single_field.http b/tests/test_data/http/single_field.http similarity index 100% rename from multipart/tests/test_data/http/single_field.http rename to tests/test_data/http/single_field.http diff --git a/multipart/tests/test_data/http/single_field.yaml b/tests/test_data/http/single_field.yaml similarity index 100% rename from multipart/tests/test_data/http/single_field.yaml rename to tests/test_data/http/single_field.yaml diff --git a/multipart/tests/test_data/http/single_field_blocks.http b/tests/test_data/http/single_field_blocks.http similarity index 100% rename from multipart/tests/test_data/http/single_field_blocks.http rename to tests/test_data/http/single_field_blocks.http diff --git a/multipart/tests/test_data/http/single_field_blocks.yaml b/tests/test_data/http/single_field_blocks.yaml similarity index 100% rename from multipart/tests/test_data/http/single_field_blocks.yaml rename to tests/test_data/http/single_field_blocks.yaml diff --git a/multipart/tests/test_data/http/single_field_longer.http b/tests/test_data/http/single_field_longer.http similarity index 100% rename from multipart/tests/test_data/http/single_field_longer.http rename to tests/test_data/http/single_field_longer.http diff --git a/multipart/tests/test_data/http/single_field_longer.yaml b/tests/test_data/http/single_field_longer.yaml similarity index 100% rename from multipart/tests/test_data/http/single_field_longer.yaml rename to tests/test_data/http/single_field_longer.yaml diff --git a/multipart/tests/test_data/http/single_field_single_file.http b/tests/test_data/http/single_field_single_file.http similarity index 100% rename from multipart/tests/test_data/http/single_field_single_file.http rename to tests/test_data/http/single_field_single_file.http diff --git a/multipart/tests/test_data/http/single_field_single_file.yaml b/tests/test_data/http/single_field_single_file.yaml similarity index 100% rename from multipart/tests/test_data/http/single_field_single_file.yaml rename to tests/test_data/http/single_field_single_file.yaml diff --git a/multipart/tests/test_data/http/single_field_with_leading_newlines.http b/tests/test_data/http/single_field_with_leading_newlines.http similarity index 100% rename from multipart/tests/test_data/http/single_field_with_leading_newlines.http rename to tests/test_data/http/single_field_with_leading_newlines.http diff --git a/multipart/tests/test_data/http/single_field_with_leading_newlines.yaml b/tests/test_data/http/single_field_with_leading_newlines.yaml similarity index 100% rename from multipart/tests/test_data/http/single_field_with_leading_newlines.yaml rename to tests/test_data/http/single_field_with_leading_newlines.yaml diff --git a/multipart/tests/test_data/http/single_file.http b/tests/test_data/http/single_file.http similarity index 100% rename from multipart/tests/test_data/http/single_file.http rename to tests/test_data/http/single_file.http diff --git a/multipart/tests/test_data/http/single_file.yaml b/tests/test_data/http/single_file.yaml similarity index 100% rename from multipart/tests/test_data/http/single_file.yaml rename to tests/test_data/http/single_file.yaml diff --git a/multipart/tests/test_data/http/utf8_filename.http b/tests/test_data/http/utf8_filename.http similarity index 100% rename from multipart/tests/test_data/http/utf8_filename.http rename to tests/test_data/http/utf8_filename.http diff --git a/multipart/tests/test_data/http/utf8_filename.yaml b/tests/test_data/http/utf8_filename.yaml similarity index 100% rename from multipart/tests/test_data/http/utf8_filename.yaml rename to tests/test_data/http/utf8_filename.yaml diff --git a/multipart/tests/test_multipart.py b/tests/test_multipart.py similarity index 64% rename from multipart/tests/test_multipart.py rename to tests/test_multipart.py index 089f451..16db5b3 100644 --- a/multipart/tests/test_multipart.py +++ b/tests/test_multipart.py @@ -1,21 +1,30 @@ import os -import sys -import glob -import yaml -import base64 import random +import sys import tempfile import unittest -from .compat import ( - parametrize, - parametrize_class, - slow_test, -) from io import BytesIO -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import Mock -from ..multipart import * +import yaml + +from multipart.decoders import Base64Decoder, QuotedPrintableDecoder +from multipart.exceptions import DecodeError, FileError, FormParserError, MultipartParseError +from multipart.multipart import ( + BaseParser, + Field, + File, + FormParser, + MultipartParser, + OctetStreamParser, + QuerystringParseError, + QuerystringParser, + create_form_parser, + parse_form, + parse_options_header, +) +from .compat import parametrize, parametrize_class, slow_test # Get the current directory for our later test cases. curr_dir = os.path.abspath(os.path.dirname(__file__)) @@ -30,53 +39,53 @@ def force_bytes(val): class TestField(unittest.TestCase): def setUp(self): - self.f = Field('foo') + self.f = Field("foo") def test_name(self): - self.assertEqual(self.f.field_name, 'foo') + self.assertEqual(self.f.field_name, "foo") def test_data(self): - self.f.write(b'test123') - self.assertEqual(self.f.value, b'test123') + self.f.write(b"test123") + self.assertEqual(self.f.value, b"test123") def test_cache_expiration(self): - self.f.write(b'test') - self.assertEqual(self.f.value, b'test') - self.f.write(b'123') - self.assertEqual(self.f.value, b'test123') + self.f.write(b"test") + self.assertEqual(self.f.value, b"test") + self.f.write(b"123") + self.assertEqual(self.f.value, b"test123") def test_finalize(self): - self.f.write(b'test123') + self.f.write(b"test123") self.f.finalize() - self.assertEqual(self.f.value, b'test123') + self.assertEqual(self.f.value, b"test123") def test_close(self): - self.f.write(b'test123') + self.f.write(b"test123") self.f.close() - self.assertEqual(self.f.value, b'test123') + self.assertEqual(self.f.value, b"test123") def test_from_value(self): - f = Field.from_value(b'name', b'value') - self.assertEqual(f.field_name, b'name') - self.assertEqual(f.value, b'value') + f = Field.from_value(b"name", b"value") + self.assertEqual(f.field_name, b"name") + self.assertEqual(f.value, b"value") - f2 = Field.from_value(b'name', None) + f2 = Field.from_value(b"name", None) self.assertEqual(f2.value, None) def test_equality(self): - f1 = Field.from_value(b'name', b'value') - f2 = Field.from_value(b'name', b'value') + f1 = Field.from_value(b"name", b"value") + f2 = Field.from_value(b"name", b"value") self.assertEqual(f1, f2) def test_equality_with_other(self): - f = Field.from_value(b'foo', b'bar') - self.assertFalse(f == b'foo') - self.assertFalse(b'foo' == f) + f = Field.from_value(b"foo", b"bar") + self.assertFalse(f == b"foo") + self.assertFalse(b"foo" == f) def test_set_none(self): - f = Field(b'foo') - self.assertEqual(f.value, b'') + f = Field(b"foo") + self.assertEqual(f.value, b"") f.set_none() self.assertEqual(f.value, None) @@ -86,7 +95,7 @@ class TestFile(unittest.TestCase): def setUp(self): self.c = {} self.d = force_bytes(tempfile.mkdtemp()) - self.f = File(b'foo.txt', config=self.c) + self.f = File(b"foo.txt", config=self.c) def assert_data(self, data): f = self.f.file_object @@ -100,26 +109,26 @@ def assert_exists(self): self.assertTrue(os.path.exists(full_path)) def test_simple(self): - self.f.write(b'foobar') - self.assert_data(b'foobar') + self.f.write(b"foobar") + self.assert_data(b"foobar") def test_invalid_write(self): m = Mock() m.write.return_value = 5 self.f._fileobj = m - v = self.f.write(b'foobar') + v = self.f.write(b"foobar") self.assertEqual(v, 5) def test_file_fallback(self): - self.c['MAX_MEMORY_FILE_SIZE'] = 1 + self.c["MAX_MEMORY_FILE_SIZE"] = 1 - self.f.write(b'1') + self.f.write(b"1") self.assertTrue(self.f.in_memory) - self.assert_data(b'1') + self.assert_data(b"1") - self.f.write(b'123') + self.f.write(b"123") self.assertFalse(self.f.in_memory) - self.assert_data(b'123') + self.assert_data(b"123") # Test flushing too. old_obj = self.f.file_object @@ -128,23 +137,23 @@ def test_file_fallback(self): self.assertIs(self.f.file_object, old_obj) def test_file_fallback_with_data(self): - self.c['MAX_MEMORY_FILE_SIZE'] = 10 + self.c["MAX_MEMORY_FILE_SIZE"] = 10 - self.f.write(b'1' * 10) + self.f.write(b"1" * 10) self.assertTrue(self.f.in_memory) - self.f.write(b'2' * 10) + self.f.write(b"2" * 10) self.assertFalse(self.f.in_memory) - self.assert_data(b'11111111112222222222') + self.assert_data(b"11111111112222222222") def test_file_name(self): # Write to this dir. - self.c['UPLOAD_DIR'] = self.d - self.c['MAX_MEMORY_FILE_SIZE'] = 10 + self.c["UPLOAD_DIR"] = self.d + self.c["MAX_MEMORY_FILE_SIZE"] = 10 # Write. - self.f.write(b'12345678901') + self.f.write(b"12345678901") self.assertFalse(self.f.in_memory) # Assert that the file exists @@ -153,125 +162,124 @@ def test_file_name(self): def test_file_full_name(self): # Write to this dir. - self.c['UPLOAD_DIR'] = self.d - self.c['UPLOAD_KEEP_FILENAME'] = True - self.c['MAX_MEMORY_FILE_SIZE'] = 10 - - # Write. - self.f.write(b'12345678901') - self.assertFalse(self.f.in_memory) - - # Assert that the file exists - self.assertEqual(self.f.actual_file_name, b'foo') - self.assert_exists() - - def test_file_full_name_with_ext(self): - self.c['UPLOAD_DIR'] = self.d - self.c['UPLOAD_KEEP_FILENAME'] = True - self.c['UPLOAD_KEEP_EXTENSIONS'] = True - self.c['MAX_MEMORY_FILE_SIZE'] = 10 + self.c["UPLOAD_DIR"] = self.d + self.c["UPLOAD_KEEP_FILENAME"] = True + self.c["MAX_MEMORY_FILE_SIZE"] = 10 # Write. - self.f.write(b'12345678901') + self.f.write(b"12345678901") self.assertFalse(self.f.in_memory) # Assert that the file exists - self.assertEqual(self.f.actual_file_name, b'foo.txt') + self.assertEqual(self.f.actual_file_name, b"foo") self.assert_exists() def test_file_full_name_with_ext(self): - self.c['UPLOAD_DIR'] = self.d - self.c['UPLOAD_KEEP_FILENAME'] = True - self.c['UPLOAD_KEEP_EXTENSIONS'] = True - self.c['MAX_MEMORY_FILE_SIZE'] = 10 + self.c["UPLOAD_DIR"] = self.d + self.c["UPLOAD_KEEP_FILENAME"] = True + self.c["UPLOAD_KEEP_EXTENSIONS"] = True + self.c["MAX_MEMORY_FILE_SIZE"] = 10 # Write. - self.f.write(b'12345678901') + self.f.write(b"12345678901") self.assertFalse(self.f.in_memory) # Assert that the file exists - self.assertEqual(self.f.actual_file_name, b'foo.txt') + self.assertEqual(self.f.actual_file_name, b"foo.txt") self.assert_exists() def test_no_dir_with_extension(self): - self.c['UPLOAD_KEEP_EXTENSIONS'] = True - self.c['MAX_MEMORY_FILE_SIZE'] = 10 + self.c["UPLOAD_KEEP_EXTENSIONS"] = True + self.c["MAX_MEMORY_FILE_SIZE"] = 10 # Write. - self.f.write(b'12345678901') + self.f.write(b"12345678901") self.assertFalse(self.f.in_memory) # Assert that the file exists ext = os.path.splitext(self.f.actual_file_name)[1] - self.assertEqual(ext, b'.txt') + self.assertEqual(ext, b".txt") self.assert_exists() def test_invalid_dir_with_name(self): # Write to this dir. - self.c['UPLOAD_DIR'] = force_bytes(os.path.join('/', 'tmp', 'notexisting')) - self.c['UPLOAD_KEEP_FILENAME'] = True - self.c['MAX_MEMORY_FILE_SIZE'] = 5 + self.c["UPLOAD_DIR"] = force_bytes(os.path.join("/", "tmp", "notexisting")) + self.c["UPLOAD_KEEP_FILENAME"] = True + self.c["MAX_MEMORY_FILE_SIZE"] = 5 # Write. with self.assertRaises(FileError): - self.f.write(b'1234567890') + self.f.write(b"1234567890") def test_invalid_dir_no_name(self): # Write to this dir. - self.c['UPLOAD_DIR'] = force_bytes(os.path.join('/', 'tmp', 'notexisting')) - self.c['UPLOAD_KEEP_FILENAME'] = False - self.c['MAX_MEMORY_FILE_SIZE'] = 5 + self.c["UPLOAD_DIR"] = force_bytes(os.path.join("/", "tmp", "notexisting")) + self.c["UPLOAD_KEEP_FILENAME"] = False + self.c["MAX_MEMORY_FILE_SIZE"] = 5 # Write. with self.assertRaises(FileError): - self.f.write(b'1234567890') + self.f.write(b"1234567890") # TODO: test uploading two files with the same name. class TestParseOptionsHeader(unittest.TestCase): def test_simple(self): - t, p = parse_options_header('application/json') - self.assertEqual(t, b'application/json') + t, p = parse_options_header("application/json") + self.assertEqual(t, b"application/json") self.assertEqual(p, {}) def test_blank(self): - t, p = parse_options_header('') - self.assertEqual(t, b'') + t, p = parse_options_header("") + self.assertEqual(t, b"") self.assertEqual(p, {}) def test_single_param(self): - t, p = parse_options_header('application/json;par=val') - self.assertEqual(t, b'application/json') - self.assertEqual(p, {b'par': b'val'}) + t, p = parse_options_header("application/json;par=val") + self.assertEqual(t, b"application/json") + self.assertEqual(p, {b"par": b"val"}) def test_single_param_with_spaces(self): - t, p = parse_options_header(b'application/json; par=val') - self.assertEqual(t, b'application/json') - self.assertEqual(p, {b'par': b'val'}) + t, p = parse_options_header(b"application/json; par=val") + self.assertEqual(t, b"application/json") + self.assertEqual(p, {b"par": b"val"}) def test_multiple_params(self): - t, p = parse_options_header(b'application/json;par=val;asdf=foo') - self.assertEqual(t, b'application/json') - self.assertEqual(p, {b'par': b'val', b'asdf': b'foo'}) + t, p = parse_options_header(b"application/json;par=val;asdf=foo") + self.assertEqual(t, b"application/json") + self.assertEqual(p, {b"par": b"val", b"asdf": b"foo"}) def test_quoted_param(self): t, p = parse_options_header(b'application/json;param="quoted"') - self.assertEqual(t, b'application/json') - self.assertEqual(p, {b'param': b'quoted'}) + self.assertEqual(t, b"application/json") + self.assertEqual(p, {b"param": b"quoted"}) def test_quoted_param_with_semicolon(self): t, p = parse_options_header(b'application/json;param="quoted;with;semicolons"') - self.assertEqual(p[b'param'], b'quoted;with;semicolons') + self.assertEqual(p[b"param"], b"quoted;with;semicolons") def test_quoted_param_with_escapes(self): t, p = parse_options_header(b'application/json;param="This \\" is \\" a \\" quote"') - self.assertEqual(p[b'param'], b'This " is " a " quote') + self.assertEqual(p[b"param"], b'This " is " a " quote') def test_handles_ie6_bug(self): t, p = parse_options_header(b'text/plain; filename="C:\\this\\is\\a\\path\\file.txt"') - self.assertEqual(p[b'filename'], b'file.txt') + self.assertEqual(p[b"filename"], b"file.txt") + + def test_redos_attack_header(self): + t, p = parse_options_header( + b'application/x-www-form-urlencoded; !="' + b"\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\" + ) + # If vulnerable, this test wouldn't finish, the line above would hang + self.assertIn(b'"\\', p[b"!"]) + + def test_handles_rfc_2231(self): + t, p = parse_options_header(b"text/plain; param*=us-ascii'en-us'encoded%20message") + + self.assertEqual(p[b"param"], b"encoded message") class TestBaseParser(unittest.TestCase): @@ -282,25 +290,26 @@ def setUp(self): def test_callbacks(self): # The stupid list-ness is to get around lack of nonlocal on py2 l = [0] + def on_foo(): l[0] += 1 - self.b.set_callback('foo', on_foo) - self.b.callback('foo') + self.b.set_callback("foo", on_foo) + self.b.callback("foo") self.assertEqual(l[0], 1) - self.b.set_callback('foo', None) - self.b.callback('foo') + self.b.set_callback("foo", None) + self.b.callback("foo") self.assertEqual(l[0], 1) class TestQuerystringParser(unittest.TestCase): def assert_fields(self, *args, **kwargs): - if kwargs.pop('finalize', True): + if kwargs.pop("finalize", True): self.p.finalize() self.assertEqual(self.f, list(args)) - if kwargs.get('reset', True): + if kwargs.get("reset", True): self.f = [] def setUp(self): @@ -319,103 +328,80 @@ def on_field_data(data, start, end): data_buffer.append(data[start:end]) def on_field_end(): - self.f.append(( - b''.join(name_buffer), - b''.join(data_buffer) - )) + self.f.append((b"".join(name_buffer), b"".join(data_buffer))) del name_buffer[:] del data_buffer[:] - callbacks = { - 'on_field_name': on_field_name, - 'on_field_data': on_field_data, - 'on_field_end': on_field_end - } - - self.p = QuerystringParser(callbacks) + self.p = QuerystringParser( + callbacks={"on_field_name": on_field_name, "on_field_data": on_field_data, "on_field_end": on_field_end} + ) def test_simple_querystring(self): - self.p.write(b'foo=bar') + self.p.write(b"foo=bar") - self.assert_fields((b'foo', b'bar')) + self.assert_fields((b"foo", b"bar")) def test_querystring_blank_beginning(self): - self.p.write(b'&foo=bar') + self.p.write(b"&foo=bar") - self.assert_fields((b'foo', b'bar')) + self.assert_fields((b"foo", b"bar")) def test_querystring_blank_end(self): - self.p.write(b'foo=bar&') + self.p.write(b"foo=bar&") - self.assert_fields((b'foo', b'bar')) + self.assert_fields((b"foo", b"bar")) def test_multiple_querystring(self): - self.p.write(b'foo=bar&asdf=baz') + self.p.write(b"foo=bar&asdf=baz") - self.assert_fields( - (b'foo', b'bar'), - (b'asdf', b'baz') - ) + self.assert_fields((b"foo", b"bar"), (b"asdf", b"baz")) def test_streaming_simple(self): - self.p.write(b'foo=bar&') - self.assert_fields( - (b'foo', b'bar'), - finalize=False - ) + self.p.write(b"foo=bar&") + self.assert_fields((b"foo", b"bar"), finalize=False) - self.p.write(b'asdf=baz') - self.assert_fields( - (b'asdf', b'baz') - ) + self.p.write(b"asdf=baz") + self.assert_fields((b"asdf", b"baz")) def test_streaming_break(self): - self.p.write(b'foo=one') + self.p.write(b"foo=one") self.assert_fields(finalize=False) - self.p.write(b'two') + self.p.write(b"two") self.assert_fields(finalize=False) - self.p.write(b'three') + self.p.write(b"three") self.assert_fields(finalize=False) - self.p.write(b'&asd') - self.assert_fields( - (b'foo', b'onetwothree'), - finalize=False - ) + self.p.write(b"&asd") + self.assert_fields((b"foo", b"onetwothree"), finalize=False) - self.p.write(b'f=baz') - self.assert_fields( - (b'asdf', b'baz') - ) + self.p.write(b"f=baz") + self.assert_fields((b"asdf", b"baz")) def test_semicolon_separator(self): - self.p.write(b'foo=bar;asdf=baz') + self.p.write(b"foo=bar;asdf=baz") - self.assert_fields( - (b'foo', b'bar'), - (b'asdf', b'baz') - ) + self.assert_fields((b"foo", b"bar"), (b"asdf", b"baz")) def test_too_large_field(self): self.p.max_size = 15 # Note: len = 8 self.p.write(b"foo=bar&") - self.assert_fields((b'foo', b'bar'), finalize=False) + self.assert_fields((b"foo", b"bar"), finalize=False) # Note: len = 8, only 7 bytes processed - self.p.write(b'a=123456') - self.assert_fields((b'a', b'12345')) + self.p.write(b"a=123456") + self.assert_fields((b"a", b"12345")) def test_invalid_max_size(self): with self.assertRaises(ValueError): p = QuerystringParser(max_size=-100) def test_strict_parsing_pass(self): - data = b'foo=bar&another=asdf' + data = b"foo=bar&another=asdf" for first, last in split_all(data): self.reset() self.p.strict_parsing = True @@ -424,10 +410,10 @@ def test_strict_parsing_pass(self): self.p.write(first) self.p.write(last) - self.assert_fields((b'foo', b'bar'), (b'another', b'asdf')) + self.assert_fields((b"foo", b"bar"), (b"another", b"asdf")) def test_strict_parsing_fail_double_sep(self): - data = b'foo=bar&&another=asdf' + data = b"foo=bar&&another=asdf" for first, last in split_all(data): self.reset() self.p.strict_parsing = True @@ -444,7 +430,7 @@ def test_strict_parsing_fail_double_sep(self): self.assertEqual(cm.exception.offset, 8 - cnt) def test_double_sep(self): - data = b'foo=bar&&another=asdf' + data = b"foo=bar&&another=asdf" for first, last in split_all(data): print(f" {first!r} / {last!r} ") self.reset() @@ -453,23 +439,19 @@ def test_double_sep(self): cnt += self.p.write(first) cnt += self.p.write(last) - self.assert_fields((b'foo', b'bar'), (b'another', b'asdf')) + self.assert_fields((b"foo", b"bar"), (b"another", b"asdf")) def test_strict_parsing_fail_no_value(self): self.p.strict_parsing = True with self.assertRaises(QuerystringParseError) as cm: - self.p.write(b'foo=bar&blank&another=asdf') + self.p.write(b"foo=bar&blank&another=asdf") if cm is not None: self.assertEqual(cm.exception.offset, 8) def test_success_no_value(self): - self.p.write(b'foo=bar&blank&another=asdf') - self.assert_fields( - (b'foo', b'bar'), - (b'blank', b''), - (b'another', b'asdf') - ) + self.p.write(b"foo=bar&blank&another=asdf") + self.assert_fields((b"foo", b"bar"), (b"blank", b""), (b"another", b"asdf")) def test_repr(self): # Issue #29; verify we don't assert on repr() @@ -482,25 +464,19 @@ def setUp(self): self.started = 0 self.finished = 0 - def on_start(): + def on_start() -> None: self.started += 1 - def on_data(data, start, end): + def on_data(data: bytes, start: int, end: int) -> None: self.d.append(data[start:end]) - def on_end(): + def on_end() -> None: self.finished += 1 - callbacks = { - 'on_start': on_start, - 'on_data': on_data, - 'on_end': on_end - } - - self.p = OctetStreamParser(callbacks) + self.p = OctetStreamParser(callbacks={"on_start": on_start, "on_data": on_data, "on_end": on_end}) def assert_data(self, data, finalize=True): - self.assertEqual(b''.join(self.d), data) + self.assertEqual(b"".join(self.d), data) self.d = [] def assert_started(self, val=True): @@ -520,9 +496,9 @@ def test_simple(self): self.assert_started(False) # Write something, it should then be started + have data - self.p.write(b'foobar') + self.p.write(b"foobar") self.assert_started() - self.assert_data(b'foobar') + self.assert_data(b"foobar") # Finalize, and check self.assert_finished(False) @@ -530,26 +506,26 @@ def test_simple(self): self.assert_finished() def test_multiple_chunks(self): - self.p.write(b'foo') - self.p.write(b'bar') - self.p.write(b'baz') + self.p.write(b"foo") + self.p.write(b"bar") + self.p.write(b"baz") self.p.finalize() - self.assert_data(b'foobarbaz') + self.assert_data(b"foobarbaz") self.assert_finished() def test_max_size(self): self.p.max_size = 5 - self.p.write(b'0123456789') + self.p.write(b"0123456789") self.p.finalize() - self.assert_data(b'01234') + self.assert_data(b"01234") self.assert_finished() def test_invalid_max_size(self): with self.assertRaises(ValueError): - q = OctetStreamParser(max_size='foo') + q = OctetStreamParser(max_size="foo") class TestBase64Decoder(unittest.TestCase): @@ -568,37 +544,37 @@ def assert_data(self, data, finalize=True): self.f.truncate() def test_simple(self): - self.d.write(b'Zm9vYmFy') - self.assert_data(b'foobar') + self.d.write(b"Zm9vYmFy") + self.assert_data(b"foobar") def test_bad(self): with self.assertRaises(DecodeError): - self.d.write(b'Zm9v!mFy') + self.d.write(b"Zm9v!mFy") def test_split_properly(self): - self.d.write(b'Zm9v') - self.d.write(b'YmFy') - self.assert_data(b'foobar') + self.d.write(b"Zm9v") + self.d.write(b"YmFy") + self.assert_data(b"foobar") def test_bad_split(self): - buff = b'Zm9v' + buff = b"Zm9v" for i in range(1, 4): first, second = buff[:i], buff[i:] self.setUp() self.d.write(first) self.d.write(second) - self.assert_data(b'foo') + self.assert_data(b"foo") def test_long_bad_split(self): - buff = b'Zm9vYmFy' + buff = b"Zm9vYmFy" for i in range(5, 8): first, second = buff[:i], buff[i:] self.setUp() self.d.write(first) self.d.write(second) - self.assert_data(b'foobar') + self.assert_data(b"foobar") def test_close_and_finalize(self): parser = Mock() @@ -611,7 +587,7 @@ def test_close_and_finalize(self): parser.close.assert_called_once_with() def test_bad_length(self): - self.d.write(b'Zm9vYmF') # missing ending 'y' + self.d.write(b"Zm9vYmF") # missing ending 'y' with self.assertRaises(DecodeError): self.d.finalize() @@ -632,35 +608,35 @@ def assert_data(self, data, finalize=True): self.f.truncate() def test_simple(self): - self.d.write(b'foobar') - self.assert_data(b'foobar') + self.d.write(b"foobar") + self.assert_data(b"foobar") def test_with_escape(self): - self.d.write(b'foo=3Dbar') - self.assert_data(b'foo=bar') + self.d.write(b"foo=3Dbar") + self.assert_data(b"foo=bar") def test_with_newline_escape(self): - self.d.write(b'foo=\r\nbar') - self.assert_data(b'foobar') + self.d.write(b"foo=\r\nbar") + self.assert_data(b"foobar") def test_with_only_newline_escape(self): - self.d.write(b'foo=\nbar') - self.assert_data(b'foobar') + self.d.write(b"foo=\nbar") + self.assert_data(b"foobar") def test_with_split_escape(self): - self.d.write(b'foo=3') - self.d.write(b'Dbar') - self.assert_data(b'foo=bar') + self.d.write(b"foo=3") + self.d.write(b"Dbar") + self.assert_data(b"foo=bar") def test_with_split_newline_escape_1(self): - self.d.write(b'foo=\r') - self.d.write(b'\nbar') - self.assert_data(b'foobar') + self.d.write(b"foo=\r") + self.d.write(b"\nbar") + self.assert_data(b"foobar") def test_with_split_newline_escape_2(self): - self.d.write(b'foo=') - self.d.write(b'\r\nbar') - self.assert_data(b'foobar') + self.d.write(b"foo=") + self.d.write(b"\r\nbar") + self.assert_data(b"foobar") def test_close_and_finalize(self): parser = Mock() @@ -676,23 +652,23 @@ def test_not_aligned(self): """ https://github.com/andrew-d/python-multipart/issues/6 """ - self.d.write(b'=3AX') - self.assert_data(b':X') + self.d.write(b"=3AX") + self.assert_data(b":X") # Additional offset tests - self.d.write(b'=3') - self.d.write(b'AX') - self.assert_data(b':X') + self.d.write(b"=3") + self.d.write(b"AX") + self.assert_data(b":X") - self.d.write(b'q=3AX') - self.assert_data(b'q:X') + self.d.write(b"q=3AX") + self.assert_data(b"q:X") # Load our list of HTTP test cases. -http_tests_dir = os.path.join(curr_dir, 'test_data', 'http') +http_tests_dir = os.path.join(curr_dir, "test_data", "http") # Read in all test cases and load them. -NON_PARAMETRIZED_TESTS = {'single_field_blocks'} +NON_PARAMETRIZED_TESTS = {"single_field_blocks"} http_tests = [] for f in os.listdir(http_tests_dir): # Only load the HTTP test cases. @@ -700,22 +676,18 @@ def test_not_aligned(self): if fname in NON_PARAMETRIZED_TESTS: continue - if ext == '.http': + if ext == ".http": # Get the YAML file and load it too. - yaml_file = os.path.join(http_tests_dir, fname + '.yaml') + yaml_file = os.path.join(http_tests_dir, fname + ".yaml") # Load both. - with open(os.path.join(http_tests_dir, f), 'rb') as f: + with open(os.path.join(http_tests_dir, f), "rb") as f: test_data = f.read() - with open(yaml_file, 'rb') as f: + with open(yaml_file, "rb") as f: yaml_data = yaml.safe_load(f) - http_tests.append({ - 'name': fname, - 'test': test_data, - 'result': yaml_data - }) + http_tests.append({"name": fname, "test": test_data, "result": yaml_data}) def split_all(val): @@ -746,8 +718,7 @@ def on_end(): self.ended = True # Get a form-parser instance. - self.f = FormParser('multipart/form-data', on_field, on_file, on_end, - boundary=boundary, config=config) + self.f = FormParser("multipart/form-data", on_field, on_file, on_end, boundary=boundary, config=config) def assert_file_data(self, f, data): o = f.file_object @@ -792,18 +763,18 @@ def assert_field(self, name, value): # Remove it for future iterations. self.fields.remove(found) - @parametrize('param', http_tests) + @parametrize("param", http_tests) def test_http(self, param): # Firstly, create our parser with the given boundary. - boundary = param['result']['boundary'] + boundary = param["result"]["boundary"] if isinstance(boundary, str): - boundary = boundary.encode('latin-1') + boundary = boundary.encode("latin-1") self.make(boundary) # Now, we feed the parser with data. exc = None try: - processed = self.f.write(param['test']) + processed = self.f.write(param["test"]) self.f.finalize() except MultipartParseError as e: processed = 0 @@ -815,29 +786,25 @@ def test_http(self, param): # print(repr(self.files)) # Do we expect an error? - if 'error' in param['result']['expected']: + if "error" in param["result"]["expected"]: self.assertIsNotNone(exc) - self.assertEqual(param['result']['expected']['error'], exc.offset) + self.assertEqual(param["result"]["expected"]["error"], exc.offset) return # No error! - self.assertEqual(processed, len(param['test'])) + self.assertEqual(processed, len(param["test"])) # Assert that the parser gave us the appropriate fields/files. - for e in param['result']['expected']: + for e in param["result"]["expected"]: # Get our type and name. - type = e['type'] - name = e['name'].encode('latin-1') + type = e["type"] + name = e["name"].encode("latin-1") - if type == 'field': - self.assert_field(name, e['data']) + if type == "field": + self.assert_field(name, e["data"]) - elif type == 'file': - self.assert_file( - name, - e['file_name'].encode('latin-1'), - e['data'] - ) + elif type == "file": + self.assert_file(name, e["file_name"].encode("latin-1"), e["data"]) else: assert False @@ -848,14 +815,14 @@ def test_random_splitting(self): through every possible split. """ # Load test data. - test_file = 'single_field_single_file.http' - with open(os.path.join(http_tests_dir, test_file), 'rb') as f: + test_file = "single_field_single_file.http" + with open(os.path.join(http_tests_dir, test_file), "rb") as f: test_data = f.read() # We split the file through all cases. for first, last in split_all(test_data): # Create form parser. - self.make('boundary') + self.make("boundary") # Feed with data in 2 chunks. i = 0 @@ -867,27 +834,27 @@ def test_random_splitting(self): self.assertEqual(i, len(test_data)) # Assert that our file and field are here. - self.assert_field(b'field', b'test1') - self.assert_file(b'file', b'file.txt', b'test2') + self.assert_field(b"field", b"test1") + self.assert_file(b"file", b"file.txt", b"test2") def test_feed_single_bytes(self): """ This test parses a simple multipart body 1 byte at a time. """ # Load test data. - test_file = 'single_field_single_file.http' - with open(os.path.join(http_tests_dir, test_file), 'rb') as f: + test_file = "single_field_single_file.http" + with open(os.path.join(http_tests_dir, test_file), "rb") as f: test_data = f.read() # Create form parser. - self.make('boundary') + self.make("boundary") # Write all bytes. # NOTE: Can't simply do `for b in test_data`, since that gives # an integer when iterating over a bytes object on Python 3. i = 0 for x in range(len(test_data)): - b = test_data[x:x + 1] + b = test_data[x : x + 1] i += self.f.write(b) self.f.finalize() @@ -896,24 +863,23 @@ def test_feed_single_bytes(self): self.assertEqual(i, len(test_data)) # Assert that our file and field are here. - self.assert_field(b'field', b'test1') - self.assert_file(b'file', b'file.txt', b'test2') + self.assert_field(b"field", b"test1") + self.assert_file(b"file", b"file.txt", b"test2") def test_feed_blocks(self): """ This test parses a simple multipart body 1 byte at a time. """ # Load test data. - test_file = 'single_field_blocks.http' - with open(os.path.join(http_tests_dir, test_file), 'rb') as f: + test_file = "single_field_blocks.http" + with open(os.path.join(http_tests_dir, test_file), "rb") as f: test_data = f.read() for c in range(1, len(test_data) + 1): # Skip first `d` bytes - not interesting for d in range(c): - # Create form parser. - self.make('boundary') + self.make("boundary") # Skip i = 0 self.f.write(test_data[:d]) @@ -922,7 +888,7 @@ def test_feed_blocks(self): # Write a chunk to achieve condition # `i == data_length - 1` # in boundary search loop (multipatr.py:1302) - b = test_data[x:x + c] + b = test_data[x : x + c] i += self.f.write(b) self.f.finalize() @@ -931,8 +897,7 @@ def test_feed_blocks(self): self.assertEqual(i, len(test_data)) # Assert that our field is here. - self.assert_field(b'field', - b'0123456789ABCDEFGHIJ0123456789ABCDEFGHIJ') + self.assert_field(b"field", b"0123456789ABCDEFGHIJ0123456789ABCDEFGHIJ") @slow_test def test_request_body_fuzz(self): @@ -945,8 +910,8 @@ def test_request_body_fuzz(self): - Randomly swapping two bytes """ # Load test data. - test_file = 'single_field_single_file.http' - with open(os.path.join(http_tests_dir, test_file), 'rb') as f: + test_file = "single_field_single_file.http" + with open(os.path.join(http_tests_dir, test_file), "rb") as f: test_data = f.read() iterations = 1000 @@ -987,7 +952,7 @@ def test_request_body_fuzz(self): print(" " + msg) # Create form parser. - self.make('boundary') + self.make("boundary") # Feed with data, and ignore form parser exceptions. i = 0 @@ -1025,7 +990,7 @@ def test_request_body_fuzz_random_data(self): print(" Testing with %d random bytes..." % (data_size,)) # Create form parser. - self.make('boundary') + self.make("boundary") # Feed with data, and ignore form parser exceptions. i = 0 @@ -1046,40 +1011,44 @@ def test_request_body_fuzz_random_data(self): print("Exceptions: %d" % (exceptions,)) def test_bad_start_boundary(self): - self.make('boundary') - data = b'--boundary\rfoobar' + self.make("boundary") + data = b"--boundary\rfoobar" with self.assertRaises(MultipartParseError): self.f.write(data) - self.make('boundary') - data = b'--boundaryfoobar' + self.make("boundary") + data = b"--boundaryfoobar" with self.assertRaises(MultipartParseError): i = self.f.write(data) def test_octet_stream(self): files = [] + def on_file(f): files.append(f) + on_field = Mock() on_end = Mock() - f = FormParser('application/octet-stream', on_field, on_file, on_end=on_end, file_name=b'foo.txt') + f = FormParser("application/octet-stream", on_field, on_file, on_end=on_end, file_name=b"foo.txt") self.assertTrue(isinstance(f.parser, OctetStreamParser)) - f.write(b'test') - f.write(b'1234') + f.write(b"test") + f.write(b"1234") f.finalize() # Assert that we only received a single file, with the right data, and that we're done. self.assertFalse(on_field.called) self.assertEqual(len(files), 1) - self.assert_file_data(files[0], b'test1234') + self.assert_file_data(files[0], b"test1234") self.assertTrue(on_end.called) def test_querystring(self): fields = [] + def on_field(f): fields.append(f) + on_file = Mock() on_end = Mock() @@ -1090,8 +1059,8 @@ def simple_test(f): on_end.reset_mock() # Write test data. - f.write(b'foo=bar') - f.write(b'&test=asdf') + f.write(b"foo=bar") + f.write(b"&test=asdf") f.finalize() # Assert we only received 2 fields... @@ -1099,26 +1068,26 @@ def simple_test(f): self.assertEqual(len(fields), 2) # ...assert that we have the correct data... - self.assertEqual(fields[0].field_name, b'foo') - self.assertEqual(fields[0].value, b'bar') + self.assertEqual(fields[0].field_name, b"foo") + self.assertEqual(fields[0].value, b"bar") - self.assertEqual(fields[1].field_name, b'test') - self.assertEqual(fields[1].value, b'asdf') + self.assertEqual(fields[1].field_name, b"test") + self.assertEqual(fields[1].value, b"asdf") # ... and assert that we've finished. self.assertTrue(on_end.called) - f = FormParser('application/x-www-form-urlencoded', on_field, on_file, on_end=on_end) + f = FormParser("application/x-www-form-urlencoded", on_field, on_file, on_end=on_end) self.assertTrue(isinstance(f.parser, QuerystringParser)) simple_test(f) - f = FormParser('application/x-url-encoded', on_field, on_file, on_end=on_end) + f = FormParser("application/x-url-encoded", on_field, on_file, on_end=on_end) self.assertTrue(isinstance(f.parser, QuerystringParser)) simple_test(f) def test_close_methods(self): parser = Mock() - f = FormParser('application/x-url-encoded', None, None) + f = FormParser("application/x-url-encoded", None, None) f.parser = parser f.finalize() @@ -1130,69 +1099,76 @@ def test_close_methods(self): def test_bad_content_type(self): # We should raise a ValueError for a bad Content-Type with self.assertRaises(ValueError): - f = FormParser('application/bad', None, None) + f = FormParser("application/bad", None, None) def test_no_boundary_given(self): # We should raise a FormParserError when parsing a multipart message # without a boundary. with self.assertRaises(FormParserError): - f = FormParser('multipart/form-data', None, None) + f = FormParser("multipart/form-data", None, None) def test_bad_content_transfer_encoding(self): - data = b'----boundary\r\nContent-Disposition: form-data; name="file"; filename="test.txt"\r\nContent-Type: text/plain\r\nContent-Transfer-Encoding: badstuff\r\n\r\nTest\r\n----boundary--\r\n' + data = ( + b'----boundary\r\nContent-Disposition: form-data; name="file"; filename="test.txt"\r\n' + b"Content-Type: text/plain\r\n" + b"Content-Transfer-Encoding: badstuff\r\n\r\n" + b"Test\r\n----boundary--\r\n" + ) files = [] + def on_file(f): files.append(f) + on_field = Mock() on_end = Mock() # Test with erroring. - config = {'UPLOAD_ERROR_ON_BAD_CTE': True} - f = FormParser('multipart/form-data', on_field, on_file, - on_end=on_end, boundary='--boundary', config=config) + config = {"UPLOAD_ERROR_ON_BAD_CTE": True} + f = FormParser("multipart/form-data", on_field, on_file, on_end=on_end, boundary="--boundary", config=config) with self.assertRaises(FormParserError): f.write(data) f.finalize() # Test without erroring. - config = {'UPLOAD_ERROR_ON_BAD_CTE': False} - f = FormParser('multipart/form-data', on_field, on_file, - on_end=on_end, boundary='--boundary', config=config) + config = {"UPLOAD_ERROR_ON_BAD_CTE": False} + f = FormParser("multipart/form-data", on_field, on_file, on_end=on_end, boundary="--boundary", config=config) f.write(data) f.finalize() - self.assert_file_data(files[0], b'Test') + self.assert_file_data(files[0], b"Test") def test_handles_None_fields(self): fields = [] + def on_field(f): fields.append(f) + on_file = Mock() on_end = Mock() - f = FormParser('application/x-www-form-urlencoded', on_field, on_file, on_end=on_end) - f.write(b'foo=bar&another&baz=asdf') + f = FormParser("application/x-www-form-urlencoded", on_field, on_file, on_end=on_end) + f.write(b"foo=bar&another&baz=asdf") f.finalize() - self.assertEqual(fields[0].field_name, b'foo') - self.assertEqual(fields[0].value, b'bar') + self.assertEqual(fields[0].field_name, b"foo") + self.assertEqual(fields[0].value, b"bar") - self.assertEqual(fields[1].field_name, b'another') + self.assertEqual(fields[1].field_name, b"another") self.assertEqual(fields[1].value, None) - self.assertEqual(fields[2].field_name, b'baz') - self.assertEqual(fields[2].value, b'asdf') + self.assertEqual(fields[2].field_name, b"baz") + self.assertEqual(fields[2].value, b"asdf") def test_max_size_multipart(self): # Load test data. - test_file = 'single_field_single_file.http' - with open(os.path.join(http_tests_dir, test_file), 'rb') as f: + test_file = "single_field_single_file.http" + with open(os.path.join(http_tests_dir, test_file), "rb") as f: test_data = f.read() # Create form parser. - self.make('boundary') + self.make("boundary") # Set the maximum length that we can process to be halfway through the # given data. @@ -1206,14 +1182,14 @@ def test_max_size_multipart(self): def test_max_size_form_parser(self): # Load test data. - test_file = 'single_field_single_file.http' - with open(os.path.join(http_tests_dir, test_file), 'rb') as f: + test_file = "single_field_single_file.http" + with open(os.path.join(http_tests_dir, test_file), "rb") as f: test_data = f.read() # Create form parser setting the maximum length that we can process to # be halfway through the given data. size = len(test_data) / 2 - self.make('boundary', config={'MAX_BODY_SIZE': size}) + self.make("boundary", config={"MAX_BODY_SIZE": size}) i = self.f.write(test_data) self.f.finalize() @@ -1223,29 +1199,35 @@ def test_max_size_form_parser(self): def test_octet_stream_max_size(self): files = [] + def on_file(f): files.append(f) + on_field = Mock() on_end = Mock() - f = FormParser('application/octet-stream', on_field, on_file, - on_end=on_end, file_name=b'foo.txt', - config={'MAX_BODY_SIZE': 10}) + f = FormParser( + "application/octet-stream", + on_field, + on_file, + on_end=on_end, + file_name=b"foo.txt", + config={"MAX_BODY_SIZE": 10}, + ) - f.write(b'0123456789012345689') + f.write(b"0123456789012345689") f.finalize() - self.assert_file_data(files[0], b'0123456789') + self.assert_file_data(files[0], b"0123456789") def test_invalid_max_size_multipart(self): with self.assertRaises(ValueError): - q = MultipartParser(b'bound', max_size='foo') + q = MultipartParser(b"bound", max_size="foo") class TestHelperFunctions(unittest.TestCase): def test_create_form_parser(self): - r = create_form_parser({'Content-Type': 'application/octet-stream'}, - None, None) + r = create_form_parser({"Content-Type": "application/octet-stream"}, None, None) self.assertTrue(isinstance(r, FormParser)) def test_create_form_parser_error(self): @@ -1257,13 +1239,7 @@ def test_parse_form(self): on_field = Mock() on_file = Mock() - parse_form( - {'Content-Type': 'application/octet-stream', - }, - BytesIO(b'123456789012345'), - on_field, - on_file - ) + parse_form({"Content-Type": "application/octet-stream"}, BytesIO(b"123456789012345"), on_field, on_file) assert on_file.call_count == 1 @@ -1273,23 +1249,21 @@ def test_parse_form(self): def test_parse_form_content_length(self): files = [] + def on_file(file): files.append(file) parse_form( - {'Content-Type': 'application/octet-stream', - 'Content-Length': '10' - }, - BytesIO(b'123456789012345'), + {"Content-Type": "application/octet-stream", "Content-Length": "10"}, + BytesIO(b"123456789012345"), None, - on_file + on_file, ) self.assertEqual(len(files), 1) self.assertEqual(files[0].size, 10) - def suite(): suite = unittest.TestSuite() suite.addTest(unittest.defaultTestLoader.loadTestsFromTestCase(TestFile)) diff --git a/tox.ini b/tox.ini index 85d1b54..abf6e29 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = py37,py38,py39,py310,py311 +envlist = py38,py39,py310,py311,py312 [testenv] deps= @@ -8,4 +8,4 @@ deps= pytest-timeout PyYAML commands= - pytest --cov-report term-missing --cov-config .coveragerc --cov multipart --timeout=30 multipart/tests + pytest --cov-report term-missing --cov-config .coveragerc --cov multipart --timeout=30 tests