Daniel Konegen / MNIST_example

Dependencies:   mbed-os

Embed: (wiki syntax)

« Back to documentation index

Show/hide line numbers dequantize.cc Source File

dequantize.cc

00001 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
00002 
00003 Licensed under the Apache License, Version 2.0 (the "License");
00004 you may not use this file except in compliance with the License.
00005 You may obtain a copy of the License at
00006 
00007     http://www.apache.org/licenses/LICENSE-2.0
00008 
00009 Unless required by applicable law or agreed to in writing, software
00010 distributed under the License is distributed on an "AS IS" BASIS,
00011 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
00012 See the License for the specific language governing permissions and
00013 limitations under the License.
00014 ==============================================================================*/
00015 
00016 #include "tensorflow/lite/kernels/internal/reference/dequantize.h"
00017 
00018 #include "tensorflow/lite/c/builtin_op_data.h"
00019 #include "tensorflow/lite/c/c_api_internal.h"
00020 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
00021 #include "tensorflow/lite/kernels/kernel_util.h"
00022 
00023 namespace tflite {
00024 namespace ops {
00025 namespace micro {
00026 namespace dequantize {
00027 
// Validates the DEQUANTIZE node's wiring and tensor types.
//
// Requires exactly one input and one output. The input must be a quantized
// tensor (uint8 or int8) and the output must be float32; any violation is
// reported through |context| and aborts preparation via the TF_LITE_ENSURE*
// macros (which return kTfLiteError on failure).
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  // TODO(b/140515557): Add cached dequant to improve hybrid model performance.
  // Tensors are fetched by index straight from the context's tensor arena.
  TfLiteTensor* input = &context->tensors[node->inputs->data[0]];
  TfLiteTensor* output = &context->tensors[node->outputs->data[0]];

  // Only uint8/int8 -> float32 dequantization is supported by this kernel.
  TF_LITE_ENSURE(context,
                 input->type == kTfLiteUInt8 || input->type == kTfLiteInt8);
  TF_LITE_ENSURE(context, output->type == kTfLiteFloat32);

  return kTfLiteOk;
}
00042 
00043 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
00044   TfLiteTensor* input = &context->tensors[node->inputs->data[0]];
00045   TfLiteTensor* output = &context->tensors[node->outputs->data[0]];
00046 
00047   tflite::DequantizationParams op_params;
00048   op_params.zero_point = input->params.zero_point;
00049   op_params.scale = input->params.scale;
00050   switch (input->type) {
00051     case kTfLiteUInt8:
00052       reference_ops::Dequantize(
00053           op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
00054           GetTensorShape(output), GetTensorData<float>(output));
00055       break;
00056     case kTfLiteInt8:
00057       reference_ops::Dequantize(
00058           op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
00059           GetTensorShape(output), GetTensorData<float>(output));
00060       break;
00061     default:
00062       context->ReportError(context, "Type %s (%d) not supported.",
00063                            TfLiteTypeGetName(input->type), input->type);
00064       return kTfLiteError;
00065   }
00066 
00067   return kTfLiteOk;
00068 }
00069 
00070 }  // namespace dequantize
00071 
00072 TfLiteRegistration* Register_DEQUANTIZE() {
00073   static TfLiteRegistration r = {nullptr, nullptr, dequantize::Prepare,
00074                                  dequantize::Eval};
00075   return &r;
00076 }
00077 
00078 }  // namespace micro
00079 }  // namespace ops
00080 }  // namespace tflite