Daniel Konegen / MNIST_example

Dependencies: mbed-os

prelu.cc

/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/kernels/internal/reference/prelu.h"

#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace micro {
namespace activations {

// No per-node state is allocated for PRELU; all work happens in Eval.
TfLiteStatus PreluPrepare(TfLiteContext* context, TfLiteNode* node) {
  return kTfLiteOk;
}

// Float PReLU with broadcasting of the alpha tensor over up to four
// dimensions: output = input when input >= 0, otherwise input * alpha.
inline void BroadcastPrelu4DSlowFloat(
    const RuntimeShape& unextended_input1_shape, const float* input1_data,
    const RuntimeShape& unextended_input2_shape, const float* input2_data,
    const RuntimeShape& unextended_output_shape, float* output_data) {
  TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
  TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
  const RuntimeShape output_shape =
      RuntimeShape::ExtendedShape(4, unextended_output_shape);

  NdArrayDesc<4> desc1;
  NdArrayDesc<4> desc2;
  NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
                                      unextended_input2_shape, &desc1, &desc2);

  for (int b = 0; b < output_shape.Dims(0); ++b) {
    for (int y = 0; y < output_shape.Dims(1); ++y) {
      for (int x = 0; x < output_shape.Dims(2); ++x) {
        for (int c = 0; c < output_shape.Dims(3); ++c) {
          auto out_idx = Offset(output_shape, b, y, x, c);
          auto in1_idx = SubscriptToIndex(desc1, b, y, x, c);
          auto in2_idx = SubscriptToIndex(desc2, b, y, x, c);
          auto in1_val = input1_data[in1_idx];
          auto in2_val = input2_data[in2_idx];
          output_data[out_idx] = in1_val >= 0.0 ? in1_val : in1_val * in2_val;
        }
      }
    }
  }
}

TfLiteStatus PreluEval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input = GetInput(context, node, 0);
  const TfLiteTensor* alpha = GetInput(context, node, 1);
  TfLiteTensor* output = GetOutput(context, node, 0);
  int32_t output_multiplier = 0;
  int output_shift = 0;
  if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt16) {
    // Fold the input, alpha, and output quantization scales into a single
    // fixed-point multiplier/shift pair for the quantized reference kernel.
    double real_multiplier =
        input->params.scale * alpha->params.scale / output->params.scale;
    QuantizeMultiplierSmallerThanOneExp(real_multiplier, &output_multiplier,
                                        &output_shift);
  }
  switch (input->type) {
    case kTfLiteFloat32: {
      BroadcastPrelu4DSlowFloat(
          GetTensorShape(input), GetTensorData<float>(input),
          GetTensorShape(alpha), GetTensorData<float>(alpha),
          GetTensorShape(output), GetTensorData<float>(output));
      return kTfLiteOk;
    } break;
    case kTfLiteUInt8: {
      PreluParams op_params;
      op_params.input_offset = -input->params.zero_point;
      op_params.alpha_offset = -alpha->params.zero_point;
      op_params.output_offset = output->params.zero_point;
      op_params.output_multiplier = output_multiplier;
      op_params.output_shift = output_shift;
      reference_ops::BroadcastPrelu4DSlow(
          op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
          GetTensorShape(alpha), GetTensorData<uint8_t>(alpha),
          GetTensorShape(output), GetTensorData<uint8_t>(output));
      return kTfLiteOk;
    } break;
    default:
      // TfLiteTypeGetName returns a string, so the format specifier is %s.
      context->ReportError(
          context, "Only float32 and uint8 are supported currently, got %s.",
          TfLiteTypeGetName(input->type));
      return kTfLiteError;
  }
}

}  // namespace activations

TfLiteRegistration* Register_PRELU() {
  static TfLiteRegistration r = {nullptr, nullptr, activations::PreluPrepare,
                                 activations::PreluEval};
  return &r;
}

}  // namespace micro
}  // namespace ops
}  // namespace tflite
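
Usage note (not part of the original file): before a model containing PRELU can run on the target, the Register_PRELU() entry point above has to be added to the application's op resolver. The sketch below assumes the older, non-templated MicroMutableOpResolver API and its AddBuiltin(BuiltinOperator, TfLiteRegistration*) overload; the header path and exact resolver signature differ between TensorFlow Lite Micro versions, so treat both as assumptions to be checked against the library bundled with this project.

// Sketch only: registering the PRELU kernel with a micro op resolver.
// The include path and AddBuiltin() signature are assumptions; adjust them
// to match the TensorFlow Lite Micro version used by this mbed project.
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"

namespace {

void AddPreluToResolver(tflite::MicroMutableOpResolver* resolver) {
  // Makes PRELU available to any model interpreted through this resolver;
  // models that reference unregistered ops fail when tensors are allocated.
  resolver->AddBuiltin(tflite::BuiltinOperator_PRELU,
                       tflite::ops::micro::Register_PRELU());
}

}  // namespace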