odml: Sync with Chrome ToT (up to Add SetPriority() to on-device Session)

This CL includes:
* Blink side changes for guidance/structured outputs
  crrev.com/c/6350280
* Bind paths to model controllers
  crrev.com/c/6419383
* Add GetProbabilitiesBlocking() API to odml mojo
  crrev.com/c/6396181
* Add GetProbabilities() API to chrome_ml_api
  crrev.com/c/6362334
* Add a few more metrics for on-device model usage
  crrev.com/c/6441693
* Add SetPriority() to on-device Session interface
  crrev.com/c/6442263

BUG=b:353900545
TEST=unit tests

Change-Id: I8eed88ecf2c86eeb90bba8d7381fc1160db8aeca

Reviewed-on: https://chromium-review.googlesource.com/c/chromiumos/platform2/+/6498469
Reviewed-by: Howard Yang <hcyang@google.com>
Tested-by: John L Chen <zuan@chromium.org>
diff --git a/odml/mojom/on_device_model.mojom b/odml/mojom/on_device_model.mojom
index 9e15050..9232360 100644
--- a/odml/mojom/on_device_model.mojom
+++ b/odml/mojom/on_device_model.mojom
@@ -214,6 +214,22 @@ Add response_json_schema and the Priority enum
   uint32? top_k;
   // Deprecated: use SessionParams.temperature instead.
   float? temperature;
+
+  // A JSON schema defining structured output requirements for the response.
+  // Passed as an opaque JSON blob to llguidance.
+  [MinVersion=1]
+  string? response_json_schema;
+};
+
+// Priorities which determine how requests to a session are scheduled.
+[Stable, Extensible]
+enum Priority {
+  // Requests should be treated as high priority; the user may be actively
+  // waiting for a response.
+  [Default] kForeground,
+  // Requests are not urgent and may be queued. Background requests will only
+  // run if there are no active foreground requests.
+  kBackground,
 };
 
 // A session for a model that allows adding context and then executing an input
@@ -243,6 +259,19 @@ Add GetProbabilitiesBlocking() and SetPriority() to Session
   // Clones the current session. The cloned session will have the same context
   // as the current session.
   Clone@4(pending_receiver<Session> session);
+
+  // Gets the probability for a series of tokens on top of the current
+  // context. Capabilities.probabilities_output must be specified to use this.
+  // Note that this is implemented as a blocking method on the service side
+  // and should only be used in debugging/testing.
+  [MinVersion=3]
+  GetProbabilitiesBlocking@8(string text) => (array<float> probabilities);
+
+  // Sets the priority for currently queued requests to this session and future
+  // requests. Any clones made of this session will inherit the current
+  // priority. Priority for new sessions defaults to kForeground.
+  [MinVersion=4]
+  SetPriority@9(Priority priority);
 };
 
 // A loaded model which can be queried. This interface must be controlled by the
diff --git a/odml/on_device_model/ml/chrome_ml.cc b/odml/on_device_model/ml/chrome_ml.cc
index 60723f9..ce3f88d 100644
--- a/odml/on_device_model/ml/chrome_ml.cc
+++ b/odml/on_device_model/ml/chrome_ml.cc
@@ -102,6 +102,24 @@ Add RecordMediumTimesHistogram()
   }
 }
 
