Daniel Konegen / MNIST_example

Dependencies: mbed-os

tensor.h Source File

/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_TENSOR_H_
#define TENSORFLOW_LITE_KERNELS_INTERNAL_TENSOR_H_

#include <complex>
#include <cstring>  // memcpy, used by SequentialTensorWriter::WriteN
#include <vector>

#include "tensorflow/lite/c/c_api_internal.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/string_util.h"

namespace tflite {

// Builds a RuntimeShape from a vector of dimension sizes.
inline RuntimeShape GetTensorShape(std::vector<int32_t> data) {
  return RuntimeShape(data.size(), data.data());
}
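
// A minimal usage sketch (the values below are hypothetical, not part of this
// header):
//
//   RuntimeShape shape = GetTensorShape({1, 28, 28, 1});  // e.g. an MNIST
//   // batch in NHWC layout; shape.DimensionsCount() == 4, shape.Dims(1) == 28.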

// A list of tensors in a format that can be used by kernels like split and
// concatenation.
template <typename T>
class VectorOfTensors {
 public:
  // Build with the tensors in 'tensor_list'.
  VectorOfTensors(const TfLiteContext& context,
                  const TfLiteIntArray& tensor_list) {
    int num_tensors = tensor_list.size;

    all_data_.reserve(num_tensors);
    all_shape_.reserve(num_tensors);
    all_shape_ptr_.reserve(num_tensors);

    for (int i = 0; i < num_tensors; ++i) {
      TfLiteTensor* t = &context.tensors[tensor_list.data[i]];
      all_data_.push_back(GetTensorData<T>(t));
      all_shape_.push_back(GetTensorShape(t));
    }

    // Taking the pointer from inside a std::vector is only OK if the vector
    // is never modified, so we populate all_shape_ in the previous loop and
    // then we are free to take pointers here.
    for (int i = 0; i < num_tensors; ++i) {
      all_shape_ptr_.push_back(&all_shape_[i]);
    }
  }
  // Return a pointer to the data pointers of all tensors in the list. For
  // example:
  //   float* const* f = v.data();
  //   f[0][1] is the second element of the first tensor.
  T* const* data() const { return all_data_.data(); }

  // Return a pointer to the shape pointers of all tensors in the list. For
  // example:
  //   const RuntimeShape* const* d = v.shapes();
  //   d[1] are the dimensions of the second tensor in the list.
  const RuntimeShape* const* shapes() const { return all_shape_ptr_.data(); }

 private:
  std::vector<T*> all_data_;
  std::vector<RuntimeShape> all_shape_;
  std::vector<RuntimeShape*> all_shape_ptr_;
};
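
// Usage sketch for VectorOfTensors (the context and index list are assumed to
// come from a kernel's invocation; the names here are hypothetical):
//
//   void ConcatFloatInputs(TfLiteContext* context, TfLiteIntArray* inputs) {
//     VectorOfTensors<float> all_inputs(*context, *inputs);
//     float* const* data = all_inputs.data();
//     const RuntimeShape* const* shapes = all_inputs.shapes();
//     // data[i][j] is element j of input tensor i; shapes[i] is its shape.
//   }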

// A list of quantized tensors in a format that can be used by kernels like
// split and concatenation.
class VectorOfQuantizedTensors : public VectorOfTensors<uint8> {
 public:
  // Build with the tensors in 'tensor_list'.
  VectorOfQuantizedTensors(const TfLiteContext& context,
                           const TfLiteIntArray& tensor_list)
      : VectorOfTensors<uint8>(context, tensor_list) {
    for (int i = 0; i < tensor_list.size; ++i) {
      TfLiteTensor* t = &context.tensors[tensor_list.data[i]];
      zero_point_.push_back(t->params.zero_point);
      scale_.push_back(t->params.scale);
    }
  }

  const float* scale() const { return scale_.data(); }
  const int32* zero_point() const { return zero_point_.data(); }

 private:
  std::vector<int32> zero_point_;
  std::vector<float> scale_;
};
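
// Usage sketch (hypothetical kernel code): the per-tensor quantization
// parameters line up index-for-index with data() and shapes():
//
//   VectorOfQuantizedTensors all_inputs(*context, *inputs);
//   const float* scales = all_inputs.scale();
//   const int32* zero_points = all_inputs.zero_point();
//   // Tensor i dequantizes as (data()[i][j] - zero_points[i]) * scales[i].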

// Writes randomly accessed values from `input` sequentially into `output`.
template <typename T>
class SequentialTensorWriter {
 public:
  SequentialTensorWriter(const TfLiteTensor* input, TfLiteTensor* output) {
    input_data_ = GetTensorData<T>(input);
    output_ptr_ = GetTensorData<T>(output);
  }
  SequentialTensorWriter(const T* input_data, T* output_data)
      : input_data_(input_data), output_ptr_(output_data) {}

  void Write(int position) { *output_ptr_++ = input_data_[position]; }
  void WriteN(int position, int len) {
    memcpy(output_ptr_, &input_data_[position], sizeof(T) * len);
    output_ptr_ += len;
  }

 private:
  const T* input_data_;
  T* output_ptr_;
};
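
// Usage sketch (a hypothetical gather-style copy): Write() appends one element
// read from the given input index, WriteN() appends a contiguous run:
//
//   SequentialTensorWriter<float> writer(input_tensor, output_tensor);
//   writer.Write(5);       // copy input element 5 to the next output slot
//   writer.WriteN(10, 4);  // copy input elements 10..13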

template <>
class SequentialTensorWriter<string> {
 public:
  SequentialTensorWriter(const TfLiteTensor* input, TfLiteTensor* output)
      : input_(input), output_(output) {}
  ~SequentialTensorWriter() { buffer_.WriteToTensor(output_, nullptr); }

  void Write(int position) { this->WriteN(position, 1); }
  void WriteN(int position, int len) {
    for (int i = 0; i < len; i++) {
      buffer_.AddString(GetString(input_, position + i));
    }
  }

 private:
  const TfLiteTensor* input_;
  TfLiteTensor* output_;
  DynamicBuffer buffer_;
};
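
// Usage sketch for the string specialization: strings accumulate in the
// DynamicBuffer and are only flushed to the output tensor by the destructor,
// so the writer must go out of scope before the output is read:
//
//   {
//     SequentialTensorWriter<string> writer(input_tensor, output_tensor);
//     writer.WriteN(0, 3);  // buffer input strings 0..2
//   }  // destructor calls WriteToTensor(), populating output_tensor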

}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_TENSOR_H_