Daniel Konegen / MNIST_example

Dependencies:   mbed-os

Embed: (wiki syntax)

« Back to documentation index

Show/hide line numbers main_functions.cc Source File

main_functions.cc

00001 #include "./main_functions.h"
00002 
00003 #include "./constants.h"
00004 #include "./output_handler.h"
00005 #include "./mnist_model_data.h"
00006 #include "./../../kernels/all_ops_resolver.h"
00007 #include "./../../micro_error_reporter.h"
00008 #include "./../../micro_interpreter.h"
00009 #include "./../../../../schema/schema_generated.h"
00010 #include "./../../../../version.h"
00011 
00012 #include "./inputdata.h"
00013 
00014 // Globals, used for compatibility with Arduino-style sketches.
00015 namespace {
00016 tflite::ErrorReporter* error_reporter = nullptr;
00017 const tflite::Model* model = nullptr;
00018 tflite::MicroInterpreter* interpreter = nullptr;
00019 TfLiteTensor* input = nullptr;
00020 TfLiteTensor* output = nullptr;
00021 
00022 // Create an area of memory to use for input, output, and intermediate arrays.
00023 // Finding the minimum value for your model may require some trial and error.
00024 constexpr int kTensorArenaSize = 10 * 1024;
00025 uint8_t tensor_arena[kTensorArenaSize];
00026 }  // namespace
00027 
00028 // The name of this function is important for Arduino compatibility.
00029 void setup() {
00030   // Set up logging. Google style is to avoid globals or statics because of
00031   // lifetime uncertainty, but since this has a trivial destructor it's okay.
00032   // NOLINTNEXTLINE(runtime-global-variables)
00033   static tflite::MicroErrorReporter micro_error_reporter;
00034   error_reporter = &micro_error_reporter;
00035 
00036   // Map the model into a usable data structure. This doesn't involve any
00037   // copying or parsing, it's a very lightweight operation.
00038   model = tflite::GetModel(model_quantized_tflite);
00039   if (model->version() != TFLITE_SCHEMA_VERSION) {
00040     error_reporter->Report(
00041         "Model provided is schema version %d not equal "
00042         "to supported version %d.",
00043         model->version(), TFLITE_SCHEMA_VERSION);
00044     return;
00045   }
00046 
00047   // This pulls in all the operation implementations we need.
00048   // NOLINTNEXTLINE(runtime-global-variables)
00049   static tflite::ops::micro::AllOpsResolver resolver;
00050 
00051   // Build an interpreter to run the model with.
00052   static tflite::MicroInterpreter static_interpreter(
00053       model, resolver, tensor_arena, kTensorArenaSize, error_reporter);
00054   interpreter = &static_interpreter;
00055 
00056   // Allocate memory from the tensor_arena for the model's tensors.
00057   TfLiteStatus allocate_status = interpreter->AllocateTensors();
00058   if (allocate_status != kTfLiteOk) {
00059     error_reporter->Report("AllocateTensors() failed");
00060     return;
00061   }
00062 
00063   // Obtain pointers to the model's input and output tensors.
00064   input = interpreter->input(0);
00065   output = interpreter->output(0);
00066   
00067   
00068   printf("HALLO\n");
00069   
00070 }
00071 
00072 // The name of this function is important for Arduino compatibility.
00073 void loop() {
00074 
00075   // Place our calculated x value in the model's input tensor
00076   //input->data.f[0] = x_val;
00077   for (int i = 0; i < 784; ++i) {
00078     input->data.f[i] = input_data[i];
00079   }
00080 
00081   // Run inference, and report any error
00082   TfLiteStatus invoke_status = interpreter->Invoke();
00083   if (invoke_status != kTfLiteOk) {
00084     error_reporter->Report("Invoke failed on x_val: \n");
00085     return;
00086   }
00087   
00088   float val = 0;
00089   int digit = 0;
00090   for (int i = 0; i < 10; ++i) {
00091         float current = output->data.f[i];
00092         if(current > 0.7 && current > val) {
00093             val = current;
00094             digit = i;
00095         }
00096   }
00097   
00098   printf("NUMBER: %d\n", digit);
00099   printf("ACC: %f\n", val);
00100 }