Important changes to repositories hosted on mbed.com
Mbed-hosted Mercurial repositories are deprecated and are due to be permanently deleted in July 2026.
To keep a copy of this software, download the repository Zip archive or clone it locally using Mercurial.
It is also possible to export all your personal repositories from the account settings page.
strided_slice.cc
00001 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 00002 00003 Licensed under the Apache License, Version 2.0 (the "License"); 00004 you may not use this file except in compliance with the License. 00005 You may obtain a copy of the License at 00006 00007 http://www.apache.org/licenses/LICENSE-2.0 00008 00009 Unless required by applicable law or agreed to in writing, software 00010 distributed under the License is distributed on an "AS IS" BASIS, 00011 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 00012 See the License for the specific language governing permissions and 00013 limitations under the License. 00014 ==============================================================================*/ 00015 #include "tensorflow/lite/kernels/internal/reference/strided_slice.h" 00016 00017 #include <cmath> 00018 00019 #include "tensorflow/lite/c/builtin_op_data.h" 00020 #include "tensorflow/lite/c/c_api_internal.h" 00021 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" 00022 #include "tensorflow/lite/kernels/kernel_util.h" 00023 #include "tensorflow/lite/kernels/op_macros.h" 00024 00025 namespace tflite { 00026 namespace ops { 00027 namespace micro { 00028 namespace strided_slice { 00029 00030 enum KernelType { 00031 kReference, 00032 // TODO(soroosh): add kGenericOptimized 00033 }; 00034 00035 constexpr int kInputTensor = 0; 00036 constexpr int kBeginTensor = 1; 00037 constexpr int kEndTensor = 2; 00038 constexpr int kStridesTensor = 3; 00039 constexpr int kOutputTensor = 0; 00040 00041 struct StridedSliceContext { 00042 StridedSliceContext(TfLiteContext* context, TfLiteNode* node) { 00043 params = reinterpret_cast<TfLiteStridedSliceParams*>(node->builtin_data); 00044 input = GetInput(context, node, kInputTensor); 00045 begin = GetInput(context, node, kBeginTensor); 00046 end = GetInput(context, node, kEndTensor); 00047 strides = GetInput(context, node, kStridesTensor); 00048 output = GetOutput(context, node, 
kOutputTensor); 00049 dims = NumDimensions(input); 00050 } 00051 const TfLiteStridedSliceParams* params; 00052 const TfLiteTensor* input; 00053 const TfLiteTensor* begin; 00054 const TfLiteTensor* end; 00055 const TfLiteTensor* strides; 00056 TfLiteTensor* output; 00057 int dims; 00058 }; 00059 00060 // This Op only supports 1-4D cases and since we use the reference 4D 00061 // implementation, the 1-3D tensors are mapped to 4D. 00062 const int kMaxDim = 4; 00063 00064 tflite::StridedSliceParams BuildStridedSliceParams( 00065 StridedSliceContext* op_context) { 00066 tflite::StridedSliceParams op_params; 00067 op_params.start_indices_count = op_context->dims; 00068 op_params.stop_indices_count = op_context->dims; 00069 op_params.strides_count = op_context->dims; 00070 00071 for (int i = 0; i < op_context->dims; ++i) { 00072 op_params.start_indices[i] = GetTensorData<int32_t>(op_context->begin)[i]; 00073 op_params.stop_indices[i] = GetTensorData<int32_t>(op_context->end)[i]; 00074 op_params.strides[i] = GetTensorData<int32_t>(op_context->strides)[i]; 00075 } 00076 00077 op_params.begin_mask = op_context->params->begin_mask; 00078 op_params.ellipsis_mask = 0; 00079 op_params.end_mask = op_context->params->end_mask; 00080 op_params.new_axis_mask = 0; 00081 op_params.shrink_axis_mask = op_context->params->shrink_axis_mask; 00082 return op_params; 00083 } 00084 00085 // Processes the indexing tensors (begin, end and strides) to resize the 00086 // output tensor. This function is callable from both Prepare() and Eval() as 00087 // long as the caller ensures the indexing tensors are present. 
00088 TfLiteStatus CheckOutputSize(TfLiteContext* context, 00089 StridedSliceContext* op_context) { 00090 using ::tflite::strided_slice::StartForAxis; 00091 using ::tflite::strided_slice::StopForAxis; 00092 TfLiteIntArray* output_shape = op_context->output->dims; 00093 int shape_size = 0; 00094 auto op_params = BuildStridedSliceParams(op_context); 00095 auto input_shape = GetTensorShape(op_context->input); 00096 for (int idx = 0; idx < op_context->dims; ++idx) { 00097 int32_t stride = GetTensorData<int32_t>(op_context->strides)[idx]; 00098 TF_LITE_ENSURE_MSG(context, stride != 0, "stride value has to be non-zero"); 00099 int32_t begin = StartForAxis(op_params, input_shape, idx); 00100 int32_t end = StopForAxis(op_params, input_shape, idx, begin); 00101 00102 // When shrinking an axis, the end position does not matter (and can be 00103 // incorrect when negative indexing is used, see Issue #19260). Always use 00104 // begin + 1 to generate a length 1 slice, since begin has 00105 // already been adjusted for negative indices by StartForAxis. 00106 const bool shrink_axis = op_context->params->shrink_axis_mask & (1 << idx); 00107 if (shrink_axis) { 00108 end = begin + 1; 00109 } 00110 00111 // This is valid for both positive and negative strides 00112 int32_t dim_shape = std::ceil((end - begin) / static_cast<float>(stride)); 00113 dim_shape = dim_shape < 0 ? 
0 : dim_shape; 00114 if (!shrink_axis) { 00115 TF_LITE_ENSURE_EQ(context, output_shape->data[shape_size], dim_shape); 00116 shape_size++; 00117 } 00118 } 00119 TF_LITE_ENSURE_EQ(context, output_shape->size, shape_size); 00120 return kTfLiteOk; 00121 } 00122 00123 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { 00124 TF_LITE_ENSURE_EQ(context, NumInputs(node), 4); 00125 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); 00126 StridedSliceContext op_context(context, node); 00127 TF_LITE_ENSURE_MSG(context, op_context.dims <= kMaxDim, 00128 "input dim should not exceed 4"); 00129 return CheckOutputSize(context, &op_context); 00130 } 00131 00132 template <KernelType kernel_type> 00133 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { 00134 StridedSliceContext op_context(context, node); 00135 auto op_params = BuildStridedSliceParams(&op_context); 00136 00137 #define TF_LITE_STRIDED_SLICE(kernel_type, data_type) \ 00138 kernel_type::StridedSlice(op_params, GetTensorShape(op_context.input), \ 00139 GetTensorData<data_type>(op_context.input), \ 00140 GetTensorShape(op_context.output), \ 00141 GetTensorData<data_type>(op_context.output)) 00142 00143 switch (op_context.input->type) { 00144 case kTfLiteFloat32: 00145 if (kernel_type == kReference) { 00146 TF_LITE_STRIDED_SLICE(reference_ops, float); 00147 } 00148 break; 00149 case kTfLiteUInt8: 00150 if (kernel_type == kReference) { 00151 TF_LITE_STRIDED_SLICE(reference_ops, uint8_t); 00152 } 00153 break; 00154 case kTfLiteInt8: 00155 if (kernel_type == kReference) { 00156 TF_LITE_STRIDED_SLICE(reference_ops, int8_t); 00157 } 00158 break; 00159 default: 00160 context->ReportError(context, 00161 "Type %d is currently not supported " 00162 "by StridedSlice.", 00163 op_context.input->type); 00164 return kTfLiteError; 00165 } 00166 #undef TF_LITE_STRIDED_SLICE 00167 return kTfLiteOk; 00168 } 00169 } // namespace strided_slice 00170 00171 TfLiteRegistration* Register_STRIDED_SLICE() { 00172 static 
TfLiteRegistration r = { 00173 nullptr, nullptr, strided_slice::Prepare, 00174 strided_slice::Eval<strided_slice::kReference>}; 00175 return &r; 00176 } 00177 00178 } // namespace micro 00179 } // namespace ops 00180 } // namespace tflite
Generated on Wed Jul 13 2022 16:03:36 by Doxygen 1.7.2