@@ -377,40 +377,49 @@ inline Integer FloorLog2(Integer n) {
377377 }
378378}
379379
380- // The size of the LUT depends on the type of input. For uint8 and int8 inputs
381- // we use a 256 entries LUT to map all the values in the (u)int8 range. For
382- // int16 inputs the high 9 bits are used for indexing and the 7 remaining bits
383- // are used for interpolation. We thus use a 513-entries LUT for int16 cases,
384- // 512 for the 9-bit indexing and 1 extra entry to interpolate the last value.
385- template <typename T>
386- constexpr int LUTSize () {
387- static_assert (std::is_same<T, uint8_t >::value ||
388- std::is_same<T, int8_t >::value ||
389- std::is_same<T, int16_t >::value,
390- " Only LUTs with uint8, int8 or int16 inputs are supported." );
391- // As per c++11: constexpr methods cannot have more than one return statement.
392- return (std::is_same<T, uint8_t >::value || std::is_same<T, int8_t >::value)
393- ? 256
394- : 513 ;
380+ namespace detail {
381+
382+ // LUTPopulate takes an optional type-erased transform_params to allow passing
383+ // extra parameters to the transform function pointer. const void* is used
384+ // instead of std::function to be compatible with TFLite Micro
385+ template <typename FloatT, typename Func>
386+ inline typename std::enable_if<std::is_same<Func, FloatT (*)(FloatT)>::value,
387+ FloatT>::type
388+ LUTTransform (Func transform, const void * /* transform_params*/ , FloatT value) {
389+ static_assert (std::is_floating_point<FloatT>::value,
390+ " FloatT must be a floating-point type." );
391+ return transform (value);
392+ }
393+
394+ template <typename FloatT, typename Func>
395+ inline typename std::enable_if<
396+ std::is_same<Func, FloatT (*)(FloatT, const void *)>::value, FloatT>::type
397+ LUTTransform (Func transform, const void * transform_params, FloatT value) {
398+ static_assert (std::is_floating_point<FloatT>::value,
399+ " FloatT must be a floating-point type." );
400+ return transform (value, transform_params);
395401}
396402
397403// Use the same LUT generation code for both uint8_t and int8_t. Int8_t indexes
398404// will be directly casted to uint8_t, the int8 LUT will thus be ordered as [0,
399405// 1, ..., 127, -128, ..., -2, -1] instead of [-128, -127, ..., -1, 0, 1, ...,
400406// 126, 127].
401- template <typename T>
402- inline typename std::enable_if<std::is_same<T, uint8_t >::value ||
403- std::is_same<T, int8_t >::value,
404- void >::type
405- LUTPopulate (float input_scale, int32_t input_zero_point, float output_scale,
406- int32_t output_zero_point, float (*transform)(float ), T* lut) {
407+ template <typename T, typename Func>
408+ inline void LUTPopulateInt8 (float input_scale, int32_t input_zero_point,
409+ float output_scale, int32_t output_zero_point,
410+ Func transform, const void * transform_params,
411+ T* lut) {
412+ static_assert (
413+ std::is_same<T, uint8_t >::value || std::is_same<T, int8_t >::value,
414+ " T must be an uint8 or int8 type." );
407415 uint8_t * lut_uint8 = reinterpret_cast <uint8_t *>(lut);
408416 const float inverse_scale = 1 / output_scale;
409417 int32_t maxval = std::numeric_limits<T>::max ();
410418 int32_t minval = std::numeric_limits<T>::min ();
411419 for (int32_t val = minval; val <= maxval; ++val) {
412420 const float dequantized = input_scale * (val - input_zero_point);
413- const float transformed = transform (dequantized);
421+ const float transformed =
422+ LUTTransform (transform, transform_params, dequantized);
414423 const float rescaled = TfLiteRound (transformed * inverse_scale);
415424 const int32_t quantized =
416425 static_cast <int32_t >(rescaled + output_zero_point);
@@ -421,10 +430,11 @@ LUTPopulate(float input_scale, int32_t input_zero_point, float output_scale,
421430
422431// Keep floating-point type configurable for backward compatibility. float
423432// should be used for FloatT by default.
424- template <typename T, typename FloatT>
425- inline typename std::enable_if<std::is_same<T, int16_t >::value, void >::type
426- LUTPopulate (FloatT input_scale, int32_t input_zero_point, FloatT output_scale,
427- int32_t output_zero_point, FloatT (*transform)(FloatT), T* lut) {
433+ template <typename FloatT, typename Func>
434+ inline void LUTPopulateInt16 (FloatT input_scale, int32_t input_zero_point,
435+ FloatT output_scale, int32_t output_zero_point,
436+ Func transform, const void * transform_params,
437+ int16_t * lut) {
428438 static_assert (std::is_floating_point<FloatT>::value,
429439 " FloatT must be a floating-point type." );
430440 const FloatT input_min =
@@ -440,16 +450,21 @@ LUTPopulate(FloatT input_scale, int32_t input_zero_point, FloatT output_scale,
440450 const FloatT step = (input_max - input_min) / nb_steps;
441451 const FloatT half_step = step / 2 ;
442452 const FloatT output_scaling_inv =
443- static_cast <FloatT>(std::numeric_limits<T >::max () -
444- std::numeric_limits<T >::min () + 1 ) /
453+ static_cast <FloatT>(std::numeric_limits<int16_t >::max () -
454+ std::numeric_limits<int16_t >::min () + 1 ) /
445455 (output_max - output_min);
446- const FloatT table_min = static_cast <FloatT>(std::numeric_limits<T>::min ());
447- const FloatT table_max = static_cast <FloatT>(std::numeric_limits<T>::max ());
456+ const FloatT table_min =
457+ static_cast <FloatT>(std::numeric_limits<int16_t >::min ());
458+ const FloatT table_max =
459+ static_cast <FloatT>(std::numeric_limits<int16_t >::max ());
448460
449461 for (int i = 0 ; i < nb_steps; i++) {
450- const FloatT val = transform (input_min + i * step);
451- const FloatT val_midpoint = transform (input_min + i * step + half_step);
452- const FloatT val_next = transform (input_min + (i + 1 ) * step);
462+ const FloatT val =
463+ LUTTransform<FloatT>(transform, transform_params, input_min + i * step);
464+ const FloatT val_midpoint = LUTTransform<FloatT>(
465+ transform, transform_params, input_min + i * step + half_step);
466+ const FloatT val_next = LUTTransform<FloatT>(transform, transform_params,
467+ input_min + (i + 1 ) * step);
453468
454469 const FloatT sample_val = TfLiteRound (val * output_scaling_inv);
455470 const FloatT midpoint_interp_val =
@@ -460,54 +475,84 @@ LUTPopulate(FloatT input_scale, int32_t input_zero_point, FloatT output_scale,
460475 const FloatT midpoint_err = midpoint_interp_val - midpoint_val;
461476 const FloatT bias = TfLiteRound (midpoint_err / 2 );
462477
463- lut[i] = static_cast <T >(std::min<FloatT>(
478+ lut[i] = static_cast <int16_t >(std::min<FloatT>(
464479 std::max<FloatT>(sample_val - bias, table_min), table_max));
465480 }
466481
467- lut[nb_steps] = static_cast <T>(std::min<FloatT>(
468- std::max<FloatT>(TfLiteRound (transform (input_max) * output_scaling_inv),
482+ lut[nb_steps] = static_cast <int16_t >(std::min<FloatT>(
483+ std::max<FloatT>(TfLiteRound (LUTTransform<FloatT>(
484+ transform, transform_params, input_max) *
485+ output_scaling_inv),
469486 table_min),
470487 table_max));
471488}
472489
490+ } // namespace detail
491+
492+ template <typename T>
493+ inline typename std::enable_if<std::is_same<T, uint8_t >::value ||
494+ std::is_same<T, int8_t >::value,
495+ void >::type
496+ LUTPopulate (float input_scale, int32_t input_zero_point, float output_scale,
497+ int32_t output_zero_point, float (*transform)(float ), T* lut) {
498+ detail::LUTPopulateInt8 (input_scale, input_zero_point, output_scale,
499+ output_zero_point, transform, nullptr , lut);
500+ }
501+
502+ template <typename T>
503+ inline typename std::enable_if<std::is_same<T, uint8_t >::value ||
504+ std::is_same<T, int8_t >::value,
505+ void >::type
506+ LUTPopulate (float input_scale, int32_t input_zero_point, float output_scale,
507+ int32_t output_zero_point, float (*transform)(float , const void *),
508+ const void * transform_params, T* lut) {
509+ detail::LUTPopulateInt8 (input_scale, input_zero_point, output_scale,
510+ output_zero_point, transform, transform_params, lut);
511+ }
512+
473513template <typename T>
474514inline typename std::enable_if<std::is_same<T, int16_t >::value, void >::type
475515LUTPopulate (float input_scale, int32_t input_zero_point, float output_scale,
476516 int32_t output_zero_point, float (*transform)(float ), T* lut) {
477- LUTPopulate<T, float >(input_scale, input_zero_point, output_scale,
478- output_zero_point, transform, lut);
517+ detail::LUTPopulateInt16<float >(input_scale, input_zero_point, output_scale,
518+ output_zero_point, transform, nullptr , lut);
519+ }
520+
521+ template <typename T>
522+ inline typename std::enable_if<std::is_same<T, int16_t >::value, void >::type
523+ LUTPopulate (float input_scale, int32_t input_zero_point, float output_scale,
524+ int32_t output_zero_point, float (*transform)(float , const void *),
525+ const void * transform_params, T* lut) {
526+ detail::LUTPopulateInt16<float >(input_scale, input_zero_point, output_scale,
527+ output_zero_point, transform,
528+ transform_params, lut);
479529}
480530
481- // Deprecated and will be removed in future, please use LUTPopulate instead
482- template <typename FloatT, typename LutInT, typename LutOutT>
483- inline void gen_lut (FloatT (*func)(FloatT), FloatT input_min, FloatT input_max,
484- FloatT output_min, FloatT output_max, LutOutT* lut) {
485- static_assert (std::is_same<LutInT, LutOutT>::value,
486- " Input and output type of the LUT must be the same." );
487- static_assert (std::is_same<LutInT, int16_t >::value,
488- " Only int16_t type LUT are supported." );
489- static_assert (std::is_same<FloatT, float >::value,
490- " Only float type is supported for FloatT." );
491- using T = LutInT;
492-
493- const auto zero_point = [](float min, float max, float scale) {
494- // Symmetric int16 LUT, we know the zero-point will not overflow an int32_t
495- // and zero-point from min will be the same as from max.
496- return static_cast <int32_t >(
497- static_cast <float >(std::numeric_limits<T>::min ()) - min / scale);
498- };
499-
500- const float scale = static_cast <float >(std::numeric_limits<T>::max () -
501- std::numeric_limits<T>::min ());
502- const float input_scale = (input_max - input_min) / scale;
503- const FloatT output_scale = (output_max - output_min) / scale;
504- const int32_t input_zero_point =
505- zero_point (input_min, input_max, input_scale);
506- const int32_t output_zero_point =
507- zero_point (output_min, output_max, output_scale);
508-
509- return LUTPopulate<T, float >(input_scale, input_zero_point, output_scale,
510- output_zero_point, func, lut);
531+ // Deprecated, avoid usage and prefer the float version. Kept for
532+ // backward-compatiblity.
533+ template <typename T>
534+ inline typename std::enable_if<std::is_same<T, int16_t >::value, void >::type
535+ LUTPopulate (double input_scale, int32_t input_zero_point, double output_scale,
536+ int32_t output_zero_point, double (*transform)(double ), T* lut) {
537+ detail::LUTPopulateInt16<double >(input_scale, input_zero_point, output_scale,
538+ output_zero_point, transform, nullptr , lut);
539+ }
540+
541+ // The size of the LUT depends on the type of input. For uint8 and int8 inputs a
542+ // simple 256 entries LUT is used. For int16 inputs the high 9 bits are used for
543+ // indexing and the 7 remaining bits are used for interpolation. We thus use a
544+ // 513-entries LUT for int16 cases, 512 for the 9-bit indexing and 1 extra entry
545+ // to interpolate the last value.
546+ template <typename T>
547+ constexpr int LUTSize () {
548+ static_assert (std::is_same<T, uint8_t >::value ||
549+ std::is_same<T, int8_t >::value ||
550+ std::is_same<T, int16_t >::value,
551+ " Only LUTs with uint8, int8 or int16 inputs are supported." );
552+ // As per c++11: constexpr methods cannot have more than one return statement.
553+ return (std::is_same<T, uint8_t >::value || std::is_same<T, int8_t >::value)
554+ ? 256
555+ : 513 ;
511556}
512557
513558// int16_t -> int16_t table lookup with interpolation
0 commit comments