Important changes to repositories hosted on mbed.com
Mbed-hosted Mercurial repositories are deprecated and are due to be permanently deleted in July 2026.
To keep a copy of this software, download the repository ZIP archive or clone it locally using Mercurial.
It is also possible to export all your personal repositories from the account settings page.
split.cc
00001 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 00002 00003 Licensed under the Apache License, Version 2.0 (the "License"); 00004 you may not use this file except in compliance with the License. 00005 You may obtain a copy of the License at 00006 00007 http://www.apache.org/licenses/LICENSE-2.0 00008 00009 Unless required by applicable law or agreed to in writing, software 00010 distributed under the License is distributed on an "AS IS" BASIS, 00011 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 00012 See the License for the specific language governing permissions and 00013 limitations under the License. 00014 ==============================================================================*/ 00015 00016 #include "tensorflow/lite/c/builtin_op_data.h" 00017 #include "tensorflow/lite/c/c_api_internal.h" 00018 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" 00019 #include "tensorflow/lite/kernels/kernel_util.h" 00020 00021 namespace tflite { 00022 namespace ops { 00023 namespace micro { 00024 namespace split { 00025 00026 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { 00027 return kTfLiteOk; 00028 } 00029 00030 template <typename T> 00031 TfLiteStatus SplitImpl(TfLiteContext* context, TfLiteNode* node, 00032 const TfLiteTensor* input, int axis_value) { 00033 const int output_count = NumOutputs(node); 00034 const TfLiteIntArray* input_dims = input->dims; 00035 const TfLiteTensor* output0 = &context->tensors[node->outputs->data[0]]; 00036 const TfLiteIntArray* output_dims = output0->dims; 00037 00038 const int split_dimensions = input_dims->size; 00039 int axis = axis_value < 0 ? 
axis_value + split_dimensions : axis_value; 00040 00041 TFLITE_DCHECK_LT(axis, split_dimensions); 00042 TFLITE_DCHECK_EQ(output_dims->size, split_dimensions); 00043 00044 int64_t split_size = output_dims->data[axis] * output_count; 00045 00046 TFLITE_DCHECK_EQ(split_size, input_dims->data[axis]); 00047 int64_t outer_size = 1; 00048 for (int i = 0; i < axis; ++i) { 00049 outer_size *= input_dims->data[i]; 00050 } 00051 00052 int64_t base_inner_size = 1; 00053 for (int i = axis + 1; i < split_dimensions; ++i) { 00054 base_inner_size *= input_dims->data[i]; 00055 } 00056 00057 const T* input_ptr = GetTensorData<T>(input); 00058 for (int k = 0; k < outer_size; ++k) { 00059 for (int i = 0; i < output_count; ++i) { 00060 TfLiteTensor* t = &context->tensors[node->outputs->data[i]]; 00061 T* output_data = GetTensorData<T>(t); 00062 const int copy_size = output_dims->data[axis] * base_inner_size; 00063 T* output_ptr = output_data + k * copy_size; 00064 for (int j = 0; j < copy_size; ++j) output_ptr[j] = input_ptr[j]; 00065 input_ptr += copy_size; 00066 } 00067 } 00068 00069 return kTfLiteOk; 00070 } 00071 00072 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { 00073 const TfLiteTensor* axis = GetInput(context, node, 0); 00074 const TfLiteTensor* input = GetInput(context, node, 1); 00075 00076 // Dynamic output tensors are needed if axis tensor is not constant. 00077 // But Micro doesn't support dynamic memeory allocation, so we only support 00078 // constant axis tensor for now. 
00079 TF_LITE_ENSURE_MSG(context, IsConstantTensor(axis), 00080 "Non constant axis tensor not supported"); 00081 00082 int axis_value = GetTensorData<int32_t>(axis)[0]; 00083 if (axis_value < 0) { 00084 axis_value += NumDimensions(input); 00085 } 00086 00087 TF_LITE_ENSURE(context, axis_value >= 0); 00088 TF_LITE_ENSURE(context, axis_value < NumDimensions(input)); 00089 00090 switch (input->type) { 00091 case kTfLiteFloat32: { 00092 return SplitImpl<float>(context, node, input, axis_value); 00093 } 00094 case kTfLiteUInt8: { 00095 return SplitImpl<uint8_t>(context, node, input, axis_value); 00096 } 00097 case kTfLiteInt8: { 00098 return SplitImpl<int8_t>(context, node, input, axis_value); 00099 } 00100 case kTfLiteInt16: { 00101 return SplitImpl<int16_t>(context, node, input, axis_value); 00102 } 00103 case kTfLiteInt32: { 00104 return SplitImpl<int32_t>(context, node, input, axis_value); 00105 } 00106 default: 00107 context->ReportError(context, "Type %s currently not supported.", 00108 TfLiteTypeGetName(input->type)); 00109 return kTfLiteError; 00110 } 00111 #undef TF_LITE_SPLIT 00112 00113 return kTfLiteOk; 00114 } 00115 00116 } // namespace split 00117 00118 TfLiteRegistration* Register_SPLIT() { 00119 static TfLiteRegistration r = {nullptr, nullptr, split::Prepare, split::Eval}; 00120 return &r; 00121 } 00122 00123 } // namespace micro 00124 } // namespace ops 00125 } // namespace tflite
Generated on Wed Jul 13 2022 16:03:36 by
