Daniel Konegen / MNIST_example

Dependencies:   mbed-os

Embed: (wiki syntax)

« Back to documentation index

Show/hide line numbers micro_interpreter.cc Source File

micro_interpreter.cc

00001 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
00002 
00003 Licensed under the Apache License, Version 2.0 (the "License");
00004 you may not use this file except in compliance with the License.
00005 You may obtain a copy of the License at
00006 
00007     http://www.apache.org/licenses/LICENSE-2.0
00008 
00009 Unless required by applicable law or agreed to in writing, software
00010 distributed under the License is distributed on an "AS IS" BASIS,
00011 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
00012 See the License for the specific language governing permissions and
00013 limitations under the License.
00014 ==============================================================================*/
#include "tensorflow/lite/experimental/micro/micro_interpreter.h"

#include <cstddef>

#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/core/api/flatbuffer_conversions.h"
#include "tensorflow/lite/experimental/micro/compatibility.h"
00020 
00021 namespace tflite {
00022 namespace {
00023 const int kStackDataAllocatorSize = 128;
00024 class StackDataAllocator : public BuiltinDataAllocator {
00025  public:
00026   void* Allocate(size_t size) override {
00027     if (size > kStackDataAllocatorSize) {
00028       return nullptr;
00029     } else {
00030       return data_;
00031     }
00032   }
00033   void Deallocate(void* data) override {
00034     // Do nothing.
00035   }
00036 
00037  private:
00038   uint8_t data_[kStackDataAllocatorSize];
00039 
00040   TF_LITE_REMOVE_VIRTUAL_DELETE
00041 };
00042 
00043 const char* OpNameFromRegistration(const TfLiteRegistration* registration) {
00044   if (registration->builtin_code == BuiltinOperator_CUSTOM) {
00045     return registration->custom_name;
00046   } else {
00047     return EnumNameBuiltinOperator(BuiltinOperator(registration->builtin_code));
00048   }
00049 }
00050 
00051 void ReportOpError(struct TfLiteContext* context, const char* format, ...) {
00052   MicroInterpreter* interpreter =
00053       static_cast<MicroInterpreter*>(context->impl_);
00054   va_list args;
00055   va_start(args, format);
00056   interpreter->error_reporter()->Report(format, args);
00057   va_end(args);
00058 }
00059 
00060 }  // namespace
00061 
00062 MicroInterpreter::MicroInterpreter(const Model* model,
00063                                    const OpResolver& op_resolver,
00064                                    uint8_t* tensor_arena,
00065                                    size_t tensor_arena_size,
00066                                    ErrorReporter* error_reporter)
00067     : model_(model),
00068       op_resolver_(op_resolver),
00069       error_reporter_(error_reporter),
00070       context_(),
00071       allocator_(&context_, model_, tensor_arena, tensor_arena_size,
00072                  error_reporter_),
00073       tensors_allocated_(false) {
00074   auto* subgraphs = model->subgraphs();
00075   if (subgraphs->size() != 1) {
00076     error_reporter->Report("Only 1 subgraph is currently supported.\n");
00077     initialization_status_ = kTfLiteError;
00078     return;
00079   }
00080   subgraph_ = (*subgraphs)[0];
00081   tensors_ = subgraph_->tensors();
00082   operators_ = subgraph_->operators();
00083 
00084   context_.impl_ = static_cast<void*>(this);
00085   context_.ReportError = ReportOpError;
00086   context_.recommended_num_threads = 1;
00087 
00088   // If the system is big endian then convert weights from the flatbuffer from
00089   // little to big endian on startup so that it does not need to be done during
00090   // inference.
00091   // NOTE: This requires that the flatbuffer is held in memory which can be
00092   // modified by this process.
00093   if (!FLATBUFFERS_LITTLEENDIAN) {
00094     for (int t = 0; t < tensors_size(); ++t) {
00095       TfLiteTensor* thisTensor = &context_.tensors[t];
00096       if (thisTensor->allocation_type == kTfLiteMmapRo)
00097         CorrectTensorEndianness(thisTensor);
00098     }
00099   }
00100 
00101   initialization_status_ = kTfLiteOk;
00102 }
00103 
00104 void MicroInterpreter::CorrectTensorEndianness(TfLiteTensor* tensorCorr) {
00105   int32_t tensorSize = 1;
00106   for (int d = 0; d < tensorCorr->dims->size; ++d)
00107     tensorSize *= reinterpret_cast<const int32_t*>(tensorCorr->dims->data)[d];
00108 
00109   switch (tensorCorr->type) {
00110     case TfLiteType::kTfLiteFloat32:
00111       CorrectTensorDataEndianness(tensorCorr->data.f, tensorSize);
00112       break;
00113     case TfLiteType::kTfLiteFloat16:
00114       CorrectTensorDataEndianness(tensorCorr->data.f16, tensorSize);
00115       break;
00116     case TfLiteType::kTfLiteInt64:
00117       CorrectTensorDataEndianness(tensorCorr->data.i64, tensorSize);
00118       break;
00119     case TfLiteType::kTfLiteInt32:
00120       CorrectTensorDataEndianness(tensorCorr->data.i32, tensorSize);
00121       break;
00122     case TfLiteType::kTfLiteInt16:
00123       CorrectTensorDataEndianness(tensorCorr->data.i16, tensorSize);
00124       break;
00125     case TfLiteType::kTfLiteComplex64:
00126       CorrectTensorDataEndianness(tensorCorr->data.c64, tensorSize);
00127       break;
00128     default:
00129       // Do nothing for other data types.
00130       break;
00131   }
00132 }
00133 
00134 template <class T>
00135 void MicroInterpreter::CorrectTensorDataEndianness(T* data, int32_t size) {
00136   for (int32_t i = 0; i < size; ++i) {
00137     data[i] = flatbuffers::EndianScalar(data[i]);
00138   }
00139 }
00140 
00141 TfLiteStatus MicroInterpreter::RegisterPreallocatedInput(uint8_t* buffer,
00142                                                          size_t input_index) {
00143   return allocator_.RegisterPreallocatedInput(buffer, input_index);
00144 }
00145 
00146 TfLiteStatus MicroInterpreter::AllocateTensors() {
00147   TF_LITE_ENSURE_OK(&context_, allocator_.AllocateNodeAndRegistrations(
00148                                    op_resolver_, &node_and_registrations_));
00149   TF_LITE_ENSURE_OK(&context_, allocator_.FinishTensorAllocation());
00150 
00151   tensors_allocated_ = true;
00152   return kTfLiteOk;
00153 }
00154 
00155 TfLiteStatus MicroInterpreter::Invoke() {
00156   if (initialization_status_ != kTfLiteOk) {
00157     error_reporter_->Report("Invoke() called after initialization failed\n");
00158     return kTfLiteError;
00159   }
00160 
00161   // Ensure tensors are allocated before the interpreter is invoked to avoid
00162   // difficult to debug segfaults.
00163   if (!tensors_allocated_) {
00164     AllocateTensors();
00165   }
00166 
00167   // Init method is not yet implemented.
00168   for (size_t i = 0; i < operators_->size(); ++i) {
00169     auto* node = &(node_and_registrations_[i].node);
00170     auto* registration = node_and_registrations_[i].registration;
00171     size_t init_data_size;
00172     const char* init_data;
00173     if (registration->builtin_code == BuiltinOperator_CUSTOM) {
00174       init_data = reinterpret_cast<const char*>(node->custom_initial_data);
00175       init_data_size = node->custom_initial_data_size;
00176     } else {
00177       init_data = reinterpret_cast<const char*>(node->builtin_data);
00178       init_data_size = 0;
00179     }
00180     if (registration->init) {
00181       node->user_data =
00182           registration->init(&context_, init_data, init_data_size);
00183     }
00184   }
00185 
00186   for (size_t i = 0; i < operators_->size(); ++i) {
00187     auto* node = &(node_and_registrations_[i].node);
00188     auto* registration = node_and_registrations_[i].registration;
00189     if (registration->prepare) {
00190       TfLiteStatus prepare_status = registration->prepare(&context_, node);
00191       if (prepare_status != kTfLiteOk) {
00192         error_reporter_->Report(
00193             "Node %s (number %d) failed to prepare with status %d",
00194             OpNameFromRegistration(registration), i, prepare_status);
00195         return kTfLiteError;
00196       }
00197     }
00198   }
00199 
00200   for (size_t i = 0; i < operators_->size(); ++i) {
00201     auto* node = &(node_and_registrations_[i].node);
00202     auto* registration = node_and_registrations_[i].registration;
00203 
00204     if (registration->invoke) {
00205       TfLiteStatus invoke_status = registration->invoke(&context_, node);
00206       if (invoke_status != kTfLiteOk) {
00207         error_reporter_->Report(
00208             "Node %s (number %d) failed to invoke with status %d",
00209             OpNameFromRegistration(registration), i, invoke_status);
00210         return kTfLiteError;
00211       }
00212     }
00213   }
00214 
00215   // This is actually a no-op.
00216   // TODO(wangtz): Consider removing this code to slightly reduce binary size.
00217   for (size_t i = 0; i < operators_->size(); ++i) {
00218     auto* node = &(node_and_registrations_[i].node);
00219     auto* registration = node_and_registrations_[i].registration;
00220     if (registration->free) {
00221       registration->free(&context_, node->user_data);
00222     }
00223   }
00224   return kTfLiteOk;
00225 }
00226 
00227 TfLiteTensor* MicroInterpreter::input(size_t index) {
00228   const flatbuffers::Vector<int32_t>* inputs = subgraph_->inputs();
00229   const size_t length = inputs->size();
00230   if ((index < 0) || (index >= length)) {
00231     error_reporter_->Report("Input index %d out of range (length is %d)", index,
00232                             length);
00233     return nullptr;
00234   }
00235   return &(context_.tensors[inputs->Get(index)]);
00236 }
00237 
00238 TfLiteTensor* MicroInterpreter::output(size_t index) {
00239   const flatbuffers::Vector<int32_t>* outputs = subgraph_->outputs();
00240   const size_t length = outputs->size();
00241   if ((index < 0) || (index >= outputs->size())) {
00242     error_reporter_->Report("Output index %d out of range (length is %d)",
00243                             index, length);
00244     return nullptr;
00245   }
00246   return &(context_.tensors[outputs->Get(index)]);
00247 }
00248 
00249 TfLiteTensor* MicroInterpreter::tensor(size_t index) {
00250   const size_t length = tensors_size();
00251   if ((index < 0) || (index >= tensors_size())) {
00252     error_reporter_->Report("Tensor index %d out of range (length is %d)",
00253                             index, length);
00254     return nullptr;
00255   }
00256   return &context_.tensors[index];
00257 }
00258 
00259 struct pairTfLiteNodeAndRegistration MicroInterpreter::node_and_registration(
00260     int node_index) {
00261   TfLiteStatus status = kTfLiteOk;
00262   struct pairTfLiteNodeAndRegistration tfNodeRegiPair;
00263   auto opcodes = model_->operator_codes();
00264   {
00265     const auto* op = operators_->Get(node_index);
00266     size_t index = op->opcode_index();
00267     if (index < 0 || index >= opcodes->size()) {
00268       error_reporter_->Report("Missing registration for opcode_index %d\n",
00269                               index);
00270     }
00271     auto opcode = (*opcodes)[index];
00272     const TfLiteRegistration* registration = nullptr;
00273     status = GetRegistrationFromOpCode(opcode, op_resolver_, error_reporter_,
00274                                        &registration);
00275     if (status != kTfLiteOk) {
00276       error_reporter_->Report("Missing registration for opcode_index %d\n",
00277                               index);
00278     }
00279     if (registration == nullptr) {
00280       error_reporter_->Report("Skipping op for opcode_index %d\n", index);
00281     }
00282 
00283     // Disregard const qualifier to workaround with existing API.
00284     TfLiteIntArray* inputs_array = const_cast<TfLiteIntArray*>(
00285         reinterpret_cast<const TfLiteIntArray*>(op->inputs()));
00286     TfLiteIntArray* outputs_array = const_cast<TfLiteIntArray*>(
00287         reinterpret_cast<const TfLiteIntArray*>(op->outputs()));
00288 
00289     TfLiteNode node;
00290     node.inputs = inputs_array;
00291     node.outputs = outputs_array;
00292     tfNodeRegiPair.node = node;
00293     tfNodeRegiPair.registration = registration;
00294   }
00295   return tfNodeRegiPair;
00296 }
00297 
00298 }  // namespace tflite