Daniel Konegen / MNIST_example

Dependencies:   mbed-os

Embed: (wiki syntax)

« Back to documentation index

Show/hide line numbers strided_slice_logic.h Source File

strided_slice_logic.h

00001 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
00002 
00003 Licensed under the Apache License, Version 2.0 (the "License");
00004 you may not use this file except in compliance with the License.
00005 You may obtain a copy of the License at
00006 
00007     http://www.apache.org/licenses/LICENSE-2.0
00008 
00009 Unless required by applicable law or agreed to in writing, software
00010 distributed under the License is distributed on an "AS IS" BASIS,
00011 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
00012 See the License for the specific language governing permissions and
00013 limitations under the License.
00014 ==============================================================================*/
00015 
00016 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_STRIDED_SLICE_LOGIC_H_
00017 #define TENSORFLOW_LITE_KERNELS_INTERNAL_STRIDED_SLICE_LOGIC_H_
00018 
00019 #include <limits>
00020 #include <vector>
00021 #include "tensorflow/lite/kernels/internal/compatibility.h"
00022 #include "tensorflow/lite/kernels/internal/types.h"
00023 
00024 namespace tflite {
00025 namespace strided_slice {
00026 
00027 // Use until std::clamp() is available from C++17.
00028 inline int Clamp(const int v, const int lo, const int hi) {
00029   TFLITE_DCHECK(!(hi < lo));
00030   if (hi < v) return hi;
00031   if (v < lo) return lo;
00032   return v;
00033 }
00034 
00035 inline void StridedSlicePadIndices(tflite::StridedSliceParams* p,
00036                                    int dim_count) {
00037   // Add indices and mask bits to fully include extra dimensions
00038   TFLITE_CHECK_LE(dim_count, 4);
00039   TFLITE_CHECK_GE(dim_count, p->start_indices_count);
00040   TFLITE_CHECK_EQ(p->start_indices_count, p->stop_indices_count);
00041   TFLITE_CHECK_EQ(p->stop_indices_count, p->strides_count);
00042 
00043   const int pad_count = dim_count - p->start_indices_count;
00044 
00045   // Pad indices at start, so move arrays by pad_count.
00046   for (int i = p->start_indices_count - 1; i >= 0; --i) {
00047     p->strides[i + pad_count] = p->strides[i];
00048     p->start_indices[i + pad_count] = p->start_indices[i];
00049     p->stop_indices[i + pad_count] = p->stop_indices[i];
00050   }
00051   for (int i = 0; i < pad_count; ++i) {
00052     p->start_indices[i] = 0;
00053     p->stop_indices[i] = 1;
00054     p->strides[i] = 1;
00055   }
00056 
00057   // Pad masks with 0s or 1s as required.
00058   p->shrink_axis_mask <<= pad_count;
00059   p->ellipsis_mask <<= pad_count;
00060   p->new_axis_mask <<= pad_count;
00061   p->begin_mask <<= pad_count;
00062   p->end_mask <<= pad_count;
00063   p->begin_mask |= (1 << pad_count) - 1;
00064   p->end_mask |= (1 << pad_count) - 1;
00065 
00066   p->start_indices_count = dim_count;
00067   p->stop_indices_count = dim_count;
00068   p->strides_count = dim_count;
00069 }
00070 
00071 // Return the index for the first element along that axis. This index will be a
00072 // positive integer between [0, axis_size - 1] that can be used to index
00073 // directly into the data.
00074 inline int StartForAxis(const tflite::StridedSliceParams& params,
00075                         const RuntimeShape& input_shape, int axis) {
00076   const auto begin_mask = params.begin_mask;
00077   const auto* start_indices = params.start_indices;
00078   const auto* strides = params.strides;
00079   // Begin with the specified index.
00080   int start = start_indices[axis];
00081 
00082   // begin_mask override
00083   if (begin_mask & 1 << axis) {
00084     if (strides[axis] > 0) {
00085       // Forward iteration - use the first element. These values will get
00086       // clamped below (Note: We could have set them to 0 and axis_size-1, but
00087       // use lowest() and max() to maintain symmetry with StopForAxis())
00088       start = std::numeric_limits<int>::lowest();
00089     } else {
00090       // Backward iteration - use the last element.
00091       start = std::numeric_limits<int>::max();
00092     }
00093   }
00094 
00095   // Handle negative indices
00096   int axis_size = input_shape.Dims(axis);
00097   if (start < 0) {
00098     start += axis_size;
00099   }
00100 
00101   // Clamping
00102   start = Clamp(start, 0, axis_size - 1);
00103 
00104   return start;
00105 }
00106 
00107 // Return the "real" index for the end of iteration along that axis. This is an
00108 // "end" in the traditional C sense, in that it points to one past the last
00109 // element. ie. So if you were iterating through all elements of a 1D array of
00110 // size 4, this function would return 4 as the stop, because it is one past the
00111 // "real" indices of 0, 1, 2 & 3.
00112 inline int StopForAxis(const tflite::StridedSliceParams& params,
00113                        const RuntimeShape& input_shape, int axis,
00114                        int start_for_axis) {
00115   const auto end_mask = params.end_mask;
00116   const auto shrink_axis_mask = params.shrink_axis_mask;
00117   const auto* stop_indices = params.stop_indices;
00118   const auto* strides = params.strides;
00119 
00120   // Begin with the specified index
00121   const bool shrink_axis = shrink_axis_mask & (1 << axis);
00122   int stop = stop_indices[axis];
00123 
00124   // When shrinking an axis, the end position does not matter (and can be
00125   // incorrect when negative indexing is used, see Issue #19260). Always use
00126   // start_for_axis + 1 to generate a length 1 slice, since start_for_axis has
00127   // already been adjusted for negative indices.
00128   if (shrink_axis) {
00129     stop = start_for_axis + 1;
00130   }
00131 
00132   // end_mask override
00133   if (end_mask & (1 << axis)) {
00134     if (strides[axis] > 0) {
00135       // Forward iteration - use the last element. These values will get
00136       // clamped below
00137       stop = std::numeric_limits<int>::max();
00138     } else {
00139       // Backward iteration - use the first element.
00140       stop = std::numeric_limits<int>::lowest();
00141     }
00142   }
00143 
00144   // Handle negative indices
00145   const int axis_size = input_shape.Dims(axis);
00146   if (stop < 0) {
00147     stop += axis_size;
00148   }
00149 
00150   // Clamping
00151   // Because the end index points one past the last element, we need slightly
00152   // different clamping ranges depending on the direction.
00153   if (strides[axis] > 0) {
00154     // Forward iteration
00155     stop = Clamp(stop, 0, axis_size);
00156   } else {
00157     // Backward iteration
00158     stop = Clamp(stop, -1, axis_size - 1);
00159   }
00160 
00161   return stop;
00162 }
00163 
// Returns true once `index` has reached or passed `stop` in the direction of
// `stride`, i.e. when iteration along the axis should end. A non-positive
// stride is treated as backward iteration.
inline bool LoopCondition(int index, int stop, int stride) {
  if (stride > 0) {
    return !(index < stop);
  }
  return !(index > stop);
}
00168 
00169 inline tflite::StridedSliceParams BuildStridedSliceParams(
00170     int begin_mask, int end_mask, int shrink_axis_mask,
00171     const std::vector<int>& start_indices, const std::vector<int>& stop_indices,
00172     const std::vector<int>& strides) {
00173   tflite::StridedSliceParams op_params;
00174   const int dims_count = start_indices.size();
00175 
00176   op_params.start_indices_count = dims_count;
00177   op_params.stop_indices_count = dims_count;
00178   op_params.strides_count = dims_count;
00179   for (int i = 0; i < dims_count; ++i) {
00180     op_params.start_indices[i] = start_indices[i];
00181     op_params.stop_indices[i] = stop_indices[i];
00182     op_params.strides[i] = strides[i];
00183   }
00184 
00185   op_params.begin_mask = begin_mask;
00186   op_params.ellipsis_mask = 0;
00187   op_params.end_mask = end_mask;
00188   op_params.new_axis_mask = 0;
00189   op_params.shrink_axis_mask = shrink_axis_mask;
00190 
00191   return op_params;
00192 }
00193 
00194 }  // namespace strided_slice
00195 
00196 }  // namespace tflite
00197 
00198 #endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_STRIDED_SLICE_LOGIC_H_