Skip to content
This repository was archived by the owner on Feb 24, 2025. It is now read-only.

Commit 93e7be8

Browse files
authored
Sync from tflite-micro. (#146)
1 parent df33fd4 commit 93e7be8

File tree

3 files changed

+17
-5
lines changed

3 files changed

+17
-5
lines changed

src/tensorflow/lite/micro/kernels/conv.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,10 @@ struct OpDataConv {
4545
// uint8_t these would be 0 and 255.
4646
int32_t output_activation_min;
4747
int32_t output_activation_max;
48+
49+
// A buffer used to store unpacked filter values. This is used if the source
50+
// tensor is of n-bit precision that cannot be easily processed by kernels.
51+
int filter_buffer_index;
4852
};
4953

5054
extern const int kConvInputTensor;

src/tensorflow/lite/micro/kernels/conv_common.cpp

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,8 @@ limitations under the License.
1414
==============================================================================*/
1515

1616
#include "tensorflow/lite/c/builtin_op_data.h"
17+
#include "tensorflow/lite/c/c_api_types.h"
1718
#include "tensorflow/lite/c/common.h"
18-
#include "tensorflow/lite/kernels/internal/common.h"
19-
#include "tensorflow/lite/kernels/internal/quantization_util.h"
20-
#include "tensorflow/lite/kernels/internal/reference/conv.h"
21-
#include "tensorflow/lite/kernels/internal/reference/integer_ops/conv.h"
22-
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
2319
#include "tensorflow/lite/kernels/kernel_util.h"
2420
#include "tensorflow/lite/kernels/padding.h"
2521
#include "tensorflow/lite/micro/kernels/conv.h"
@@ -188,6 +184,15 @@ TfLiteStatus ConvPrepare(TfLiteContext* context, TfLiteNode* node) {
188184
context, node, params, input_width, input_height, filter_width,
189185
filter_height, output_width, output_height, input->type, data));
190186

187+
if (filter->type == kTfLiteInt4) {
188+
int filter_size =
189+
RuntimeShape(filter->dims->size,
190+
reinterpret_cast<const int32_t*>(filter->dims->data))
191+
.FlatSize();
192+
context->RequestScratchBufferInArena(context, filter_size,
193+
&data->filter_buffer_index);
194+
}
195+
191196
micro_context->DeallocateTempTfLiteTensor(filter);
192197
micro_context->DeallocateTempTfLiteTensor(input);
193198
micro_context->DeallocateTempTfLiteTensor(output);

src/tensorflow/lite/micro/memory_helpers.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,9 @@ TfLiteStatus TfLiteTypeSizeOf(TfLiteType type, size_t* size) {
9090
case kTfLiteComplex128:
9191
*size = sizeof(double) * 2;
9292
break;
93+
case kTfLiteInt4:
94+
*size = sizeof(int8_t);
95+
break;
9396
default:
9497
return kTfLiteError;
9598
}

0 commit comments

Comments
 (0)