micro_interpreter.cc
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/experimental/micro/micro_interpreter.h"

#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/core/api/flatbuffer_conversions.h"
#include "tensorflow/lite/experimental/micro/compatibility.h"

namespace tflite {
namespace {
const int kStackDataAllocatorSize = 128;
class StackDataAllocator : public BuiltinDataAllocator {
 public:
  void* Allocate(size_t size) override {
    if (size > kStackDataAllocatorSize) {
      return nullptr;
    } else {
      return data_;
    }
  }
  void Deallocate(void* data) override {
    // Do nothing.
  }

 private:
  uint8_t data_[kStackDataAllocatorSize];

  TF_LITE_REMOVE_VIRTUAL_DELETE
};

const char* OpNameFromRegistration(const TfLiteRegistration* registration) {
  if (registration->builtin_code == BuiltinOperator_CUSTOM) {
    return registration->custom_name;
  } else {
    return EnumNameBuiltinOperator(BuiltinOperator(registration->builtin_code));
  }
}

void ReportOpError(struct TfLiteContext* context, const char* format, ...) {
  MicroInterpreter* interpreter =
      static_cast<MicroInterpreter*>(context->impl_);
  va_list args;
  va_start(args, format);
  interpreter->error_reporter()->Report(format, args);
  va_end(args);
}

}  // namespace

MicroInterpreter::MicroInterpreter(const Model* model,
                                   const OpResolver& op_resolver,
                                   uint8_t* tensor_arena,
                                   size_t tensor_arena_size,
                                   ErrorReporter* error_reporter)
    : model_(model),
      op_resolver_(op_resolver),
      error_reporter_(error_reporter),
      context_(),
      allocator_(&context_, model_, tensor_arena, tensor_arena_size,
                 error_reporter_),
      tensors_allocated_(false) {
  auto* subgraphs = model->subgraphs();
  if (subgraphs->size() != 1) {
    error_reporter->Report("Only 1 subgraph is currently supported.\n");
    initialization_status_ = kTfLiteError;
    return;
  }
  subgraph_ = (*subgraphs)[0];
  tensors_ = subgraph_->tensors();
  operators_ = subgraph_->operators();

  context_.impl_ = static_cast<void*>(this);
  context_.ReportError = ReportOpError;
  context_.recommended_num_threads = 1;

  // If the system is big endian then convert weights from the flatbuffer from
  // little to big endian on startup so that it does not need to be done during
  // inference.
  // NOTE: This requires that the flatbuffer is held in memory which can be
  // modified by this process.
  if (!FLATBUFFERS_LITTLEENDIAN) {
    for (int t = 0; t < tensors_size(); ++t) {
      TfLiteTensor* thisTensor = &context_.tensors[t];
      if (thisTensor->allocation_type == kTfLiteMmapRo)
        CorrectTensorEndianness(thisTensor);
    }
  }

  initialization_status_ = kTfLiteOk;
}

void MicroInterpreter::CorrectTensorEndianness(TfLiteTensor* tensorCorr) {
  int32_t tensorSize = 1;
  for (int d = 0; d < tensorCorr->dims->size; ++d)
    tensorSize *= reinterpret_cast<const int32_t*>(tensorCorr->dims->data)[d];

  switch (tensorCorr->type) {
    case TfLiteType::kTfLiteFloat32:
      CorrectTensorDataEndianness(tensorCorr->data.f, tensorSize);
      break;
    case TfLiteType::kTfLiteFloat16:
      CorrectTensorDataEndianness(tensorCorr->data.f16, tensorSize);
      break;
    case TfLiteType::kTfLiteInt64:
      CorrectTensorDataEndianness(tensorCorr->data.i64, tensorSize);
      break;
    case TfLiteType::kTfLiteInt32:
      CorrectTensorDataEndianness(tensorCorr->data.i32, tensorSize);
      break;
    case TfLiteType::kTfLiteInt16:
      CorrectTensorDataEndianness(tensorCorr->data.i16, tensorSize);
      break;
    case TfLiteType::kTfLiteComplex64:
      CorrectTensorDataEndianness(tensorCorr->data.c64, tensorSize);
      break;
    default:
      // Do nothing for other data types.
      break;
  }
}

template <class T>
void MicroInterpreter::CorrectTensorDataEndianness(T* data, int32_t size) {
  for (int32_t i = 0; i < size; ++i) {
    data[i] = flatbuffers::EndianScalar(data[i]);
  }
}

TfLiteStatus MicroInterpreter::RegisterPreallocatedInput(uint8_t* buffer,
                                                         size_t input_index) {
  return allocator_.RegisterPreallocatedInput(buffer, input_index);
}

TfLiteStatus MicroInterpreter::AllocateTensors() {
  TF_LITE_ENSURE_OK(&context_, allocator_.AllocateNodeAndRegistrations(
                                   op_resolver_, &node_and_registrations_));
  TF_LITE_ENSURE_OK(&context_, allocator_.FinishTensorAllocation());

  tensors_allocated_ = true;
  return kTfLiteOk;
}

TfLiteStatus MicroInterpreter::Invoke() {
  if (initialization_status_ != kTfLiteOk) {
    error_reporter_->Report("Invoke() called after initialization failed\n");
    return kTfLiteError;
  }

  // Ensure tensors are allocated before the interpreter is invoked to avoid
  // difficult to debug segfaults.
  if (!tensors_allocated_) {
    AllocateTensors();
  }

  // Init method is not yet implemented.
  for (size_t i = 0; i < operators_->size(); ++i) {
    auto* node = &(node_and_registrations_[i].node);
    auto* registration = node_and_registrations_[i].registration;
    size_t init_data_size;
    const char* init_data;
    if (registration->builtin_code == BuiltinOperator_CUSTOM) {
      init_data = reinterpret_cast<const char*>(node->custom_initial_data);
      init_data_size = node->custom_initial_data_size;
    } else {
      init_data = reinterpret_cast<const char*>(node->builtin_data);
      init_data_size = 0;
    }
    if (registration->init) {
      node->user_data =
          registration->init(&context_, init_data, init_data_size);
    }
  }

  for (size_t i = 0; i < operators_->size(); ++i) {
    auto* node = &(node_and_registrations_[i].node);
    auto* registration = node_and_registrations_[i].registration;
    if (registration->prepare) {
      TfLiteStatus prepare_status = registration->prepare(&context_, node);
      if (prepare_status != kTfLiteOk) {
        error_reporter_->Report(
            "Node %s (number %d) failed to prepare with status %d",
            OpNameFromRegistration(registration), i, prepare_status);
        return kTfLiteError;
      }
    }
  }

  for (size_t i = 0; i < operators_->size(); ++i) {
    auto* node = &(node_and_registrations_[i].node);
    auto* registration = node_and_registrations_[i].registration;

    if (registration->invoke) {
      TfLiteStatus invoke_status = registration->invoke(&context_, node);
      if (invoke_status != kTfLiteOk) {
        error_reporter_->Report(
            "Node %s (number %d) failed to invoke with status %d",
            OpNameFromRegistration(registration), i, invoke_status);
        return kTfLiteError;
      }
    }
  }

  // This is actually a no-op.
  // TODO(wangtz): Consider removing this code to slightly reduce binary size.
  for (size_t i = 0; i < operators_->size(); ++i) {
    auto* node = &(node_and_registrations_[i].node);
    auto* registration = node_and_registrations_[i].registration;
    if (registration->free) {
      registration->free(&context_, node->user_data);
    }
  }
  return kTfLiteOk;
}

TfLiteTensor* MicroInterpreter::input(size_t index) {
  const flatbuffers::Vector<int32_t>* inputs = subgraph_->inputs();
  const size_t length = inputs->size();
  if ((index < 0) || (index >= length)) {
    error_reporter_->Report("Input index %d out of range (length is %d)", index,
                            length);
    return nullptr;
  }
  return &(context_.tensors[inputs->Get(index)]);
}

TfLiteTensor* MicroInterpreter::output(size_t index) {
  const flatbuffers::Vector<int32_t>* outputs = subgraph_->outputs();
  const size_t length = outputs->size();
  if ((index < 0) || (index >= outputs->size())) {
    error_reporter_->Report("Output index %d out of range (length is %d)",
                            index, length);
    return nullptr;
  }
  return &(context_.tensors[outputs->Get(index)]);
}

TfLiteTensor* MicroInterpreter::tensor(size_t index) {
  const size_t length = tensors_size();
  if ((index < 0) || (index >= tensors_size())) {
    error_reporter_->Report("Tensor index %d out of range (length is %d)",
                            index, length);
    return nullptr;
  }
  return &context_.tensors[index];
}

struct pairTfLiteNodeAndRegistration MicroInterpreter::node_and_registration(
    int node_index) {
  TfLiteStatus status = kTfLiteOk;
  struct pairTfLiteNodeAndRegistration tfNodeRegiPair;
  auto opcodes = model_->operator_codes();
  {
    const auto* op = operators_->Get(node_index);
    size_t index = op->opcode_index();
    if (index < 0 || index >= opcodes->size()) {
      error_reporter_->Report("Missing registration for opcode_index %d\n",
                              index);
    }
    auto opcode = (*opcodes)[index];
    const TfLiteRegistration* registration = nullptr;
    status = GetRegistrationFromOpCode(opcode, op_resolver_, error_reporter_,
                                       &registration);
    if (status != kTfLiteOk) {
      error_reporter_->Report("Missing registration for opcode_index %d\n",
                              index);
    }
    if (registration == nullptr) {
      error_reporter_->Report("Skipping op for opcode_index %d\n", index);
    }

    // Disregard const qualifier to workaround with existing API.
    TfLiteIntArray* inputs_array = const_cast<TfLiteIntArray*>(
        reinterpret_cast<const TfLiteIntArray*>(op->inputs()));
    TfLiteIntArray* outputs_array = const_cast<TfLiteIntArray*>(
        reinterpret_cast<const TfLiteIntArray*>(op->outputs()));

    TfLiteNode node;
    node.inputs = inputs_array;
    node.outputs = outputs_array;
    tfNodeRegiPair.node = node;
    tfNodeRegiPair.registration = registration;
  }
  return tfNodeRegiPair;
}

}  // namespace tflite
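For reference, the sketch below shows how this class is typically driven end to end: build the interpreter from a flatbuffer model, an op resolver, and a caller-owned tensor arena, then allocate tensors, write the input, and invoke. It is a minimal, hypothetical example and not part of this file; the model symbol g_model_data, the arena size, and the single float input/output are illustrative assumptions, while AllOpsResolver, MicroErrorReporter, and GetModel are taken from the same experimental micro tree as this source.

#include "tensorflow/lite/experimental/micro/kernels/all_ops_resolver.h"
#include "tensorflow/lite/experimental/micro/micro_error_reporter.h"
#include "tensorflow/lite/experimental/micro/micro_interpreter.h"
#include "tensorflow/lite/schema/schema_generated.h"

namespace {
// Caller-owned working memory for tensor allocation; 2 KB is an arbitrary
// size chosen here for a very small model.
constexpr size_t kTensorArenaSize = 2 * 1024;
uint8_t tensor_arena[kTensorArenaSize];
}  // namespace

// Runs one inference on a hypothetical single-input, single-output float
// model whose flatbuffer is pointed to by g_model_data.
TfLiteStatus RunOnce(const unsigned char* g_model_data, float x, float* y) {
  tflite::MicroErrorReporter micro_error_reporter;
  const tflite::Model* model = tflite::GetModel(g_model_data);
  tflite::ops::micro::AllOpsResolver resolver;
  tflite::MicroInterpreter interpreter(model, resolver, tensor_arena,
                                       kTensorArenaSize,
                                       &micro_error_reporter);
  // Plan the arena up front; Invoke() would otherwise call this lazily.
  if (interpreter.AllocateTensors() != kTfLiteOk) {
    return kTfLiteError;
  }
  interpreter.input(0)->data.f[0] = x;  // write the input tensor
  if (interpreter.Invoke() != kTfLiteOk) {
    return kTfLiteError;
  }
  *y = interpreter.output(0)->data.f[0];  // read the output tensor
  return kTfLiteOk;
}

Note that in this version of the interpreter, Invoke() runs the init, prepare, invoke, and free hooks for every operator on each call, so calling AllocateTensors() explicitly as above only moves the arena planning out of the first Invoke().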