Daniel Konegen / MNIST_example

Dependencies:   mbed-os

Embed: (wiki syntax)

« Back to documentation index

Show/hide line numbers micro_utils.cc Source File

micro_utils.cc

00001 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
00002 
00003 Licensed under the Apache License, Version 2.0 (the "License");
00004 you may not use this file except in compliance with the License.
00005 You may obtain a copy of the License at
00006 
00007     http://www.apache.org/licenses/LICENSE-2.0
00008 
00009 Unless required by applicable law or agreed to in writing, software
00010 distributed under the License is distributed on an "AS IS" BASIS,
00011 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
00012 See the License for the specific language governing permissions and
00013 limitations under the License.
00014 ==============================================================================*/
00015 
00016 #include "tensorflow/lite/experimental/micro/micro_utils.h"
00017 
00018 #include <limits.h>
00019 #include <math.h>
00020 #include <stdint.h>
00021 
00022 #include "tensorflow/lite/c/c_api_internal.h"
00023 
00024 namespace tflite {
00025 
namespace {

// Saturation bounds for the eight-bit quantizers defined in this file.

// Asymmetric unsigned quantization uses the full uint8_t range [0, 255].
constexpr uint8_t kAsymmetricUInt8Min = 0;
constexpr uint8_t kAsymmetricUInt8Max = 255;

// Symmetric unsigned quantization excludes 0, leaving [1, 255].
constexpr uint8_t kSymmetricUInt8Min = 1;
constexpr uint8_t kSymmetricUInt8Max = 255;

// Full int8_t range, used to shift between the signed and unsigned domains.
constexpr int8_t kAsymmetricInt8Min = -128;
constexpr int8_t kAsymmetricInt8Max = 127;

// Largest magnitude used by symmetric int8 quantization (127).
constexpr int kSymmetricInt8Scale = kAsymmetricInt8Max;

}  // namespace
00037 
00038 int ElementCount(const TfLiteIntArray& dims) {
00039   int result = 1;
00040   for (int i = 0; i < dims.size; ++i) {
00041     result *= dims.data[i];
00042   }
00043   return result;
00044 }
00045 
00046 // Converts a float value into an unsigned eight-bit quantized value.
00047 uint8_t FloatToAsymmetricQuantizedUInt8(const float value, const float scale,
00048                                         const int zero_point) {
00049   int32_t result = round(value / scale) + zero_point;
00050   if (result < kAsymmetricUInt8Min) {
00051     result = kAsymmetricUInt8Min;
00052   }
00053   if (result > kAsymmetricUInt8Max) {
00054     result = kAsymmetricUInt8Max;
00055   }
00056   return result;
00057 }
00058 
00059 uint8_t FloatToSymmetricQuantizedUInt8(const float value, const float scale) {
00060   int32_t result = round(value / scale);
00061   if (result < kSymmetricUInt8Min) {
00062     result = kSymmetricUInt8Min;
00063   }
00064   if (result > kSymmetricUInt8Max) {
00065     result = kSymmetricUInt8Max;
00066   }
00067   return result;
00068 }
00069 
00070 int8_t FloatToAsymmetricQuantizedInt8(const float value, const float scale,
00071                                       const int zero_point) {
00072   return FloatToAsymmetricQuantizedUInt8(value, scale,
00073                                          zero_point - kAsymmetricInt8Min) +
00074          kAsymmetricInt8Min;
00075 }
00076 
// Converts a float value into a signed eight-bit symmetric quantized value:
// round(value / scale), saturated to [-INT8_MAX, INT8_MAX] = [-127, 127]
// (the same bound as kSymmetricInt8Scale). 0.0 maps to 0.
//
// This computes the value directly instead of subtracting 128 from
// FloatToSymmetricQuantizedUInt8. That variant clamps to [1, 255] without a
// zero-point offset, so the subtraction mapped 0.0 to -127, every negative
// input to -127, and e.g. 5.0 (scale 1) to -123.
//
// The clamp is done in double before the integer conversion, which is
// undefined for out-of-range values (e.g. +/-inf from scale == 0). A NaN
// quotient is treated as 0.
int8_t FloatToSymmetricQuantizedInt8(const float value, const float scale) {
  double quantized = round(value / scale);
  if (isnan(quantized)) {
    quantized = 0.0;
  }
  if (quantized > INT8_MAX) {
    return INT8_MAX;
  }
  if (quantized < -INT8_MAX) {
    return -INT8_MAX;
  }
  return static_cast<int8_t>(quantized);
}
00080 
// Converts a float value into a signed 32-bit symmetric quantized value:
// round(value / scale), saturated to [INT32_MIN, INT32_MAX]. Halfway cases
// round away from zero.
//
// The rounded value stays in double, which represents every int32_t exactly.
// A float cannot: INT_MAX rounds up to 2^31 as a float, so a float-domain
// clamp lets 2^31 through to an out-of-range conversion, which is undefined
// behavior. A NaN quotient (e.g. 0 / 0 from a zero scale) is treated as 0.
int32_t FloatToSymmetricQuantizedInt32(const float value, const float scale) {
  double quantized = round(value / scale);
  if (isnan(quantized)) {
    quantized = 0.0;
  }
  if (quantized >= INT32_MAX) {
    return INT32_MAX;
  }
  if (quantized <= INT32_MIN) {
    return INT32_MIN;
  }
  return static_cast<int32_t>(quantized);
}
00091 
00092 void AsymmetricQuantize(const float* input, int8_t* output, int num_elements,
00093                         float scale, int zero_point) {
00094   for (int i = 0; i < num_elements; i++) {
00095     output[i] = FloatToAsymmetricQuantizedInt8(input[i], scale, zero_point);
00096   }
00097 }
00098 
00099 void AsymmetricQuantize(const float* input, uint8_t* output, int num_elements,
00100                         float scale, int zero_point) {
00101   for (int i = 0; i < num_elements; i++) {
00102     output[i] = FloatToAsymmetricQuantizedUInt8(input[i], scale, zero_point);
00103   }
00104 }
00105 
00106 void SymmetricQuantize(const float* input, int32_t* output, int num_elements,
00107                        float scale) {
00108   for (int i = 0; i < num_elements; i++) {
00109     output[i] = FloatToSymmetricQuantizedInt32(input[i], scale);
00110   }
00111 }
00112 
00113 void SymmetricPerChannelQuantize(const float* input, int32_t* output,
00114                                  int num_elements, int num_channels,
00115                                  float* scales) {
00116   int elements_per_channel = num_elements / num_channels;
00117   for (int i = 0; i < num_channels; i++) {
00118     for (int j = 0; j < elements_per_channel; j++) {
00119       output[i * elements_per_channel + j] = FloatToSymmetricQuantizedInt32(
00120           input[i * elements_per_channel + j], scales[i]);
00121     }
00122   }
00123 }
00124 
00125 void SignedSymmetricPerChannelQuantize(const float* values,
00126                                        TfLiteIntArray* dims,
00127                                        int quantized_dimension,
00128                                        int8_t* quantized_values,
00129                                        float* scaling_factors) {
00130   int input_size = ElementCount(*dims);
00131   int channel_count = dims->data[quantized_dimension];
00132   int per_channel_size = input_size / channel_count;
00133   for (int channel = 0; channel < channel_count; channel++) {
00134     float min = 0;
00135     float max = 0;
00136     int stride = 1;
00137     for (int i = 0; i < quantized_dimension; i++) {
00138       stride *= dims->data[i];
00139     }
00140     int channel_stride = per_channel_size / stride;
00141     // Calculate scales for each channel.
00142     for (int i = 0; i < per_channel_size; i++) {
00143       int idx = channel * channel_stride + i * stride;
00144       min = fminf(min, values[idx]);
00145       max = fmaxf(max, values[idx]);
00146     }
00147     scaling_factors[channel] =
00148         fmaxf(fabs(min), fabs(max)) / kSymmetricInt8Scale;
00149     for (int i = 0; i < per_channel_size; i++) {
00150       int idx = channel * channel_stride + i * stride;
00151       const int32_t quantized_value =
00152           static_cast<int32_t>(roundf(values[idx] / scaling_factors[channel]));
00153       // Clamp: just in case some odd numeric offset.
00154       quantized_values[idx] = fminf(
00155           kSymmetricInt8Scale, fmaxf(-kSymmetricInt8Scale, quantized_value));
00156     }
00157   }
00158 }
00159 
00160 void SignedSymmetricQuantize(const float* values, TfLiteIntArray* dims,
00161                              int8_t* quantized_values, float* scaling_factor) {
00162   int input_size = ElementCount(*dims);
00163 
00164   float min = 0;
00165   float max = 0;
00166   for (int i = 0; i < input_size; i++) {
00167     min = fminf(min, values[i]);
00168     max = fmaxf(max, values[i]);
00169   }
00170   *scaling_factor = fmaxf(fabs(min), fabs(max)) / kSymmetricInt8Scale;
00171   for (int i = 0; i < input_size; i++) {
00172     const int32_t quantized_value =
00173         static_cast<int32_t>(roundf(values[i] / *scaling_factor));
00174     // Clamp: just in case some odd numeric offset.
00175     quantized_values[i] = fminf(kSymmetricInt8Scale,
00176                                 fmaxf(-kSymmetricInt8Scale, quantized_value));
00177   }
00178 }
00179 
00180 void SymmetricQuantize(const float* values, TfLiteIntArray* dims,
00181                        uint8_t* quantized_values, float* scaling_factor) {
00182   SignedSymmetricQuantize(values, dims,
00183                           reinterpret_cast<int8_t*>(quantized_values),
00184                           scaling_factor);
00185 }
00186 
// Dequantizes `size` signed eight-bit values:
//   dequantized_values[i] = values[i] * dequantization_scale.
// A non-positive `size` writes nothing.
void SymmetricDequantize(const int8_t* values, const int size,
                         const float dequantization_scale,
                         float* dequantized_values) {
  for (int remaining = size; remaining > 0; --remaining) {
    *dequantized_values++ = *values++ * dequantization_scale;
  }
}
00194 
00195 }  // namespace tflite