From bd4040474b106bacc5f915027d9c8e1aa150c149 Mon Sep 17 00:00:00 2001
From: Naman Trivedi <trivenay@amazon.com>
Date: Wed, 31 Jul 2024 13:57:40 +0000
Subject: [PATCH] Raise all init errors in init instead of suppressing them
 until the fist invoke

---
 awslambdaric/bootstrap.py | 31 +++++++++-----------
 tests/test_bootstrap.py   | 60 ++++++---------------------------------
 2 files changed, 21 insertions(+), 70 deletions(-)

diff --git a/awslambdaric/bootstrap.py b/awslambdaric/bootstrap.py
index e737b7b..60aa216 100644
--- a/awslambdaric/bootstrap.py
+++ b/awslambdaric/bootstrap.py
@@ -37,36 +37,30 @@ def _get_handler(handler):
     try:
         (modname, fname) = handler.rsplit(".", 1)
     except ValueError as e:
-        fault = FaultException(
+        raise FaultException(
             FaultException.MALFORMED_HANDLER_NAME,
             "Bad handler '{}': {}".format(handler, str(e)),
         )
-        return make_fault_handler(fault)
 
     try:
         if modname.split(".")[0] in sys.builtin_module_names:
-            fault = FaultException(
+            raise FaultException(
                 FaultException.BUILT_IN_MODULE_CONFLICT,
                 "Cannot use built-in module {} as a handler module".format(modname),
             )
-            return make_fault_handler(fault)
         m = importlib.import_module(modname.replace("/", "."))
     except ImportError as e:
-        fault = FaultException(
+        raise FaultException(
             FaultException.IMPORT_MODULE_ERROR,
             "Unable to import module '{}': {}".format(modname, str(e)),
         )
-        request_handler = make_fault_handler(fault)
-        return request_handler
     except SyntaxError as e:
         trace = ['  File "%s" Line %s\n    %s' % (e.filename, e.lineno, e.text)]
-        fault = FaultException(
+        raise FaultException(
             FaultException.USER_CODE_SYNTAX_ERROR,
             "Syntax error in module '{}': {}".format(modname, str(e)),
             trace,
         )
-        request_handler = make_fault_handler(fault)
-        return request_handler
 
     try:
         request_handler = getattr(m, fname)
@@ -76,15 +70,8 @@ def _get_handler(handler):
             "Handler '{}' missing on module '{}'".format(fname, modname),
             None,
         )
-        request_handler = make_fault_handler(fault)
-    return request_handler
-
-
-def make_fault_handler(fault):
-    def result(*args):
         raise fault
-
-    return result
+    return request_handler
 
 
 def make_error(
@@ -475,15 +462,23 @@ def run(app_root, handler, lambda_runtime_api_addr):
         lambda_runtime_client = LambdaRuntimeClient(
             lambda_runtime_api_addr, use_thread_for_polling_next
         )
+        error_result = None
 
         try:
             _setup_logging(_AWS_LAMBDA_LOG_FORMAT, _AWS_LAMBDA_LOG_LEVEL, log_sink)
             global _GLOBAL_AWS_REQUEST_ID
 
             request_handler = _get_handler(handler)
+        except FaultException as e:
+            error_result = make_error(
+                e.msg,
+                e.exception_type,
+                e.trace,
+            )
         except Exception:
             error_result = build_fault_result(sys.exc_info(), None)
 
+        if error_result is not None:
             log_error(error_result, log_sink)
             lambda_runtime_client.post_init_error(to_json(error_result))
 
diff --git a/tests/test_bootstrap.py b/tests/test_bootstrap.py
index fd56d9f..7bc2ad2 100644
--- a/tests/test_bootstrap.py
+++ b/tests/test_bootstrap.py
@@ -603,43 +603,6 @@ def raise_exception_handler(json_input, lambda_context):
 
         self.assertEqual(mock_stdout.getvalue(), error_logs)
 
-    # The order of patches matter. Using MagicMock resets sys.stdout to the default.
-    @patch("importlib.import_module")
-    @patch("sys.stdout", new_callable=StringIO)
-    def test_handle_event_request_fault_exception_logging_syntax_error(
-        self, mock_stdout, mock_import_module
-    ):
-        try:
-            eval("-")
-        except SyntaxError as e:
-            syntax_error = e
-
-        mock_import_module.side_effect = syntax_error
-
-        response_handler = bootstrap._get_handler("a.b")
-
-        bootstrap.handle_event_request(
-            self.lambda_runtime,
-            response_handler,
-            "invoke_id",
-            self.event_body,
-            "application/json",
-            {},
-            {},
-            "invoked_function_arn",
-            0,
-            bootstrap.StandardLogSink(),
-        )
-        error_logs = (
-            lambda_unhandled_exception_warning_message
-            + f"[ERROR] Runtime.UserCodeSyntaxError: Syntax error in module 'a': {syntax_error}\r"
-        )
-        error_logs += "Traceback (most recent call last):\r"
-        error_logs += '  File "<string>" Line 1\r'
-        error_logs += "    -\n"
-
-        self.assertEqual(mock_stdout.getvalue(), error_logs)
-
 
 class TestXrayFault(unittest.TestCase):
     def test_make_xray(self):
@@ -717,10 +680,8 @@ def __eq__(self, other):
 
     def test_get_event_handler_bad_handler(self):
         handler_name = "bad_handler"
-        response_handler = bootstrap._get_handler(handler_name)
         with self.assertRaises(FaultException) as cm:
-            response_handler()
-
+            response_handler = bootstrap._get_handler(handler_name)
         returned_exception = cm.exception
         self.assertEqual(
             self.FaultExceptionMatcher(
@@ -732,9 +693,8 @@ def test_get_event_handler_bad_handler(self):
 
     def test_get_event_handler_import_error(self):
         handler_name = "no_module.handler"
-        response_handler = bootstrap._get_handler(handler_name)
         with self.assertRaises(FaultException) as cm:
-            response_handler()
+            response_handler = bootstrap._get_handler(handler_name)
         returned_exception = cm.exception
         self.assertEqual(
             self.FaultExceptionMatcher(
@@ -757,10 +717,9 @@ def test_get_event_handler_syntax_error(self):
             filename_w_ext = os.path.basename(tmp_file.name)
             filename, _ = os.path.splitext(filename_w_ext)
             handler_name = "{}.syntax_error".format(filename)
-            response_handler = bootstrap._get_handler(handler_name)
 
             with self.assertRaises(FaultException) as cm:
-                response_handler()
+                response_handler = bootstrap._get_handler(handler_name)
             returned_exception = cm.exception
             self.assertEqual(
                 self.FaultExceptionMatcher(
@@ -782,9 +741,8 @@ def test_get_event_handler_missing_error(self):
             filename_w_ext = os.path.basename(tmp_file.name)
             filename, _ = os.path.splitext(filename_w_ext)
             handler_name = "{}.my_handler".format(filename)
-            response_handler = bootstrap._get_handler(handler_name)
             with self.assertRaises(FaultException) as cm:
-                response_handler()
+                response_handler = bootstrap._get_handler(handler_name)
             returned_exception = cm.exception
             self.assertEqual(
                 self.FaultExceptionMatcher(
@@ -801,9 +759,8 @@ def test_get_event_handler_slash(self):
         response_handler()
 
     def test_get_event_handler_build_in_conflict(self):
-        response_handler = bootstrap._get_handler("sys.hello")
         with self.assertRaises(FaultException) as cm:
-            response_handler()
+            response_handler = bootstrap._get_handler("sys.hello")
         returned_exception = cm.exception
         self.assertEqual(
             self.FaultExceptionMatcher(
@@ -1452,9 +1409,8 @@ def test_set_log_level_with_dictConfig(self, mock_stderr, mock_stdout):
 
 
 class TestBootstrapModule(unittest.TestCase):
-    @patch("awslambdaric.bootstrap.handle_event_request")
     @patch("awslambdaric.bootstrap.LambdaRuntimeClient")
-    def test_run(self, mock_runtime_client, mock_handle_event_request):
+    def test_run(self, mock_runtime_client):
         expected_app_root = "/tmp/test/app_root"
         expected_handler = "app.my_test_handler"
         expected_lambda_runtime_api_addr = "test_addr"
@@ -1467,12 +1423,12 @@ def test_run(self, mock_runtime_client, mock_handle_event_request):
             MagicMock(),
         ]
 
-        with self.assertRaises(TypeError):
+        with self.assertRaises(SystemExit) as cm:
             bootstrap.run(
                 expected_app_root, expected_handler, expected_lambda_runtime_api_addr
             )
 
-        mock_handle_event_request.assert_called_once()
+        self.assertEqual(cm.exception.code, 1)
 
     @patch(
         "awslambdaric.bootstrap.LambdaLoggerHandler",