Important changes to repositories hosted on mbed.com
Mbed-hosted Mercurial repositories are deprecated and are due to be permanently deleted in July 2026.
To keep a copy of this software, download the repository ZIP archive or clone it locally using Mercurial.
It is also possible to export all your personal repositories from the account settings page.
strided_slice_logic.h
00001 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 00002 00003 Licensed under the Apache License, Version 2.0 (the "License"); 00004 you may not use this file except in compliance with the License. 00005 You may obtain a copy of the License at 00006 00007 http://www.apache.org/licenses/LICENSE-2.0 00008 00009 Unless required by applicable law or agreed to in writing, software 00010 distributed under the License is distributed on an "AS IS" BASIS, 00011 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 00012 See the License for the specific language governing permissions and 00013 limitations under the License. 00014 ==============================================================================*/ 00015 00016 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_STRIDED_SLICE_LOGIC_H_ 00017 #define TENSORFLOW_LITE_KERNELS_INTERNAL_STRIDED_SLICE_LOGIC_H_ 00018 00019 #include <limits> 00020 #include <vector> 00021 #include "tensorflow/lite/kernels/internal/compatibility.h" 00022 #include "tensorflow/lite/kernels/internal/types.h" 00023 00024 namespace tflite { 00025 namespace strided_slice { 00026 00027 // Use until std::clamp() is available from C++17. 00028 inline int Clamp(const int v, const int lo, const int hi) { 00029 TFLITE_DCHECK(!(hi < lo)); 00030 if (hi < v) return hi; 00031 if (v < lo) return lo; 00032 return v; 00033 } 00034 00035 inline void StridedSlicePadIndices(tflite::StridedSliceParams* p, 00036 int dim_count) { 00037 // Add indices and mask bits to fully include extra dimensions 00038 TFLITE_CHECK_LE(dim_count, 4); 00039 TFLITE_CHECK_GE(dim_count, p->start_indices_count); 00040 TFLITE_CHECK_EQ(p->start_indices_count, p->stop_indices_count); 00041 TFLITE_CHECK_EQ(p->stop_indices_count, p->strides_count); 00042 00043 const int pad_count = dim_count - p->start_indices_count; 00044 00045 // Pad indices at start, so move arrays by pad_count. 
00046 for (int i = p->start_indices_count - 1; i >= 0; --i) { 00047 p->strides[i + pad_count] = p->strides[i]; 00048 p->start_indices[i + pad_count] = p->start_indices[i]; 00049 p->stop_indices[i + pad_count] = p->stop_indices[i]; 00050 } 00051 for (int i = 0; i < pad_count; ++i) { 00052 p->start_indices[i] = 0; 00053 p->stop_indices[i] = 1; 00054 p->strides[i] = 1; 00055 } 00056 00057 // Pad masks with 0s or 1s as required. 00058 p->shrink_axis_mask <<= pad_count; 00059 p->ellipsis_mask <<= pad_count; 00060 p->new_axis_mask <<= pad_count; 00061 p->begin_mask <<= pad_count; 00062 p->end_mask <<= pad_count; 00063 p->begin_mask |= (1 << pad_count) - 1; 00064 p->end_mask |= (1 << pad_count) - 1; 00065 00066 p->start_indices_count = dim_count; 00067 p->stop_indices_count = dim_count; 00068 p->strides_count = dim_count; 00069 } 00070 00071 // Return the index for the first element along that axis. This index will be a 00072 // positive integer between [0, axis_size - 1] that can be used to index 00073 // directly into the data. 00074 inline int StartForAxis(const tflite::StridedSliceParams& params, 00075 const RuntimeShape& input_shape, int axis) { 00076 const auto begin_mask = params.begin_mask; 00077 const auto* start_indices = params.start_indices; 00078 const auto* strides = params.strides; 00079 // Begin with the specified index. 00080 int start = start_indices[axis]; 00081 00082 // begin_mask override 00083 if (begin_mask & 1 << axis) { 00084 if (strides[axis] > 0) { 00085 // Forward iteration - use the first element. These values will get 00086 // clamped below (Note: We could have set them to 0 and axis_size-1, but 00087 // use lowest() and max() to maintain symmetry with StopForAxis()) 00088 start = std::numeric_limits<int>::lowest(); 00089 } else { 00090 // Backward iteration - use the last element. 
00091 start = std::numeric_limits<int>::max(); 00092 } 00093 } 00094 00095 // Handle negative indices 00096 int axis_size = input_shape.Dims(axis); 00097 if (start < 0) { 00098 start += axis_size; 00099 } 00100 00101 // Clamping 00102 start = Clamp(start, 0, axis_size - 1); 00103 00104 return start; 00105 } 00106 00107 // Return the "real" index for the end of iteration along that axis. This is an 00108 // "end" in the traditional C sense, in that it points to one past the last 00109 // element. ie. So if you were iterating through all elements of a 1D array of 00110 // size 4, this function would return 4 as the stop, because it is one past the 00111 // "real" indices of 0, 1, 2 & 3. 00112 inline int StopForAxis(const tflite::StridedSliceParams& params, 00113 const RuntimeShape& input_shape, int axis, 00114 int start_for_axis) { 00115 const auto end_mask = params.end_mask; 00116 const auto shrink_axis_mask = params.shrink_axis_mask; 00117 const auto* stop_indices = params.stop_indices; 00118 const auto* strides = params.strides; 00119 00120 // Begin with the specified index 00121 const bool shrink_axis = shrink_axis_mask & (1 << axis); 00122 int stop = stop_indices[axis]; 00123 00124 // When shrinking an axis, the end position does not matter (and can be 00125 // incorrect when negative indexing is used, see Issue #19260). Always use 00126 // start_for_axis + 1 to generate a length 1 slice, since start_for_axis has 00127 // already been adjusted for negative indices. 00128 if (shrink_axis) { 00129 stop = start_for_axis + 1; 00130 } 00131 00132 // end_mask override 00133 if (end_mask & (1 << axis)) { 00134 if (strides[axis] > 0) { 00135 // Forward iteration - use the last element. These values will get 00136 // clamped below 00137 stop = std::numeric_limits<int>::max(); 00138 } else { 00139 // Backward iteration - use the first element. 
00140 stop = std::numeric_limits<int>::lowest(); 00141 } 00142 } 00143 00144 // Handle negative indices 00145 const int axis_size = input_shape.Dims(axis); 00146 if (stop < 0) { 00147 stop += axis_size; 00148 } 00149 00150 // Clamping 00151 // Because the end index points one past the last element, we need slightly 00152 // different clamping ranges depending on the direction. 00153 if (strides[axis] > 0) { 00154 // Forward iteration 00155 stop = Clamp(stop, 0, axis_size); 00156 } else { 00157 // Backward iteration 00158 stop = Clamp(stop, -1, axis_size - 1); 00159 } 00160 00161 return stop; 00162 } 00163 00164 inline bool LoopCondition(int index, int stop, int stride) { 00165 // True when we have reached the end of an axis and should loop. 00166 return stride > 0 ? index >= stop : index <= stop; 00167 } 00168 00169 inline tflite::StridedSliceParams BuildStridedSliceParams( 00170 int begin_mask, int end_mask, int shrink_axis_mask, 00171 const std::vector<int>& start_indices, const std::vector<int>& stop_indices, 00172 const std::vector<int>& strides) { 00173 tflite::StridedSliceParams op_params; 00174 const int dims_count = start_indices.size(); 00175 00176 op_params.start_indices_count = dims_count; 00177 op_params.stop_indices_count = dims_count; 00178 op_params.strides_count = dims_count; 00179 for (int i = 0; i < dims_count; ++i) { 00180 op_params.start_indices[i] = start_indices[i]; 00181 op_params.stop_indices[i] = stop_indices[i]; 00182 op_params.strides[i] = strides[i]; 00183 } 00184 00185 op_params.begin_mask = begin_mask; 00186 op_params.ellipsis_mask = 0; 00187 op_params.end_mask = end_mask; 00188 op_params.new_axis_mask = 0; 00189 op_params.shrink_axis_mask = shrink_axis_mask; 00190 00191 return op_params; 00192 } 00193 00194 } // namespace strided_slice 00195 00196 } // namespace tflite 00197 00198 #endif // TENSORFLOW_LITE_KERNELS_INTERNAL_STRIDED_SLICE_LOGIC_H_
Generated on Wed Jul 13 2022 16:03:36 by
