Daniel Konegen / MNIST_example

Dependencies:   mbed-os

Embed: (wiki syntax)

« Back to documentation index

Show/hide line numbers prelu.h Source File

prelu.h

00001 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
00002 
00003 Licensed under the Apache License, Version 2.0 (the "License");
00004 you may not use this file except in compliance with the License.
00005 You may obtain a copy of the License at
00006 
00007     http://www.apache.org/licenses/LICENSE-2.0
00008 
00009 Unless required by applicable law or agreed to in writing, software
00010 distributed under the License is distributed on an "AS IS" BASIS,
00011 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
00012 See the License for the specific language governing permissions and
00013 limitations under the License.
00014 ==============================================================================*/
00015 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_PRELU_H_
00016 #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_PRELU_H_
00017 
00018 #include "tensorflow/lite/kernels/internal/common.h"
00019 #include "tensorflow/lite/kernels/internal/compatibility.h"
00020 #include "tensorflow/lite/kernels/internal/types.h"
00021 
00022 namespace tflite {
00023 
00024 namespace reference_ops {
00025 
00026 // Broadcast prelu to output_shape for quantized uint8 data.
00027 inline void BroadcastPrelu4DSlow(const PreluParams& params,
00028                                  const RuntimeShape& input_shape,
00029                                  const uint8* input_data,
00030                                  const RuntimeShape& alpha_shape,
00031                                  const uint8* alpha_data,
00032                                  const RuntimeShape& output_shape,
00033                                  uint8* output_data) {
00034   TFLITE_DCHECK_LE(input_shape.DimensionsCount(), 4);
00035   TFLITE_DCHECK_LE(alpha_shape.DimensionsCount(), 4);
00036   TFLITE_DCHECK_LE(output_shape.DimensionsCount(), 4);
00037   const RuntimeShape extended_output_shape =
00038       RuntimeShape::ExtendedShape(4, output_shape);
00039   NdArrayDesc<4> desc1;
00040   NdArrayDesc<4> desc2;
00041   NdArrayDescsForElementwiseBroadcast(input_shape, alpha_shape, &desc1, &desc2);
00042 
00043   for (int b = 0; b < extended_output_shape.Dims(0); ++b) {
00044     for (int y = 0; y < extended_output_shape.Dims(1); ++y) {
00045       for (int x = 0; x < extended_output_shape.Dims(2); ++x) {
00046         for (int c = 0; c < extended_output_shape.Dims(3); ++c) {
00047           int output_index = Offset(extended_output_shape, b, y, x, c);
00048           int input_index = SubscriptToIndex(desc1, b, y, x, c);
00049           const int32 input_value =
00050               params.input_offset + input_data[input_index];
00051           if (input_value >= 0) {
00052             output_data[output_index] = input_data[input_index];
00053           } else {
00054             auto alpha_index = SubscriptToIndex(desc2, b, y, x, c);
00055             const int32 alpha_value =
00056                 params.alpha_offset + alpha_data[alpha_index];
00057             const int32 unclamped_output =
00058                 params.output_offset +
00059                 MultiplyByQuantizedMultiplierSmallerThanOneExp(
00060                     input_value * alpha_value, params.output_multiplier,
00061                     params.output_shift);
00062             const int32 quantized_min = std::numeric_limits<uint8_t>::min();
00063             const int32 quantized_max = std::numeric_limits<uint8_t>::max();
00064             const int32 clamped_output = std::min(
00065                 quantized_max, std::max(quantized_min, unclamped_output));
00066             output_data[output_index] = static_cast<uint8>(clamped_output);
00067           }
00068         }
00069       }
00070     }
00071   }
00072 }
00073 
00074 }  // namespace reference_ops
00075 }  // namespace tflite
00076 
00077 #endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_PRELU_H_