Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion tests/files_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import mocket
import pytest
import requests as python_requests
from local_test_server import uses_local_server


@pytest.fixture
Expand All @@ -20,7 +21,7 @@ def log_stream():

@pytest.fixture
def post_url():
return "https://httpbin.org/post"
return "http://127.0.0.1:5000/post"


@pytest.fixture
Expand Down Expand Up @@ -63,6 +64,7 @@ def get_actual_request_data(log_stream):
return boundary, content_length, actual_request_post


@uses_local_server
def test_post_file_as_data( # pylint: disable=unused-argument
requests, sock, log_stream, post_url, request_logging
):
Expand All @@ -85,6 +87,7 @@ def test_post_file_as_data( # pylint: disable=unused-argument
assert sent.endswith(actual_request_post)


@uses_local_server
def test_post_files_text( # pylint: disable=unused-argument
sock, requests, log_stream, post_url, request_logging
):
Expand Down Expand Up @@ -120,6 +123,7 @@ def test_post_files_text( # pylint: disable=unused-argument
assert sent.endswith(actual_request_post)


@uses_local_server
def test_post_files_file( # pylint: disable=unused-argument
sock, requests, log_stream, post_url, request_logging
):
Expand Down Expand Up @@ -164,6 +168,7 @@ def test_post_files_file( # pylint: disable=unused-argument
assert sent.endswith(actual_request_post)


@uses_local_server
def test_post_files_complex( # pylint: disable=unused-argument
sock, requests, log_stream, post_url, request_logging
):
Expand Down
36 changes: 36 additions & 0 deletions tests/local_test_server.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,47 @@
# SPDX-FileCopyrightText: 2025 Tim Cocks
#
# SPDX-License-Identifier: MIT
import functools
import json
import socketserver
import threading
import time
from http.server import SimpleHTTPRequestHandler


def uses_local_server(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
with ReusableAddressTCPServer(("127.0.0.1", 5000), LocalTestServerHandler) as server:
server_thread = threading.Thread(target=server.serve_forever)
server_thread.daemon = True
server_thread.start()
time.sleep(2) # Give the server some time to start

result = func(*args, **kwargs)

server.shutdown()
server.server_close()
return result

return wrapper


class ReusableAddressTCPServer(socketserver.TCPServer):
# Enable SO_REUSEADDR
allow_reuse_address = True


class LocalTestServerHandler(SimpleHTTPRequestHandler):
def do_POST(self):
if self.path == "/post":
resp_body = json.dumps({"url": "http://localhost:5000/post"}).encode("utf-8")
self.send_response(200)
self.send_header("Content-type", "application/json")
self.send_header("Content-Length", str(len(resp_body)))
self.end_headers()
self.wfile.write(resp_body)

def do_GET(self):
if self.path == "/get":
resp_body = json.dumps({"url": "http://localhost:5000/get"}).encode("utf-8")
Expand Down
42 changes: 23 additions & 19 deletions tests/real_call_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@

import adafruit_connection_manager
import pytest
from local_test_server import LocalTestServerHandler
from local_test_server import uses_local_server

import adafruit_requests


@uses_local_server
def test_gets():
path_index = 0
status_code_index = 1
Expand All @@ -28,25 +29,28 @@ def test_gets():
("status/204", 204, "", None),
]

with socketserver.TCPServer(("127.0.0.1", 5000), LocalTestServerHandler) as server:
server_thread = threading.Thread(target=server.serve_forever)
server_thread.daemon = True
server_thread.start()
for case in cases:
requests = adafruit_requests.Session(socket, ssl.create_default_context())
with requests.get(f"http://127.0.0.1:5000/{case[path_index]}") as response:
assert response.status_code == case[status_code_index]
if case[text_result_index] is not None:
assert response.text == case[text_result_index]
if case[json_keys_index] is not None:
for key, value in case[json_keys_index].items():
assert response.json()[key] == value

time.sleep(2) # Give the server some time to start
adafruit_connection_manager.connection_manager_close_all(release_references=True)

for case in cases:
requests = adafruit_requests.Session(socket, ssl.create_default_context())
with requests.get(f"http://127.0.0.1:5000/{case[path_index]}") as response:
assert response.status_code == case[status_code_index]
if case[text_result_index] is not None:
assert response.text == case[text_result_index]
if case[json_keys_index] is not None:
for key, value in case[json_keys_index].items():
assert response.json()[key] == value

adafruit_connection_manager.connection_manager_close_all(release_references=True)
def test_http_to_https_redirect():
url = "http://www.adafruit.com/api/quotes.php"
requests = adafruit_requests.Session(socket, ssl.create_default_context())
with requests.get(url) as response:
assert response.status_code == 200

server.shutdown()
server.server_close()
time.sleep(2)

def test_https_direct():
url = "https://www.adafruit.com/api/quotes.php"
requests = adafruit_requests.Session(socket, ssl.create_default_context())
with requests.get(url) as response:
assert response.status_code == 200