From 365c3b88aa2e6d15a866be2bdae73f7684479777 Mon Sep 17 00:00:00 2001
From: Artem Krivonos <artemkr@amazon.com>
Date: Mon, 21 Aug 2023 17:15:10 +0100
Subject: [PATCH] Add structured logging implementation

---
 awslambdaric/bootstrap.py                | 140 +++++++++-----
 awslambdaric/lambda_runtime_log_utils.py | 123 +++++++++++++
 tests/test_bootstrap.py                  | 221 ++++++++++++++++++++++-
 tests/test_lambda_context.py             |   2 +-
 tests/test_lambda_runtime_client.py      |   5 +-
 5 files changed, 436 insertions(+), 55 deletions(-)
 create mode 100644 awslambdaric/lambda_runtime_log_utils.py

diff --git a/awslambdaric/bootstrap.py b/awslambdaric/bootstrap.py
index e7b9e5a..5ad7bb5 100644
--- a/awslambdaric/bootstrap.py
+++ b/awslambdaric/bootstrap.py
@@ -13,10 +13,19 @@
 from .lambda_context import LambdaContext
 from .lambda_runtime_client import LambdaRuntimeClient
 from .lambda_runtime_exception import FaultException
+from .lambda_runtime_log_utils import (
+    _DATETIME_FORMAT,
+    _DEFAULT_FRAME_TYPE,
+    _JSON_FRAME_TYPES,
+    JsonFormatter,
+    LogFormat,
+)
 from .lambda_runtime_marshaller import to_json
 
 ERROR_LOG_LINE_TERMINATE = "\r"
 ERROR_LOG_IDENT = "\u00a0"  # NO-BREAK SPACE U+00A0
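+# Log format and level are resolved once, at import time, from the environment.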
+_AWS_LAMBDA_LOG_FORMAT = LogFormat.from_str(os.environ.get("AWS_LAMBDA_LOG_FORMAT"))
+_AWS_LAMBDA_LOG_LEVEL = os.environ.get("AWS_LAMBDA_LOG_LEVEL", "").upper()
 
 
 def _get_handler(handler):
@@ -73,7 +82,12 @@ def result(*args):
     return result
 
 
-def make_error(error_message, error_type, stack_trace, invoke_id=None):
+def make_error(
+    error_message,
+    error_type,
+    stack_trace,
+    invoke_id=None,
+):
     result = {
         "errorMessage": error_message if error_message else "",
         "errorType": error_type if error_type else "",
@@ -92,34 +106,52 @@ def replace_line_indentation(line, indent_char, new_indent_char):
     return (new_indent_char * ident_chars_count) + line[ident_chars_count:]
 
 
-def log_error(error_result, log_sink):
-    error_description = "[ERROR]"
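+# The error logging implementation is selected once at import time based on the configured log format.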
+if _AWS_LAMBDA_LOG_FORMAT == LogFormat.JSON:
+    _ERROR_FRAME_TYPE = _JSON_FRAME_TYPES[logging.ERROR]
+
+    def log_error(error_result, log_sink):
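+        # Emit the unhandled error as a single JSON document, prepending timestamp and log level.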
+        error_result = {
+            "timestamp": time.strftime(
+                _DATETIME_FORMAT, logging.Formatter.converter(time.time())
+            ),
+            "log_level": "ERROR",
+            **error_result,
+        }
+        log_sink.log_error(
+            [to_json(error_result)],
+        )
 
-    error_result_type = error_result.get("errorType")
-    if error_result_type:
-        error_description += " " + error_result_type
+else:
+    _ERROR_FRAME_TYPE = _DEFAULT_FRAME_TYPE
 
-    error_result_message = error_result.get("errorMessage")
-    if error_result_message:
+    def log_error(error_result, log_sink):
+        error_description = "[ERROR]"
+
+        error_result_type = error_result.get("errorType")
         if error_result_type:
-            error_description += ":"
-        error_description += " " + error_result_message
+            error_description += " " + error_result_type
+
+        error_result_message = error_result.get("errorMessage")
+        if error_result_message:
+            if error_result_type:
+                error_description += ":"
+            error_description += " " + error_result_message
 
-    error_message_lines = [error_description]
+        error_message_lines = [error_description]
 
-    stack_trace = error_result.get("stackTrace")
-    if stack_trace is not None:
-        error_message_lines += ["Traceback (most recent call last):"]
-        for trace_element in stack_trace:
-            if trace_element == "":
-                error_message_lines += [""]
-            else:
-                for trace_line in trace_element.splitlines():
-                    error_message_lines += [
-                        replace_line_indentation(trace_line, " ", ERROR_LOG_IDENT)
-                    ]
+        stack_trace = error_result.get("stackTrace")
+        if stack_trace is not None:
+            error_message_lines += ["Traceback (most recent call last):"]
+            for trace_element in stack_trace:
+                if trace_element == "":
+                    error_message_lines += [""]
+                else:
+                    for trace_line in trace_element.splitlines():
+                        error_message_lines += [
+                            replace_line_indentation(trace_line, " ", ERROR_LOG_IDENT)
+                        ]
 
-    log_sink.log_error(error_message_lines)
+        log_sink.log_error(error_message_lines)
 
 
 def handle_event_request(
@@ -152,7 +184,12 @@ def handle_event_request(
         )
     except FaultException as e:
         xray_fault = make_xray_fault("LambdaValidationError", e.msg, os.getcwd(), [])
-        error_result = make_error(e.msg, e.exception_type, e.trace, invoke_id)
+        error_result = make_error(
+            e.msg,
+            e.exception_type,
+            e.trace,
+            invoke_id,
+        )
 
     except Exception:
         etype, value, tb = sys.exc_info()
@@ -221,7 +258,9 @@ def build_fault_result(exc_info, msg):
             break
 
     return make_error(
-        msg if msg else str(value), etype.__name__, traceback.format_list(tb_tuples)
+        msg if msg else str(value),
+        etype.__name__,
+        traceback.format_list(tb_tuples),
     )
 
 
@@ -257,7 +296,8 @@ def __init__(self, log_sink):
 
     def emit(self, record):
         msg = self.format(record)
-        self.log_sink.log(msg)
+
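+        # Forward the frame type assigned by the formatter (if any) so framed log sinks can tag the frame.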
+        self.log_sink.log(msg, frame_type=getattr(record, "_frame_type", None))
 
 
 class LambdaLoggerFilter(logging.Filter):
@@ -298,7 +338,7 @@ def __enter__(self):
     def __exit__(self, exc_type, exc_value, exc_tb):
         pass
 
-    def log(self, msg):
+    def log(self, msg, frame_type=None):
         sys.stdout.write(msg)
 
     def log_error(self, message_lines):
@@ -324,7 +364,6 @@ class FramedTelemetryLogSink(object):
 
     def __init__(self, fd):
         self.fd = int(fd)
-        self.frame_type = 0xA55A0003.to_bytes(4, "big")
 
     def __enter__(self):
         self.file = os.fdopen(self.fd, "wb", 0)
@@ -333,11 +372,12 @@ def __enter__(self):
     def __exit__(self, exc_type, exc_value, exc_tb):
         self.file.close()
 
-    def log(self, msg):
+    def log(self, msg, frame_type=None):
         encoded_msg = msg.encode("utf8")
+
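+        # Messages without an explicit frame type (plain text logs) fall back to the default frame marker.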
         timestamp = int(time.time_ns() / 1000)  # UNIX timestamp in microseconds
         log_msg = (
-            self.frame_type
+            (frame_type or _DEFAULT_FRAME_TYPE)
             + len(encoded_msg).to_bytes(4, "big")
             + timestamp.to_bytes(8, "big")
             + encoded_msg
@@ -346,7 +386,10 @@ def log(self, msg):
 
     def log_error(self, message_lines):
         error_message = "\n".join(message_lines)
-        self.log(error_message)
+        self.log(
+            error_message,
+            frame_type=_ERROR_FRAME_TYPE,
+        )
 
 
 def update_xray_env_variable(xray_trace_id):
@@ -370,6 +413,28 @@ def create_log_sink():
 _GLOBAL_AWS_REQUEST_ID = None
 
 
+def _setup_logging(log_format, log_level, log_sink):
+    logging.Formatter.converter = time.gmtime
+    logger = logging.getLogger()
+    logger_handler = LambdaLoggerHandler(log_sink)
+    if log_format == LogFormat.JSON:
+        logger_handler.setFormatter(JsonFormatter())
+
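+        # Register "TRACE" as an alias of DEBUG so AWS_LAMBDA_LOG_LEVEL=TRACE is accepted below.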
+        logging.addLevelName(logging.DEBUG, "TRACE")
+        if log_level in logging._nameToLevel:
+            logger.setLevel(log_level)
+    else:
+        logger_handler.setFormatter(
+            logging.Formatter(
+                "[%(levelname)s]\t%(asctime)s.%(msecs)03dZ\t%(aws_request_id)s\t%(message)s\n",
+                "%Y-%m-%dT%H:%M:%S",
+            )
+        )
+
+    logger_handler.addFilter(LambdaLoggerFilter())
+    logger.addHandler(logger_handler)
+
+
 def run(app_root, handler, lambda_runtime_api_addr):
     sys.stdout = Unbuffered(sys.stdout)
     sys.stderr = Unbuffered(sys.stderr)
@@ -378,18 +443,7 @@ def run(app_root, handler, lambda_runtime_api_addr):
         lambda_runtime_client = LambdaRuntimeClient(lambda_runtime_api_addr)
 
         try:
-            logging.Formatter.converter = time.gmtime
-            logger = logging.getLogger()
-            logger_handler = LambdaLoggerHandler(log_sink)
-            logger_handler.setFormatter(
-                logging.Formatter(
-                    "[%(levelname)s]\t%(asctime)s.%(msecs)03dZ\t%(aws_request_id)s\t%(message)s\n",
-                    "%Y-%m-%dT%H:%M:%S",
-                )
-            )
-            logger_handler.addFilter(LambdaLoggerFilter())
-            logger.addHandler(logger_handler)
-
+            _setup_logging(_AWS_LAMBDA_LOG_FORMAT, _AWS_LAMBDA_LOG_LEVEL, log_sink)
             global _GLOBAL_AWS_REQUEST_ID
 
             request_handler = _get_handler(handler)
diff --git a/awslambdaric/lambda_runtime_log_utils.py b/awslambdaric/lambda_runtime_log_utils.py
new file mode 100644
index 0000000..f140253
--- /dev/null
+++ b/awslambdaric/lambda_runtime_log_utils.py
@@ -0,0 +1,123 @@
+"""
+Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+"""
+
+import json
+import logging
+import traceback
+from enum import IntEnum
+
+_DATETIME_FORMAT = "%Y-%m-%dT%H:%M:%SZ"
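+# Standard LogRecord attributes (and runtime-internal fields) that are never copied into the JSON output as extras.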
+_RESERVED_FIELDS = {
+    "name",
+    "msg",
+    "args",
+    "levelname",
+    "levelno",
+    "pathname",
+    "filename",
+    "module",
+    "exc_info",
+    "exc_text",
+    "stack_info",
+    "lineno",
+    "funcName",
+    "created",
+    "msecs",
+    "relativeCreated",
+    "thread",
+    "threadName",
+    "processName",
+    "process",
+    "aws_request_id",
+    "_frame_type",
+}
+
+
+class LogFormat(IntEnum):
+    JSON = 0b0
+    TEXT = 0b1
+
+    @classmethod
+    def from_str(cls, value: str):
+        if value and value.upper() == "JSON":
+            return cls.JSON.value
+        return cls.TEXT.value
+
+
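+# 4-byte frame markers used by FramedTelemetryLogSink: JSON frames carry the record's log level,
+# while plain text frames use the default marker.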
+_JSON_FRAME_TYPES = {
+    logging.NOTSET: 0xA55A0002.to_bytes(4, "big"),
+    logging.DEBUG: 0xA55A000A.to_bytes(4, "big"),
+    logging.INFO: 0xA55A000E.to_bytes(4, "big"),
+    logging.WARNING: 0xA55A0012.to_bytes(4, "big"),
+    logging.ERROR: 0xA55A0016.to_bytes(4, "big"),
+    logging.CRITICAL: 0xA55A001A.to_bytes(4, "big"),
+}
+_DEFAULT_FRAME_TYPE = 0xA55A0003.to_bytes(4, "big")
+
+_json_encoder = json.JSONEncoder(ensure_ascii=False)
+_encode_json = _json_encoder.encode
+
+
+class JsonFormatter(logging.Formatter):
+    def __init__(self):
+        super().__init__(datefmt=_DATETIME_FORMAT)
+
+    @staticmethod
+    def __format_stacktrace(exc_info):
+        if not exc_info:
+            return None
+        return traceback.format_tb(exc_info[2])
+
+    @staticmethod
+    def __format_exception_name(exc_info):
+        if not exc_info:
+            return None
+
+        return exc_info[0].__name__
+
+    @staticmethod
+    def __format_exception(exc_info):
+        if not exc_info:
+            return None
+
+        return str(exc_info[1])
+
+    @staticmethod
+    def __format_location(record: logging.LogRecord):
+        if not record.exc_info:
+            return None
+
+        return f"{record.pathname}:{record.funcName}:{record.lineno}"
+
+    @staticmethod
+    def __format_log_level(record: logging.LogRecord):
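+        # Clamp custom levels into [0, 50] and round down to the nearest standard level.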
+        record.levelno = min(50, max(0, record.levelno)) // 10 * 10
+        record.levelname = logging.getLevelName(record.levelno)
+
+    def format(self, record: logging.LogRecord) -> str:
+        self.__format_log_level(record)
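+        # Attach the frame type matching the record's level so the log sink can frame it accordingly.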
+        record._frame_type = _JSON_FRAME_TYPES.get(
+            record.levelno, _JSON_FRAME_TYPES[logging.NOTSET]
+        )
+
+        result = {
+            "timestamp": self.formatTime(record, self.datefmt),
+            "level": record.levelname,
+            "message": record.getMessage(),
+            "logger": record.name,
+            "stackTrace": self.__format_stacktrace(record.exc_info),
+            "errorType": self.__format_exception_name(record.exc_info),
+            "errorMessage": self.__format_exception(record.exc_info),
+            "requestId": getattr(record, "aws_request_id", None),
+            "location": self.__format_location(record),
+        }
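+        # Merge in any user-supplied extras that do not collide with reserved attributes or computed fields.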
+        result.update(
+            (key, value)
+            for key, value in record.__dict__.items()
+            if key not in _RESERVED_FIELDS and key not in result
+        )
+
+        result = {k: v for k, v in result.items() if v is not None}
+
+        return _encode_json(result) + "\n"
diff --git a/tests/test_bootstrap.py b/tests/test_bootstrap.py
index edb0737..5614a2e 100644
--- a/tests/test_bootstrap.py
+++ b/tests/test_bootstrap.py
@@ -4,6 +4,8 @@
 
 import importlib
 import json
+import logging
+import logging.config
 import os
 import re
 import tempfile
@@ -16,6 +18,7 @@
 
 import awslambdaric.bootstrap as bootstrap
 from awslambdaric.lambda_runtime_exception import FaultException
+from awslambdaric.lambda_runtime_log_utils import LogFormat
 from awslambdaric.lambda_runtime_marshaller import LambdaMarshaller
 
 
@@ -613,14 +616,7 @@ def test_handle_event_request_fault_exception_logging_syntax_error(
             bootstrap.StandardLogSink(),
         )
 
-        import sys
-
-        sys.stderr.write(mock_stdout.getvalue())
-
-        error_logs = (
-            "[ERROR] Runtime.UserCodeSyntaxError: Syntax error in module 'a': "
-            "unexpected EOF while parsing (<string>, line 1)\r"
-        )
+        error_logs = f"[ERROR] Runtime.UserCodeSyntaxError: Syntax error in module 'a': {syntax_error}\r"
         error_logs += "Traceback (most recent call last):\r"
         error_logs += '  File "<string>" Line 1\r'
         error_logs += "    -\n"
@@ -1174,6 +1170,215 @@ def test_multiple_frame(self):
                 self.assertEqual(content[pos:], b"")
 
 
+class TestLoggingSetup(unittest.TestCase):
+    def test_log_level(self) -> None:
+        test_cases = [
+            (LogFormat.JSON, "TRACE", logging.DEBUG),
+            (LogFormat.JSON, "DEBUG", logging.DEBUG),
+            (LogFormat.JSON, "INFO", logging.INFO),
+            (LogFormat.JSON, "WARN", logging.WARNING),
+            (LogFormat.JSON, "ERROR", logging.ERROR),
+            (LogFormat.JSON, "FATAL", logging.CRITICAL),
+            # The log level is only applied for the JSON format
+            (LogFormat.TEXT, "TRACE", logging.NOTSET),
+            (LogFormat.TEXT, "DEBUG", logging.NOTSET),
+            (LogFormat.TEXT, "INFO", logging.NOTSET),
+            (LogFormat.TEXT, "WARN", logging.NOTSET),
+            (LogFormat.TEXT, "ERROR", logging.NOTSET),
+            (LogFormat.TEXT, "FATAL", logging.NOTSET),
+            ("Unknown format", "INFO", logging.NOTSET),
+            # If the level is unknown, fall back to the default
+            (LogFormat.JSON, "Unknown level", logging.NOTSET),
+        ]
+        for fmt, log_level, expected_level in test_cases:
+            with self.subTest():
+                # Drop previous setup
+                logging.getLogger().handlers.clear()
+                logging.getLogger().level = logging.NOTSET
+
+                bootstrap._setup_logging(fmt, log_level, bootstrap.StandardLogSink())
+
+                self.assertEqual(expected_level, logging.getLogger().level)
+
+
+class TestLogging(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls) -> None:
+        logging.getLogger().handlers.clear()
+        logging.getLogger().level = logging.NOTSET
+        bootstrap._setup_logging(
+            LogFormat.from_str("JSON"), "INFO", bootstrap.StandardLogSink()
+        )
+
+    @patch("sys.stderr", new_callable=StringIO)
+    def test_json_formatter(self, mock_stderr):
+        logger = logging.getLogger("a.b")
+
+        test_cases = [
+            (
+                logging.ERROR,
+                "TEST 1",
+                {
+                    "level": "ERROR",
+                    "logger": "a.b",
+                    "message": "TEST 1",
+                    "requestId": "",
+                },
+            ),
+            (
+                logging.ERROR,
+                "test \nwith \nnew \nlines",
+                {
+                    "level": "ERROR",
+                    "logger": "a.b",
+                    "message": "test \nwith \nnew \nlines",
+                    "requestId": "",
+                },
+            ),
+            (
+                logging.CRITICAL,
+                "TEST CRITICAL",
+                {
+                    "level": "CRITICAL",
+                    "logger": "a.b",
+                    "message": "TEST CRITICAL",
+                    "requestId": "",
+                },
+            ),
+        ]
+        for level, msg, expected in test_cases:
+            with self.subTest(msg):
+                with patch("sys.stdout", new_callable=StringIO) as mock_stdout:
+                    logger.log(level, msg)
+
+                    data = json.loads(mock_stdout.getvalue())
+                    data.pop("timestamp")
+                    self.assertEqual(
+                        data,
+                        expected,
+                    )
+        self.assertEqual(mock_stderr.getvalue(), "")
+
+    @patch("sys.stdout", new_callable=StringIO)
+    @patch("sys.stderr", new_callable=StringIO)
+    def test_exception(self, mock_stderr, mock_stdout):
+        try:
+            raise ValueError("error message")
+        except ValueError:
+            logging.getLogger("test.logger").exception("test exception")
+
+        exception_log = json.loads(mock_stdout.getvalue())
+        self.assertIn("location", exception_log)
+        self.assertIn("stackTrace", exception_log)
+        exception_log.pop("timestamp")
+        exception_log.pop("location")
+        stack_trace = exception_log.pop("stackTrace")
+
+        self.assertEqual(len(stack_trace), 1)
+
+        self.assertEqual(
+            exception_log,
+            {
+                "errorMessage": "error message",
+                "errorType": "ValueError",
+                "level": "ERROR",
+                "logger": "test.logger",
+                "message": "test exception",
+                "requestId": "",
+            },
+        )
+
+        self.assertEqual(mock_stderr.getvalue(), "")
+
+    @patch("sys.stdout", new_callable=StringIO)
+    @patch("sys.stderr", new_callable=StringIO)
+    def test_log_level(self, mock_stderr, mock_stdout):
+        logger = logging.getLogger("test.logger")
+
+        logger.debug("debug message")
+        logger.info("info message")
+
+        data = json.loads(mock_stdout.getvalue())
+        data.pop("timestamp")
+
+        self.assertEqual(
+            data,
+            {
+                "level": "INFO",
+                "logger": "test.logger",
+                "message": "info message",
+                "requestId": "",
+            },
+        )
+        self.assertEqual(mock_stderr.getvalue(), "")
+
+    @patch("sys.stdout", new_callable=StringIO)
+    @patch("sys.stderr", new_callable=StringIO)
+    def test_set_log_level_manually(self, mock_stderr, mock_stdout):
+        logger = logging.getLogger("test.logger")
+
+        # Changing the log level after `bootstrap._setup_logging`
+        logging.getLogger().setLevel(logging.CRITICAL)
+
+        logger.debug("debug message")
+        logger.info("info message")
+        logger.warning("warning message")
+        logger.error("error message")
+        logger.critical("critical message")
+
+        data = json.loads(mock_stdout.getvalue())
+        data.pop("timestamp")
+
+        self.assertEqual(
+            data,
+            {
+                "level": "CRITICAL",
+                "logger": "test.logger",
+                "message": "critical message",
+                "requestId": "",
+            },
+        )
+        self.assertEqual(mock_stderr.getvalue(), "")
+
+    @patch("sys.stdout", new_callable=StringIO)
+    @patch("sys.stderr", new_callable=StringIO)
+    def test_set_log_level_with_dictConfig(self, mock_stderr, mock_stdout):
+        # Changing the log level after `bootstrap._setup_logging`
+        logging.config.dictConfig(
+            {
+                "version": 1,
+                "disable_existing_loggers": False,
+                "formatters": {"simple": {"format": "%(levelname)-8s - %(message)s"}},
+                "handlers": {
+                    "stdout": {
+                        "class": "logging.StreamHandler",
+                        "formatter": "simple",
+                    },
+                },
+                "root": {
+                    "level": "CRITICAL",
+                    "handlers": [
+                        "stdout",
+                    ],
+                },
+            }
+        )
+
+        logger = logging.getLogger("test.logger")
+        logger.debug("debug message")
+        logger.info("info message")
+        logger.warning("warning message")
+        logger.error("error message")
+        logger.critical("critical message")
+
+        data = mock_stderr.getvalue()
+        self.assertEqual(
+            data,
+            "CRITICAL - critical message\n",
+        )
+        self.assertEqual(mock_stdout.getvalue(), "")
+
+
 class TestBootstrapModule(unittest.TestCase):
     @patch("awslambdaric.bootstrap.handle_event_request")
     @patch("awslambdaric.bootstrap.LambdaRuntimeClient")
diff --git a/tests/test_lambda_context.py b/tests/test_lambda_context.py
index 545efa1..34d59da 100644
--- a/tests/test_lambda_context.py
+++ b/tests/test_lambda_context.py
@@ -4,7 +4,7 @@
 
 import os
 import unittest
-from unittest.mock import patch, MagicMock
+from unittest.mock import MagicMock, patch
 
 from awslambdaric.lambda_context import LambdaContext
 
diff --git a/tests/test_lambda_runtime_client.py b/tests/test_lambda_runtime_client.py
index 814ca96..47d95cf 100644
--- a/tests/test_lambda_runtime_client.py
+++ b/tests/test_lambda_runtime_client.py
@@ -6,13 +6,12 @@
 import http.client
 import unittest.mock
 from unittest.mock import MagicMock, patch
-from awslambdaric import __version__
-
 
+from awslambdaric import __version__
 from awslambdaric.lambda_runtime_client import (
+    InvocationRequest,
     LambdaRuntimeClient,
     LambdaRuntimeClientError,
-    InvocationRequest,
     _user_agent,
 )