Daniel Konegen / MNIST_example

Dependencies:   mbed-os

Embed: (wiki syntax)

« Back to documentation index

Show/hide line numbers process_broadcast_shapes.h Source File

process_broadcast_shapes.h

/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
00015 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_PROCESS_BROADCAST_SHAPES_H_
00016 #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_PROCESS_BROADCAST_SHAPES_H_
00017 
#include <algorithm>

#include "tensorflow/lite/kernels/internal/types.h"
00019 
00020 namespace tflite {
00021 
00022 namespace reference_ops {
00023 
00024 // Consolidates dimensions in broadcast inputs, checks for five-fold pattern.
00025 //
00026 // For example, if sequence of dimensions of one input is
00027 // ..., 1, 3, 1, 7, 9, 5,... and the other is ..., 2, 3, 1, 7, 1, 1, ...
00028 // we can consolidate these as
00029 // ..., 1, 3*7, 9*5, ... and 2, 3*7, 1.
00030 //
00031 // The category is updated in the less-frequent case of shapes that are
00032 // not suited to a fivefold-loop broadcast.
00033 //
00034 // Falls back to generic pattern when it does not know how to process properly.
00035 //
00036 // Returns true iff there is some sort of broadcast, which includes five-fold
00037 // patterns and falling back to generic broadcast.
00038 inline bool ProcessBroadcastShapes(const RuntimeShape& shape0,
00039                                    const RuntimeShape& shape1,
00040                                    tflite::ArithmeticParams* params) {
00041   const int dims_count =
00042       std::max(shape0.DimensionsCount(), shape1.DimensionsCount());
00043 
00044   params->broadcast_category = BroadcastableOpCategory::kGenericBroadcast;
00045   RuntimeShape scalar_shape(dims_count, 1);
00046 
00047   auto extended_shape0 = RuntimeShape::ExtendedShape(dims_count, shape0);
00048   auto extended_shape1 = RuntimeShape::ExtendedShape(dims_count, shape1);
00049 
00050   // Check for "exact" match, implicitly accepting any scalar shapes.
00051   if (extended_shape0 == extended_shape1) {
00052     params->broadcast_category = BroadcastableOpCategory::kNonBroadcast;
00053     return false;
00054   }
00055 
00056   for (int i = dims_count - 1; i >= 0; --i) {
00057     if (extended_shape0.Dims(i) == extended_shape1.Dims(i)) {
00058       continue;
00059     } else if (extended_shape0.Dims(i) == 1) {
00060       params->broadcast_category =
00061           BroadcastableOpCategory::kFirstInputBroadcastsFast;
00062       break;
00063     } else if (extended_shape1.Dims(i) == 1) {
00064       params->broadcast_category =
00065           BroadcastableOpCategory::kSecondInputBroadcastsFast;
00066       break;
00067     } else {
00068       // This case is erroneous: there is a dimension that does not match and
00069       // is not a broadcast from one shape to the other.
00070       params->broadcast_category = BroadcastableOpCategory::kGenericBroadcast;
00071       return true;
00072     }
00073   }
00074 
00075   if (params->broadcast_category !=
00076           BroadcastableOpCategory::kFirstInputBroadcastsFast &&
00077       params->broadcast_category !=
00078           BroadcastableOpCategory::kSecondInputBroadcastsFast) {
00079     return false;
00080   }
00081 
00082   // From this point it is assumed contractually that corresponding dimensions
00083   // in shape0 and shape1 are either (a) equal or (b) one or other equals 1.
00084   const bool swap_inputs = params->broadcast_category ==
00085                            BroadcastableOpCategory::kSecondInputBroadcastsFast;
00086   const RuntimeShape* shape_a =
00087       swap_inputs ? &extended_shape1 : &extended_shape0;
00088   const RuntimeShape* shape_b =
00089       swap_inputs ? &extended_shape0 : &extended_shape1;
00090 
00091   int i = dims_count - 1;
00092   params->broadcast_shape[0] = 1;
00093   params->broadcast_shape[1] = 1;
00094   params->broadcast_shape[2] = 1;
00095   params->broadcast_shape[3] = 1;
00096   params->broadcast_shape[4] = 1;
00097   // y_0 is greedy: include dims if both or neither equal 1: in other words,
00098   // test for equality rather than (shape_a->Dims(i) != 1).
00099   while (i >= 0 && shape_a->Dims(i) == shape_b->Dims(i)) {
00100     params->broadcast_shape[4] *= shape_b->Dims(i);
00101     --i;
00102   }
00103   // Here either input_a or input_b has dim of 1 (if i >= 0).  If it is input_b
00104   // that has the unit dimension, the next two loops are not entered.
00105   while (i >= 0 && shape_a->Dims(i) == 1) {
00106     params->broadcast_shape[3] *= shape_b->Dims(i);
00107     --i;
00108   }
00109   while (i >= 0 && shape_a->Dims(i) == shape_b->Dims(i)) {
00110     params->broadcast_shape[2] *= shape_a->Dims(i);
00111     --i;
00112   }
00113   // Here either input_a or input_b has dim of 1 (if i >= 0).
00114   while (i >= 0 && shape_b->Dims(i) == 1) {
00115     params->broadcast_shape[1] *= shape_a->Dims(i);
00116     --i;
00117   }
00118   while (i >= 0 && shape_a->Dims(i) == shape_b->Dims(i)) {
00119     params->broadcast_shape[0] *= shape_b->Dims(i);
00120     --i;
00121   }
00122 
00123   // Rarer case is when the broadcast dimensions cannot be handled by a fivefold
00124   // loop.
00125   if (i >= 0) {
00126     params->broadcast_category = BroadcastableOpCategory::kGenericBroadcast;
00127   }
00128   return true;
00129 }
00130 
00131 }  // namespace reference_ops
00132 }  // namespace tflite
00133 
00134 #endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_PROCESS_BROADCAST_SHAPES_H_