process_broadcast_shapes.h
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_PROCESS_BROADCAST_SHAPES_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_PROCESS_BROADCAST_SHAPES_H_

#include <algorithm>  // for std::max

#include "tensorflow/lite/kernels/internal/types.h"

namespace tflite {

namespace reference_ops {

// Consolidates dimensions in broadcast inputs, checks for five-fold pattern.
//
// For example, if sequence of dimensions of one input is
// ..., 1, 3, 1, 7, 9, 5,... and the other is ..., 2, 3, 1, 7, 1, 1, ...
// we can consolidate these as
// ..., 1, 3*7, 9*5, ... and 2, 3*7, 1.
//
// The category is updated in the less-frequent case of shapes that are
// not suited to a fivefold-loop broadcast.
//
// Falls back to generic pattern when it does not know how to process properly.
//
// Returns true iff there is some sort of broadcast, which includes five-fold
// patterns and falling back to generic broadcast.
inline bool ProcessBroadcastShapes(const RuntimeShape& shape0,
                                   const RuntimeShape& shape1,
                                   tflite::ArithmeticParams* params) {
  const int dims_count =
      std::max(shape0.DimensionsCount(), shape1.DimensionsCount());

  params->broadcast_category = BroadcastableOpCategory::kGenericBroadcast;
  RuntimeShape scalar_shape(dims_count, 1);

  auto extended_shape0 = RuntimeShape::ExtendedShape(dims_count, shape0);
  auto extended_shape1 = RuntimeShape::ExtendedShape(dims_count, shape1);

  // Check for "exact" match, implicitly accepting any scalar shapes.
  if (extended_shape0 == extended_shape1) {
    params->broadcast_category = BroadcastableOpCategory::kNonBroadcast;
    return false;
  }

  for (int i = dims_count - 1; i >= 0; --i) {
    if (extended_shape0.Dims(i) == extended_shape1.Dims(i)) {
      continue;
    } else if (extended_shape0.Dims(i) == 1) {
      params->broadcast_category =
          BroadcastableOpCategory::kFirstInputBroadcastsFast;
      break;
    } else if (extended_shape1.Dims(i) == 1) {
      params->broadcast_category =
          BroadcastableOpCategory::kSecondInputBroadcastsFast;
      break;
    } else {
      // This case is erroneous: there is a dimension that does not match and
      // is not a broadcast from one shape to the other.
      params->broadcast_category = BroadcastableOpCategory::kGenericBroadcast;
      return true;
    }
  }

  if (params->broadcast_category !=
          BroadcastableOpCategory::kFirstInputBroadcastsFast &&
      params->broadcast_category !=
          BroadcastableOpCategory::kSecondInputBroadcastsFast) {
    return false;
  }

  // From this point it is assumed contractually that corresponding dimensions
  // in shape0 and shape1 are either (a) equal or (b) one or other equals 1.
  const bool swap_inputs = params->broadcast_category ==
                           BroadcastableOpCategory::kSecondInputBroadcastsFast;
  const RuntimeShape* shape_a =
      swap_inputs ? &extended_shape1 : &extended_shape0;
  const RuntimeShape* shape_b =
      swap_inputs ? &extended_shape0 : &extended_shape1;

  int i = dims_count - 1;
  params->broadcast_shape[0] = 1;
  params->broadcast_shape[1] = 1;
  params->broadcast_shape[2] = 1;
  params->broadcast_shape[3] = 1;
  params->broadcast_shape[4] = 1;
  // y_0 is greedy: include dims if both or neither equal 1: in other words,
  // test for equality rather than (shape_a->Dims(i) != 1).
  while (i >= 0 && shape_a->Dims(i) == shape_b->Dims(i)) {
    params->broadcast_shape[4] *= shape_b->Dims(i);
    --i;
  }
  // Here either input_a or input_b has dim of 1 (if i >= 0). If it is input_b
  // that has the unit dimension, the next two loops are not entered.
  while (i >= 0 && shape_a->Dims(i) == 1) {
    params->broadcast_shape[3] *= shape_b->Dims(i);
    --i;
  }
  while (i >= 0 && shape_a->Dims(i) == shape_b->Dims(i)) {
    params->broadcast_shape[2] *= shape_a->Dims(i);
    --i;
  }
  // Here either input_a or input_b has dim of 1 (if i >= 0).
  while (i >= 0 && shape_b->Dims(i) == 1) {
    params->broadcast_shape[1] *= shape_a->Dims(i);
    --i;
  }
  while (i >= 0 && shape_a->Dims(i) == shape_b->Dims(i)) {
    params->broadcast_shape[0] *= shape_b->Dims(i);
    --i;
  }

  // Rarer case is when the broadcast dimensions cannot be handled by a
  // fivefold loop.
  if (i >= 0) {
    params->broadcast_category = BroadcastableOpCategory::kGenericBroadcast;
  }
  return true;
}

}  // namespace reference_ops
}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_PROCESS_BROADCAST_SHAPES_H_
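For reference, a minimal usage sketch of how a caller might drive ProcessBroadcastShapes. The shapes below are the ones from the header comment; the main() harness, the include path, and the printed fields are illustrative assumptions rather than part of TensorFlow Lite.

// Minimal usage sketch (assumption: compiled against the TFLite internal
// headers listed in the file above; this harness is not part of the library).
#include <cstdio>

#include "tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h"
#include "tensorflow/lite/kernels/internal/types.h"

int main() {
  // Shapes taken from the example in the header comment:
  // one input is (1, 3, 1, 7, 9, 5), the other is (2, 3, 1, 7, 1, 1).
  const tflite::RuntimeShape shape0({1, 3, 1, 7, 9, 5});
  const tflite::RuntimeShape shape1({2, 3, 1, 7, 1, 1});

  tflite::ArithmeticParams params;
  const bool is_broadcast =
      tflite::reference_ops::ProcessBroadcastShapes(shape0, shape1, &params);

  // With these shapes the second input broadcasts onto the first, so the
  // expected consolidation is broadcast_shape = {1, 2, 3*7, 9*5, 1}.
  std::printf("broadcast: %s\n", is_broadcast ? "yes" : "no");
  std::printf("category: %d\n", static_cast<int>(params.broadcast_category));
  for (int i = 0; i < 5; ++i) {
    std::printf("broadcast_shape[%d] = %d\n", i, params.broadcast_shape[i]);
  }
  return 0;
}

The point of the consolidation is that downstream kernels can treat the five entries of broadcast_shape as nested loop bounds, so the innermost loop runs over a contiguous run of elements instead of recomputing a broadcast index per element.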