blob: e8734e1647b75f36c757025083baff413dfe0c48 [file] [log] [blame] [edit]
// Copyright 2018 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#ifndef ML_TEST_UTILS_H_
#define ML_TEST_UTILS_H_
#include <string>
#include <vector>
#include "ml/mojom/tensor.mojom.h"
#include "ml/tensor_view.h"
namespace ml {
// Builds a new mojom Tensor whose shape and values are copies of |shape| and
// |values|. Nothing is validated here, and that is intentional: tests that
// exercise error handling need to be able to construct malformed tensors
// (e.g. a shape that does not match the number of values).
template <typename T>
chromeos::machine_learning::mojom::TensorPtr NewTensor(
    const std::vector<int64_t>& shape, const std::vector<T>& values) {
  namespace mojom = chromeos::machine_learning::mojom;

  auto result = mojom::Tensor::New();

  // The view writes through to |result|; Allocate() must run before the
  // shape and value containers are accessed.
  TensorView<T> view(result);
  view.Allocate();
  view.GetShape() = shape;
  view.GetValues() = values;

  return result;
}
// Returns the directory that holds the models used by tests. By contract this
// never reports failure to the caller: if the directory cannot be obtained,
// the call dies instead. Only the declaration lives in this header.
std::string GetTestModelDir();
} // namespace ml
#endif // ML_TEST_UTILS_H_