Important changes to repositories hosted on mbed.com
Mbed-hosted Mercurial repositories are deprecated and are due to be permanently deleted in July 2026.
To keep a copy of this software, download the repository Zip archive or clone it locally using Mercurial.
It is also possible to export all your personal repositories from the account settings page.
reshape.cc
00001 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 00002 00003 Licensed under the Apache License, Version 2.0 (the "License"); 00004 you may not use this file except in compliance with the License. 00005 You may obtain a copy of the License at 00006 00007 http://www.apache.org/licenses/LICENSE-2.0 00008 00009 Unless required by applicable law or agreed to in writing, software 00010 distributed under the License is distributed on an "AS IS" BASIS, 00011 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 00012 See the License for the specific language governing permissions and 00013 limitations under the License. 00014 ==============================================================================*/ 00015 00016 #include "tensorflow/lite/c/builtin_op_data.h" 00017 #include "tensorflow/lite/c/c_api_internal.h" 00018 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" 00019 #include "tensorflow/lite/kernels/kernel_util.h" 00020 #include "tensorflow/lite/kernels/op_macros.h" 00021 00022 namespace tflite { 00023 namespace ops { 00024 namespace micro { 00025 namespace reshape { 00026 00027 constexpr int kInputTensor = 0; 00028 constexpr int kShapeTensor = 1; 00029 constexpr int kOutputTensor = 0; 00030 00031 TfLiteStatus ReshapeOutput(TfLiteContext* context, TfLiteNode* node) { 00032 const TfLiteTensor* input = GetInput(context, node, kInputTensor); 00033 TfLiteTensor* output = GetOutput(context, node, kOutputTensor); 00034 // Tensorflow's Reshape allows one of the shape components to have the 00035 // special -1 value, meaning it will be calculated automatically based on the 00036 // input. Here we calculate what that dimension should be so that the number 00037 // of output elements in the same as the number of input elements. 00038 int num_input_elements = NumElements(input); 00039 TfLiteIntArray* output_shape = output->dims; 00040 00041 if (NumInputs(node) == 1 && // Legacy scalar supported with params. 
00042 output_shape->size == 1 && output_shape->data[0] == 0) { 00043 // Legacy tflite models use a shape parameter of [0] to indicate scalars, 00044 // so adjust accordingly. TODO(b/111614235): Allow zero-sized buffers during 00045 // toco conversion. 00046 output_shape->size = 0; 00047 } 00048 00049 int num_output_elements = 1; 00050 int stretch_dim = -1; 00051 for (int i = 0; i < output_shape->size; ++i) { 00052 int value = output_shape->data[i]; 00053 if (value == -1) { 00054 TF_LITE_ENSURE_EQ(context, stretch_dim, -1); 00055 stretch_dim = i; 00056 } else { 00057 num_output_elements *= value; 00058 } 00059 } 00060 if (stretch_dim != -1) { 00061 output_shape->data[stretch_dim] = num_input_elements / num_output_elements; 00062 num_output_elements *= output_shape->data[stretch_dim]; 00063 } 00064 00065 TF_LITE_ENSURE_EQ(context, input->type, output->type); 00066 TF_LITE_ENSURE_EQ(context, num_input_elements, num_output_elements); 00067 return kTfLiteOk; 00068 } 00069 00070 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { 00071 TF_LITE_ENSURE(context, NumInputs(node) == 1 || NumInputs(node) == 2); 00072 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); 00073 return kTfLiteOk; 00074 } 00075 00076 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { 00077 const TfLiteTensor* input = GetInput(context, node, kInputTensor); 00078 TfLiteTensor* output = GetOutput(context, node, kOutputTensor); 00079 if (ReshapeOutput(context, node) != kTfLiteOk) { 00080 return kTfLiteError; 00081 } 00082 00083 for (int i = 0; i < input->bytes; ++i) { 00084 output->data.raw[i] = input->data.raw[i]; 00085 } 00086 return kTfLiteOk; 00087 } 00088 00089 } // namespace reshape 00090 00091 TfLiteRegistration* Register_RESHAPE() { 00092 static TfLiteRegistration r = {nullptr, nullptr, reshape::Prepare, 00093 reshape::Eval}; 00094 return &r; 00095 } 00096 00097 } // namespace micro 00098 } // namespace ops 00099 } // namespace tflite
Generated on Wed Jul 13 2022 16:03:35 by Doxygen 1.7.2