Important changes to repositories hosted on mbed.com
Mbed-hosted Mercurial repositories are deprecated and are due to be permanently deleted in July 2026.
To keep a copy of this software, download the repository ZIP archive or clone it locally using Mercurial.
It is also possible to export all your personal repositories from the account settings page.
prelu.cc
00001 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 00002 00003 Licensed under the Apache License, Version 2.0 (the "License"); 00004 you may not use this file except in compliance with the License. 00005 You may obtain a copy of the License at 00006 00007 http://www.apache.org/licenses/LICENSE-2.0 00008 00009 Unless required by applicable law or agreed to in writing, software 00010 distributed under the License is distributed on an "AS IS" BASIS, 00011 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 00012 See the License for the specific language governing permissions and 00013 limitations under the License. 00014 ==============================================================================*/ 00015 00016 #include "tensorflow/lite/kernels/internal/reference/prelu.h" 00017 00018 #include "tensorflow/lite/c/c_api_internal.h" 00019 #include "tensorflow/lite/kernels/internal/quantization_util.h" 00020 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" 00021 #include "tensorflow/lite/kernels/kernel_util.h" 00022 00023 namespace tflite { 00024 namespace ops { 00025 namespace micro { 00026 namespace activations { 00027 00028 TfLiteStatus PreluPrepare(TfLiteContext* context, TfLiteNode* node) { 00029 return kTfLiteOk; 00030 } 00031 00032 inline void BroadcastPrelu4DSlowFloat( 00033 const RuntimeShape& unextended_input1_shape, const float* input1_data, 00034 const RuntimeShape& unextended_input2_shape, const float* input2_data, 00035 const RuntimeShape& unextended_output_shape, float* output_data) { 00036 TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4); 00037 TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4); 00038 TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4); 00039 const RuntimeShape output_shape = 00040 RuntimeShape::ExtendedShape(4, unextended_output_shape); 00041 00042 NdArrayDesc<4> desc1; 00043 NdArrayDesc<4> desc2; 00044 
NdArrayDescsForElementwiseBroadcast(unextended_input1_shape, 00045 unextended_input2_shape, &desc1, &desc2); 00046 00047 for (int b = 0; b < output_shape.Dims(0); ++b) { 00048 for (int y = 0; y < output_shape.Dims(1); ++y) { 00049 for (int x = 0; x < output_shape.Dims(2); ++x) { 00050 for (int c = 0; c < output_shape.Dims(3); ++c) { 00051 auto out_idx = Offset(output_shape, b, y, x, c); 00052 auto in1_idx = SubscriptToIndex(desc1, b, y, x, c); 00053 auto in2_idx = SubscriptToIndex(desc2, b, y, x, c); 00054 auto in1_val = input1_data[in1_idx]; 00055 auto in2_val = input2_data[in2_idx]; 00056 output_data[out_idx] = in1_val >= 0.0 ? in1_val : in1_val * in2_val; 00057 } 00058 } 00059 } 00060 } 00061 } 00062 00063 TfLiteStatus PreluEval(TfLiteContext* context, TfLiteNode* node) { 00064 const TfLiteTensor* input = GetInput(context, node, 0); 00065 const TfLiteTensor* alpha = GetInput(context, node, 1); 00066 TfLiteTensor* output = GetOutput(context, node, 0); 00067 int32_t output_multiplier = 0; 00068 int output_shift = 0; 00069 if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt16) { 00070 double real_multiplier = 00071 input->params.scale * alpha->params.scale / output->params.scale; 00072 QuantizeMultiplierSmallerThanOneExp(real_multiplier, &output_multiplier, 00073 &output_shift); 00074 } 00075 switch (input->type) { 00076 case kTfLiteFloat32: { 00077 BroadcastPrelu4DSlowFloat( 00078 GetTensorShape(input), GetTensorData<float>(input), 00079 GetTensorShape(alpha), GetTensorData<float>(alpha), 00080 GetTensorShape(output), GetTensorData<float>(output)); 00081 return kTfLiteOk; 00082 } break; 00083 case kTfLiteUInt8: { 00084 PreluParams op_params; 00085 op_params.input_offset = -input->params.zero_point; 00086 op_params.alpha_offset = -alpha->params.zero_point; 00087 op_params.output_offset = output->params.zero_point; 00088 op_params.output_multiplier = output_multiplier; 00089 op_params.output_shift = output_shift; 00090 reference_ops::BroadcastPrelu4DSlow( 
00091 op_params, GetTensorShape(input), GetTensorData<uint8_t>(input), 00092 GetTensorShape(alpha), GetTensorData<uint8_t>(alpha), 00093 GetTensorShape(output), GetTensorData<uint8_t>(output)); 00094 return kTfLiteOk; 00095 } break; 00096 default: 00097 context->ReportError( 00098 context, "Only float32 and uint8 are supported currently, got %d.", 00099 TfLiteTypeGetName(input->type)); 00100 return kTfLiteError; 00101 } 00102 } 00103 00104 } // namespace activations 00105 00106 TfLiteRegistration* Register_PRELU() { 00107 static TfLiteRegistration r = {nullptr, nullptr, activations::PreluPrepare, 00108 activations::PreluEval}; 00109 return &r; 00110 } 00111 00112 } // namespace micro 00113 } // namespace ops 00114 } // namespace tflite
Generated on Wed Jul 13 2022 16:03:35 by Doxygen 1.7.2