Skip to content
This repository was archived by the owner on Feb 24, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/tensorflow/lite/micro/kernels/conv.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ struct OpDataConv {
// uint8_t these would be 0 and 255.
int32_t output_activation_min;
int32_t output_activation_max;

// A buffer used to store unpacked filter values. This is used if the source
// tensor is of n-bit precision that cannot be easily processed by kernels.
int filter_buffer_index;
};

extern const int kConvInputTensor;
Expand Down
15 changes: 10 additions & 5 deletions src/tensorflow/lite/micro/kernels/conv_common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,8 @@ limitations under the License.
==============================================================================*/

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_types.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/conv.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/conv.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/padding.h"
#include "tensorflow/lite/micro/kernels/conv.h"
Expand Down Expand Up @@ -188,6 +184,15 @@ TfLiteStatus ConvPrepare(TfLiteContext* context, TfLiteNode* node) {
context, node, params, input_width, input_height, filter_width,
filter_height, output_width, output_height, input->type, data));

if (filter->type == kTfLiteInt4) {
int filter_size =
RuntimeShape(filter->dims->size,
reinterpret_cast<const int32_t*>(filter->dims->data))
.FlatSize();
context->RequestScratchBufferInArena(context, filter_size,
&data->filter_buffer_index);
}

micro_context->DeallocateTempTfLiteTensor(filter);
micro_context->DeallocateTempTfLiteTensor(input);
micro_context->DeallocateTempTfLiteTensor(output);
Expand Down
3 changes: 3 additions & 0 deletions src/tensorflow/lite/micro/memory_helpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@ TfLiteStatus TfLiteTypeSizeOf(TfLiteType type, size_t* size) {
case kTfLiteComplex128:
*size = sizeof(double) * 2;
break;
case kTfLiteInt4:
*size = sizeof(int8_t);
break;
default:
return kTfLiteError;
}
Expand Down