Important changes to repositories hosted on mbed.com
Mbed-hosted Mercurial repositories are deprecated and are due to be permanently deleted in July 2026.
To keep a copy of this software, download the repository Zip archive or clone it locally using Mercurial.
It is also possible to export all your personal repositories from the account settings page.
unpack.cc
00001 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 00002 00003 Licensed under the Apache License, Version 2.0 (the "License"); 00004 you may not use this file except in compliance with the License. 00005 You may obtain a copy of the License at 00006 00007 http://www.apache.org/licenses/LICENSE-2.0 00008 00009 Unless required by applicable law or agreed to in writing, software 00010 distributed under the License is distributed on an "AS IS" BASIS, 00011 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 00012 See the License for the specific language governing permissions and 00013 limitations under the License. 00014 ==============================================================================*/ 00015 00016 #include "tensorflow/lite/c/builtin_op_data.h" 00017 #include "tensorflow/lite/c/c_api_internal.h" 00018 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" 00019 #include "tensorflow/lite/kernels/kernel_util.h" 00020 00021 namespace tflite { 00022 namespace ops { 00023 namespace micro { 00024 namespace unpack { 00025 namespace { 00026 00027 constexpr int kInputTensor = 0; 00028 00029 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { 00030 return kTfLiteOk; 00031 } 00032 00033 template <typename T> 00034 TfLiteStatus UnpackImpl(TfLiteContext* context, TfLiteNode* node, 00035 const TfLiteTensor* input, int output_count, int axis) { 00036 const TfLiteTensor* output0 = &context->tensors[node->outputs->data[0]]; 00037 const TfLiteIntArray* input_dims = input->dims; 00038 const TfLiteIntArray* output_dims = output0->dims; 00039 const int dimensions = input_dims->size; 00040 00041 if (axis < 0) { 00042 axis += NumDimensions(input); 00043 } 00044 00045 TFLITE_DCHECK_LT(axis, dimensions); 00046 00047 int outer_size = 1; 00048 for (int i = 0; i < axis; ++i) { 00049 outer_size *= input_dims->data[i]; 00050 } 00051 int copy_size = 1; 00052 for (int i = axis + 1; i < dimensions; ++i) { 00053 copy_size *= 
input_dims->data[i]; 00054 } 00055 int output_size = 1; 00056 for (int i = 0; i < output_dims->size; ++i) { 00057 output_size *= output_dims->data[i]; 00058 } 00059 TFLITE_DCHECK_EQ(output_size, copy_size * outer_size); 00060 00061 const T* input_data = GetTensorData<T>(input); 00062 00063 for (int i = 0; i < output_count; ++i) { 00064 TfLiteTensor* t = &context->tensors[node->outputs->data[i]]; 00065 T* output_data = GetTensorData<T>(t); 00066 for (int k = 0; k < outer_size; ++k) { 00067 T* output_ptr = output_data + copy_size * k; 00068 int loc = k * output_count * copy_size + i * copy_size; 00069 const T* input_ptr = input_data + loc; 00070 for (int j = 0; j < copy_size; ++j) output_ptr[j] = input_ptr[j]; 00071 } 00072 } 00073 00074 return kTfLiteOk; 00075 } 00076 00077 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { 00078 TfLiteUnpackParams* data = 00079 reinterpret_cast<TfLiteUnpackParams*>(node->builtin_data); 00080 00081 const TfLiteTensor* input = GetInput(context, node, kInputTensor); 00082 00083 switch (input->type) { 00084 case kTfLiteFloat32: { 00085 return UnpackImpl<float>(context, node, input, data->num, data->axis); 00086 } 00087 case kTfLiteInt32: { 00088 return UnpackImpl<int32_t>(context, node, input, data->num, data->axis); 00089 } 00090 case kTfLiteUInt8: { 00091 return UnpackImpl<uint8_t>(context, node, input, data->num, data->axis); 00092 } 00093 case kTfLiteInt8: { 00094 return UnpackImpl<int8_t>(context, node, input, data->num, data->axis); 00095 } 00096 default: { 00097 context->ReportError(context, "Type '%s' is not supported by unpack.", 00098 TfLiteTypeGetName(input->type)); 00099 return kTfLiteError; 00100 } 00101 } 00102 00103 return kTfLiteOk; 00104 } 00105 } // namespace 00106 } // namespace unpack 00107 00108 TfLiteRegistration* Register_UNPACK() { 00109 static TfLiteRegistration r = {nullptr, nullptr, unpack::Prepare, 00110 unpack::Eval}; 00111 return &r; 00112 } 00113 00114 } // namespace micro 00115 } // namespace 
ops 00116 } // namespace tflite
Generated on Wed Jul 13 2022 16:03:36 by
