diff --git a/aws_lambda_powertools/utilities/kafka/deserializer/protobuf.py b/aws_lambda_powertools/utilities/kafka/deserializer/protobuf.py
index f4e02b8c565..a9209330505 100644
--- a/aws_lambda_powertools/utilities/kafka/deserializer/protobuf.py
+++ b/aws_lambda_powertools/utilities/kafka/deserializer/protobuf.py
@@ -2,6 +2,7 @@
 
 from typing import Any
 
+from google.protobuf.internal.decoder import _DecodeVarint  # type: ignore[attr-defined]
 from google.protobuf.json_format import MessageToDict
 
 from aws_lambda_powertools.utilities.kafka.deserializer.base import DeserializerBase
@@ -43,6 +44,12 @@ def deserialize(self, data: bytes | str) -> dict:
             When the data cannot be deserialized according to the message class,
             typically due to data format incompatibility or incorrect message class.
 
+        Notes
+        -----
+        This deserializer handles both the standard Protocol Buffers format and the Confluent
+        Schema Registry format, which prefixes the payload with message index information. It
+        first attempts standard deserialization and falls back to message index handling if needed.
+
         Example
         --------
         >>> # Assuming proper protobuf setup
@@ -54,11 +61,56 @@ def deserialize(self, data: bytes | str) -> dict:
         ... except KafkaConsumerDeserializationError as e:
         ...     print(f"Failed to deserialize: {e}")
         """
+        value = self._decode_input(data)
         try:
-            value = self._decode_input(data)
             message = self.message_class()
             message.ParseFromString(value)
             return MessageToDict(message, preserving_proto_field_name=True)
+        except Exception:
+            return self._deserialize_with_message_index(value, self.message_class())
+
+    def _deserialize_with_message_index(self, data: bytes, parser: Any) -> dict:
+        """
+        Deserialize protobuf message with Confluent message index handling.
+
+        Parameters
+        ----------
+        data : bytes
+            Serialized protobuf data, optionally prefixed with Confluent message indexes
+        parser : google.protobuf.message.Message
+            Protobuf message instance to parse the data into
+
+        Returns
+        -------
+        dict
+            Dictionary representation of the parsed protobuf message with original field names
+
+        Raises
+        ------
+        KafkaConsumerDeserializationError
+            If deserialization fails
+
+        Notes
+        -----
+        This method handles the special case of Confluent Schema Registry's message index format:
+        the protobuf payload is prefixed with either a single 0 (shorthand for the first message
+        type in the schema) or a varint count followed by that many message indexes.
+        """
+
+        buffer = memoryview(data)
+        pos = 0
+
+        try:
+            first_value, new_pos = _DecodeVarint(buffer, pos)
+            pos = new_pos
+
+            if first_value != 0:
+                for _ in range(first_value):
+                    _, new_pos = _DecodeVarint(buffer, pos)
+                    pos = new_pos
+
+            parser.ParseFromString(data[pos:])
+            return MessageToDict(parser, preserving_proto_field_name=True)
         except Exception as e:
             raise KafkaConsumerDeserializationError(
                 f"Error trying to deserialize protobuf data - {type(e).__name__}: {str(e)}",
diff --git a/tests/functional/kafka_consumer/_protobuf/confluent_protobuf.proto b/tests/functional/kafka_consumer/_protobuf/confluent_protobuf.proto
new file mode 100644
index 00000000000..ee7e7593c32
--- /dev/null
+++ b/tests/functional/kafka_consumer/_protobuf/confluent_protobuf.proto
@@ -0,0 +1,13 @@
+syntax = "proto3";
+
+package org.demo.kafka.protobuf;
+
+option java_package = "org.demo.kafka.protobuf";
+option java_outer_classname = "ProtobufProductOuterClass";
+option java_multiple_files = true;
+
+message ProtobufProduct {
+  int32 id = 1;
+  string name = 2;
+  double price = 3;
+}
diff --git a/tests/functional/kafka_consumer/_protobuf/confluent_protobuf_pb2.py b/tests/functional/kafka_consumer/_protobuf/confluent_protobuf_pb2.py
new file mode 100644
index 00000000000..87bf81abe44
--- /dev/null
+++ b/tests/functional/kafka_consumer/_protobuf/confluent_protobuf_pb2.py
@@ -0,0 +1,37 @@
+# Generated by the protocol buffer compiler. DO NOT EDIT!
+# NO CHECKED-IN PROTOBUF GENCODE
+# Protobuf Python Version: 6.30.2
+"""Generated protocol buffer code."""
+
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import descriptor_pool as _descriptor_pool
+from google.protobuf import runtime_version as _runtime_version
+from google.protobuf import symbol_database as _symbol_database
+from google.protobuf.internal import builder as _builder
+
+_runtime_version.ValidateProtobufRuntimeVersion(
+    _runtime_version.Domain.PUBLIC,
+    6,
+    30,
+    2,
+    "",
+    "confluent_protobuf.proto",
+)
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
+    b'\n\x18\x63onfluent_protobuf.proto\x12\x17org.demo.kafka.protobuf":\n\x0fProtobufProduct\x12\n\n\x02id\x18\x01 \x01(\x05\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\r\n\x05price\x18\x03 \x01(\x01\x42\x36\n\x17org.demo.kafka.protobufB\x19ProtobufProductOuterClassP\x01\x62\x06proto3',  # noqa: E501
+)
+
+_globals = globals()
+_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
+_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "confluent_protobuf_pb2", _globals)
+if not _descriptor._USE_C_DESCRIPTORS:
+    _globals["DESCRIPTOR"]._loaded_options = None
+    _globals["DESCRIPTOR"]._serialized_options = b"\n\027org.demo.kafka.protobufB\031ProtobufProductOuterClassP\001"
+    _globals["_PROTOBUFPRODUCT"]._serialized_start = 53
+    _globals["_PROTOBUFPRODUCT"]._serialized_end = 111
+# @@protoc_insertion_point(module_scope)
diff --git a/tests/functional/kafka_consumer/_protobuf/test_kafka_consumer_with_protobuf.py b/tests/functional/kafka_consumer/_protobuf/test_kafka_consumer_with_protobuf.py
index 0fbc07158eb..2ab38dcc4f6 100644
--- a/tests/functional/kafka_consumer/_protobuf/test_kafka_consumer_with_protobuf.py
+++ b/tests/functional/kafka_consumer/_protobuf/test_kafka_consumer_with_protobuf.py
@@ -12,6 +12,9 @@
 from aws_lambda_powertools.utilities.kafka.kafka_consumer import kafka_consumer
 from aws_lambda_powertools.utilities.kafka.schema_config import SchemaConfig
 
+# Import confluent complex schema
+from .confluent_protobuf_pb2 import ProtobufProduct
+
 # Import the generated protobuf classes
 from .user_pb2 import Key, User
 
@@ -335,3 +338,87 @@ def test_kafka_consumer_without_protobuf_key_schema():
     # Verify the error message mentions the missing key schema
     assert "key_schema" in str(excinfo.value)
     assert "PROTOBUF" in str(excinfo.value)
+
+
+def test_confluent_complex_schema(lambda_context):
+    # GIVEN
+    # A scenario where a complex schema is used with the PROTOBUF schema type
+    complex_event = {
+        "eventSource": "aws:kafka",
+        "eventSourceArn": "arn:aws:kafka:us-east-1:0123456789019:cluster/SalesCluster/abcd1234",
+        "bootstrapServers": "b-2.demo-cluster-1.a1bcde.c1.kafka.us-east-1.amazonaws.com:9092",
+        "records": {
+            "mytopic-0": [
+                {
+                    "topic": "mytopic",
+                    "partition": 0,
+                    "offset": 15,
+                    "timestamp": 1545084650987,
+                    "timestampType": "CREATE_TIME",
+                    "key": "NDI=",
+                    "value": "COkHEgZMYXB0b3AZUrgehes/j0A=",
+                    "headers": [{"headerKey": [104, 101, 97, 100, 101, 114, 86, 97, 108, 117, 101]}],
+                },
+                {
+                    "topic": "mytopic",
+                    "partition": 0,
+                    "offset": 16,
+                    "timestamp": 1545084650988,
+                    "timestampType": "CREATE_TIME",
+                    "key": "NDI=",
+                    "value": "AAjpBxIGTGFwdG9wGVK4HoXrP49A",
+                    "headers": [{"headerKey": [104, 101, 97, 100, 101, 114, 86, 97, 108, 117, 101]}],
+                },
+                {
+                    "topic": "mytopic",
+                    "partition": 0,
+                    "offset": 17,
+                    "timestamp": 1545084650989,
+                    "timestampType": "CREATE_TIME",
+                    "key": "NDI=",
+                    "value": "AgEACOkHEgZMYXB0b3AZUrgehes/j0A=",
+                    "headers": [{"headerKey": [104, 101, 97, 100, 101, 114, 86, 97, 108, 117, 101]}],
+                },
+            ],
+        },
+    }
+
+    # GIVEN A Kafka consumer configured to deserialize Protobuf data
+    # using the ProtobufProduct message type as the value schema
+    schema_config = SchemaConfig(
+        value_schema_type="PROTOBUF",
+        value_schema=ProtobufProduct,
+    )
+
+    processed_records = []
+
+    @kafka_consumer(schema_config=schema_config)
+    def handler(event: ConsumerRecords, context):
+        for record in event.records:
+            processed_records.append(
+                {"id": record.value["id"], "name": record.value["name"], "price": record.value["price"]},
+            )
+        return {"processed": len(processed_records)}
+
+    # WHEN The handler processes a Kafka event containing Protobuf-encoded data
+    result = handler(complex_event, lambda_context)
+
+    # THEN
+    # The handler should successfully process all three records
+    # and return the correct count
+    assert result == {"processed": 3}
+
+    # All records should be correctly deserialized with proper values
+    assert len(processed_records) == 3
+
+    # First record should contain decoded values
+    assert processed_records[0]["id"] == 1001
+    assert processed_records[0]["name"] == "Laptop"
+
+    # Second record should contain decoded values
+    assert processed_records[1]["id"] == 1001
+    assert processed_records[1]["name"] == "Laptop"
+
+    # Third record should contain decoded values
+    assert processed_records[2]["id"] == 1001
+    assert processed_records[2]["name"] == "Laptop"
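
Reviewer note (not part of the diff): a minimal sketch of how Confluent-framed payloads like the three test fixtures above can be produced. The product values and the index list are illustrative assumptions, and it uses the protobuf runtime's internal `_VarintBytes` helper, which is an internal API just like the `_DecodeVarint` used by the new fallback.

```python
# Sketch: building Confluent-style protobuf payloads similar to the test fixtures.
# Assumes the generated confluent_protobuf_pb2 module added in this diff is importable.
import base64

from google.protobuf.internal.encoder import _VarintBytes  # internal helper, mirror of _DecodeVarint

from confluent_protobuf_pb2 import ProtobufProduct

product = ProtobufProduct(id=1001, name="Laptop", price=999.99)  # illustrative values
message = product.SerializeToString()

# Record 1: plain protobuf, no Confluent framing.
plain = message

# Record 2: Confluent shorthand framing - a single 0 byte stands for message index [0].
single_zero = b"\x00" + message

# Record 3: explicit framing - a varint count followed by that many varint message indexes.
indexes = [1, 0]  # illustrative index list
framed = _VarintBytes(len(indexes)) + b"".join(_VarintBytes(i) for i in indexes) + message

for payload in (plain, single_zero, framed):
    print(base64.b64encode(payload).decode())  # base64 strings like the "value" fields in the event
```

The new `_deserialize_with_message_index` fallback inverts this framing: it reads the first varint, treats a non-zero value as the number of indexes to skip, and parses the remaining bytes as the protobuf message.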