+void RecordMediumTimesHistogram(const char* name, int64_t milliseconds) {
+  base::AutoLock lock(*g_metrics_lock);
+  if (g_metrics) {
+    // Equivalent of Chromium's Uma_HistogramMediumTimes() (1 ms to 3 min, 50
+    // buckets), which ChromiumOS lacks; the underscore stops check-libchrome.py
+    // from mistaking this comment for a call. SendToUMA() takes an int sample,
+    // so clamp the int64 input into [0, max] instead of narrowing it
+    // implicitly; clamped values still land in the under/overflow buckets.
+    const int64_t min_ms = base::Milliseconds(1).InMilliseconds();
+    const int64_t max_ms = base::Minutes(3).InMilliseconds();
+    const int64_t sample =
+        milliseconds < 0 ? 0 : (milliseconds > max_ms ? max_ms : milliseconds);
+    g_metrics->SendToUMA(name, static_cast<int>(sample),
+                         static_cast<int>(min_ms), static_cast<int>(max_ms),
+                         50);
+  }
+}
+
 }  // namespace
 
 ChromeML::ChromeML(raw_ref<MetricsLibraryInterface> metrics,
@@ -174,6 +192,7 @@ Register RecordMediumTimesHistogram in ChromeMLMetricsFns
     const ChromeMLMetricsFns metrics_fns{
         .RecordExactLinearHistogram = &RecordExactLinearHistogram,
         .RecordCustomCountsHistogram = &RecordCustomCountsHistogram,
+        .RecordMediumTimesHistogram = &RecordMediumTimesHistogram,
     };
     api->SetMetricsFns(&metrics_fns);
   }
diff --git a/odml/on_device_model/ml/chrome_ml_api.h b/odml/on_device_model/ml/chrome_ml_api.h
index 51a1c60..85fd1c8 100644
--- a/odml/on_device_model/ml/chrome_ml_api.h
+++ b/odml/on_device_model/ml/chrome_ml_api.h
@@ -9,6 +9,7 @@ Include <vector> for std::vector<float>
 #include <functional>
 #include <optional>
 #include <string>
+#include <vector>
 
 #include "odml/on_device_model/ml/chrome_ml_types.h"
 #include "odml/on_device_model/ml/forward_declare.h"
@@ -195,6 +196,11 @@ Add ChromeMLGetProbabilitiesBlockingFn
 // This will be called on the internal thread executing the model.
 using ChromeMLScoreFn = std::function<void(float)>;
 
+// Called with a vector of probability scores after a call to
+// SessionGetProbabilitiesBlocking().
+using ChromeMLGetProbabilitiesBlockingFn =
+    std::function<void(const std::vector<float>&)>;
+
 struct ChromeMLExecuteOptions {
   int context_mode;
   uint32_t max_tokens;
@@ -234,6 +240,9 @@ Add RecordMediumTimesHistogram to ChromeMLMetricsFns
   // spanning the specified range.
   void (*RecordCustomCountsHistogram)(
       const char* name, int sample, int min, int exclusive_max, size_t buckets);
+
+  // Logs a timing sample, given in milliseconds, for timings up to 3 minutes.
+  void (*RecordMediumTimesHistogram)(const char* name, int64_t milliseconds);
 };
 
 // Precision used by the gpu delegate during inference.
@@ -360,6 +369,13 @@ Add the SessionGetProbabilitiesBlocking entry point
                        const std::string& text,
                        const ChromeMLScoreFn& fn);
 
+  // Gets the probabilities of the tokens in `input` and passes them to `fn`.
+  // Note that this is a blocking call, mainly intended for testing purposes.
+  void (*SessionGetProbabilitiesBlocking)(
+      ChromeMLSession session,
+      const std::string& input,
+      const ChromeMLGetProbabilitiesBlockingFn& fn);
+
   // Create a new session in the model, optionally loading adaptation data.
   ChromeMLSession (*CreateSession)(
       ChromeMLModel model, const ChromeMLAdaptationDescriptor* descriptor);
diff --git a/odml/on_device_model/ml/on_device_model_executor.cc b/odml/on_device_model/ml/on_device_model_executor.cc
index a926429..cdb6780 100644
--- a/odml/on_device_model/ml/on_device_model_executor.cc
+++ b/odml/on_device_model/ml/on_device_model_executor.cc
@@ -331,6 +331,14 @@ Add SessionImpl::GetProbabilitiesBlocking(), forwarding to session_
   session_->Score(text, ConvertCallbackToFn(std::move(callback)));
 }
 
