blob: f0c4b53039563b6e767056f95cb2f1ecd1ab6231 [file] [log] [blame]
// Copyright 2021 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include <memory>
#include <string>
#include <vector>
#include <base/at_exit.h>
#include <base/check.h>
#include <base/check_op.h>
#include <base/bind.h>
#include <base/run_loop.h>
#include <base/task/single_thread_task_executor.h>
#include <brillo/message_loops/base_message_loop.h>
#include <fuzzer/FuzzedDataProvider.h>
#include <libprotobuf-mutator/src/libfuzzer/libfuzzer_macro.h>
#include <mojo/public/cpp/bindings/remote.h>
#include "chrome/knowledge/handwriting/handwriting_interface.pb.h"
#include "chrome/knowledge/handwriting/web_platform_handwriting_fuzz_container.pb.h"
#include "ml/fuzzers/handwriting_fake.h"
#include "ml/handwriting.h"
#include "ml/machine_learning_service_impl.h"
#include "ml/mojom/machine_learning_service.mojom.h"
#include "ml/mojom/web_platform_handwriting.mojom.h"
#include "ml/process.h"
#include "mojo/core/embedder/embedder.h"
#include "mojo/core/embedder/scoped_ipc_support.h"
namespace ml {
namespace {
using ::chromeos::machine_learning::mojom::LoadHandwritingModelResult;
using ::chromeos::machine_learning::mojom::MachineLearningService;
using ::chromeos::machine_learning::web_platform::mojom::HandwritingHints;
using ::chromeos::machine_learning::web_platform::mojom::HandwritingHintsPtr;
using ::chromeos::machine_learning::web_platform::mojom::HandwritingPoint;
using ::chromeos::machine_learning::web_platform::mojom::
HandwritingPredictionPtr;
using ::chromeos::machine_learning::web_platform::mojom::HandwritingStroke;
using ::chromeos::machine_learning::web_platform::mojom::HandwritingStrokePtr;
} // namespace
class WebPlatformHandwritingFuzzer {
public:
WebPlatformHandwritingFuzzer() = default;
WebPlatformHandwritingFuzzer(const WebPlatformHandwritingFuzzer&) = delete;
WebPlatformHandwritingFuzzer& operator=(const WebPlatformHandwritingFuzzer&) =
delete;
~WebPlatformHandwritingFuzzer() = default;
void SetUp(const std::string& language) {
ipc_support_ = std::make_unique<mojo::core::ScopedIPCSupport>(
base::ThreadTaskRunnerHandle::Get(),
mojo::core::ScopedIPCSupport::ShutdownPolicy::FAST);
Process::GetInstance()->SetTypeForTesting(
Process::Type::kSingleProcessForTest);
ml_service_impl_ = std::make_unique<MachineLearningServiceImpl>(
ml_service_.BindNewPipeAndPassReceiver(), base::Closure());
bool model_callback_done = false;
auto constraint = chromeos::machine_learning::web_platform::mojom::
HandwritingModelConstraint::New();
constraint->languages.push_back(language);
ml_service_->LoadWebPlatformHandwritingModel(
std::move(constraint), recognizer_.BindNewPipeAndPassReceiver(),
base::Bind(
[](bool* model_callback_done,
const LoadHandwritingModelResult result) {
CHECK_EQ(result, LoadHandwritingModelResult::OK);
*model_callback_done = true;
},
&model_callback_done));
base::RunLoop().RunUntilIdle();
CHECK(model_callback_done);
CHECK(recognizer_.is_bound());
}
// Populates inputs with the fuzzed proto and calls GetPrediction, this fuzzes
// the input and output conversion code.
void PerformPrediction(
const chrome_knowledge::WebPlatformHandwritingFuzzContainer&
random_container_proto) {
// Populates inputs with the random_container_proto.
std::vector<HandwritingStrokePtr> strokes;
HandwritingHintsPtr hints = HandwritingHints::New();
hints->alternatives = random_container_proto.hint_alternatives();
hints->text_context = random_container_proto.hint_text_context();
for (const chrome_knowledge::InkStroke& ink_stroke :
random_container_proto.strokes()) {
auto stroke = HandwritingStroke::New();
for (const chrome_knowledge::InkPoint& ink_point : ink_stroke.points()) {
auto point = HandwritingPoint::New();
auto location = gfx::mojom::PointF::New();
location->x = ink_point.x();
location->y = ink_point.y();
point->location = std::move(location);
stroke->points.push_back(std::move(point));
}
strokes.push_back(std::move(stroke));
}
bool infer_callback_done = false;
recognizer_->GetPrediction(
std::move(strokes), std::move(hints),
base::Bind([](bool* infer_callback_done,
base::Optional<std::vector<HandwritingPredictionPtr>>
predictions) { *infer_callback_done = true; },
&infer_callback_done));
base::RunLoop().RunUntilIdle();
CHECK(infer_callback_done);
}
private:
std::unique_ptr<mojo::core::ScopedIPCSupport> ipc_support_;
std::unique_ptr<MachineLearningServiceImpl> ml_service_impl_;
mojo::Remote<MachineLearningService> ml_service_;
mojo::Remote<
chromeos::machine_learning::web_platform::mojom::HandwritingRecognizer>
recognizer_;
};
} // namespace ml
class Environment {
public:
Environment() {
logging::SetMinLogLevel(logging::LOGGING_FATAL); // <- DISABLE LOGGING.
mojo::core::Init();
}
};
DEFINE_PROTO_FUZZER(const chrome_knowledge::WebPlatformHandwritingFuzzContainer&
random_container_proto) {
static Environment environment;
base::AtExitManager at_exit_manager;
// Mock main task runner
base::SingleThreadTaskExecutor task_executor(base::MessagePumpType::IO);
brillo::BaseMessageLoop brillo_loop(task_executor.task_runner());
brillo_loop.SetAsCurrent();
ml::HandwritingLibraryFake handwriting_library_fake;
handwriting_library_fake.SetOutput(
random_container_proto.recognizer_result());
ml::HandwritingLibrary::UseFakeHandwritingLibraryForTesting(
&handwriting_library_fake);
ml::WebPlatformHandwritingFuzzer fuzzer;
fuzzer.SetUp(random_container_proto.constraint_language());
fuzzer.PerformPrediction(random_container_proto);
}