Daniel Konegen / MNIST_example

Dependencies:   mbed-os

Embed: (wiki syntax)

« Back to documentation index

Show/hide line numbers logistic.cc Source File

logistic.cc

00001 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
00002 
00003 Licensed under the Apache License, Version 2.0 (the "License");
00004 you may not use this file except in compliance with the License.
00005 You may obtain a copy of the License at
00006 
00007     http://www.apache.org/licenses/LICENSE-2.0
00008 
00009 Unless required by applicable law or agreed to in writing, software
00010 distributed under the License is distributed on an "AS IS" BASIS,
00011 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
00012 See the License for the specific language governing permissions and
00013 limitations under the License.
00014 ==============================================================================*/
00015 
00016 #include "tensorflow/lite/kernels/internal/reference/logistic.h"
00017 
00018 #include "tensorflow/lite/c/builtin_op_data.h"
00019 #include "tensorflow/lite/c/c_api_internal.h"
00020 #include "tensorflow/lite/kernels/internal/common.h"
00021 #include "tensorflow/lite/kernels/internal/quantization_util.h"
00022 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
00023 #include "tensorflow/lite/kernels/kernel_util.h"
00024 #include "tensorflow/lite/kernels/op_macros.h"
00025 
00026 namespace tflite {
00027 namespace ops {
00028 namespace micro {
00029 namespace activations {
00030 
// Position, within the node's input list, of the tensor that Eval reads.
constexpr int kInputTensor{0};
// Position, within the node's output list, of the tensor that Eval writes.
constexpr int kOutputTensor{0};
00033 
00034 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
00035   return kTfLiteOk;
00036 }
00037 
00038 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
00039   const TfLiteTensor* input = GetInput(context, node, kInputTensor);
00040   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
00041 
00042   switch (input->type) {
00043     case kTfLiteFloat32: {
00044       reference_ops::Logistic(
00045           GetTensorShape(input), GetTensorData<float>(input),
00046           GetTensorShape(output), GetTensorData<float>(output));
00047       return kTfLiteOk;
00048     }
00049     default: {
00050       // TODO(b/141211002): Also support other data types once we have supported
00051       // temporary tensors in TFLM.
00052       context->ReportError(context,
00053                            "Only float32 is supported currently, got %s",
00054                            TfLiteTypeGetName(input->type));
00055       return kTfLiteError;
00056     }
00057   }
00058 }
00059 
00060 }  // namespace activations
00061 
00062 TfLiteRegistration* Register_LOGISTIC() {
00063   static TfLiteRegistration r = {/*init=*/nullptr,
00064                                  /*free=*/nullptr, activations::Prepare,
00065                                  activations::Eval};
00066   return &r;
00067 }
00068 }  // namespace micro
00069 }  // namespace ops
00070 }  // namespace tflite