From 365c3b88aa2e6d15a866be2bdae73f7684479777 Mon Sep 17 00:00:00 2001
From: Artem Krivonos <artemkr@amazon.com>
Date: Mon, 21 Aug 2023 17:15:10 +0100
Subject: [PATCH] Add structured logging implementation
---
awslambdaric/bootstrap.py | 140 +++++++++-----
awslambdaric/lambda_runtime_log_utils.py | 123 +++++++++++++
tests/test_bootstrap.py | 221 ++++++++++++++++++++++-
tests/test_lambda_context.py | 2 +-
tests/test_lambda_runtime_client.py | 5 +-
5 files changed, 436 insertions(+), 55 deletions(-)
create mode 100644 awslambdaric/lambda_runtime_log_utils.py
diff --git a/awslambdaric/bootstrap.py b/awslambdaric/bootstrap.py
index e7b9e5a..5ad7bb5 100644
--- a/awslambdaric/bootstrap.py
+++ b/awslambdaric/bootstrap.py
@@ -13,10 +13,19 @@
from .lambda_context import LambdaContext
from .lambda_runtime_client import LambdaRuntimeClient
from .lambda_runtime_exception import FaultException
+from .lambda_runtime_log_utils import (
+ _DATETIME_FORMAT,
+ _DEFAULT_FRAME_TYPE,
+ _JSON_FRAME_TYPES,
+ JsonFormatter,
+ LogFormat,
+)
from .lambda_runtime_marshaller import to_json
ERROR_LOG_LINE_TERMINATE = "\r"
ERROR_LOG_IDENT = "\u00a0" # NO-BREAK SPACE U+00A0
+_AWS_LAMBDA_LOG_FORMAT = LogFormat.from_str(os.environ.get("AWS_LAMBDA_LOG_FORMAT"))
+_AWS_LAMBDA_LOG_LEVEL = os.environ.get("AWS_LAMBDA_LOG_LEVEL", "").upper()
def _get_handler(handler):
@@ -73,7 +82,12 @@ def result(*args):
return result
-def make_error(error_message, error_type, stack_trace, invoke_id=None):
+def make_error(
+ error_message,
+ error_type,
+ stack_trace,
+ invoke_id=None,
+):
result = {
"errorMessage": error_message if error_message else "",
"errorType": error_type if error_type else "",
@@ -92,34 +106,52 @@ def replace_line_indentation(line, indent_char, new_indent_char):
return (new_indent_char * ident_chars_count) + line[ident_chars_count:]
-def log_error(error_result, log_sink):
- error_description = "[ERROR]"
+if _AWS_LAMBDA_LOG_FORMAT == LogFormat.JSON:
+ _ERROR_FRAME_TYPE = _JSON_FRAME_TYPES[logging.ERROR]
+
+ def log_error(error_result, log_sink):
+ error_result = {
+ "timestamp": time.strftime(
+ _DATETIME_FORMAT, logging.Formatter.converter(time.time())
+ ),
+ "log_level": "ERROR",
+ **error_result,
+ }
+ log_sink.log_error(
+ [to_json(error_result)],
+ )
- error_result_type = error_result.get("errorType")
- if error_result_type:
- error_description += " " + error_result_type
+else:
+ _ERROR_FRAME_TYPE = _DEFAULT_FRAME_TYPE
- error_result_message = error_result.get("errorMessage")
- if error_result_message:
+ def log_error(error_result, log_sink):
+ error_description = "[ERROR]"
+
+ error_result_type = error_result.get("errorType")
if error_result_type:
- error_description += ":"
- error_description += " " + error_result_message
+ error_description += " " + error_result_type
+
+ error_result_message = error_result.get("errorMessage")
+ if error_result_message:
+ if error_result_type:
+ error_description += ":"
+ error_description += " " + error_result_message
- error_message_lines = [error_description]
+ error_message_lines = [error_description]
- stack_trace = error_result.get("stackTrace")
- if stack_trace is not None:
- error_message_lines += ["Traceback (most recent call last):"]
- for trace_element in stack_trace:
- if trace_element == "":
- error_message_lines += [""]
- else:
- for trace_line in trace_element.splitlines():
- error_message_lines += [
- replace_line_indentation(trace_line, " ", ERROR_LOG_IDENT)
- ]
+ stack_trace = error_result.get("stackTrace")
+ if stack_trace is not None:
+ error_message_lines += ["Traceback (most recent call last):"]
+ for trace_element in stack_trace:
+ if trace_element == "":
+ error_message_lines += [""]
+ else:
+ for trace_line in trace_element.splitlines():
+ error_message_lines += [
+ replace_line_indentation(trace_line, " ", ERROR_LOG_IDENT)
+ ]
- log_sink.log_error(error_message_lines)
+ log_sink.log_error(error_message_lines)
def handle_event_request(
@@ -152,7 +184,12 @@ def handle_event_request(
)
except FaultException as e:
xray_fault = make_xray_fault("LambdaValidationError", e.msg, os.getcwd(), [])
- error_result = make_error(e.msg, e.exception_type, e.trace, invoke_id)
+ error_result = make_error(
+ e.msg,
+ e.exception_type,
+ e.trace,
+ invoke_id,
+ )
except Exception:
etype, value, tb = sys.exc_info()
@@ -221,7 +258,9 @@ def build_fault_result(exc_info, msg):
break
return make_error(
- msg if msg else str(value), etype.__name__, traceback.format_list(tb_tuples)
+ msg if msg else str(value),
+ etype.__name__,
+ traceback.format_list(tb_tuples),
)
@@ -257,7 +296,8 @@ def __init__(self, log_sink):
def emit(self, record):
msg = self.format(record)
- self.log_sink.log(msg)
+
+ self.log_sink.log(msg, frame_type=getattr(record, "_frame_type", None))
class LambdaLoggerFilter(logging.Filter):
@@ -298,7 +338,7 @@ def __enter__(self):
def __exit__(self, exc_type, exc_value, exc_tb):
pass
- def log(self, msg):
+ def log(self, msg, frame_type=None):
sys.stdout.write(msg)
def log_error(self, message_lines):
@@ -324,7 +364,6 @@ class FramedTelemetryLogSink(object):
def __init__(self, fd):
self.fd = int(fd)
- self.frame_type = 0xA55A0003.to_bytes(4, "big")
def __enter__(self):
self.file = os.fdopen(self.fd, "wb", 0)
@@ -333,11 +372,12 @@ def __enter__(self):
def __exit__(self, exc_type, exc_value, exc_tb):
self.file.close()
- def log(self, msg):
+ def log(self, msg, frame_type=None):
encoded_msg = msg.encode("utf8")
+
timestamp = int(time.time_ns() / 1000) # UNIX timestamp in microseconds
log_msg = (
- self.frame_type
+ (frame_type or _DEFAULT_FRAME_TYPE)
+ len(encoded_msg).to_bytes(4, "big")
+ timestamp.to_bytes(8, "big")
+ encoded_msg
@@ -346,7 +386,10 @@ def log(self, msg):
def log_error(self, message_lines):
error_message = "\n".join(message_lines)
- self.log(error_message)
+ self.log(
+ error_message,
+ frame_type=_ERROR_FRAME_TYPE,
+ )
def update_xray_env_variable(xray_trace_id):
@@ -370,6 +413,28 @@ def create_log_sink():
_GLOBAL_AWS_REQUEST_ID = None
+def _setup_logging(log_format, log_level, log_sink):  # attach the runtime's log handler to the root logger
+ logging.Formatter.converter = time.gmtime  # set on the class, so every Formatter renders times via gmtime (UTC)
+ logger = logging.getLogger()  # root logger
+ logger_handler = LambdaLoggerHandler(log_sink)  # handler that forwards formatted records to log_sink
+ if log_format == LogFormat.JSON:
+ logger_handler.setFormatter(JsonFormatter())  # structured output: one JSON document per record
+
+ logging.addLevelName(logging.DEBUG, "TRACE")  # makes "TRACE" a registered name for level 10 (DEBUG)
+ if log_level in logging._nameToLevel:  # unknown level names leave the root level unchanged
+ logger.setLevel(log_level)  # NOTE(review): the tests expect the level to be applied only for JSON format
+ else:  # any format other than JSON keeps the plain-text layout
+ logger_handler.setFormatter(
+ logging.Formatter(
+ "[%(levelname)s]\t%(asctime)s.%(msecs)03dZ\t%(aws_request_id)s\t%(message)s\n",
+ "%Y-%m-%dT%H:%M:%S",
+ )
+ )
+
+ logger_handler.addFilter(LambdaLoggerFilter())  # presumably supplies aws_request_id for the text format - confirm
+ logger.addHandler(logger_handler)
+
+
def run(app_root, handler, lambda_runtime_api_addr):
sys.stdout = Unbuffered(sys.stdout)
sys.stderr = Unbuffered(sys.stderr)
@@ -378,18 +443,7 @@ def run(app_root, handler, lambda_runtime_api_addr):
lambda_runtime_client = LambdaRuntimeClient(lambda_runtime_api_addr)
try:
- logging.Formatter.converter = time.gmtime
- logger = logging.getLogger()
- logger_handler = LambdaLoggerHandler(log_sink)
- logger_handler.setFormatter(
- logging.Formatter(
- "[%(levelname)s]\t%(asctime)s.%(msecs)03dZ\t%(aws_request_id)s\t%(message)s\n",
- "%Y-%m-%dT%H:%M:%S",
- )
- )
- logger_handler.addFilter(LambdaLoggerFilter())
- logger.addHandler(logger_handler)
-
+ _setup_logging(_AWS_LAMBDA_LOG_FORMAT, _AWS_LAMBDA_LOG_LEVEL, log_sink)
global _GLOBAL_AWS_REQUEST_ID
request_handler = _get_handler(handler)
diff --git a/awslambdaric/lambda_runtime_log_utils.py b/awslambdaric/lambda_runtime_log_utils.py
new file mode 100644
index 0000000..f140253
--- /dev/null
+++ b/awslambdaric/lambda_runtime_log_utils.py
@@ -0,0 +1,123 @@
+"""
+Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+"""
+
+import json
+import logging
+import traceback
+from enum import IntEnum
+
+_DATETIME_FORMAT = "%Y-%m-%dT%H:%M:%SZ"
+_RESERVED_FIELDS = {
+ "name",
+ "msg",
+ "args",
+ "levelname",
+ "levelno",
+ "pathname",
+ "filename",
+ "module",
+ "exc_info",
+ "exc_text",
+ "stack_info",
+ "lineno",
+ "funcName",
+ "created",
+ "msecs",
+ "relativeCreated",
+ "thread",
+ "threadName",
+ "processName",
+ "process",
+ "aws_request_id",
+ "_frame_type",
+}
+
+
+class LogFormat(IntEnum):  # closed set of log output formats understood by the runtime
+    JSON = 0b0
+    TEXT = 0b1
+
+    @classmethod
+    def from_str(cls, value: str):  # value may be None or empty; both are treated as "not JSON"
+        if value and value.upper() == "JSON":  # case-insensitive match
+            return cls.JSON  # return the member itself; as an IntEnum it still compares equal to 0
+        return cls.TEXT  # every other input falls back to TEXT
+
+
+_JSON_FRAME_TYPES = {
+ logging.NOTSET: 0xA55A0002.to_bytes(4, "big"),
+ logging.DEBUG: 0xA55A000A.to_bytes(4, "big"),
+ logging.INFO: 0xA55A000E.to_bytes(4, "big"),
+ logging.WARNING: 0xA55A0012.to_bytes(4, "big"),
+ logging.ERROR: 0xA55A0016.to_bytes(4, "big"),
+ logging.CRITICAL: 0xA55A001A.to_bytes(4, "big"),
+}
+_DEFAULT_FRAME_TYPE = 0xA55A0003.to_bytes(4, "big")
+
+_json_encoder = json.JSONEncoder(ensure_ascii=False)
+_encode_json = _json_encoder.encode
+
+
+class JsonFormatter(logging.Formatter):  # renders each LogRecord as a single-line JSON document
+    def __init__(self):
+        super().__init__(datefmt=_DATETIME_FORMAT)  # timestamps are rendered with _DATETIME_FORMAT
+
+    @staticmethod
+    def __format_stacktrace(exc_info):
+        if not exc_info or exc_info[0] is None:  # (None, None, None) means nothing was raised
+            return None
+        return traceback.format_tb(exc_info[2])
+
+    @staticmethod
+    def __format_exception_name(exc_info):
+        # `exc_info=True` with no active exception arrives as (None, None, None).
+        if not exc_info or exc_info[0] is None:
+            return None
+        return exc_info[0].__name__
+
+    @staticmethod
+    def __format_exception(exc_info):
+        # There is no exception value to stringify unless an exception type is present.
+        if not exc_info or exc_info[0] is None:
+            return None
+        return str(exc_info[1])
+
+    @staticmethod
+    def __format_location(record: logging.LogRecord):
+        # The location is reported only alongside a real exception.
+        if not record.exc_info or record.exc_info[0] is None:
+            return None
+        return f"{record.pathname}:{record.funcName}:{record.lineno}"
+
+    @staticmethod
+    def __format_log_level(record: logging.LogRecord):
+        record.levelno = min(50, max(0, record.levelno)) // 10 * 10  # clamp to 0..50, round down to a multiple of 10
+        record.levelname = logging.getLevelName(record.levelno)  # name matching the normalized level
+
+    def format(self, record: logging.LogRecord) -> str:
+        self.__format_log_level(record)  # normalizes levelno/levelname on the record in place
+        record._frame_type = _JSON_FRAME_TYPES.get(
+            record.levelno, _JSON_FRAME_TYPES[logging.NOTSET]
+        )
+
+        result = {
+            "timestamp": self.formatTime(record, self.datefmt),
+            "level": record.levelname,
+            "message": record.getMessage(),
+            "logger": record.name,
+            "stackTrace": self.__format_stacktrace(record.exc_info),
+            "errorType": self.__format_exception_name(record.exc_info),
+            "errorMessage": self.__format_exception(record.exc_info),
+            "requestId": getattr(record, "aws_request_id", None),
+            "location": self.__format_location(record),
+        }
+        result.update(  # copy caller-supplied extras without overriding the keys built above
+            (key, value)
+            for key, value in record.__dict__.items()
+            if key not in _RESERVED_FIELDS and key not in result
+        )
+
+        result = {k: v for k, v in result.items() if v is not None}  # omit fields that have no value
+
+        return _encode_json(result) + "\n"
diff --git a/tests/test_bootstrap.py b/tests/test_bootstrap.py
index edb0737..5614a2e 100644
--- a/tests/test_bootstrap.py
+++ b/tests/test_bootstrap.py
@@ -4,6 +4,8 @@
import importlib
import json
+import logging
+import logging.config
import os
import re
import tempfile
@@ -16,6 +18,7 @@
import awslambdaric.bootstrap as bootstrap
from awslambdaric.lambda_runtime_exception import FaultException
+from awslambdaric.lambda_runtime_log_utils import LogFormat
from awslambdaric.lambda_runtime_marshaller import LambdaMarshaller
@@ -613,14 +616,7 @@ def test_handle_event_request_fault_exception_logging_syntax_error(
bootstrap.StandardLogSink(),
)
- import sys
-
- sys.stderr.write(mock_stdout.getvalue())
-
- error_logs = (
- "[ERROR] Runtime.UserCodeSyntaxError: Syntax error in module 'a': "
- "unexpected EOF while parsing (<string>, line 1)\r"
- )
+ error_logs = f"[ERROR] Runtime.UserCodeSyntaxError: Syntax error in module 'a': {syntax_error}\r"
error_logs += "Traceback (most recent call last):\r"
error_logs += ' File "<string>" Line 1\r'
error_logs += " -\n"
@@ -1174,6 +1170,215 @@ def test_multiple_frame(self):
self.assertEqual(content[pos:], b"")
+class TestLoggingSetup(unittest.TestCase):  # checks the root logger level chosen by bootstrap._setup_logging
+    def test_log_level(self) -> None:
+        test_cases = [  # (log format, requested level name, expected root logger level)
+            (LogFormat.JSON, "TRACE", logging.DEBUG),
+            (LogFormat.JSON, "DEBUG", logging.DEBUG),
+            (LogFormat.JSON, "INFO", logging.INFO),
+            (LogFormat.JSON, "WARN", logging.WARNING),
+            (LogFormat.JSON, "ERROR", logging.ERROR),
+            (LogFormat.JSON, "FATAL", logging.CRITICAL),
+            # Log level is set only for Json format
+            (LogFormat.TEXT, "TRACE", logging.NOTSET),
+            (LogFormat.TEXT, "DEBUG", logging.NOTSET),
+            (LogFormat.TEXT, "INFO", logging.NOTSET),
+            (LogFormat.TEXT, "WARN", logging.NOTSET),
+            (LogFormat.TEXT, "ERROR", logging.NOTSET),
+            (LogFormat.TEXT, "FATAL", logging.NOTSET),
+            ("Unknown format", "INFO", logging.NOTSET),
+            # if level is unknown fall back to default
+            (LogFormat.JSON, "Unknown level", logging.NOTSET),
+        ]
+        for fmt, log_level, expected_level in test_cases:
+            with self.subTest(fmt=fmt, log_level=log_level):  # label each case so a failure names its row
+                # Drop previous setup
+                logging.getLogger().handlers.clear()
+                logging.getLogger().level = logging.NOTSET
+
+                bootstrap._setup_logging(fmt, log_level, bootstrap.StandardLogSink())
+
+                self.assertEqual(expected_level, logging.getLogger().level)
+
+
+class TestLogging(unittest.TestCase):
+ @classmethod
+ def setUpClass(cls) -> None:
+ logging.getLogger().handlers.clear()
+ logging.getLogger().level = logging.NOTSET
+ bootstrap._setup_logging(
+ LogFormat.from_str("JSON"), "INFO", bootstrap.StandardLogSink()
+ )
+
+ @patch("sys.stderr", new_callable=StringIO)
+ def test_json_formatter(self, mock_stderr):
+ logger = logging.getLogger("a.b")
+
+ test_cases = [
+ (
+ logging.ERROR,
+ "TEST 1",
+ {
+ "level": "ERROR",
+ "logger": "a.b",
+ "message": "TEST 1",
+ "requestId": "",
+ },
+ ),
+ (
+ logging.ERROR,
+ "test \nwith \nnew \nlines",
+ {
+ "level": "ERROR",
+ "logger": "a.b",
+ "message": "test \nwith \nnew \nlines",
+ "requestId": "",
+ },
+ ),
+ (
+ logging.CRITICAL,
+ "TEST CRITICAL",
+ {
+ "level": "CRITICAL",
+ "logger": "a.b",
+ "message": "TEST CRITICAL",
+ "requestId": "",
+ },
+ ),
+ ]
+ for level, msg, expected in test_cases:
+ with self.subTest(msg):
+ with patch("sys.stdout", new_callable=StringIO) as mock_stdout:
+ logger.log(level, msg)
+
+ data = json.loads(mock_stdout.getvalue())
+ data.pop("timestamp")
+ self.assertEqual(
+ data,
+ expected,
+ )
+ self.assertEqual(mock_stderr.getvalue(), "")
+
+ @patch("sys.stdout", new_callable=StringIO)
+ @patch("sys.stderr", new_callable=StringIO)
+ def test_exception(self, mock_stderr, mock_stdout):
+ try:
+ raise ValueError("error message")
+ except ValueError:
+ logging.getLogger("test.logger").exception("test exception")
+
+ exception_log = json.loads(mock_stdout.getvalue())
+ self.assertIn("location", exception_log)
+ self.assertIn("stackTrace", exception_log)
+ exception_log.pop("timestamp")
+ exception_log.pop("location")
+ stack_trace = exception_log.pop("stackTrace")
+
+ self.assertEqual(len(stack_trace), 1)
+
+ self.assertEqual(
+ exception_log,
+ {
+ "errorMessage": "error message",
+ "errorType": "ValueError",
+ "level": "ERROR",
+ "logger": "test.logger",
+ "message": "test exception",
+ "requestId": "",
+ },
+ )
+
+ self.assertEqual(mock_stderr.getvalue(), "")
+
+ @patch("sys.stdout", new_callable=StringIO)
+ @patch("sys.stderr", new_callable=StringIO)
+ def test_log_level(self, mock_stderr, mock_stdout):
+ logger = logging.getLogger("test.logger")
+
+ logger.debug("debug message")
+ logger.info("info message")
+
+ data = json.loads(mock_stdout.getvalue())
+ data.pop("timestamp")
+
+ self.assertEqual(
+ data,
+ {
+ "level": "INFO",
+ "logger": "test.logger",
+ "message": "info message",
+ "requestId": "",
+ },
+ )
+ self.assertEqual(mock_stderr.getvalue(), "")
+
+ @patch("sys.stdout", new_callable=StringIO)
+ @patch("sys.stderr", new_callable=StringIO)
+ def test_set_log_level_manually(self, mock_stderr, mock_stdout):
+ logger = logging.getLogger("test.logger")
+
+ # Changing log level after `bootstrap.setup_logging`
+ logging.getLogger().setLevel(logging.CRITICAL)
+
+ logger.debug("debug message")
+ logger.info("info message")
+ logger.warning("warning message")
+ logger.error("error message")
+ logger.critical("critical message")
+
+ data = json.loads(mock_stdout.getvalue())
+ data.pop("timestamp")
+
+ self.assertEqual(
+ data,
+ {
+ "level": "CRITICAL",
+ "logger": "test.logger",
+ "message": "critical message",
+ "requestId": "",
+ },
+ )
+ self.assertEqual(mock_stderr.getvalue(), "")
+
+ @patch("sys.stdout", new_callable=StringIO)
+ @patch("sys.stderr", new_callable=StringIO)
+ def test_set_log_level_with_dictConfig(self, mock_stderr, mock_stdout):
+ # Changing log level after `bootstrap.setup_logging`
+ logging.config.dictConfig(
+ {
+ "version": 1,
+ "disable_existing_loggers": False,
+ "formatters": {"simple": {"format": "%(levelname)-8s - %(message)s"}},
+ "handlers": {
+ "stdout": {
+ "class": "logging.StreamHandler",
+ "formatter": "simple",
+ },
+ },
+ "root": {
+ "level": "CRITICAL",
+ "handlers": [
+ "stdout",
+ ],
+ },
+ }
+ )
+
+ logger = logging.getLogger("test.logger")
+ logger.debug("debug message")
+ logger.info("info message")
+ logger.warning("warning message")
+ logger.error("error message")
+ logger.critical("critical message")
+
+ data = mock_stderr.getvalue()
+ self.assertEqual(
+ data,
+ "CRITICAL - critical message\n",
+ )
+ self.assertEqual(mock_stdout.getvalue(), "")
+
+
class TestBootstrapModule(unittest.TestCase):
@patch("awslambdaric.bootstrap.handle_event_request")
@patch("awslambdaric.bootstrap.LambdaRuntimeClient")
diff --git a/tests/test_lambda_context.py b/tests/test_lambda_context.py
index 545efa1..34d59da 100644
--- a/tests/test_lambda_context.py
+++ b/tests/test_lambda_context.py
@@ -4,7 +4,7 @@
import os
import unittest
-from unittest.mock import patch, MagicMock
+from unittest.mock import MagicMock, patch
from awslambdaric.lambda_context import LambdaContext
diff --git a/tests/test_lambda_runtime_client.py b/tests/test_lambda_runtime_client.py
index 814ca96..47d95cf 100644
--- a/tests/test_lambda_runtime_client.py
+++ b/tests/test_lambda_runtime_client.py
@@ -6,13 +6,12 @@
import http.client
import unittest.mock
from unittest.mock import MagicMock, patch
-from awslambdaric import __version__
-
+from awslambdaric import __version__
from awslambdaric.lambda_runtime_client import (
+ InvocationRequest,
LambdaRuntimeClient,
LambdaRuntimeClientError,
- InvocationRequest,
_user_agent,
)