Daniel Konegen / MNIST_example

Dependencies:   mbed-os

Embed: (wiki syntax)

« Back to documentation index

Show/hide line numbers round.cc Source File

round.cc

00001 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
00002 
00003 Licensed under the Apache License, Version 2.0 (the "License");
00004 you may not use this file except in compliance with the License.
00005 You may obtain a copy of the License at
00006 
00007     http://www.apache.org/licenses/LICENSE-2.0
00008 
00009 Unless required by applicable law or agreed to in writing, software
00010 distributed under the License is distributed on an "AS IS" BASIS,
00011 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
00012 See the License for the specific language governing permissions and
00013 limitations under the License.
00014 ==============================================================================*/
00015 
00016 #include "tensorflow/lite/kernels/internal/reference/round.h"
00017 
00018 #include "tensorflow/lite/c/c_api_internal.h"
00019 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
00020 #include "tensorflow/lite/kernels/kernel_util.h"
00021 
00022 namespace tflite {
00023 namespace ops {
00024 namespace micro {
00025 namespace round {
00026 
// Position of the single input tensor, as passed to GetInput().
constexpr int kInputTensor{0};
// Position of the single output tensor, as passed to GetOutput().
constexpr int kOutputTensor{0};
00029 
00030 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
00031   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
00032   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
00033   TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
00034   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
00035   TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
00036   TF_LITE_ENSURE_EQ(context, output->type, input->type);
00037   TF_LITE_ENSURE_EQ(context, output->bytes, input->bytes);
00038   TF_LITE_ENSURE_EQ(context, output->dims->size, input->dims->size);
00039   for (int i = 0; i < output->dims->size; ++i) {
00040     TF_LITE_ENSURE_EQ(context, output->dims->data[i], input->dims->data[i]);
00041   }
00042   return kTfLiteOk;
00043 }
00044 
00045 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
00046   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
00047   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
00048 
00049   reference_ops::Round(GetTensorShape(input), GetTensorData<float>(input),
00050                        GetTensorShape(output), GetTensorData<float>(output));
00051 
00052   return kTfLiteOk;
00053 }
00054 }  // namespace round
00055 
00056 TfLiteRegistration* Register_ROUND() {
00057   static TfLiteRegistration r = {/*init=*/nullptr,
00058                                  /*free=*/nullptr, round::Prepare, round::Eval};
00059   return &r;
00060 }
00061 
00062 }  // namespace micro
00063 }  // namespace ops
00064 }  // namespace tflite