Daniel Konegen / MNIST_example

Dependencies:   mbed-os

Embed: (wiki syntax)

« Back to documentation index

Show/hide line numbers quantize.cc Source File

quantize.cc

00001 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
00002 
00003 Licensed under the Apache License, Version 2.0 (the "License");
00004 you may not use this file except in compliance with the License.
00005 You may obtain a copy of the License at
00006 
00007     http://www.apache.org/licenses/LICENSE-2.0
00008 
00009 Unless required by applicable law or agreed to in writing, software
00010 distributed under the License is distributed on an "AS IS" BASIS,
00011 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
00012 See the License for the specific language governing permissions and
00013 limitations under the License.
00014 ==============================================================================*/
00015 #include "tensorflow/lite/kernels/internal/reference/quantize.h"
00016 
00017 #include "tensorflow/lite/c/c_api_internal.h"
00018 #include "tensorflow/lite/kernels/internal/quantization_util.h"
00019 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
00020 #include "tensorflow/lite/kernels/kernel_util.h"
00021 
00022 namespace tflite {
00023 namespace ops {
00024 namespace micro {
00025 namespace quantize {
00026 
00027 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
00028   return nullptr;
00029 }
00030 
00031 void Free(TfLiteContext* context, void* buffer) {}
00032 
00033 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
00034   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
00035   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
00036 
00037   TfLiteTensor* input = &context->tensors[node->inputs->data[0]];
00038   TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
00039 
00040   // TODO(b/128934713): Add support for fixed-point per-channel quantization.
00041   // Currently this only support affine per-layer quantization.
00042   TF_LITE_ENSURE_EQ(context, output->quantization.type,
00043                     kTfLiteAffineQuantization);
00044   const auto* affine_quantization =
00045       reinterpret_cast<TfLiteAffineQuantization*>(output->quantization.params);
00046   TF_LITE_ENSURE(context, affine_quantization);
00047   TF_LITE_ENSURE(context, affine_quantization->scale);
00048   TF_LITE_ENSURE(context, affine_quantization->scale->size == 1);
00049 
00050   TF_LITE_ENSURE(context, input->type == kTfLiteFloat32);
00051   TF_LITE_ENSURE(context,
00052                  output->type == kTfLiteUInt8 || output->type == kTfLiteInt8);
00053 
00054   return kTfLiteOk;
00055 }
00056 
00057 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
00058   TfLiteTensor* input = &context->tensors[node->inputs->data[0]];
00059   TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
00060 
00061   tflite::QuantizationParams op_params;
00062   op_params.zero_point = output->params.zero_point;
00063   op_params.scale = output->params.scale;
00064   switch (output->type) {
00065     case kTfLiteInt8:
00066       reference_ops::AffineQuantize(
00067           op_params, GetTensorShape(input), GetTensorData<float>(input),
00068           GetTensorShape(output), GetTensorData<int8_t>(output));
00069       break;
00070     case kTfLiteUInt8:
00071       reference_ops::AffineQuantize(
00072           op_params, GetTensorShape(input), GetTensorData<float>(input),
00073           GetTensorShape(output), GetTensorData<uint8_t>(output));
00074       break;
00075     default:
00076       context->ReportError(context, "Output type %s (%d) not supported",
00077                            TfLiteTypeGetName(input->type), output->type);
00078       return kTfLiteError;
00079   }
00080 
00081   return kTfLiteOk;
00082 }
00083 
00084 }  // namespace quantize
00085 
// This Op (QUANTIZE) quantizes the input and produces quantized output.
// AffineQuantize takes scale and zero point and quantizes the float value to
// quantized output, in int8 or uint8 format.
//
// Returns a pointer to a function-local static registration, so the same
// instance is shared by every call; the pointer stays valid for the life of
// the program.  The initializer order must match TfLiteRegistration's field
// order: {init, free, prepare, invoke}.
TfLiteRegistration* Register_QUANTIZE() {
  static TfLiteRegistration r = {quantize::Init, quantize::Free,
                                 quantize::Prepare, quantize::Eval};
  return &r;
}
00094 
00095 }  // namespace micro
00096 }  // namespace ops
00097 }  // namespace tflite