feat(kafka): add logic to handle protobuf deserialization #6841

Merged 1 commit on Jun 20, 2025
27 changes: 21 additions & 6 deletions aws_lambda_powertools/utilities/kafka/consumer_records.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import logging
from functools import cached_property
from typing import TYPE_CHECKING, Any

@@ -13,6 +14,8 @@

from aws_lambda_powertools.utilities.kafka.schema_config import SchemaConfig

logger = logging.getLogger(__name__)


class ConsumerRecordRecords(KafkaEventRecordBase):
"""
@@ -31,18 +34,24 @@ def key(self) -> Any:
if not key:
return None

logger.debug("Deserializing key field")

# Determine schema type and schema string
schema_type = None
schema_str = None
schema_value = None
output_serializer = None

if self.schema_config and self.schema_config.key_schema_type:
schema_type = self.schema_config.key_schema_type
schema_str = self.schema_config.key_schema
schema_value = self.schema_config.key_schema
output_serializer = self.schema_config.key_output_serializer

# Always use get_deserializer if None it will default to DEFAULT
deserializer = get_deserializer(schema_type, schema_str)
deserializer = get_deserializer(
schema_type=schema_type,
schema_value=schema_value,
field_metadata=self.key_schema_metadata,
)
deserialized_value = deserializer.deserialize(key)

# Apply output serializer if specified
@@ -57,16 +66,22 @@ def value(self) -> Any:

# Determine schema type and schema string
schema_type = None
schema_str = None
schema_value = None
output_serializer = None

logger.debug("Deserializing value field")

if self.schema_config and self.schema_config.value_schema_type:
schema_type = self.schema_config.value_schema_type
schema_str = self.schema_config.value_schema
schema_value = self.schema_config.value_schema
output_serializer = self.schema_config.value_output_serializer

# Always use get_deserializer if None it will default to DEFAULT
deserializer = get_deserializer(schema_type, schema_str)
deserializer = get_deserializer(
schema_type=schema_type,
schema_value=schema_value,
field_metadata=self.value_schema_metadata,
)
deserialized_value = deserializer.deserialize(value)

# Apply output serializer if specified
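For orientation, here is a minimal sketch of how a handler could opt into the new protobuf path via SchemaConfig. The keyword names are assumed to mirror the attributes read in key() and value() above, and the bundled StringValue message stands in for a real generated class so the snippet runs without code generation:

```python
# Sketch only: assumes SchemaConfig exposes these keyword arguments and that any
# protobuf-generated message class can be passed as the schema value.
from google.protobuf import wrappers_pb2

from aws_lambda_powertools.utilities.kafka.schema_config import SchemaConfig

schema_config = SchemaConfig(
    value_schema_type="PROTOBUF",           # routed to ProtobufDeserializer by the factory
    value_schema=wrappers_pb2.StringValue,  # message class forwarded as schema_value
)
```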
15 changes: 14 additions & 1 deletion aws_lambda_powertools/utilities/kafka/deserializer/avro.py
@@ -1,6 +1,8 @@
from __future__ import annotations

import io
import logging
from typing import Any

from avro.io import BinaryDecoder, DatumReader
from avro.schema import parse as parse_schema
@@ -9,8 +11,11 @@
from aws_lambda_powertools.utilities.kafka.exceptions import (
KafkaConsumerAvroSchemaParserError,
KafkaConsumerDeserializationError,
KafkaConsumerDeserializationFormatMismatch,
)

logger = logging.getLogger(__name__)


class AvroDeserializer(DeserializerBase):
"""
@@ -20,10 +25,11 @@ class AvroDeserializer(DeserializerBase):
a provided Avro schema definition.
"""

def __init__(self, schema_str: str):
def __init__(self, schema_str: str, field_metadata: dict[str, Any] | None = None):
try:
self.parsed_schema = parse_schema(schema_str)
self.reader = DatumReader(self.parsed_schema)
self.field_metatada = field_metadata
except Exception as e:
raise KafkaConsumerAvroSchemaParserError(
f"Invalid Avro schema. Please ensure the provided avro schema is valid: {type(e).__name__}: {str(e)}",
@@ -60,6 +66,13 @@ def deserialize(self, data: bytes | str) -> object:
... except KafkaConsumerDeserializationError as e:
... print(f"Failed to deserialize: {e}")
"""
data_format = self.field_metatada.get("dataFormat") if self.field_metatada else None

if data_format and data_format != "AVRO":
raise KafkaConsumerDeserializationFormatMismatch(f"Expected data is AVRO but you sent {data_format}")

logger.debug("Deserializing data with AVRO format")

try:
value = self._decode_input(data)
bytes_reader = io.BytesIO(value)
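To make the new guard concrete, a minimal sketch that constructs the deserializer directly (normally the factory does this) and triggers the format mismatch before any Avro decoding happens:

```python
from aws_lambda_powertools.utilities.kafka.deserializer.avro import AvroDeserializer
from aws_lambda_powertools.utilities.kafka.exceptions import KafkaConsumerDeserializationFormatMismatch

# A small but valid Avro schema so the constructor's parse step succeeds.
schema = '{"type": "record", "name": "User", "fields": [{"name": "name", "type": "string"}]}'

deserializer = AvroDeserializer(schema_str=schema, field_metadata={"dataFormat": "JSON"})

try:
    deserializer.deserialize(b"irrelevant")  # the guard fires before the payload is touched
except KafkaConsumerDeserializationFormatMismatch as exc:
    print(exc)  # Expected data is AVRO but you sent JSON
```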
@@ -1,9 +1,12 @@
from __future__ import annotations

import base64
import logging

from aws_lambda_powertools.utilities.kafka.deserializer.base import DeserializerBase

logger = logging.getLogger(__name__)


class DefaultDeserializer(DeserializerBase):
"""
Expand Down Expand Up @@ -43,4 +46,5 @@ def deserialize(self, data: bytes | str) -> str:
>>> result = deserializer.deserialize(bytes_data)
>>> print(result == bytes_data) # Output: True
"""
logger.debug("Deserializing data with primitives types")
return base64.b64decode(data).decode("utf-8")
@@ -13,21 +13,27 @@
_deserializer_cache: dict[str, DeserializerBase] = {}


def _get_cache_key(schema_type: str | object, schema_value: Any) -> str:
def _get_cache_key(schema_type: str | object, schema_value: Any, field_metadata: dict[str, Any]) -> str:
schema_metadata = None

if field_metadata:
schema_metadata = field_metadata.get("schemaId")

if schema_value is None:
return str(schema_type)
schema_hash = f"{str(schema_type)}_{schema_metadata}"

if isinstance(schema_value, str):
hashable_value = f"{schema_value}_{schema_metadata}"
# For string schemas like Avro, hash the content
schema_hash = hashlib.md5(schema_value.encode("utf-8"), usedforsecurity=False).hexdigest()
schema_hash = hashlib.md5(hashable_value.encode("utf-8"), usedforsecurity=False).hexdigest()
else:
# For objects like Protobuf, use the object id
schema_hash = str(id(schema_value))
schema_hash = f"{str(id(schema_value))}_{schema_metadata}"

return f"{schema_type}_{schema_hash}"


def get_deserializer(schema_type: str | object, schema_value: Any) -> DeserializerBase:
def get_deserializer(schema_type: str | object, schema_value: Any, field_metadata: Any) -> DeserializerBase:
"""
Factory function to get the appropriate deserializer based on schema type.

@@ -75,7 +81,7 @@ def get_deserializer(schema_type: str | object, schema_value: Any) -> Deserializ
"""

# Generate a cache key based on schema type and value
cache_key = _get_cache_key(schema_type, schema_value)
cache_key = _get_cache_key(schema_type, schema_value, field_metadata)

# Check if we already have this deserializer in cache
if cache_key in _deserializer_cache:
@@ -87,14 +93,14 @@ def get_deserializer(schema_type: str | object, schema_value: Any) -> Deserializ
# Import here to avoid dependency if not used
from aws_lambda_powertools.utilities.kafka.deserializer.avro import AvroDeserializer

deserializer = AvroDeserializer(schema_value)
deserializer = AvroDeserializer(schema_str=schema_value, field_metadata=field_metadata)
elif schema_type == "PROTOBUF":
# Import here to avoid dependency if not used
from aws_lambda_powertools.utilities.kafka.deserializer.protobuf import ProtobufDeserializer

deserializer = ProtobufDeserializer(schema_value)
deserializer = ProtobufDeserializer(message_class=schema_value, field_metadata=field_metadata)
elif schema_type == "JSON":
deserializer = JsonDeserializer()
deserializer = JsonDeserializer(field_metadata=field_metadata)

else:
# Default to no-op deserializer
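A quick sketch of the factory with the new arguments; the import path is an assumption based on the module layout, since the file name is not visible in this capture. Because the cache key now folds in the schemaId, two calls with identical type, value and metadata should return the same instance:

```python
from aws_lambda_powertools.utilities.kafka.deserializer.deserializer import get_deserializer  # assumed path

metadata = {"dataFormat": "JSON", "schemaId": "123"}

first = get_deserializer(schema_type="JSON", schema_value=None, field_metadata=metadata)
second = get_deserializer(schema_type="JSON", schema_value=None, field_metadata=metadata)

print(first is second)  # True, the cached JsonDeserializer is reused
```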
20 changes: 19 additions & 1 deletion aws_lambda_powertools/utilities/kafka/deserializer/json.py
@@ -2,9 +2,16 @@

import base64
import json
import logging
from typing import Any

from aws_lambda_powertools.utilities.kafka.deserializer.base import DeserializerBase
from aws_lambda_powertools.utilities.kafka.exceptions import KafkaConsumerDeserializationError
from aws_lambda_powertools.utilities.kafka.exceptions import (
KafkaConsumerDeserializationError,
KafkaConsumerDeserializationFormatMismatch,
)

logger = logging.getLogger(__name__)


class JsonDeserializer(DeserializerBase):
@@ -15,6 +22,9 @@ class JsonDeserializer(DeserializerBase):
into Python dictionaries.
"""

def __init__(self, field_metadata: dict[str, Any] | None = None):
self.field_metatada = field_metadata

def deserialize(self, data: bytes | str) -> dict:
"""
Deserialize JSON data to a Python dictionary.
@@ -45,6 +55,14 @@ def deserialize(self, data: bytes | str) -> dict:
... except KafkaConsumerDeserializationError as e:
... print(f"Failed to deserialize: {e}")
"""

data_format = self.field_metatada.get("dataFormat") if self.field_metatada else None

if data_format and data_format != "JSON":
raise KafkaConsumerDeserializationFormatMismatch(f"Expected data is JSON but you sent {data_format}")

logger.debug("Deserializing data with JSON format")

try:
return json.loads(base64.b64decode(data).decode("utf-8"))
except Exception as e:
94 changes: 49 additions & 45 deletions aws_lambda_powertools/utilities/kafka/deserializer/protobuf.py
@@ -1,15 +1,19 @@
from __future__ import annotations

import logging
from typing import Any

from google.protobuf.internal.decoder import _DecodeVarint # type: ignore[attr-defined]
from google.protobuf.internal.decoder import _DecodeSignedVarint # type: ignore[attr-defined]
from google.protobuf.json_format import MessageToDict

from aws_lambda_powertools.utilities.kafka.deserializer.base import DeserializerBase
from aws_lambda_powertools.utilities.kafka.exceptions import (
KafkaConsumerDeserializationError,
KafkaConsumerDeserializationFormatMismatch,
)

logger = logging.getLogger(__name__)


class ProtobufDeserializer(DeserializerBase):
"""
@@ -19,8 +23,9 @@ class ProtobufDeserializer(DeserializerBase):
into Python dictionaries using the provided Protocol Buffer message class.
"""

def __init__(self, message_class: Any):
def __init__(self, message_class: Any, field_metadata: dict[str, Any] | None = None):
self.message_class = message_class
self.field_metatada = field_metadata

def deserialize(self, data: bytes | str) -> dict:
"""
@@ -61,57 +66,56 @@ def deserialize(self, data: bytes | str) -> dict:
... except KafkaConsumerDeserializationError as e:
... print(f"Failed to deserialize: {e}")
"""
value = self._decode_input(data)
try:
message = self.message_class()
message.ParseFromString(value)
return MessageToDict(message, preserving_proto_field_name=True)
except Exception:
return self._deserialize_with_message_index(value, self.message_class())

def _deserialize_with_message_index(self, data: bytes, parser: Any) -> dict:
"""
Deserialize protobuf message with Confluent message index handling.
data_format = self.field_metatada.get("dataFormat") if self.field_metatada else None
schema_id = self.field_metatada.get("schemaId") if self.field_metatada else None

Parameters
----------
data : bytes
data
parser : google.protobuf.message.Message
Protobuf message instance to parse the data into
if data_format and data_format != "PROTOBUF":
raise KafkaConsumerDeserializationFormatMismatch(f"Expected data is PROTOBUF but you sent {data_format}")

Returns
-------
dict
Dictionary representation of the parsed protobuf message with original field names
logger.debug("Deserializing data with PROTOBUF format")

Raises
------
KafkaConsumerDeserializationError
If deserialization fails
try:
value = self._decode_input(data)
message = self.message_class()
if schema_id is None:
logger.debug("Plain PROTOBUF data: using default deserializer")
# Plain protobuf - direct parser
message.ParseFromString(value)
elif len(schema_id) > 20:
logger.debug("PROTOBUF data integrated with Glue SchemaRegistry: using Glue deserializer")
# Glue schema registry integration - remove the first byte
message.ParseFromString(value[1:])
else:
logger.debug("PROTOBUF data integrated with Confluent SchemaRegistry: using Confluent deserializer")
# Confluent schema registry integration - remove message index list
message.ParseFromString(self._remove_message_index(value))

Notes
-----
This method handles the special case of Confluent Schema Registry's message index
format, where the message is prefixed with either a single 0 (for the first schema)
or a list of schema indexes. The actual protobuf message follows these indexes.
"""
return MessageToDict(message, preserving_proto_field_name=True)
except Exception as e:
raise KafkaConsumerDeserializationError(
f"Error trying to deserialize protobuf data - {type(e).__name__}: {str(e)}",
) from e

def _remove_message_index(self, data):
"""
Identifies and removes Confluent Schema Registry MessageIndex from bytes.
Returns pure protobuf bytes.
"""
buffer = memoryview(data)
pos = 0

try:
first_value, new_pos = _DecodeVarint(buffer, pos)
pos = new_pos
logger.debug("Removing message list bytes")

if first_value != 0:
for _ in range(first_value):
_, new_pos = _DecodeVarint(buffer, pos)
pos = new_pos
# Read first varint (index count or 0)
first_value, new_pos = _DecodeSignedVarint(buffer, pos)
pos = new_pos

parser.ParseFromString(data[pos:])
return MessageToDict(parser, preserving_proto_field_name=True)
except Exception as e:
raise KafkaConsumerDeserializationError(
f"Error trying to deserialize protobuf data - {type(e).__name__}: {str(e)}",
) from e
# Skip index values if present
if first_value != 0:
for _ in range(first_value):
_, new_pos = _DecodeSignedVarint(buffer, pos)
pos = new_pos

# Return remaining bytes (pure protobuf)
return data[pos:]
6 changes: 6 additions & 0 deletions aws_lambda_powertools/utilities/kafka/exceptions.py
@@ -4,6 +4,12 @@ class KafkaConsumerAvroSchemaParserError(Exception):
"""


class KafkaConsumerDeserializationFormatMismatch(Exception):
"""
Error raised when deserialization format is incompatible
"""


class KafkaConsumerDeserializationError(Exception):
"""
Error raised when message deserialization fails.
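For completeness, a sketch of how consumer code might distinguish the new error from a plain decoding failure; the handler wiring is illustrative and not part of this PR:

```python
from aws_lambda_powertools.utilities.kafka.exceptions import (
    KafkaConsumerDeserializationError,
    KafkaConsumerDeserializationFormatMismatch,
)


def process(record):
    try:
        return record.value  # deserialization runs lazily inside the cached property
    except KafkaConsumerDeserializationFormatMismatch:
        # Registry metadata says the payload format differs from the configured schema type
        raise
    except KafkaConsumerDeserializationError:
        # Format matched but the payload could not be decoded
        raise
```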