blob: c8418fee4f216a6c861c563e465e45c1dd8efe71 [file] [log] [blame] [edit]
// Copyright 2018 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include <cstdint>
#include <vector>

#include "ml/mojom/tensor.mojom.h"
#ifndef ML_TENSOR_VIEW_H_
#define ML_TENSOR_VIEW_H_
namespace ml {
// Provides basic error checking and a common interface for mojom::TensorPtrs of
// any underlying data type.
//
// Basic usage of a TensorView is as follows.
// Non-const view:
// TensorPtr ptr;
// TensorView<double> view(ptr);
// view.Allocate(); // Creates a FloatList in ptr.
// view.GetValues().push_back(0.5); // Adds first elem of FloatList in ptr.
//
// Const view:
// const TensorPtr ptr = ...
// const TensorView<double> view(ptr);
// double v = view.GetValues()[0]; // Gets first elem of FloatList in ptr.
//
// Type-specific functionality is provided by the template specializations
// declared below.
template <typename T>
class TensorView {
 public:
  // Wraps `tensor` without taking ownership: only a reference is stored, so
  // the given tensor must outlive this view.
  explicit TensorView(
      const chromeos::machine_learning::mojom::TensorPtr& tensor)
      : tensor_(tensor) {}
  TensorView(const TensorView&) = delete;
  TensorView& operator=(const TensorView&) = delete;

  // Return the shape array of the tensor. The tensor's `shape` member is
  // dereferenced without a null check, so it must already be set (Allocate()
  // does this).
  std::vector<int64_t>& GetShape() { return tensor_->shape->value; }
  const std::vector<int64_t>& GetShape() const { return tensor_->shape->value; }

  // Return the value array of the tensor.
  // Defined only in each specialization for T.
  std::vector<T>& GetValues();
  // Const overload: forwards to the non-const accessor above so that each
  // specialization only has to provide one definition. The cast removes the
  // constness of the view only; the result is handed back as a const
  // reference.
  const std::vector<T>& GetValues() const {
    return const_cast<TensorView<T>*>(this)->GetValues();
  }

  // Return true if the tensor contains values of the correct type. Should be
  // specialized for each tensor data type T; this generic fallback rejects
  // every tensor.
  bool IsValidType() const { return false; }

  // Return true if the tensor is in a valid format (i.e. valid dimensions and
  // the right number of entries for its shape).
  //
  // Rules:
  //  - An empty shape is valid only when there are no values.
  //  - Any negative dimension makes the tensor invalid.
  //  - Otherwise the number of values must equal the product of all
  //    dimensions (which is zero as soon as one dimension is zero).
  //
  // The product is never allowed to overflow: shapes may come from untrusted
  // callers, and overflowing a signed 64-bit accumulator would be undefined
  // behaviour.
  bool IsValidFormat() const {
    const std::vector<int64_t>& dims = GetShape();
    const std::vector<T>& values = GetValues();

    // Special case: no entries.
    if (dims.empty())
      return values.empty();

    // Reject negative dimensions wherever they appear, and remember whether
    // any dimension is zero (in which case the expected entry count is zero
    // regardless of how large the other dimensions are).
    bool has_zero_dim = false;
    for (const int64_t dim : dims) {
      if (dim < 0)
        return false;
      if (dim == 0)
        has_zero_dim = true;
    }
    if (has_zero_dim)
      return values.empty();

    // Every dimension is now >= 1, so the running product never decreases.
    // As soon as multiplying by the next dimension would exceed the actual
    // number of values the answer is known to be "invalid"; checking that via
    // division keeps the product bounded by `num_values`, so it cannot wrap.
    const uint64_t num_values = values.size();
    uint64_t num_entries = 1;
    for (const int64_t dim : dims) {
      const uint64_t extent = static_cast<uint64_t>(dim);
      if (num_entries > num_values / extent)
        return false;
      num_entries *= extent;
    }
    return num_entries == num_values;
  }

  // Allocate memory for the members of the tensor object (including values).
  // Replaces any existing `shape` and `data` members of the tensor with
  // freshly created ones, then lets the specialization for T create the value
  // array.
  void Allocate() {
    tensor_->shape = chromeos::machine_learning::mojom::Int64List::New();
    // TODO(hidehiko): assigning std::vector<>() to `value` is unneeded
    // on libmojo uprev. Remove them after the uprev.
    tensor_->shape->value = std::vector<int64_t>();
    tensor_->data = chromeos::machine_learning::mojom::ValueList::New();
    AllocateValues();
  }

 private:
  // Allocate memory for the value array of this tensor.
  // Defined only in each specialization for T.
  void AllocateValues();

  // The viewed tensor. Held by reference, not owned.
  const chromeos::machine_learning::mojom::TensorPtr& tensor_;
};
// Specializations for int tensors (T = int64_t).
// These are explicit specialization declarations only; the definitions are
// not visible in this header. Declaring them here ensures that code including
// this header uses the specialized members instead of implicitly
// instantiating the primary template's versions.
//
// Value accessor for int tensors. The primary template declares GetValues()
// but provides no definition, so this specialization supplies it.
template <>
std::vector<int64_t>& TensorView<int64_t>::GetValues();
// Type check for int tensors. Replaces the primary template's IsValidType(),
// which always returns false.
template <>
bool TensorView<int64_t>::IsValidType() const;
// Value-array allocation for int tensors; called by Allocate(). The primary
// template declares AllocateValues() but provides no definition.
template <>
void TensorView<int64_t>::AllocateValues();
// Specializations for float tensors (T = double).
// These are explicit specialization declarations only; the definitions are
// not visible in this header. Declaring them here ensures that code including
// this header uses the specialized members instead of implicitly
// instantiating the primary template's versions.
//
// Value accessor for float tensors. The primary template declares GetValues()
// but provides no definition, so this specialization supplies it.
template <>
std::vector<double>& TensorView<double>::GetValues();
// Type check for float tensors. Replaces the primary template's
// IsValidType(), which always returns false.
template <>
bool TensorView<double>::IsValidType() const;
// Value-array allocation for float tensors; called by Allocate(). The primary
// template declares AllocateValues() but provides no definition.
template <>
void TensorView<double>::AllocateValues();
} // namespace ml
#endif // ML_TENSOR_VIEW_H_