Daniel Konegen / MNIST_example

Dependencies:   mbed-os

Embed: (wiki syntax)

« Back to documentation index

Show/hide line numbers strided_slice.h Source File

strided_slice.h

00001 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
00002 
00003 Licensed under the Apache License, Version 2.0 (the "License");
00004 you may not use this file except in compliance with the License.
00005 You may obtain a copy of the License at
00006 
00007     http://www.apache.org/licenses/LICENSE-2.0
00008 
00009 Unless required by applicable law or agreed to in writing, software
00010 distributed under the License is distributed on an "AS IS" BASIS,
00011 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
00012 See the License for the specific language governing permissions and
00013 limitations under the License.
00014 ==============================================================================*/
00015 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_STRIDED_SLICE_H_
00016 #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_STRIDED_SLICE_H_
00017 
00018 #include "tensorflow/lite/kernels/internal/common.h"
00019 #include "tensorflow/lite/kernels/internal/strided_slice_logic.h"
00020 #include "tensorflow/lite/kernels/internal/types.h"
00021 
00022 namespace tflite {
00023 
00024 namespace reference_ops {
00025 template <typename T>
00026 inline void StridedSlice(const tflite::StridedSliceParams& op_params,
00027                          const RuntimeShape& unextended_input_shape,
00028                          const T* input_data,
00029                          const RuntimeShape& unextended_output_shape,
00030                          T* output_data) {
00031   // Note that the output_shape is not used herein.
00032   tflite::StridedSliceParams params_copy = op_params;
00033 
00034   TFLITE_DCHECK_LE(unextended_input_shape.DimensionsCount(), 4);
00035   TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
00036   const RuntimeShape input_shape =
00037       RuntimeShape::ExtendedShape(4, unextended_input_shape);
00038   const RuntimeShape output_shape =
00039       RuntimeShape::ExtendedShape(4, unextended_output_shape);
00040 
00041   // Reverse and pad to 4 dimensions because that is what the runtime code
00042   // requires (ie. all shapes must be 4D and are given backwards).
00043   strided_slice::StridedSlicePadIndices(&params_copy, 4);
00044 
00045   const int start_b = strided_slice::StartForAxis(params_copy, input_shape, 0);
00046   const int stop_b =
00047       strided_slice::StopForAxis(params_copy, input_shape, 0, start_b);
00048   const int start_h = strided_slice::StartForAxis(params_copy, input_shape, 1);
00049   const int stop_h =
00050       strided_slice::StopForAxis(params_copy, input_shape, 1, start_h);
00051   const int start_w = strided_slice::StartForAxis(params_copy, input_shape, 2);
00052   const int stop_w =
00053       strided_slice::StopForAxis(params_copy, input_shape, 2, start_w);
00054   const int start_d = strided_slice::StartForAxis(params_copy, input_shape, 3);
00055   const int stop_d =
00056       strided_slice::StopForAxis(params_copy, input_shape, 3, start_d);
00057 
00058   T* out_ptr = output_data;
00059   for (int in_b = start_b;
00060        !strided_slice::LoopCondition(in_b, stop_b, params_copy.strides[0]);
00061        in_b += params_copy.strides[0]) {
00062     for (int in_h = start_h;
00063          !strided_slice::LoopCondition(in_h, stop_h, params_copy.strides[1]);
00064          in_h += params_copy.strides[1]) {
00065       for (int in_w = start_w;
00066            !strided_slice::LoopCondition(in_w, stop_w, params_copy.strides[2]);
00067            in_w += params_copy.strides[2]) {
00068         for (int in_d = start_d; !strided_slice::LoopCondition(
00069                  in_d, stop_d, params_copy.strides[3]);
00070              in_d += params_copy.strides[3]) {
00071           *out_ptr++ = input_data[Offset(input_shape, in_b, in_h, in_w, in_d)];
00072         }
00073       }
00074     }
00075   }
00076 }
00077 }  // namespace reference_ops
00078 }  // namespace tflite
00079 
00080 #endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_STRIDED_SLICE_H_