Daniel Konegen / MNIST_example

Dependencies:   mbed-os

Embed: (wiki syntax)

« Back to documentation index

Show/hide line numbers binary_function.h Source File

binary_function.h

00001 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
00002 
00003 Licensed under the Apache License, Version 2.0 (the "License");
00004 you may not use this file except in compliance with the License.
00005 You may obtain a copy of the License at
00006 
00007     http://www.apache.org/licenses/LICENSE-2.0
00008 
00009 Unless required by applicable law or agreed to in writing, software
00010 distributed under the License is distributed on an "AS IS" BASIS,
00011 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
00012 See the License for the specific language governing permissions and
00013 limitations under the License.
00014 ==============================================================================*/
00015 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BINARY_FUNCTION_H_
00016 #define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BINARY_FUNCTION_H_
00017 
00018 #include "tensorflow/lite/kernels/internal/common.h"
00019 #include "tensorflow/lite/kernels/internal/compatibility.h"
00020 #include "tensorflow/lite/kernels/internal/types.h"
00021 
00022 namespace tflite {
00023 
00024 namespace reference_ops {
00025 
00026 // TODO(ycling): Refactoring. Remove BroadcastLogical and use the more
00027 // generalized and efficient BroadcastBinaryFunction.
00028 //
00029 // Also appears to duplicte MinimumMaximum.
00030 //
00031 // R: Result type. T1: Input 1 type. T2: Input 2 type.
00032 template <typename R, typename T1, typename T2>
00033 inline void BroadcastBinaryFunction4DSlow(
00034     const RuntimeShape& unextended_input1_shape, const T1* input1_data,
00035     const RuntimeShape& unextended_input2_shape, const T2* input2_data,
00036     const RuntimeShape& unextended_output_shape, R* output_data,
00037     R (*func)(T1, T2)) {
00038   TFLITE_DCHECK_LE(unextended_input1_shape.DimensionsCount(), 4);
00039   TFLITE_DCHECK_LE(unextended_input2_shape.DimensionsCount(), 4);
00040   TFLITE_DCHECK_LE(unextended_output_shape.DimensionsCount(), 4);
00041   const RuntimeShape output_shape =
00042       RuntimeShape::ExtendedShape(4, unextended_output_shape);
00043 
00044   NdArrayDesc<4> desc1;
00045   NdArrayDesc<4> desc2;
00046   NdArrayDescsForElementwiseBroadcast(unextended_input1_shape,
00047                                       unextended_input2_shape, &desc1, &desc2);
00048 
00049   for (int b = 0; b < output_shape.Dims(0); ++b) {
00050     for (int y = 0; y < output_shape.Dims(1); ++y) {
00051       for (int x = 0; x < output_shape.Dims(2); ++x) {
00052         for (int c = 0; c < output_shape.Dims(3); ++c) {
00053           auto out_idx = Offset(output_shape, b, y, x, c);
00054           auto in1_idx = SubscriptToIndex(desc1, b, y, x, c);
00055           auto in2_idx = SubscriptToIndex(desc2, b, y, x, c);
00056           auto in1_val = input1_data[in1_idx];
00057           auto in2_val = input2_data[in2_idx];
00058           output_data[out_idx] = func(in1_val, in2_val);
00059         }
00060       }
00061     }
00062   }
00063 }
00064 
00065 // R: Result type. T1: Input 1 type. T2: Input 2 type.
00066 // TODO(renjieliu): Refactor other binary functions to use this one.
00067 template <typename R, typename T1, typename T2>
00068 inline void BinaryFunction(const RuntimeShape& input1_shape,
00069                            const T1* input1_data,
00070                            const RuntimeShape& input2_shape,
00071                            const T2* input2_data,
00072                            const RuntimeShape& output_shape, R* output_data,
00073                            R (*func)(T1, T2)) {
00074   const int flat_size =
00075       MatchingFlatSize(input1_shape, input2_shape, output_shape);
00076   for (int i = 0; i < flat_size; ++i) {
00077     output_data[i] = func(input1_data[i], input2_data[i]);
00078   }
00079 }
00080 
00081 }  // namespace reference_ops
00082 }  // namespace tflite
00083 
00084 #endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_BINARY_FUNCTION_H_