diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 37e2265ea8a..0ddf287f264 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -2008,6 +2008,7 @@ def include_router(self, router: "Router", prefix: Optional[str] = None) -> None # use pointer to allow context clearance after event is processed e.g., resolve(evt, ctx) router.context = self.context + # Iterate through the routes defined in the router to configure and apply middlewares for each route for route, func in router._routes.items(): new_route = route @@ -2017,7 +2018,8 @@ def include_router(self, router: "Router", prefix: Optional[str] = None) -> None new_route = (rule, *route[1:]) # Middlewares are stored by route separately - must grab them to include - middlewares = router._routes_with_middleware.get(new_route) + # Middleware store the route without prefix, so we must not include prefix when grabbing + middlewares = router._routes_with_middleware.get(route) # Need to use "type: ignore" here since mypy does not like a named parameter after # tuple expansion since may cause duplicate named parameters in the function signature. diff --git a/tests/functional/event_handler/test_api_gateway.py b/tests/functional/event_handler/test_api_gateway.py index 18d10ec4167..8ad5ac35b18 100644 --- a/tests/functional/event_handler/test_api_gateway.py +++ b/tests/functional/event_handler/test_api_gateway.py @@ -1197,6 +1197,25 @@ def base(): assert result["multiValueHeaders"]["Content-Type"] == [content_types.APPLICATION_JSON] +def test_api_gateway_app_with_strip_prefix_and_route_prefix(): + # GIVEN all routes are stripped from its version e.g., /v1 + app = ApiGatewayResolver(strip_prefixes=["/v1"]) + router = Router() + + event = {"httpMethod": "GET", "path": "/v1/users/leandro", "resource": "/users"} + + @router.get("") + def base(user_id: str): + return {"user": user_id} + + # WHEN a router is included prefixing all routes with "/users/" + app.include_router(router, prefix="/users/") + result = app(event, {}) + + # THEN route correctly to the registered route after stripping each prefix (global + router) + assert result["statusCode"] == 200 + + def test_api_gateway_app_router(): # GIVEN a Router with registered routes app = ApiGatewayResolver() diff --git a/tests/functional/event_handler/test_api_middlewares.py b/tests/functional/event_handler/test_api_middlewares.py index 8f98b93343f..58bec259072 100644 --- a/tests/functional/event_handler/test_api_middlewares.py +++ b/tests/functional/event_handler/test_api_middlewares.py @@ -397,6 +397,34 @@ def dummy_route(): assert result["statusCode"] == 200 +def test_api_gateway_middleware_with_include_router_prefix(): + # GIVEN an App and Router instance + app = ApiGatewayResolver() + router = Router() + + def app_middleware(app: EventHandlerInstance, next_middleware: NextMiddleware): + # AND a variable injected into resolver context + app.append_context(injected="injected_value") + return next_middleware(app) + + # WHEN we register a route with a middleware + @router.get("/path", middlewares=[app_middleware]) + def dummy_route(): + # THEN we should have access to the middleware's injected variable + assert app.context["injected"] == "injected_value" + + return Response(status_code=200, body="works!") + + # WHEN register the route with a prefix + app.include_router(router, prefix="/my") + + # THEN resolving a request must execute the middleware + # and return a successful response http 200 status code + result = app(API_REST_EVENT, {}) + + assert result["statusCode"] == 200 + + @pytest.mark.parametrize( "app, event", [