odml: Sync with Chrome ToT (up to "Add SetPriority() to on-device Session")
This CL includes:
* Blink-side changes for guidance/structured outputs
crrev.com/c/6350280
* Bind paths to model controllers
crrev.com/c/6419383
* Add GetProbabilitiesBlocking() API to odml mojo
crrev.com/c/6396181
* Add GetProbabilities() API to chrome_ml_api
crrev.com/c/6362334
* Add a few more metrics for on-device model usage
crrev.com/c/6441693
* Add SetPriority() to on-device Session interface
crrev.com/c/6442263
BUG=b:353900545
TEST=unit tests
Change-Id: I8eed88ecf2c86eeb90bba8d7381fc1160db8aeca
Reviewed-on: https://chromium-review.googlesource.com/c/chromiumos/platform2/+/6498469
Reviewed-by: Howard Yang <hcyang@google.com>
Tested-by: John L Chen <zuan@chromium.org>
diff --git a/odml/mojom/on_device_model.mojom b/odml/mojom/on_device_model.mojom
index 9e15050..9232360 100644
--- a/odml/mojom/on_device_model.mojom
+++ b/odml/mojom/on_device_model.mojom
@@ -214,6 +214,22 @@
uint32? top_k;
// Deprecated: use SessionParams.temperature instead.
float? temperature;
+
+ // A JSON schema defining structured output requirements for the response.
+ // Passed as an opaque JSON blob to llguidance.
+ [MinVersion=1]
+ string? response_json_schema;
+};
+
+// Priorities which determine how requests to a session are scheduled.
+// Because the enum is [Extensible] with a [Default], a value unknown to the
+// receiving side deserializes as kForeground.
+[Stable, Extensible]
+enum Priority {
+ // Requests should be treated as high priority; the user may be actively
+ // waiting for a response.
+ [Default] kForeground,
+ // Requests are not urgent and may be queued. Background requests will only
+ // run if there are no active foreground requests.
+ kBackground,
};
// A session for a model that allows adding context and then executing an input
@@ -243,6 +259,19 @@
// Clones the current session. The cloned session will have the same context
// as the current session.
Clone@4(pending_receiver<Session> session);
+
+ // Gets the probabilities for a series of tokens on top of the current
+ // context. Capabilities.probabilities_output must be specified to use this.
+ // Note that this is implemented as a blocking method on the service side
+ // and should only be used in debugging/testing.
+ [MinVersion=3]
+ GetProbabilitiesBlocking@8(string text) => (array<float> probabilities);
+
+ // Sets the priority for currently queued requests to this session and future
+ // requests. Any clones made of this session will inherit the current
+ // priority. Priority for new sessions defaults to kForeground.
+ [MinVersion=4]
+ SetPriority@9(Priority priority);
};
// A loaded model which can be queried. This interface must be controlled by the
diff --git a/odml/on_device_model/ml/chrome_ml.cc b/odml/on_device_model/ml/chrome_ml.cc
index 60723f9..ce3f88d 100644
--- a/odml/on_device_model/ml/chrome_ml.cc
+++ b/odml/on_device_model/ml/chrome_ml.cc
@@ -102,6 +102,24 @@
}
}
+// Records |milliseconds| to the UMA histogram |name| while holding
+// g_metrics_lock. Does nothing while g_metrics is null.
+void RecordMediumTimesHistogram(const char* name, int64_t milliseconds) {
+ base::AutoLock lock(*g_metrics_lock);
+ if (g_metrics) {
+ // Stand-in for Chromium's Uma_HistogramMediumTimes(), which is not
+ // available on ChromiumOS: buckets span 1 ms to 3 minutes, 50 buckets.
+ // Note: the real function name has no underscore. It is written with one
+ // here because check-libchrome.py scans comments too and would otherwise
+ // wrongly assume that method is being called and fail.
+ // NOTE(review): the Milliseconds()/InMilliseconds() round trip is an
+ // identity for realistic values; the int64_t result is presumably narrowed
+ // to SendToUMA()'s int sample parameter (confirm in MetricsLibrary).
+ g_metrics->SendToUMA(name,
+ base::Milliseconds(milliseconds).InMilliseconds(),
+ base::Milliseconds(1).InMilliseconds(),
+ base::Minutes(3).InMilliseconds(), 50);
+ }
+}
+
} // namespace
ChromeML::ChromeML(raw_ref<MetricsLibraryInterface> metrics,
@@ -174,6 +192,7 @@
const ChromeMLMetricsFns metrics_fns{
.RecordExactLinearHistogram = &RecordExactLinearHistogram,
.RecordCustomCountsHistogram = &RecordCustomCountsHistogram,
+ .RecordMediumTimesHistogram = &RecordMediumTimesHistogram,
};
api->SetMetricsFns(&metrics_fns);
}
diff --git a/odml/on_device_model/ml/chrome_ml_api.h b/odml/on_device_model/ml/chrome_ml_api.h
index 51a1c60..85fd1c8 100644
--- a/odml/on_device_model/ml/chrome_ml_api.h
+++ b/odml/on_device_model/ml/chrome_ml_api.h
@@ -9,6 +9,7 @@
#include <functional>
#include <optional>
#include <string>
+#include <vector>
#include "odml/on_device_model/ml/chrome_ml_types.h"
#include "odml/on_device_model/ml/forward_declare.h"
@@ -195,6 +196,11 @@
// This will be called on the internal thread executing the model.
using ChromeMLScoreFn = std::function<void(float)>;
+// Called with a vector of probability scores once
+// SessionGetProbabilitiesBlocking() has computed them. The vector is passed
+// by const reference; copy it if it must outlive the call.
+using ChromeMLGetProbabilitiesBlockingFn =
+ std::function<void(const std::vector<float>&)>;
+
struct ChromeMLExecuteOptions {
int context_mode;
uint32_t max_tokens;
@@ -234,6 +240,9 @@
// spanning the specified range.
void (*RecordCustomCountsHistogram)(
const char* name, int sample, int min, int exclusive_max, size_t buckets);
+
+ // Logs a timing sample (in milliseconds) for durations up to 3 minutes.
+ void (*RecordMediumTimesHistogram)(const char* name, int64_t milliseconds);
};
// Precision used by the gpu delegate during inference.
@@ -360,6 +369,13 @@
const std::string& text,
const ChromeMLScoreFn& fn);
+ // Get the probabilities of a batch of tokens.
+ // Note that this is a blocking call, mainly intended for testing purposes.
+ void (*SessionGetProbabilitiesBlocking)(
+ ChromeMLSession session,
+ const std::string& input,
+ const ChromeMLGetProbabilitiesBlockingFn& fn);
+
// Create a new session in the model, optionally loading adaptation data.
ChromeMLSession (*CreateSession)(
ChromeMLModel model, const ChromeMLAdaptationDescriptor* descriptor);
diff --git a/odml/on_device_model/ml/on_device_model_executor.cc b/odml/on_device_model/ml/on_device_model_executor.cc
index a926429..cdb6780 100644
--- a/odml/on_device_model/ml/on_device_model_executor.cc
+++ b/odml/on_device_model/ml/on_device_model_executor.cc
@@ -331,6 +331,14 @@
session_->Score(text, ConvertCallbackToFn(std::move(callback)));
}
+// Requests probabilities for |input| on top of the session's current context.
+// |callback| is adapted to the std::function form via ConvertCallbackToFn().
+DISABLE_CFI_DLSYM
+void SessionImpl::GetProbabilitiesBlocking(
+ const std::string& input,
+ base::OnceCallback<void(const std::vector<float>&)> callback) {
+ session_->GetProbabilitiesBlocking(input,
+ ConvertCallbackToFn(std::move(callback)));
+}
+
std::unique_ptr<SessionImpl> SessionImpl::Clone() {
return std::make_unique<SessionImpl>(metrics_, chrome_ml_.get(), model_,
session_->Clone(), max_tokens_,
diff --git a/odml/on_device_model/ml/on_device_model_executor.h b/odml/on_device_model/ml/on_device_model_executor.h
index 57d69d2..a601775 100644
--- a/odml/on_device_model/ml/on_device_model_executor.h
+++ b/odml/on_device_model/ml/on_device_model_executor.h
@@ -10,6 +10,7 @@
#include <memory>
#include <set>
#include <string>
+#include <vector>
#include <absl/container/flat_hash_map.h>
#include <base/files/file_path.h>
@@ -57,6 +58,9 @@
void SizeInTokens(on_device_model::mojom::InputPtr input,
base::OnceCallback<void(uint32_t)> callback);
void Score(const std::string& text, base::OnceCallback<void(float)> callback);
+ void GetProbabilitiesBlocking(
+ const std::string& input,
+ base::OnceCallback<void(const std::vector<float>&)> callback);
std::unique_ptr<SessionImpl> Clone();
private:
diff --git a/odml/on_device_model/ml/session_accessor.cc b/odml/on_device_model/ml/session_accessor.cc
index 577fff5..74e66db 100644
--- a/odml/on_device_model/ml/session_accessor.cc
+++ b/odml/on_device_model/ml/session_accessor.cc
@@ -125,6 +125,14 @@
text, std::move(score_fn)));
}
+// Posts the probability request to |task_runner_| so the blocking ChromeML
+// call runs there instead of on the caller's sequence. |input| is copied into
+// the bound task; the result is delivered through |get_prob_fn|.
+void SessionAccessor::GetProbabilitiesBlocking(
+ const std::string& input, ChromeMLGetProbabilitiesBlockingFn get_prob_fn) {
+ // NOTE(review): base::Unretained assumes |this| outlives the posted task.
+ task_runner_->PostTask(
+ FROM_HERE,
+ base::BindOnce(&SessionAccessor::GetProbabilitiesBlockingInternal,
+ base::Unretained(this), input, std::move(get_prob_fn)));
+}
+
void SessionAccessor::SizeInTokens(on_device_model::mojom::InputPtr input,
ChromeMLSizeInTokensFn size_in_tokens_fn) {
task_runner_->PostTask(
@@ -239,6 +247,14 @@
}
DISABLE_CFI_DLSYM
+void SessionAccessor::GetProbabilitiesBlockingInternal(
+ const std::string& input, ChromeMLGetProbabilitiesBlockingFn get_prob_fn) {
+ // Must run on |task_runner_|. SessionGetProbabilitiesBlocking() is a
+ // blocking ChromeML API call that reports its result through |get_prob_fn|.
+ DCHECK(task_runner_->RunsTasksInCurrentSequence());
+ chrome_ml_->api().SessionGetProbabilitiesBlocking(session_, input,
+ get_prob_fn);
+}
+
+DISABLE_CFI_DLSYM
void SessionAccessor::SizeInTokensInternal(
on_device_model::mojom::InputPtr input,
ChromeMLSizeInTokensFn size_in_tokens_fn) {
diff --git a/odml/on_device_model/ml/session_accessor.h b/odml/on_device_model/ml/session_accessor.h
index cd506f6..86f2661 100644
--- a/odml/on_device_model/ml/session_accessor.h
+++ b/odml/on_device_model/ml/session_accessor.h
@@ -42,6 +42,8 @@
ChromeMLCancelFn Generate(on_device_model::mojom::GenerateOptionsPtr options,
ChromeMLExecutionOutputFn output_fn);
void Score(const std::string& text, ChromeMLScoreFn score_fn);
+ void GetProbabilitiesBlocking(const std::string& input,
+ ChromeMLGetProbabilitiesBlockingFn get_prob_fn);
void SizeInTokens(on_device_model::mojom::InputPtr input,
ChromeMLSizeInTokensFn size_in_tokens_fn);
@@ -65,6 +67,8 @@
ChromeMLExecutionOutputFn output_fn,
scoped_refptr<Canceler> canceler);
void ScoreInternal(const std::string& text, ChromeMLScoreFn score_fn);
+ void GetProbabilitiesBlockingInternal(
+ const std::string& input, ChromeMLGetProbabilitiesBlockingFn get_prob_fn);
void SizeInTokensInternal(on_device_model::mojom::InputPtr input,
ChromeMLSizeInTokensFn size_in_tokens_fn);
diff --git a/odml/on_device_model/on_device_model_service.cc b/odml/on_device_model/on_device_model_service.cc
index 2a2bc92..bb53916 100644
--- a/odml/on_device_model/on_device_model_service.cc
+++ b/odml/on_device_model/on_device_model_service.cc
@@ -66,6 +66,10 @@
GetSizeInTokensCallback callback) override;
void Score(const std::string& text, ScoreCallback callback) override;
void Clone(mojo::PendingReceiver<mojom::Session> session) override;
+ void GetProbabilitiesBlocking(
+ const std::string& input,
+ GetProbabilitiesBlockingCallback callback) override;
+ void SetPriority(mojom::Priority priority) override;
mojo::Receiver<mojom::Session>& receiver() { return receiver_; }
@@ -97,6 +101,14 @@
session_->Score(text, std::move(callback).Then(std::move(on_complete)));
}
+ // Forwards to the underlying session; |on_complete| is chained with Then()
+ // so it runs right after |callback| receives the probabilities.
+ void GetProbabilitiesBlockingInternal(
+ const std::string& input,
+ GetProbabilitiesBlockingCallback callback,
+ base::OnceClosure on_complete) {
+ session_->GetProbabilitiesBlocking(
+ input, std::move(callback).Then(std::move(on_complete)));
+ }
+
void CloneInternal(mojo::PendingReceiver<mojom::Session> session);
base::WeakPtr<ModelWrapper> model_;
@@ -322,6 +334,23 @@
weak_ptr_factory_.GetWeakPtr());
}
+// mojom::Session implementation. Hands the request to the owning model via
+// AddAndRunPendingTask(); the work itself happens in
+// GetProbabilitiesBlockingInternal().
+void SessionWrapper::GetProbabilitiesBlocking(
+ const std::string& input, GetProbabilitiesBlockingCallback callback) {
+ if (!model_) {
+ // The owning model is gone. NOTE(review): |callback| is dropped unrun
+ // here; mojo treats an unrun response callback on a still-bound pipe as
+ // an error, so confirm this path is unreachable while bound.
+ return;
+ }
+
+ model_->AddAndRunPendingTask(
+ base::BindOnce(&SessionWrapper::GetProbabilitiesBlockingInternal,
+ weak_ptr_factory_.GetWeakPtr(), input,
+ std::move(callback)),
+ weak_ptr_factory_.GetWeakPtr());
+}
+
+// Scheduling priority is not implemented on ChromeOS: |priority| is ignored
+// and an INFO log line is emitted on every call.
+void SessionWrapper::SetPriority(mojom::Priority priority) {
+ LOG(INFO) << "on_device_model priority is not supported on ChromeOS";
+}
+
void SessionWrapper::Clone(mojo::PendingReceiver<mojom::Session> session) {
if (!model_) {
return;
diff --git a/odml/on_device_model/public/cpp/model_assets.h b/odml/on_device_model/public/cpp/model_assets.h
index d6c8b3e..beb1bd6 100644
--- a/odml/on_device_model/public/cpp/model_assets.h
+++ b/odml/on_device_model/public/cpp/model_assets.h
@@ -44,6 +44,10 @@
AdaptationAssetPaths(const AdaptationAssetPaths&);
~AdaptationAssetPaths();
+ // Equal when both hold the same |weights| path. NOTE(review): update this
+ // if AdaptationAssetPaths gains more fields.
+ bool operator==(const AdaptationAssetPaths& other) const {
+ return weights == other.weights;
+ }
+
base::FilePath weights;
};