Daniel Konegen / MNIST_example

Dependencies:   mbed-os

split.cc

/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace micro {
namespace split {

// Nothing to prepare here; axis and shape checks happen in Eval/SplitImpl.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  return kTfLiteOk;
}

// Copies the input tensor into the node's output tensors, slicing it along
// |axis_value|. The copy logic is type-agnostic; T is the element type.
template <typename T>
TfLiteStatus SplitImpl(TfLiteContext* context, TfLiteNode* node,
                       const TfLiteTensor* input, int axis_value) {
  const int output_count = NumOutputs(node);
  const TfLiteIntArray* input_dims = input->dims;
  const TfLiteTensor* output0 = &context->tensors[node->outputs->data[0]];
  const TfLiteIntArray* output_dims = output0->dims;

  const int split_dimensions = input_dims->size;
  int axis = axis_value < 0 ? axis_value + split_dimensions : axis_value;

  TFLITE_DCHECK_LT(axis, split_dimensions);
  TFLITE_DCHECK_EQ(output_dims->size, split_dimensions);

  // The outputs must tile the input exactly along the split axis.
  int64_t split_size = output_dims->data[axis] * output_count;

  TFLITE_DCHECK_EQ(split_size, input_dims->data[axis]);

  // outer_size: number of independent slabs before the split axis.
  int64_t outer_size = 1;
  for (int i = 0; i < axis; ++i) {
    outer_size *= input_dims->data[i];
  }

  // base_inner_size: number of contiguous elements after the split axis.
  int64_t base_inner_size = 1;
  for (int i = axis + 1; i < split_dimensions; ++i) {
    base_inner_size *= input_dims->data[i];
  }

  // Walk the input once, front to back; each outer slab contributes
  // copy_size contiguous elements to every output in turn.
  const T* input_ptr = GetTensorData<T>(input);
  for (int k = 0; k < outer_size; ++k) {
    for (int i = 0; i < output_count; ++i) {
      TfLiteTensor* t = &context->tensors[node->outputs->data[i]];
      T* output_data = GetTensorData<T>(t);
      const int copy_size = output_dims->data[axis] * base_inner_size;
      T* output_ptr = output_data + k * copy_size;
      for (int j = 0; j < copy_size; ++j) output_ptr[j] = input_ptr[j];
      input_ptr += copy_size;
    }
  }

  return kTfLiteOk;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* axis = GetInput(context, node, 0);
  const TfLiteTensor* input = GetInput(context, node, 1);

  // Dynamic output tensors are needed if the axis tensor is not constant,
  // but Micro doesn't support dynamic memory allocation, so only a constant
  // axis tensor is supported for now.
  TF_LITE_ENSURE_MSG(context, IsConstantTensor(axis),
                     "Non constant axis tensor not supported");

  // Resolve a negative axis to its positive equivalent and validate it.
  int axis_value = GetTensorData<int32_t>(axis)[0];
  if (axis_value < 0) {
    axis_value += NumDimensions(input);
  }

  TF_LITE_ENSURE(context, axis_value >= 0);
  TF_LITE_ENSURE(context, axis_value < NumDimensions(input));

  // Dispatch on the element type; the copy itself is type-agnostic.
  switch (input->type) {
    case kTfLiteFloat32: {
      return SplitImpl<float>(context, node, input, axis_value);
    }
    case kTfLiteUInt8: {
      return SplitImpl<uint8_t>(context, node, input, axis_value);
    }
    case kTfLiteInt8: {
      return SplitImpl<int8_t>(context, node, input, axis_value);
    }
    case kTfLiteInt16: {
      return SplitImpl<int16_t>(context, node, input, axis_value);
    }
    case kTfLiteInt32: {
      return SplitImpl<int32_t>(context, node, input, axis_value);
    }
    default:
      context->ReportError(context, "Type %s currently not supported.",
                           TfLiteTypeGetName(input->type));
      return kTfLiteError;
  }
}

}  // namespace split

// Registration record for the SPLIT kernel; the init and free hooks are not
// needed, so only Prepare and Eval are provided.
TfLiteRegistration* Register_SPLIT() {
  static TfLiteRegistration r = {nullptr, nullptr, split::Prepare, split::Eval};
  return &r;
}

}  // namespace micro
}  // namespace ops
}  // namespace tflite
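
The index arithmetic in SplitImpl is easiest to see on a concrete shape. The sketch below is not part of the kernel; the shapes, values, and the main() driver are made up purely for illustration. It repeats the same outer_size / copy_size walk on a {2, 6} row-major input split along axis 1 into three {2, 2} outputs, and shows why the input is read once sequentially while each output receives a strided block per outer slab.

#include <cstdio>
#include <vector>

int main() {
  // Input of shape {2, 6} in row-major order, split along axis 1 into
  // 3 outputs of shape {2, 2} each.
  const std::vector<int> input = {0, 1, 2,  3,  4,  5,
                                  6, 7, 8,  9, 10, 11};
  const int output_count = 3;
  const int outer_size = 2;       // product of dims before the split axis
  const int base_inner_size = 1;  // product of dims after the split axis
  const int copy_size = 2 * base_inner_size;  // output_dims[axis] * inner size

  std::vector<std::vector<int>> outputs(output_count);
  const int* input_ptr = input.data();
  // Same traversal as SplitImpl: the input is read once, front to back;
  // each outer slab hands copy_size elements to every output in turn.
  for (int k = 0; k < outer_size; ++k) {
    for (int i = 0; i < output_count; ++i) {
      for (int j = 0; j < copy_size; ++j) outputs[i].push_back(input_ptr[j]);
      input_ptr += copy_size;
    }
  }

  // Expected: output 0: 0 1 6 7, output 1: 2 3 8 9, output 2: 4 5 10 11.
  for (int i = 0; i < output_count; ++i) {
    std::printf("output %d:", i);
    for (int v : outputs[i]) std::printf(" %d", v);
    std::printf("\n");
  }
  return 0;
}

In the application itself, the pointer returned by Register_SPLIT() is what the op resolver stores for the builtin SPLIT operator, so the interpreter calls split::Prepare and split::Eval whenever the model graph contains a SPLIT node.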