Daniel Konegen / MNIST_example

Dependencies:   mbed-os

Embed: (wiki syntax)

« Back to documentation index

Show/hide line numbers reshape.cc Source File

reshape.cc

00001 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
00002 
00003 Licensed under the Apache License, Version 2.0 (the "License");
00004 you may not use this file except in compliance with the License.
00005 You may obtain a copy of the License at
00006 
00007     http://www.apache.org/licenses/LICENSE-2.0
00008 
00009 Unless required by applicable law or agreed to in writing, software
00010 distributed under the License is distributed on an "AS IS" BASIS,
00011 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
00012 See the License for the specific language governing permissions and
00013 limitations under the License.
00014 ==============================================================================*/
00015 
00016 #include "tensorflow/lite/c/builtin_op_data.h"
00017 #include "tensorflow/lite/c/c_api_internal.h"
00018 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
00019 #include "tensorflow/lite/kernels/kernel_util.h"
00020 #include "tensorflow/lite/kernels/op_macros.h"
00021 
00022 namespace tflite {
00023 namespace ops {
00024 namespace micro {
00025 namespace reshape {
00026 
00027 constexpr int kInputTensor = 0;
00028 constexpr int kShapeTensor = 1;
00029 constexpr int kOutputTensor = 0;
00030 
00031 TfLiteStatus ReshapeOutput(TfLiteContext* context, TfLiteNode* node) {
00032   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
00033   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
00034   // Tensorflow's Reshape allows one of the shape components to have the
00035   // special -1 value, meaning it will be calculated automatically based on the
00036   // input. Here we calculate what that dimension should be so that the number
00037   // of output elements in the same as the number of input elements.
00038   int num_input_elements = NumElements(input);
00039   TfLiteIntArray* output_shape = output->dims;
00040 
00041   if (NumInputs(node) == 1 &&  // Legacy scalar supported with params.
00042       output_shape->size == 1 && output_shape->data[0] == 0) {
00043     // Legacy tflite models use a shape parameter of [0] to indicate scalars,
00044     // so adjust accordingly. TODO(b/111614235): Allow zero-sized buffers during
00045     // toco conversion.
00046     output_shape->size = 0;
00047   }
00048 
00049   int num_output_elements = 1;
00050   int stretch_dim = -1;
00051   for (int i = 0; i < output_shape->size; ++i) {
00052     int value = output_shape->data[i];
00053     if (value == -1) {
00054       TF_LITE_ENSURE_EQ(context, stretch_dim, -1);
00055       stretch_dim = i;
00056     } else {
00057       num_output_elements *= value;
00058     }
00059   }
00060   if (stretch_dim != -1) {
00061     output_shape->data[stretch_dim] = num_input_elements / num_output_elements;
00062     num_output_elements *= output_shape->data[stretch_dim];
00063   }
00064 
00065   TF_LITE_ENSURE_EQ(context, input->type, output->type);
00066   TF_LITE_ENSURE_EQ(context, num_input_elements, num_output_elements);
00067   return kTfLiteOk;
00068 }
00069 
00070 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
00071   TF_LITE_ENSURE(context, NumInputs(node) == 1 || NumInputs(node) == 2);
00072   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
00073   return kTfLiteOk;
00074 }
00075 
00076 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
00077   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
00078   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
00079   if (ReshapeOutput(context, node) != kTfLiteOk) {
00080     return kTfLiteError;
00081   }
00082 
00083   for (int i = 0; i < input->bytes; ++i) {
00084     output->data.raw[i] = input->data.raw[i];
00085   }
00086   return kTfLiteOk;
00087 }
00088 
00089 }  // namespace reshape
00090 
00091 TfLiteRegistration* Register_RESHAPE() {
00092   static TfLiteRegistration r = {nullptr, nullptr, reshape::Prepare,
00093                                  reshape::Eval};
00094   return &r;
00095 }
00096 
00097 }  // namespace micro
00098 }  // namespace ops
00099 }  // namespace tflite