Important changes to repositories hosted on mbed.com
Mbed-hosted Mercurial repositories are deprecated and are due to be permanently deleted in July 2026.
To keep a copy of this software, download the repository ZIP archive or clone it locally using Mercurial.
It is also possible to export all your personal repositories from the account settings page.
binary_function.h
00001 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 00002 00003 Licensed under the Apache License, Version 2.0 (the "License"); 00004 you may not use this file except in compliance with the License. 00005 You may obtain a copy of the License at 00006 00007 http://www.apache.org/licenses/LICENSE-2.0 00008 00009 Unless required by applicable law or agreed to in writing, software 00010 distributed under the License is distributed on an "AS IS" BASIS, 00011 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 00012 See the License for the specific language governing permissions and 00013 limitations under the License. 00014 ==============================================================================*/ 00015 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BINARY_FUNCTION_H_ 00016 #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BINARY_FUNCTION_H_ 00017 00018 #include "tensorflow/lite/kernels/internal/common.h" 00019 #include "tensorflow/lite/kernels/internal/compatibility.h" 00020 #include "tensorflow/lite/kernels/internal/types.h" 00021 00022 namespace tflite { 00023 00024 namespace reference_ops { 00025 00026 // TODO(ycling): Refactoring. Remove BroadcastLogical and use the more 00027 // generalized and efficient BroadcastBinaryFunction. 00028 // 00029 // Also appears to duplicte MinimumMaximum. 00030 // 00031 // R: Result type. T1: Input 1 type. T2: Input 2 type. 
00032 template <typename R, typename T1, typename T2> 00033 inline void BroadcastBinaryFunction4DSlow( 00034 const RuntimeShape& unextended_input1_shape, const T1* input1_data, 00035 const RuntimeShape& unextended_input2_shape, const T2* input2_data, 00036 const RuntimeShape& unextended_output_shape, R* output_data, 00037 R (*func)(T1, T2)) { 00038 TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4); 00039 TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4); 00040 TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4); 00041 const RuntimeShape output_shape = 00042 RuntimeShape::ExtendedShape(4, unextended_output_shape); 00043 00044 NdArrayDesc<4> desc1; 00045 NdArrayDesc<4> desc2; 00046 NdArrayDescsForElementwiseBroadcast(unextended_input1_shape, 00047 unextended_input2_shape, &desc1, &desc2); 00048 00049 for (int b = 0; b < output_shape.Dims(0); ++b) { 00050 for (int y = 0; y < output_shape.Dims(1); ++y) { 00051 for (int x = 0; x < output_shape.Dims(2); ++x) { 00052 for (int c = 0; c < output_shape.Dims(3); ++c) { 00053 auto out_idx = Offset(output_shape, b, y, x, c); 00054 auto in1_idx = SubscriptToIndex(desc1, b, y, x, c); 00055 auto in2_idx = SubscriptToIndex(desc2, b, y, x, c); 00056 auto in1_val = input1_data[in1_idx]; 00057 auto in2_val = input2_data[in2_idx]; 00058 output_data[out_idx] = func(in1_val, in2_val); 00059 } 00060 } 00061 } 00062 } 00063 } 00064 00065 // R: Result type. T1: Input 1 type. T2: Input 2 type. 00066 // TODO(renjieliu): Refactor other binary functions to use this one. 
00067 template <typename R, typename T1, typename T2> 00068 inline void BinaryFunction(const RuntimeShape& input1_shape, 00069 const T1* input1_data, 00070 const RuntimeShape& input2_shape, 00071 const T2* input2_data, 00072 const RuntimeShape& output_shape, R* output_data, 00073 R (*func)(T1, T2)) { 00074 const int flat_size = 00075 MatchingFlatSize(input1_shape, input2_shape, output_shape); 00076 for (int i = 0; i < flat_size; ++i) { 00077 output_data[i] = func(input1_data[i], input2_data[i]); 00078 } 00079 } 00080 00081 } // namespace reference_ops 00082 } // namespace tflite 00083 00084 #endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BINARY_FUNCTION_H_
Generated on Wed Jul 13 2022 16:03:35 by
1.7.2