blob: 377582f5e45d0abda5bd3acd6faaada94a1bf759 [file] [log] [blame]
// Copyright 2019 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "ml/machine_learning_service_impl.h"
#include <memory>
#include <string>
#include <vector>
#include <base/at_exit.h>
#include <base/bind.h>
#include <base/macros.h>
#include <base/run_loop.h>
#include <base/message_loop/message_loop.h>
#include <brillo/message_loops/base_message_loop.h>
#include <fuzzer/FuzzedDataProvider.h>
#include <mojo/public/cpp/bindings/binding.h>
#include <mojo/public/cpp/bindings/interface_request.h>
#include "ml/mojom/graph_executor.mojom.h"
#include "ml/mojom/machine_learning_service.mojom.h"
#include "ml/mojom/model.mojom.h"
#include "ml/tensor_view.h"
#include "mojo/core/embedder/embedder.h"
#include "mojo/core/embedder/scoped_ipc_support.h"
namespace ml {
namespace {
using ::chromeos::machine_learning::mojom::FlatBufferModelSpec;
using ::chromeos::machine_learning::mojom::FlatBufferModelSpecPtr;
using ::chromeos::machine_learning::mojom::LoadModelResult;
using ::chromeos::machine_learning::mojom::MachineLearningServicePtr;
using ::chromeos::machine_learning::mojom::ModelPtr;
class Environment {
public:
Environment() {
logging::SetMinLogLevel(logging::LOG_FATAL); // <- DISABLE LOGGING.
mojo::core::Init();
}
};
} // namespace
class MLServiceFuzzer {
public:
MLServiceFuzzer() = default;
~MLServiceFuzzer() = default;
void SetUp() {
ipc_support_ = std::make_unique<mojo::core::ScopedIPCSupport>(
base::ThreadTaskRunnerHandle::Get(),
mojo::core::ScopedIPCSupport::ShutdownPolicy::FAST);
ml_service_impl_ = std::make_unique<MachineLearningServiceImpl>(
mojo::MakeRequest(&ml_service_).PassMessagePipe(),
base::Closure());
}
void PerformInference(const uint8_t* data, size_t size) {
FlatBufferModelSpecPtr spec = FlatBufferModelSpec::New();
spec->model_string = std::string(reinterpret_cast<const char*>(data), size);
spec->inputs["input"] = 3;
spec->outputs["output"] = 4;
spec->metrics_model_name = "TestModel";
// Load model.
bool load_model_done = false;
ml_service_->LoadFlatBufferModel(
std::move(spec), mojo::MakeRequest(&model_),
base::Bind(
[](bool* load_model_done, const LoadModelResult result) {
*load_model_done = true;
},
&load_model_done));
base::RunLoop().RunUntilIdle();
CHECK(load_model_done);
}
private:
std::unique_ptr<mojo::core::ScopedIPCSupport> ipc_support_;
MachineLearningServicePtr ml_service_;
std::unique_ptr<MachineLearningServiceImpl> ml_service_impl_;
ModelPtr model_;
DISALLOW_COPY_AND_ASSIGN(MLServiceFuzzer);
};
} // namespace ml
extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) {
static ml::Environment env;
base::AtExitManager at_exit_manager;
// Mock main task runner
base::MessageLoopForIO message_loop;
brillo::BaseMessageLoop brillo_loop(&message_loop);
brillo_loop.SetAsCurrent();
ml::MLServiceFuzzer fuzzer;
fuzzer.SetUp();
fuzzer.PerformInference(data, size);
return 0;
}