// Copyright 2018 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "ml/mojom/tensor.mojom.h"

#include <vector>

#ifndef ML_TENSOR_VIEW_H_
#define ML_TENSOR_VIEW_H_

namespace ml {

// Provides basic error checking and a common interface for mojom::TensorPtrs of
// any underlying data type.
//
// Basic usage of a TensorView is as follows.
// Non-const view:
//   TensorPtr ptr;
//   TensorView<double> view(ptr);
//   view.Allocate();                  // Creates a FloatList in ptr.
//   view.GetValues().push_back(0.5);  // Adds first elem of FloatList in ptr.
//
// Const view:
//   const TensorPtr ptr = ...
//   const TensorView<double> view(ptr);
//   double v = view.GetValues()[0];   // Gets first elem of FloatList in ptr.
//
// Type-specific functionality is implemented in template specializations below.
template <typename T>
class TensorView {
 public:
  // Wraps |tensor| without taking ownership: only a reference is stored, so
  // the given tensor must outlive this view.
  explicit TensorView(
      const chromeos::machine_learning::mojom::TensorPtr& tensor)
      : tensor_(tensor) {}

  // Return the shape array of the tensor. The tensor's |shape| member is
  // dereferenced directly, so it must already be set (Allocate() creates it).
  std::vector<int64_t>& GetShape() { return tensor_->shape->value; }

  const std::vector<int64_t>& GetShape() const {
    return tensor_->shape->value;
  }

  // Return the value array of the tensor.
  // Defined only in each specialization for T.
  std::vector<T>& GetValues();

  // Const overload: forwards to the non-const accessor so that each
  // specialization only has to define GetValues() once.
  const std::vector<T>& GetValues() const {
    return const_cast<TensorView<T>*>(this)->GetValues();
  }

  // Return true if the tensor contains values of the correct type. Should be
  // specialized for each tensor data type T; this generic version always
  // reports false.
  bool IsValidType() const { return false; }

  // Return true if the tensor is in a valid format (i.e. valid dimensions and
  // the right number of entries for its shape).
  //
  // An empty shape is valid only when there are no values. Otherwise every
  // dimension length must be non-negative and the number of values must equal
  // the product of all dimension lengths.
  bool IsValidFormat() const {
    const std::vector<int64_t>& dims = GetShape();

    // Special case: no entries.
    if (dims.empty())
      return GetValues().empty();

    // Reject negative dimension lengths before touching the values. A
    // zero-length dimension makes the expected entry count zero regardless of
    // the other dimensions.
    bool has_zero_dim = false;
    for (const int64_t dim : dims) {
      if (dim < 0)
        return false;
      if (dim == 0)
        has_zero_dim = true;
    }

    // Compare in a single signed type to avoid a signed/unsigned comparison.
    // Assumes the vector's size fits in int64_t (it is bounded by max_size()).
    const int64_t num_values = static_cast<int64_t>(GetValues().size());
    if (has_zero_dim)
      return num_values == 0;

    // All dimension lengths are >= 1 from here on, so the running product
    // never shrinks. Stop as soon as it would exceed |num_values|; this keeps
    // the multiplication below from overflowing int64_t for huge shapes.
    int64_t num_entries = 1;
    for (const int64_t dim : dims) {
      // Equivalent to (num_entries * dim > num_values), without overflow.
      if (num_entries > num_values / dim)
        return false;

      num_entries *= dim;
    }
    return num_entries == num_values;
  }

  // Allocate memory for the members of the tensor object (including values).
  // Replaces the tensor's |shape| and |data| members with freshly created
  // objects and then lets the type-specific AllocateValues() set up the value
  // array.
  void Allocate() {
    tensor_->shape = chromeos::machine_learning::mojom::Int64List::New();
    // TODO(hidehiko): assigning std::vector<>() to |value| is unneeded
    // on libmojo uprev. Remove them after the uprev.
    tensor_->shape->value = std::vector<int64_t>();
    tensor_->data = chromeos::machine_learning::mojom::ValueList::New();
    AllocateValues();
  }

 private:
  // Allocate memory for the value array of this tensor.
  // Defined only in each specialization for T.
  void AllocateValues();

  // The viewed tensor; not owned by this view.
  const chromeos::machine_learning::mojom::TensorPtr& tensor_;

  DISALLOW_COPY_AND_ASSIGN(TensorView);
};

// TensorView<int64_t>: explicit specialization declarations for tensors whose
// values are integers. The definitions are not part of this header.
template <>
void TensorView<int64_t>::AllocateValues();
template <>
bool TensorView<int64_t>::IsValidType() const;
template <>
std::vector<int64_t>& TensorView<int64_t>::GetValues();

// TensorView<double>: explicit specialization declarations for tensors whose
// values are floating point. The definitions are not part of this header.
template <>
void TensorView<double>::AllocateValues();
template <>
bool TensorView<double>::IsValidType() const;
template <>
std::vector<double>& TensorView<double>::GetValues();

}  // namespace ml

#endif  // ML_TENSOR_VIEW_H_
