diff --git a/tests/files_test.py b/tests/files_test.py index c263cd3..5a81877 100644 --- a/tests/files_test.py +++ b/tests/files_test.py @@ -11,6 +11,7 @@ import mocket import pytest import requests as python_requests +from local_test_server import uses_local_server @pytest.fixture @@ -20,7 +21,7 @@ def log_stream(): @pytest.fixture def post_url(): - return "https://httpbin.org/post" + return "http://127.0.0.1:5000/post" @pytest.fixture @@ -63,6 +64,7 @@ def get_actual_request_data(log_stream): return boundary, content_length, actual_request_post +@uses_local_server def test_post_file_as_data( # pylint: disable=unused-argument requests, sock, log_stream, post_url, request_logging ): @@ -85,6 +87,7 @@ def test_post_file_as_data( # pylint: disable=unused-argument assert sent.endswith(actual_request_post) +@uses_local_server def test_post_files_text( # pylint: disable=unused-argument sock, requests, log_stream, post_url, request_logging ): @@ -120,6 +123,7 @@ def test_post_files_text( # pylint: disable=unused-argument assert sent.endswith(actual_request_post) +@uses_local_server def test_post_files_file( # pylint: disable=unused-argument sock, requests, log_stream, post_url, request_logging ): @@ -164,6 +168,7 @@ def test_post_files_file( # pylint: disable=unused-argument assert sent.endswith(actual_request_post) +@uses_local_server def test_post_files_complex( # pylint: disable=unused-argument sock, requests, log_stream, post_url, request_logging ): diff --git a/tests/local_test_server.py b/tests/local_test_server.py index 73b52b3..04f47ce 100644 --- a/tests/local_test_server.py +++ b/tests/local_test_server.py @@ -1,11 +1,47 @@ # SPDX-FileCopyrightText: 2025 Tim Cocks # # SPDX-License-Identifier: MIT +import functools import json +import socketserver +import threading +import time from http.server import SimpleHTTPRequestHandler +def uses_local_server(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + with ReusableAddressTCPServer(("127.0.0.1", 5000), LocalTestServerHandler) as server: + server_thread = threading.Thread(target=server.serve_forever) + server_thread.daemon = True + server_thread.start() + time.sleep(2) # Give the server some time to start + + result = func(*args, **kwargs) + + server.shutdown() + server.server_close() + return result + + return wrapper + + +class ReusableAddressTCPServer(socketserver.TCPServer): + # Enable SO_REUSEADDR + allow_reuse_address = True + + class LocalTestServerHandler(SimpleHTTPRequestHandler): + def do_POST(self): + if self.path == "/post": + resp_body = json.dumps({"url": "http://localhost:5000/post"}).encode("utf-8") + self.send_response(200) + self.send_header("Content-type", "application/json") + self.send_header("Content-Length", str(len(resp_body))) + self.end_headers() + self.wfile.write(resp_body) + def do_GET(self): if self.path == "/get": resp_body = json.dumps({"url": "http://localhost:5000/get"}).encode("utf-8") diff --git a/tests/real_call_test.py b/tests/real_call_test.py index b982f00..fad0178 100644 --- a/tests/real_call_test.py +++ b/tests/real_call_test.py @@ -12,11 +12,12 @@ import adafruit_connection_manager import pytest -from local_test_server import LocalTestServerHandler +from local_test_server import uses_local_server import adafruit_requests +@uses_local_server def test_gets(): path_index = 0 status_code_index = 1 @@ -28,25 +29,35 @@ def test_gets(): ("status/204", 204, "", None), ] - with socketserver.TCPServer(("127.0.0.1", 5000), LocalTestServerHandler) as server: - server_thread = threading.Thread(target=server.serve_forever) - server_thread.daemon = True - server_thread.start() - - time.sleep(2) # Give the server some time to start - - for case in cases: - requests = adafruit_requests.Session(socket, ssl.create_default_context()) - with requests.get(f"http://127.0.0.1:5000/{case[path_index]}") as response: - assert response.status_code == case[status_code_index] - if case[text_result_index] is not None: - assert response.text == case[text_result_index] - if case[json_keys_index] is not None: - for key, value in case[json_keys_index].items(): - assert response.json()[key] == value - - adafruit_connection_manager.connection_manager_close_all(release_references=True) - - server.shutdown() - server.server_close() - time.sleep(2) + for case in cases: + requests = adafruit_requests.Session(socket, ssl.create_default_context()) + with requests.get(f"http://127.0.0.1:5000/{case[path_index]}") as response: + assert response.status_code == case[status_code_index] + if case[text_result_index] is not None: + assert response.text == case[text_result_index] + if case[json_keys_index] is not None: + for key, value in case[json_keys_index].items(): + assert response.json()[key] == value + + adafruit_connection_manager.connection_manager_close_all(release_references=True) + + +@pytest.mark.parametrize( + ("allow_redirects", "status_code"), + ( + (True, 200), + (False, 301), + ), +) +def test_http_to_https_redirect(allow_redirects, status_code): + url = "http://www.adafruit.com/api/quotes.php" + requests = adafruit_requests.Session(socket, ssl.create_default_context()) + with requests.get(url, allow_redirects=allow_redirects) as response: + assert response.status_code == status_code + + +def test_https_direct(): + url = "https://www.adafruit.com/api/quotes.php" + requests = adafruit_requests.Session(socket, ssl.create_default_context()) + with requests.get(url) as response: + assert response.status_code == 200