// Copyright 2022 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include "ml/web_platform_model_impl.h"
#include <algorithm>
#include <utility>
#include <base/bind.h>
#include <base/callback_helpers.h>
#include <base/debug/leak_annotations.h>
#include <base/notreached.h>
#include <base/time/time.h>
#include <brillo/message_loops/message_loop.h>
#include "ml/machine_learning_service_impl.h"
#include "ml/mojom/big_buffer.mojom.h"
#include "ml/mojom/machine_learning_service.mojom.h"
#include "ml/process.h"
namespace ml {
namespace {
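
// Converts a TfLiteIntArray of dimensions into the unsigned vector used by the
// mojom interface. Returns an empty vector if the array is null or contains a
// negative entry.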
std::vector<unsigned int> ConvertTfLiteDimensions(
TfLiteIntArray* tflite_int_array) {
if (tflite_int_array == nullptr)
return {};
std::vector<unsigned int> ret(tflite_int_array->size);
for (int i = 0; i < tflite_int_array->size; i++) {
const auto v = tflite_int_array->data[i];
    // Entries in a TfLiteIntArray may be negative (e.g. -1 denotes a dynamic
    // dimension), but a concrete dimension must be >= 0.
if (v < 0)
return {};
ret[i] = static_cast<unsigned int>(v);
}
return ret;
}
// Maps a TfLiteType to the corresponding mojom `DataType`; types the Web
// Platform API does not expose map to `kUnknown`. Note that newer versions of
// TFLite add kTfLiteUInt16, which should be handled here once the library is
// upreved.
model_loader::mojom::DataType ConvertTfLiteTypeToMojo(TfLiteType tflite_type) {
switch (tflite_type) {
case kTfLiteFloat32:
return model_loader::mojom::DataType::kFloat32;
case kTfLiteInt32:
return model_loader::mojom::DataType::kInt32;
case kTfLiteUInt8:
return model_loader::mojom::DataType::kUint8;
case kTfLiteInt64:
return model_loader::mojom::DataType::kInt64;
case kTfLiteBool:
return model_loader::mojom::DataType::kBool;
case kTfLiteInt16:
return model_loader::mojom::DataType::kInt16;
case kTfLiteInt8:
return model_loader::mojom::DataType::kInt8;
case kTfLiteFloat16:
return model_loader::mojom::DataType::kFloat16;
case kTfLiteFloat64:
return model_loader::mojom::DataType::kFloat64;
case kTfLiteUInt64:
return model_loader::mojom::DataType::kUint64;
case kTfLiteUInt32:
return model_loader::mojom::DataType::kUint32;
case kTfLiteNoType:
case kTfLiteString:
case kTfLiteComplex64:
case kTfLiteComplex128:
case kTfLiteResource:
case kTfLiteVariant:
return model_loader::mojom::DataType::kUnknown;
  }
  // The switch above is exhaustive over TfLiteType, so this is only reachable
  // if a new enum value is added without updating the mapping.
  NOTREACHED();
  return model_loader::mojom::DataType::kUnknown;
}
} // namespace
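
// Creates a `WebPlatformModelImpl` bound to a new message pipe and reports the
// result through `callback`. The object is heap-allocated: on failure it is
// deleted immediately; on success its lifetime is managed by the disconnect
// handler of `receiver_`.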
void WebPlatformModelImpl::Create(
mojo_base::mojom::BigBufferPtr model_content,
WebPlatformModelLoaderImpl::LoadCallback callback,
WebPlatformModelLoaderImpl* loader) {
mojo::PendingRemote<ml::model_loader::mojom::Model> remote;
auto model_loaded_impl =
new WebPlatformModelImpl(remote.InitWithNewPipeAndPassReceiver(), loader);
if (!model_loaded_impl->Load(std::move(model_content), callback)) {
    // `Load()` has already run `callback`, reporting the error message to the
    // remote process.
delete model_loaded_impl;
} else {
loader->RegisterModel();
model_loaded_impl->receiver_.set_disconnect_handler(
base::BindOnce(&WebPlatformModelImpl::DefaultDisconnectHandler,
base::Unretained(model_loaded_impl)));
std::move(callback).Run(model_loader::mojom::LoadModelResult::kOk,
std::move(remote),
model_loaded_impl->GetModelInfo());
}
}
WebPlatformModelImpl::WebPlatformModelImpl(
mojo::PendingReceiver<model_loader::mojom::Model> receiver,
WebPlatformModelLoaderImpl* loader)
: loader_(loader), receiver_(this, std::move(receiver)) {}
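
// Invoked when the remote end of the model pipe disconnects: unregisters the
// model from the loader and either frees this object or, if this was the last
// model and the loader itself is already gone, stops the message loop.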
void WebPlatformModelImpl::DefaultDisconnectHandler() {
const auto remaining_models = loader_->UnregisterModel();
if (remaining_models == 0 && !loader_->IsValid()) {
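    // No models remain and the loader connection is gone, so stop the worker's
    // message loop; `this` is intentionally not freed because the process is
    // about to exit.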
brillo::MessageLoop::current()->BreakLoop();
} else {
delete this;
}
}
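
// Builds the TFLite model directly from the bytes received over mojo. The
// buffer is used in place when it is already 4-byte aligned; otherwise its
// contents are first copied into an aligned allocation.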
void WebPlatformModelImpl::BuildModelFromBytes(
mojo_base::mojom::BigBufferPtr& model_content) {
const auto incoming_pointer =
reinterpret_cast<char*>(model_content->get_bytes().data());
  // Checks alignment: TFLite requires the model buffer to be 4-byte (32-bit)
  // aligned.
if (reinterpret_cast<std::uintptr_t>(incoming_pointer) % 4 == 0) {
model_big_buffer_ptr_ = std::move(model_content);
model_ = tflite::FlatBufferModel::VerifyAndBuildFromBuffer(
incoming_pointer, model_big_buffer_ptr_->get_bytes().size());
} else {
model_size_ = model_content->get_bytes().size();
    // A buffer from `new char[]` is aligned to at least
    // alignof(std::max_align_t), which satisfies TFLite's requirement.
model_content_.reset(new char[model_size_]);
memcpy(model_content_.get(), incoming_pointer, model_size_);
model_ = tflite::FlatBufferModel::VerifyAndBuildFromBuffer(
model_content_.get(), model_size_);
}
}
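
// Builds the TFLite model from a shared-memory region. Returns false (after
// running `callback` with an error) when the region cannot be converted or
// mapped.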
bool WebPlatformModelImpl::BuildModelFromSharedBuffer(
mojo_base::mojom::BigBufferPtr& model_content,
WebPlatformModelLoaderImpl::LoadCallback& callback) {
  // Shared memory MUST be copied for security reasons: the remote process may
  // retain writable handles to the region, so the contents could otherwise be
  // modified after verification (a time-of-check/time-of-use risk).
model_size_ = model_content->get_shared_memory()->size;
auto shared_region = base::WritableSharedMemoryRegion::ConvertToReadOnly(
mojo::UnwrapWritableSharedMemoryRegion(
std::move(model_content->get_shared_memory()->buffer_handle)));
auto shared_mapping = shared_region.Map();
if (!shared_region.IsValid() || !shared_mapping.IsValid()) {
std::move(callback).Run(model_loader::mojom::LoadModelResult::kUnknownError,
mojo::NullRemote(), nullptr);
return false;
}
model_content_.reset(new char[model_size_]);
memcpy(model_content_.get(), shared_mapping.GetMemoryAs<char>(), model_size_);
model_ = tflite::FlatBufferModel::VerifyAndBuildFromBuffer(
model_content_.get(), model_size_);
return true;
}
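
// For each tensor index, caches the tensor's byte size and type in
// `name_to_tensor_info_` and fills `io_tensor_info` with the mojom description
// returned to the client.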
void WebPlatformModelImpl::CollectTensorInformation(
const std::vector<int>& tensor_indices_in_model,
base::flat_map<std::string, model_loader::mojom::TensorInfoPtr>&
io_tensor_info) {
for (auto tensor_idx : tensor_indices_in_model) {
std::string tensor_name(interpreter_->tensor(tensor_idx)->name);
TensorInfo tensor_info;
tensor_info.size = interpreter_->tensor(tensor_idx)->bytes;
tensor_info.data_type = interpreter_->tensor(tensor_idx)->type;
name_to_tensor_info_[tensor_name] = tensor_info;
auto mojo_tensor_info = model_loader::mojom::TensorInfo::New();
mojo_tensor_info->byte_size = tensor_info.size;
mojo_tensor_info->data_type =
ConvertTfLiteTypeToMojo(tensor_info.data_type);
mojo_tensor_info->dimensions =
ConvertTfLiteDimensions(interpreter_->tensor(tensor_idx)->dims);
io_tensor_info[tensor_name] = std::move(mojo_tensor_info);
}
}
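
// Builds the model and interpreter from `model_content` and allocates the
// tensors. Returns false on failure, in which case `callback` has already been
// run with the failure reason.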
bool WebPlatformModelImpl::Load(
mojo_base::mojom::BigBufferPtr model_content,
WebPlatformModelLoaderImpl::LoadCallback& callback) {
if (model_content->is_invalid_buffer()) {
std::move(callback).Run(model_loader::mojom::LoadModelResult::kUnknownError,
mojo::NullRemote(), nullptr);
return false;
} else if (model_content->is_bytes()) {
BuildModelFromBytes(model_content);
} else if (model_content->is_shared_memory()) {
    if (!BuildModelFromSharedBuffer(model_content, callback)) {
      // `callback` has already been run with an appropriate error message in
      // `BuildModelFromSharedBuffer()`.
      return false;
    }
} else {
LOG(FATAL) << "Unknown type of input BigBuffer. Please check if "
"mojom::BigBuffer has been extended.";
}
  if (model_ == nullptr) {
    std::move(callback).Run(model_loader::mojom::LoadModelResult::kInvalidModel,
                            mojo::NullRemote(), nullptr);
    return false;
  }
// Sets up the interpreter.
resolver_ = std::make_unique<tflite::ops::builtin::BuiltinOpResolver>();
interpreter_ = std::make_unique<tflite::Interpreter>();
const TfLiteStatus resolve_status =
tflite::InterpreterBuilder(*model_, *resolver_)(&interpreter_);
if (resolve_status != kTfLiteOk || !interpreter_) {
std::move(callback).Run(model_loader::mojom::LoadModelResult::kInvalidModel,
mojo::NullRemote(), nullptr);
return false;
}
  // If a delegate (e.g. NNAPI) ever needs to be set up, this is the place.
  // Allocates the tensors.
  // Note: this could potentially be deferred to `Compute()`.
if (interpreter_->AllocateTensors() != kTfLiteOk) {
std::move(callback).Run(model_loader::mojom::LoadModelResult::kUnknownError,
mojo::NullRemote(), nullptr);
return false;
}
return true;
}
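
// Assembles the mojom `ModelInfo` describing the interpreter's input and
// output tensors.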
model_loader::mojom::ModelInfoPtr WebPlatformModelImpl::GetModelInfo() {
auto model_info = model_loader::mojom::ModelInfo::New();
CollectTensorInformation(interpreter_->inputs(),
model_info->input_tensor_info);
CollectTensorInformation(interpreter_->outputs(),
model_info->output_tensor_info);
return model_info;
}
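
// Validates the named input tensors, copies them into the interpreter, invokes
// the model, and returns the output tensors keyed by name.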
void WebPlatformModelImpl::Compute(
const base::flat_map<std::string, std::vector<uint8_t>>& name_tensors,
ComputeCallback callback) {
  // Sets up the inputs.
  // Checks that the number of tensors supplied matches the model's inputs.
if (interpreter_->inputs().size() != name_tensors.size()) {
std::move(callback).Run(
model_loader::mojom::ComputeResult::kIncorrectNumberOfInputs,
std::nullopt);
return;
}
  // Further consistency checks: every model input must be supplied by name and
  // have the expected byte size.
for (auto tensor_idx : interpreter_->inputs()) {
std::string tensor_name(interpreter_->tensor(tensor_idx)->name);
auto iter = name_tensors.find(tensor_name);
if (iter == name_tensors.end()) {
std::move(callback).Run(model_loader::mojom::ComputeResult::kMissingInput,
std::nullopt);
return;
}
if (iter->second.size() != interpreter_->tensor(tensor_idx)->bytes) {
std::move(callback).Run(
model_loader::mojom::ComputeResult::kInvalidInputBufferSize,
std::nullopt);
return;
}
}
  // Copies each input into the interpreter's tensor buffer.
for (auto tensor_idx : interpreter_->inputs()) {
std::string tensor_name(interpreter_->tensor(tensor_idx)->name);
auto iter = name_tensors.find(tensor_name);
memcpy(interpreter_->tensor(tensor_idx)->data.raw, iter->second.data(),
iter->second.size());
}
// Does the computation.
if (interpreter_->Invoke() != kTfLiteOk) {
std::move(callback).Run(model_loader::mojom::ComputeResult::kUnknownError,
std::nullopt);
return;
}
  // Copies each output tensor into the result map.
base::flat_map<std::string, std::vector<uint8_t>> output_buffer_infos;
for (auto tensor_idx : interpreter_->outputs()) {
std::vector<uint8_t> tensor(
static_cast<size_t>(interpreter_->tensor(tensor_idx)->bytes));
memcpy(tensor.data(), interpreter_->tensor(tensor_idx)->data.raw,
interpreter_->tensor(tensor_idx)->bytes);
output_buffer_infos[interpreter_->tensor(tensor_idx)->name] =
std::move(tensor);
}
std::move(callback).Run(model_loader::mojom::ComputeResult::kOk,
std::move(output_buffer_infos));
}
} // namespace ml