Daniel Konegen / MNIST_example

Dependencies:   mbed-os

Embed: (wiki syntax)

« Back to documentation index

Show/hide line numbers pack.cc Source File

pack.cc

00001 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
00002 
00003 Licensed under the Apache License, Version 2.0 (the "License");
00004 you may not use this file except in compliance with the License.
00005 You may obtain a copy of the License at
00006 
00007     http://www.apache.org/licenses/LICENSE-2.0
00008 
00009 Unless required by applicable law or agreed to in writing, software
00010 distributed under the License is distributed on an "AS IS" BASIS,
00011 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
00012 See the License for the specific language governing permissions and
00013 limitations under the License.
00014 ==============================================================================*/
00015 
00016 #include "tensorflow/lite/c/builtin_op_data.h"
00017 #include "tensorflow/lite/c/c_api_internal.h"
00018 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
00019 #include "tensorflow/lite/kernels/kernel_util.h"
00020 
00021 namespace tflite {
00022 namespace ops {
00023 namespace micro {
00024 namespace pack {
00025 namespace {
00026 
// Index passed to GetOutput() in Eval to fetch this op's single output tensor.
constexpr int kOutputTensor{0};
00028 
00029 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
00030   return kTfLiteOk;
00031 }
00032 
00033 template <typename T>
00034 TfLiteStatus PackImpl(TfLiteContext* context, TfLiteNode* node,
00035                       TfLiteTensor* output, int values_count, int axis) {
00036   const int dimensions = output->dims->size;
00037   const TfLiteTensor* input0 = &context->tensors[node->inputs->data[0]];
00038   const TfLiteIntArray* input_dims = input0->dims;
00039   const TfLiteIntArray* output_dims = output->dims;
00040 
00041   if (axis < 0) {
00042     axis += dimensions;
00043   }
00044 
00045   int outer_size = 1;
00046   for (int i = 0; i < axis; ++i) {
00047     outer_size *= output_dims->data[i];
00048   }
00049   int copy_size = 1;
00050   for (int i = axis + 1; i < dimensions; ++i) {
00051     copy_size *= output_dims->data[i];
00052   }
00053   int input_size = 1;
00054   for (int i = 0; i < input_dims->size; ++i) {
00055     input_size *= input_dims->data[i];
00056   }
00057   TFLITE_DCHECK_EQ(input_size, copy_size * outer_size);
00058 
00059   T* output_data = GetTensorData<T>(output);
00060 
00061   for (int i = 0; i < values_count; ++i) {
00062     TfLiteTensor* t = &context->tensors[node->inputs->data[i]];
00063     const T* input_data = GetTensorData<T>(t);
00064     for (int k = 0; k < outer_size; ++k) {
00065       const T* input_ptr = input_data + copy_size * k;
00066       int loc = k * values_count * copy_size + i * copy_size;
00067       T* output_ptr = output_data + loc;
00068       for (int j = 0; j < copy_size; ++j) output_ptr[j] = input_ptr[j];
00069     }
00070   }
00071 
00072   return kTfLiteOk;
00073 }
00074 
00075 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
00076   const TfLitePackParams* data =
00077       reinterpret_cast<TfLitePackParams*>(node->builtin_data);
00078 
00079   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
00080 
00081   switch (output->type) {
00082     case kTfLiteFloat32: {
00083       return PackImpl<float>(context, node, output, data->values_count,
00084                              data->axis);
00085     }
00086     case kTfLiteUInt8: {
00087       return PackImpl<uint8_t>(context, node, output, data->values_count,
00088                                data->axis);
00089     }
00090     case kTfLiteInt8: {
00091       return PackImpl<int8_t>(context, node, output, data->values_count,
00092                               data->axis);
00093     }
00094     case kTfLiteInt32: {
00095       return PackImpl<int32_t>(context, node, output, data->values_count,
00096                                data->axis);
00097     }
00098     case kTfLiteInt64: {
00099       return PackImpl<int64_t>(context, node, output, data->values_count,
00100                                data->axis);
00101     }
00102     default: {
00103       context->ReportError(context, "Type '%s' is not supported by pack.",
00104                            TfLiteTypeGetName(output->type));
00105       return kTfLiteError;
00106     }
00107   }
00108 
00109   return kTfLiteOk;
00110 }
00111 
00112 }  // namespace
00113 }  // namespace pack
00114 
00115 TfLiteRegistration* Register_PACK() {
00116   static TfLiteRegistration r = {nullptr, nullptr, pack::Prepare, pack::Eval};
00117   return &r;
00118 }
00119 
00120 }  // namespace micro
00121 }  // namespace ops
00122 }  // namespace tflite