+DISABLE_CFI_DLSYM
+void SessionImpl::GetProbabilitiesBlocking(
+    const std::string& input,
+    base::OnceCallback<void(const std::vector<float>&)> callback) {
+  session_->GetProbabilitiesBlocking(input,
+                                     ConvertCallbackToFn(std::move(callback)));
+}
+
 std::unique_ptr<SessionImpl> SessionImpl::Clone() {
   return std::make_unique<SessionImpl>(metrics_, chrome_ml_.get(), model_,
                                        session_->Clone(), max_tokens_,
diff --git a/odml/on_device_model/ml/on_device_model_executor.h b/odml/on_device_model/ml/on_device_model_executor.h
index 57d69d2..a601775 100644
--- a/odml/on_device_model/ml/on_device_model_executor.h
+++ b/odml/on_device_model/ml/on_device_model_executor.h
@@ -10,6 +10,7 @@ Include <vector> for std::vector<float>
 #include <memory>
 #include <set>
 #include <string>
+#include <vector>
 
 #include <absl/container/flat_hash_map.h>
 #include <base/files/file_path.h>
@@ -57,6 +58,9 @@ Declare SessionImpl::GetProbabilitiesBlocking()
   void SizeInTokens(on_device_model::mojom::InputPtr input,
                     base::OnceCallback<void(uint32_t)> callback);
   void Score(const std::string& text, base::OnceCallback<void(float)> callback);
+  void GetProbabilitiesBlocking(
+      const std::string& input,
+      base::OnceCallback<void(const std::vector<float>&)> callback);
   std::unique_ptr<SessionImpl> Clone();
 
  private:
diff --git a/odml/on_device_model/ml/session_accessor.cc b/odml/on_device_model/ml/session_accessor.cc
index 577fff5..74e66db 100644
--- a/odml/on_device_model/ml/session_accessor.cc
+++ b/odml/on_device_model/ml/session_accessor.cc
@@ -125,6 +125,14 @@ Post GetProbabilitiesBlocking() to task_runner_
                      text, std::move(score_fn)));
 }
 
+void SessionAccessor::GetProbabilitiesBlocking(
+    const std::string& input, ChromeMLGetProbabilitiesBlockingFn get_prob_fn) {
+  task_runner_->PostTask(
+      FROM_HERE,
+      base::BindOnce(&SessionAccessor::GetProbabilitiesBlockingInternal,
+                     base::Unretained(this), input, std::move(get_prob_fn)));
+}
+
 void SessionAccessor::SizeInTokens(on_device_model::mojom::InputPtr input,
                                    ChromeMLSizeInTokensFn size_in_tokens_fn) {
   task_runner_->PostTask(
@@ -239,6 +247,14 @@ Call the ChromeML API from task_runner_'s sequence
 }
 
 DISABLE_CFI_DLSYM
+void SessionAccessor::GetProbabilitiesBlockingInternal(
+    const std::string& input, ChromeMLGetProbabilitiesBlockingFn get_prob_fn) {
+  DCHECK(task_runner_->RunsTasksInCurrentSequence());
+  chrome_ml_->api().SessionGetProbabilitiesBlocking(session_, input,
+                                                    get_prob_fn);
+}
+
+DISABLE_CFI_DLSYM
 void SessionAccessor::SizeInTokensInternal(
     on_device_model::mojom::InputPtr input,
     ChromeMLSizeInTokensFn size_in_tokens_fn) {
diff --git a/odml/on_device_model/ml/session_accessor.h b/odml/on_device_model/ml/session_accessor.h
index cd506f6..86f2661 100644
--- a/odml/on_device_model/ml/session_accessor.h
+++ b/odml/on_device_model/ml/session_accessor.h
@@ -42,6 +42,8 @@ Declare SessionAccessor::GetProbabilitiesBlocking()
   ChromeMLCancelFn Generate(on_device_model::mojom::GenerateOptionsPtr options,
                             ChromeMLExecutionOutputFn output_fn);
   void Score(const std::string& text, ChromeMLScoreFn score_fn);
+  void GetProbabilitiesBlocking(const std::string& input,
+                                ChromeMLGetProbabilitiesBlockingFn get_prob_fn);
   void SizeInTokens(on_device_model::mojom::InputPtr input,
                     ChromeMLSizeInTokensFn size_in_tokens_fn);
 
@@ -65,6 +67,8 @@ Declare GetProbabilitiesBlockingInternal()
       ChromeMLExecutionOutputFn output_fn,
       scoped_refptr<Canceler> canceler);
   void ScoreInternal(const std::string& text, ChromeMLScoreFn score_fn);
+  void GetProbabilitiesBlockingInternal(
+      const std::string& input, ChromeMLGetProbabilitiesBlockingFn get_prob_fn);
   void SizeInTokensInternal(on_device_model::mojom::InputPtr input,
                             ChromeMLSizeInTokensFn size_in_tokens_fn);
 
diff --git a/odml/on_device_model/on_device_model_service.cc b/odml/on_device_model/on_device_model_service.cc
index 2a2bc92..bb53916 100644
--- a/odml/on_device_model/on_device_model_service.cc
+++ b/odml/on_device_model/on_device_model_service.cc
@@ -66,6 +66,10 @@ Declare the new mojom::Session overrides in SessionWrapper
                        GetSizeInTokensCallback callback) override;
   void Score(const std::string& text, ScoreCallback callback) override;
   void Clone(mojo::PendingReceiver<mojom::Session> session) override;
