blob: 513d797821515c976d87072d793efee4ca109227 [file] [log] [blame]
// Copyright 2022 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "ml/web_platform_model_loader_impl.h"
#include <algorithm>
#include <utility>
#include <base/bind.h>
#include <base/time/time.h>
#include <base/callback_helpers.h>
#include <base/notreached.h>
#include <brillo/message_loops/message_loop.h>
#include "base/debug/leak_annotations.h"
#include "ml/machine_learning_service_impl.h"
#include "ml/mojom/big_buffer.mojom.h"
#include "ml/mojom/machine_learning_service.mojom.h"
#include "ml/process.h"
#include "ml/web_platform_model_impl.h"
namespace ml {
namespace {
using model_loader::mojom::CreateModelLoaderOptionsPtr;
using model_loader::mojom::ModelLoader;
} // namespace
WebPlatformModelLoaderImpl* WebPlatformModelLoaderImpl::Create(
mojo::PendingReceiver<ModelLoader> receiver,
CreateModelLoaderOptionsPtr options) {
// TODO(honglinyu): in the first version, `options.threads` is unused. We need
// to make it meaningful.
auto model_loaded_impl =
new WebPlatformModelLoaderImpl(std::move(receiver), std::move(options));
// In production, `model_loaded_impl` is intentionally leaked, because this
// model runs in its own process and the model's memory is freed when the
// process exits. However, when being tested with ASAN, this memory leak
// causes an error. Therefore, we annotate it as an intentional leak.
ANNOTATE_LEAKING_OBJECT_PTR(model_loaded_impl);
// Sets the default disconnection handler which,
// - resets the `receiver_`.
// - quits the message loop if no model is connected.
// Note that it should NOT delete itself.
model_loaded_impl->receiver_.set_disconnect_handler(
base::BindOnce(&WebPlatformModelLoaderImpl::DefaultDisconnectHandler,
base::Unretained(model_loaded_impl)));
return model_loaded_impl;
}
WebPlatformModelLoaderImpl::WebPlatformModelLoaderImpl(
mojo::PendingReceiver<ModelLoader> receiver,
CreateModelLoaderOptionsPtr options)
: num_of_connected_models_(0), receiver_(this, std::move(receiver)) {
// TODO(honglinyu): makes use of `options->num_threads`.
DETACH_FROM_SEQUENCE(sequence_checker_);
}
bool WebPlatformModelLoaderImpl::IsValid() const {
return receiver_.is_bound();
}
void WebPlatformModelLoaderImpl::RegisterModel() {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
++num_of_connected_models_;
}
int WebPlatformModelLoaderImpl::UnregisterModel() {
DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
return --num_of_connected_models_;
}
void WebPlatformModelLoaderImpl::Load(
mojo_base::mojom::BigBufferPtr model_content, LoadCallback callback) {
WebPlatformModelImpl::Create(std::move(model_content), std::move(callback),
this);
}
// Note that the disconnect handler should NOT delete itself because the
// `WebPlatformModelImpl` objects have reference pointers to it.
void WebPlatformModelLoaderImpl::DefaultDisconnectHandler() {
receiver_.reset();
if (num_of_connected_models_ == 0) {
brillo::MessageLoop::current()->BreakLoop();
}
}
} // namespace ml