Daniel Konegen / MNIST_example

Dependencies:   mbed-os

Embed: (wiki syntax)

« Back to documentation index

Show/hide line numbers strided_slice.cc Source File

strided_slice.cc

00001 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
00002 
00003 Licensed under the Apache License, Version 2.0 (the "License");
00004 you may not use this file except in compliance with the License.
00005 You may obtain a copy of the License at
00006 
00007     http://www.apache.org/licenses/LICENSE-2.0
00008 
00009 Unless required by applicable law or agreed to in writing, software
00010 distributed under the License is distributed on an "AS IS" BASIS,
00011 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
00012 See the License for the specific language governing permissions and
00013 limitations under the License.
00014 ==============================================================================*/
00015 #include "tensorflow/lite/kernels/internal/reference/strided_slice.h"
00016 
00017 #include <cmath>
00018 
00019 #include "tensorflow/lite/c/builtin_op_data.h"
00020 #include "tensorflow/lite/c/c_api_internal.h"
00021 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
00022 #include "tensorflow/lite/kernels/kernel_util.h"
00023 #include "tensorflow/lite/kernels/op_macros.h"
00024 
00025 namespace tflite {
00026 namespace ops {
00027 namespace micro {
00028 namespace strided_slice {
00029 
// Kernel implementations that Eval<> can be instantiated with.
enum KernelType {
  kReference,  // Dispatches to reference_ops::StridedSlice.
  // TODO(soroosh): add kGenericOptimized
};
00034 
// Operand positions in the node's input list.
constexpr int kInputTensor = 0;    // tensor being sliced
constexpr int kBeginTensor = 1;    // per-axis start indices
constexpr int kEndTensor = 2;      // per-axis stop indices
constexpr int kStridesTensor = 3;  // per-axis strides
// Operand position in the node's output list.
constexpr int kOutputTensor = 0;
00040 
00041 struct StridedSliceContext {
00042   StridedSliceContext(TfLiteContext* context, TfLiteNode* node) {
00043     params = reinterpret_cast<TfLiteStridedSliceParams*>(node->builtin_data);
00044     input = GetInput(context, node, kInputTensor);
00045     begin = GetInput(context, node, kBeginTensor);
00046     end = GetInput(context, node, kEndTensor);
00047     strides = GetInput(context, node, kStridesTensor);
00048     output = GetOutput(context, node, kOutputTensor);
00049     dims = NumDimensions(input);
00050   }
00051   const TfLiteStridedSliceParams* params;
00052   const TfLiteTensor* input;
00053   const TfLiteTensor* begin;
00054   const TfLiteTensor* end;
00055   const TfLiteTensor* strides;
00056   TfLiteTensor* output;
00057   int dims;
00058 };
00059 
// Highest input rank this op accepts. The reference implementation works on
// 4-D shapes, so 1-D to 3-D inputs are mapped onto 4-D.
constexpr int kMaxDim = 4;
00063 
00064 tflite::StridedSliceParams BuildStridedSliceParams(
00065     StridedSliceContext* op_context) {
00066   tflite::StridedSliceParams op_params;
00067   op_params.start_indices_count = op_context->dims;
00068   op_params.stop_indices_count = op_context->dims;
00069   op_params.strides_count = op_context->dims;
00070 
00071   for (int i = 0; i < op_context->dims; ++i) {
00072     op_params.start_indices[i] = GetTensorData<int32_t>(op_context->begin)[i];
00073     op_params.stop_indices[i] = GetTensorData<int32_t>(op_context->end)[i];
00074     op_params.strides[i] = GetTensorData<int32_t>(op_context->strides)[i];
00075   }
00076 
00077   op_params.begin_mask = op_context->params->begin_mask;
00078   op_params.ellipsis_mask = 0;
00079   op_params.end_mask = op_context->params->end_mask;
00080   op_params.new_axis_mask = 0;
00081   op_params.shrink_axis_mask = op_context->params->shrink_axis_mask;
00082   return op_params;
00083 }
00084 
00085 // Processes the indexing tensors (begin, end and strides) to resize the
00086 // output tensor. This function is callable from both Prepare() and Eval() as
00087 // long as the caller ensures the indexing tensors are present.
00088 TfLiteStatus CheckOutputSize(TfLiteContext* context,
00089                              StridedSliceContext* op_context) {
00090   using ::tflite::strided_slice::StartForAxis;
00091   using ::tflite::strided_slice::StopForAxis;
00092   TfLiteIntArray* output_shape = op_context->output->dims;
00093   int shape_size = 0;
00094   auto op_params = BuildStridedSliceParams(op_context);
00095   auto input_shape = GetTensorShape(op_context->input);
00096   for (int idx = 0; idx < op_context->dims; ++idx) {
00097     int32_t stride = GetTensorData<int32_t>(op_context->strides)[idx];
00098     TF_LITE_ENSURE_MSG(context, stride != 0, "stride value has to be non-zero");
00099     int32_t begin = StartForAxis(op_params, input_shape, idx);
00100     int32_t end = StopForAxis(op_params, input_shape, idx, begin);
00101 
00102     // When shrinking an axis, the end position does not matter (and can be
00103     // incorrect when negative indexing is used, see Issue #19260). Always use
00104     // begin + 1 to generate a length 1 slice, since begin has
00105     // already been adjusted for negative indices by StartForAxis.
00106     const bool shrink_axis = op_context->params->shrink_axis_mask & (1 << idx);
00107     if (shrink_axis) {
00108       end = begin + 1;
00109     }
00110 
00111     // This is valid for both positive and negative strides
00112     int32_t dim_shape = std::ceil((end - begin) / static_cast<float>(stride));
00113     dim_shape = dim_shape < 0 ? 0 : dim_shape;
00114     if (!shrink_axis) {
00115       TF_LITE_ENSURE_EQ(context, output_shape->data[shape_size], dim_shape);
00116       shape_size++;
00117     }
00118   }
00119   TF_LITE_ENSURE_EQ(context, output_shape->size, shape_size);
00120   return kTfLiteOk;
00121 }
00122 
00123 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
00124   TF_LITE_ENSURE_EQ(context, NumInputs(node), 4);
00125   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
00126   StridedSliceContext op_context(context, node);
00127   TF_LITE_ENSURE_MSG(context, op_context.dims <= kMaxDim,
00128                      "input dim should not exceed 4");
00129   return CheckOutputSize(context, &op_context);
00130 }
00131 
00132 template <KernelType kernel_type>
00133 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
00134   StridedSliceContext op_context(context, node);
00135   auto op_params = BuildStridedSliceParams(&op_context);
00136 
00137 #define TF_LITE_STRIDED_SLICE(kernel_type, data_type)                    \
00138   kernel_type::StridedSlice(op_params, GetTensorShape(op_context.input), \
00139                             GetTensorData<data_type>(op_context.input),  \
00140                             GetTensorShape(op_context.output),           \
00141                             GetTensorData<data_type>(op_context.output))
00142 
00143   switch (op_context.input->type) {
00144     case kTfLiteFloat32:
00145       if (kernel_type == kReference) {
00146         TF_LITE_STRIDED_SLICE(reference_ops, float);
00147       }
00148       break;
00149     case kTfLiteUInt8:
00150       if (kernel_type == kReference) {
00151         TF_LITE_STRIDED_SLICE(reference_ops, uint8_t);
00152       }
00153       break;
00154     case kTfLiteInt8:
00155       if (kernel_type == kReference) {
00156         TF_LITE_STRIDED_SLICE(reference_ops, int8_t);
00157       }
00158       break;
00159     default:
00160       context->ReportError(context,
00161                            "Type %d is currently not supported "
00162                            "by StridedSlice.",
00163                            op_context.input->type);
00164       return kTfLiteError;
00165   }
00166 #undef TF_LITE_STRIDED_SLICE
00167   return kTfLiteOk;
00168 }
00169 }  // namespace strided_slice
00170 
00171 TfLiteRegistration* Register_STRIDED_SLICE() {
00172   static TfLiteRegistration r = {
00173       nullptr, nullptr, strided_slice::Prepare,
00174       strided_slice::Eval<strided_slice::kReference>};
00175   return &r;
00176 }
00177 
00178 }  // namespace micro
00179 }  // namespace ops
00180 }  // namespace tflite