+  void GetProbabilitiesBlocking(
+      const std::string& input,
+      GetProbabilitiesBlockingCallback callback) override;
+  void SetPriority(mojom::Priority priority) override;
 
   mojo::Receiver<mojom::Session>& receiver() { return receiver_; }
 
@@ -97,6 +101,14 @@ Run on_complete after the probabilities callback
     session_->Score(text, std::move(callback).Then(std::move(on_complete)));
   }
 
+  void GetProbabilitiesBlockingInternal(
+      const std::string& input,
+      GetProbabilitiesBlockingCallback callback,
+      base::OnceClosure on_complete) {
+    session_->GetProbabilitiesBlocking(
+        input, std::move(callback).Then(std::move(on_complete)));
+  }
+
   void CloneInternal(mojo::PendingReceiver<mojom::Session> session);
 
   base::WeakPtr<ModelWrapper> model_;
@@ -322,6 +334,23 @@ Queue GetProbabilitiesBlocking() on the model; SetPriority() only logs
       weak_ptr_factory_.GetWeakPtr());
 }
 
+void SessionWrapper::GetProbabilitiesBlocking(
+    const std::string& input, GetProbabilitiesBlockingCallback callback) {
+  if (!model_) {
+    return;
+  }
+
+  model_->AddAndRunPendingTask(
+      base::BindOnce(&SessionWrapper::GetProbabilitiesBlockingInternal,
+                     weak_ptr_factory_.GetWeakPtr(), input,
+                     std::move(callback)),
+      weak_ptr_factory_.GetWeakPtr());
+}
+
+void SessionWrapper::SetPriority(mojom::Priority priority) {
+  LOG(INFO) << "on_device_model priority is not supported on ChromeOS";
+}
+
 void SessionWrapper::Clone(mojo::PendingReceiver<mojom::Session> session) {
   if (!model_) {
     return;
diff --git a/odml/on_device_model/public/cpp/model_assets.h b/odml/on_device_model/public/cpp/model_assets.h
index d6c8b3e..beb1bd6 100644
--- a/odml/on_device_model/public/cpp/model_assets.h
+++ b/odml/on_device_model/public/cpp/model_assets.h
@@ -44,6 +44,10 @@
   AdaptationAssetPaths(const AdaptationAssetPaths&);
   ~AdaptationAssetPaths();
 
+  bool operator==(const AdaptationAssetPaths& other) const {
+    return weights == other.weights;
+  }
+
   base::FilePath weights;
 };