Daniel Konegen / MNIST_example

Dependencies:   mbed-os

Embed: (wiki syntax)

« Back to documentation index

Show/hide line numbers micro_mutable_op_resolver.cc Source File

micro_mutable_op_resolver.cc

00001 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
00002 
00003 Licensed under the Apache License, Version 2.0 (the "License");
00004 you may not use this file except in compliance with the License.
00005 You may obtain a copy of the License at
00006 
00007     http://www.apache.org/licenses/LICENSE-2.0
00008 
00009 Unless required by applicable law or agreed to in writing, software
00010 distributed under the License is distributed on an "AS IS" BASIS,
00011 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
00012 See the License for the specific language governing permissions and
00013 limitations under the License.
00014 ==============================================================================*/
00015 
00016 #include "tensorflow/lite/experimental/micro/micro_mutable_op_resolver.h"
00017 
00018 namespace tflite {
00019 
00020 const TfLiteRegistration* MicroMutableOpResolver::FindOp(
00021     tflite::BuiltinOperator op, int version) const {
00022   for (int i = 0; i < registrations_len_; ++i) {
00023     const TfLiteRegistration& registration = registrations_[i];
00024     if ((registration.builtin_code == op) &&
00025         (registration.version == version)) {
00026       return &registration;
00027     }
00028   }
00029   return nullptr;
00030 }
00031 
00032 const TfLiteRegistration* MicroMutableOpResolver::FindOp(const char* op,
00033                                                          int version) const {
00034   for (int i = 0; i < registrations_len_; ++i) {
00035     const TfLiteRegistration& registration = registrations_[i];
00036     if ((registration.builtin_code == BuiltinOperator_CUSTOM) &&
00037         (strcmp(registration.custom_name, op) == 0) &&
00038         (registration.version == version)) {
00039       return &registration;
00040     }
00041   }
00042   return nullptr;
00043 }
00044 
00045 void MicroMutableOpResolver::AddBuiltin(tflite::BuiltinOperator op,
00046                                         TfLiteRegistration* registration,
00047                                         int min_version, int max_version) {
00048   for (int version = min_version; version <= max_version; ++version) {
00049     if (registrations_len_ >= TFLITE_REGISTRATIONS_MAX) {
00050       // TODO(petewarden) - Add error reporting hooks so we can report this!
00051       return;
00052     }
00053     TfLiteRegistration* new_registration = &registrations_[registrations_len_];
00054     registrations_len_ += 1;
00055 
00056     *new_registration = *registration;
00057     new_registration->builtin_code = op;
00058     new_registration->version = version;
00059   }
00060 }
00061 
00062 void MicroMutableOpResolver::AddCustom(const char* name,
00063                                        TfLiteRegistration* registration,
00064                                        int min_version, int max_version) {
00065   for (int version = min_version; version <= max_version; ++version) {
00066     if (registrations_len_ >= TFLITE_REGISTRATIONS_MAX) {
00067       // TODO(petewarden) - Add error reporting hooks so we can report this!
00068       return;
00069     }
00070     TfLiteRegistration* new_registration = &registrations_[registrations_len_];
00071     registrations_len_ += 1;
00072 
00073     *new_registration = *registration;
00074     new_registration->builtin_code = BuiltinOperator_CUSTOM;
00075     new_registration->custom_name = name;
00076     new_registration->version = version;
00077   }
00078 }
00079 
00080 }  // namespace tflite