Daniel Konegen / MNIST_example

Dependencies:   mbed-os

Embed: (wiki syntax)

« Back to documentation index

Show/hide line numbers — unpack.cc Source File

unpack.cc

00001 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
00002 
00003 Licensed under the Apache License, Version 2.0 (the "License");
00004 you may not use this file except in compliance with the License.
00005 You may obtain a copy of the License at
00006 
00007     http://www.apache.org/licenses/LICENSE-2.0
00008 
00009 Unless required by applicable law or agreed to in writing, software
00010 distributed under the License is distributed on an "AS IS" BASIS,
00011 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
00012 See the License for the specific language governing permissions and
00013 limitations under the License.
00014 ==============================================================================*/
00015 
00016 #include "tensorflow/lite/c/builtin_op_data.h"
00017 #include "tensorflow/lite/c/c_api_internal.h"
00018 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
00019 #include "tensorflow/lite/kernels/kernel_util.h"
00020 
00021 namespace tflite {
00022 namespace ops {
00023 namespace micro {
00024 namespace unpack {
00025 namespace {
00026 
// Position of the tensor to unpack within the node's input list.
constexpr int kInputTensor{0};
00028 
00029 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
00030   return kTfLiteOk;
00031 }
00032 
00033 template <typename T>
00034 TfLiteStatus UnpackImpl(TfLiteContext* context, TfLiteNode* node,
00035                         const TfLiteTensor* input, int output_count, int axis) {
00036   const TfLiteTensor* output0 = &context->tensors[node->outputs->data[0]];
00037   const TfLiteIntArray* input_dims = input->dims;
00038   const TfLiteIntArray* output_dims = output0->dims;
00039   const int dimensions = input_dims->size;
00040 
00041   if (axis < 0) {
00042     axis += NumDimensions(input);
00043   }
00044 
00045   TFLITE_DCHECK_LT(axis, dimensions);
00046 
00047   int outer_size = 1;
00048   for (int i = 0; i < axis; ++i) {
00049     outer_size *= input_dims->data[i];
00050   }
00051   int copy_size = 1;
00052   for (int i = axis + 1; i < dimensions; ++i) {
00053     copy_size *= input_dims->data[i];
00054   }
00055   int output_size = 1;
00056   for (int i = 0; i < output_dims->size; ++i) {
00057     output_size *= output_dims->data[i];
00058   }
00059   TFLITE_DCHECK_EQ(output_size, copy_size * outer_size);
00060 
00061   const T* input_data = GetTensorData<T>(input);
00062 
00063   for (int i = 0; i < output_count; ++i) {
00064     TfLiteTensor* t = &context->tensors[node->outputs->data[i]];
00065     T* output_data = GetTensorData<T>(t);
00066     for (int k = 0; k < outer_size; ++k) {
00067       T* output_ptr = output_data + copy_size * k;
00068       int loc = k * output_count * copy_size + i * copy_size;
00069       const T* input_ptr = input_data + loc;
00070       for (int j = 0; j < copy_size; ++j) output_ptr[j] = input_ptr[j];
00071     }
00072   }
00073 
00074   return kTfLiteOk;
00075 }
00076 
00077 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
00078   TfLiteUnpackParams* data =
00079       reinterpret_cast<TfLiteUnpackParams*>(node->builtin_data);
00080 
00081   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
00082 
00083   switch (input->type) {
00084     case kTfLiteFloat32: {
00085       return UnpackImpl<float>(context, node, input, data->num, data->axis);
00086     }
00087     case kTfLiteInt32: {
00088       return UnpackImpl<int32_t>(context, node, input, data->num, data->axis);
00089     }
00090     case kTfLiteUInt8: {
00091       return UnpackImpl<uint8_t>(context, node, input, data->num, data->axis);
00092     }
00093     case kTfLiteInt8: {
00094       return UnpackImpl<int8_t>(context, node, input, data->num, data->axis);
00095     }
00096     default: {
00097       context->ReportError(context, "Type '%s' is not supported by unpack.",
00098                            TfLiteTypeGetName(input->type));
00099       return kTfLiteError;
00100     }
00101   }
00102 
00103   return kTfLiteOk;
00104 }
00105 }  // namespace
00106 }  // namespace unpack
00107 
00108 TfLiteRegistration* Register_UNPACK() {
00109   static TfLiteRegistration r = {nullptr, nullptr, unpack::Prepare,
00110                                  unpack::Eval};
00111   return &r;
00112 }
00113 
00114 }  // namespace micro
00115 }  // namespace ops
00116 }  // namespace tflite