// fune/toolkit/components/ml/actors/MLEngineParent.sys.mjs
//
// 403 lines
// 12 KiB
// JavaScript

/* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at http://mozilla.org/MPL/2.0/. */
/**
* @typedef {object} Lazy
* @property {typeof console} console
* @property {typeof import("../content/EngineProcess.sys.mjs").EngineProcess} EngineProcess
* @property {typeof import("../../../../services/settings/remote-settings.sys.mjs").RemoteSettings} RemoteSettings
* @property {typeof import("../../translations/actors/TranslationsParent.sys.mjs").TranslationsParent} TranslationsParent
*/
/** @type {Lazy} */
const lazy = {};

// Logger shared by this module. Messages carry the "ML" prefix and the
// maximum log level is read from the "browser.ml.logLevel" pref.
ChromeUtils.defineLazyGetter(lazy, "console", () =>
  console.createInstance({
    maxLogLevelPref: "browser.ml.logLevel",
    prefix: "ML",
  })
);

// ES modules exposed as getters on `lazy`.
ChromeUtils.defineESModuleGetters(lazy, {
  EngineProcess: "chrome://global/content/ml/EngineProcess.sys.mjs",
  RemoteSettings: "resource://services-settings/remote-settings.sys.mjs",
  TranslationsParent: "resource://gre/actors/TranslationsParent.sys.mjs",
});
/**
* @typedef {import("../../translations/translations").WasmRecord} WasmRecord
*/
/**
 * Default value for the `cacheTimeoutMS` argument of
 * `MLEngineParent.prototype.getEngine`, in milliseconds (15 seconds).
 */
const DEFAULT_CACHE_TIMEOUT_MS = 15 * 1000;
/**
* The ML engine is in its own content process. This actor handles the
* marshalling of the data such as the engine payload.
*/
export class MLEngineParent extends JSWindowActorParent {
  /**
   * The RemoteSettingsClient that downloads the wasm binaries.
   *
   * @type {RemoteSettingsClient | null}
   */
  static #remoteClient = null;

  /**
   * Cached lookup of the wasm record. It is stored as a promise so that
   * concurrent callers share a single lookup. It is reset to null when the
   * lookup fails and whenever the remote client is mocked or un-mocked.
   *
   * @type {Promise<WasmRecord> | null}
   */
  static #wasmRecord = null;

  /**
   * The following constant controls the major version for wasm downloaded from
   * Remote Settings. When a breaking change is introduced, Nightly will have these
   * numbers incremented by one, but Beta and Release will still be on the previous
   * version. Remote Settings will ship both versions of the records, and the latest
   * asset released in that version will be used. For instance, with a major version
   * of "1", assets can be downloaded for "1.0", "1.2", "1.3beta", but assets marked
   * as "2.0", "2.1", etc will not be downloaded.
   */
  static WASM_MAJOR_VERSION = 1;

  /**
   * Remote settings isn't available in tests, so provide mocked responses.
   * Also drops any cached wasm record so the next lookup uses the mock.
   *
   * @param {RemoteSettingsClient} remoteClient
   */
  static mockRemoteSettings(remoteClient) {
    lazy.console.log("Mocking remote settings in MLEngineParent.");
    MLEngineParent.#remoteClient = remoteClient;
    MLEngineParent.#wasmRecord = null;
  }

  /**
   * Remove anything that could have been mocked.
   */
  static removeMocks() {
    lazy.console.log("Removing mocked remote client in MLEngineParent.");
    MLEngineParent.#remoteClient = null;
    MLEngineParent.#wasmRecord = null;
  }

  /**
   * Create an MLEngine handle bound to this actor.
   *
   * @param {string} engineName
   * @param {() => Promise<ArrayBuffer>} getModel
   * @param {number} cacheTimeoutMS - How long the engine cache remains alive between
   *   uses, in milliseconds. In automation the engine is manually created and destroyed
   *   to avoid timing issues.
   * @returns {MLEngine}
   */
  getEngine(engineName, getModel, cacheTimeoutMS = DEFAULT_CACHE_TIMEOUT_MS) {
    return new MLEngine(this, engineName, getModel, cacheTimeoutMS);
  }

  /**
   * Handle messages sent to this actor.
   *
   * Only "MLEngine:GetWasmArrayBuffer" produces a return value (the wasm
   * buffer); the other messages are handled for their side effects.
   *
   * @param {{ name: string }} message
   * @returns {Promise<ArrayBuffer | undefined>}
   */
  // eslint-disable-next-line consistent-return
  async receiveMessage({ name }) {
    switch (name) {
      case "MLEngine:Ready":
        if (lazy.EngineProcess.resolveMLEngineParent) {
          lazy.EngineProcess.resolveMLEngineParent(this);
        } else {
          lazy.console.error(
            "Expected #resolveMLEngineParent to exist when then ML Engine is ready."
          );
        }
        break;
      case "MLEngine:GetWasmArrayBuffer":
        return MLEngineParent.getWasmArrayBuffer();
      case "MLEngine:DestroyEngineProcess":
        // Deliberately not awaited; failures are only logged. The module
        // logger is used so the output is prefixed and respects the log-level
        // pref like every other message in this file.
        lazy.EngineProcess.destroyMLEngine().catch(error =>
          lazy.console.error(error)
        );
        break;
    }
  }

  /**
   * Look up the wasm record for the inference engine.
   *
   * @param {RemoteSettingsClient} client
   * @returns {Promise<WasmRecord>}
   * @throws {Error} When no matching record is returned.
   */
  static async #getWasmArrayRecord(client) {
    // Load the wasm binary from remote settings, if it hasn't been already.
    lazy.console.log(`Getting remote wasm records.`);

    /** @type {WasmRecord[]} */
    const wasmRecords = await lazy.TranslationsParent.getMaxVersionRecords(
      client,
      {
        // TODO - This record needs to be created with the engine wasm payload.
        filters: { name: "inference-engine" },
        majorVersion: MLEngineParent.WASM_MAJOR_VERSION,
      }
    );

    if (wasmRecords.length === 0) {
      // The remote settings client provides an empty list of records when there is
      // an error.
      throw new Error("Unable to get the ML engine from Remote Settings.");
    }

    if (wasmRecords.length > 1) {
      // Unexpected but not fatal: log it and continue with the first record.
      // This class has no `reportError` helper, so log directly.
      lazy.console.error(
        new Error("Expected the ml engine to only have 1 record."),
        wasmRecords
      );
    }

    const [record] = wasmRecords;
    lazy.console.log(
      `Using ${record.name}@${record.release} release version ${record.version} first released on Fx${record.fx_release}`,
      record
    );
    return record;
  }

  /**
   * Download the wasm for the ML inference engine.
   *
   * @returns {Promise<ArrayBuffer>}
   * @throws {Error} When the wasm record cannot be obtained. The cached
   *   lookup is cleared in that case so a later call can retry.
   */
  static async getWasmArrayBuffer() {
    const client = MLEngineParent.#getRemoteClient();

    if (!MLEngineParent.#wasmRecord) {
      // Place the records into a promise to prevent any races.
      MLEngineParent.#wasmRecord = MLEngineParent.#getWasmArrayRecord(client);
    }

    let wasmRecord;
    try {
      wasmRecord = await MLEngineParent.#wasmRecord;
      if (!wasmRecord) {
        // Thrown inside the try block so the catch below also clears the
        // cached (useless) lookup.
        throw new Error("Unable to get the ML engine from Remote Settings.");
      }
    } catch (error) {
      MLEngineParent.#wasmRecord = null;
      throw error;
    }

    /** @type {{buffer: ArrayBuffer}} */
    const { buffer } = await client.attachments.download(wasmRecord);
    return buffer;
  }

  /**
   * Lazily initializes the RemoteSettingsClient for the downloaded wasm binary data.
   *
   * @returns {RemoteSettingsClient}
   */
  static #getRemoteClient() {
    if (MLEngineParent.#remoteClient) {
      return MLEngineParent.#remoteClient;
    }

    /** @type {RemoteSettingsClient} */
    const client = lazy.RemoteSettings("ml-wasm");
    MLEngineParent.#remoteClient = client;

    client.on("sync", async ({ data: { created, updated, deleted } }) => {
      lazy.console.log(`"sync" event for ml-wasm`, {
        created,
        updated,
        deleted,
      });

      // Remove all the deleted records.
      for (const record of deleted) {
        await client.attachments.deleteDownloaded(record);
      }

      // Remove any updated records, and download the new ones.
      for (const { old: oldRecord } of updated) {
        await client.attachments.deleteDownloaded(oldRecord);
      }

      // Do nothing for the created records.
    });

    return client;
  }

  /**
   * Send a message to gracefully shutdown all of the ML engines in the engine process.
   * This mostly exists for testing the shutdown paths of the code.
   *
   * @returns {Promise<unknown>} The result of the "MLEngine:ForceShutdown" query.
   */
  forceShutdown() {
    return this.sendQuery("MLEngine:ForceShutdown");
  }
}
/**
* This contains all of the information needed to perform a translation request.
*
* @typedef {object} TranslationRequest
* @property {Node} node
* @property {string} sourceText
* @property {boolean} isHTML
* @property {Function} resolve
* @property {Function} reject
*/
/**
* The interface to communicate to an MLEngine in the parent process. The engine manages
* its own lifetime, and is kept alive with a timeout. A reference to this engine can
* be retained, but once idle, the engine will be destroyed. If a new request to run
* is sent, the engine will be recreated on demand. This balances the cost of retaining
* potentially large amounts of memory to run models, with the speed and ease of running
* the engine.
*
* @template Request
* @template Response
*/
class MLEngine {
/**
* @type {MessagePort | null}
*/
#port = null;
#nextRequestId = 0;
/**
* Tie together a message id to a resolved response.
*
* @type {Map<number, PromiseWithResolvers<Request>>}
*/
#requests = new Map();
/**
* @type {"uninitialized" | "ready" | "error" | "closed"}
*/
engineStatus = "uninitialized";
/**
* @param {MLEngineParent} mlEngineParent
* @param {string} engineName
* @param {() => Promise<ArrayBuffer>} getModel
* @param {number} timeoutMS
*/
constructor(mlEngineParent, engineName, getModel, timeoutMS) {
/** @type {MLEngineParent} */
this.mlEngineParent = mlEngineParent;
/** @type {string} */
this.engineName = engineName;
/** @type {() => Promise<ArrayBuffer>} */
this.getModel = getModel;
/** @type {number} */
this.timeoutMS = timeoutMS;
this.#setupPortCommunication();
}
/**
* Create a MessageChannel to communicate with the engine directly.
*/
#setupPortCommunication() {
const { port1: childPort, port2: parentPort } = new MessageChannel();
const transferables = [childPort];
this.#port = parentPort;
this.#port.onmessage = this.handlePortMessage;
this.mlEngineParent.sendAsyncMessage(
"MLEngine:NewPort",
{
port: childPort,
engineName: this.engineName,
timeoutMS: this.timeoutMS,
},
transferables
);
}
handlePortMessage = ({ data }) => {
switch (data.type) {
case "EnginePort:ModelRequest": {
if (this.#port) {
this.getModel().then(
model => {
this.#port.postMessage({
type: "EnginePort:ModelResponse",
model,
error: null,
});
},
error => {
this.#port.postMessage({
type: "EnginePort:ModelResponse",
model: null,
error,
});
if (
// Ignore intentional errors in tests.
!error?.message.startsWith("Intentionally")
) {
lazy.console.error("Failed to get the model", error);
}
}
);
} else {
lazy.console.error(
"Expected a port to exist during the EnginePort:GetModel event"
);
}
break;
}
case "EnginePort:RunResponse": {
const { response, error, requestId } = data;
const request = this.#requests.get(requestId);
if (request) {
if (response) {
request.resolve(response);
} else {
request.reject(error);
}
} else {
lazy.console.error(
"Could not resolve response in the MLEngineParent",
data
);
}
this.#requests.delete(requestId);
break;
}
case "EnginePort:EngineTerminated": {
// The engine was terminated, and if a new run is needed a new port
// will need to be requested.
this.engineStatus = "closed";
this.discardPort();
break;
}
default:
lazy.console.error("Unknown port message from engine", data);
break;
}
};
discardPort() {
if (this.#port) {
this.#port.postMessage({ type: "EnginePort:Discard" });
this.#port.close();
this.#port = null;
}
}
terminate() {
this.#port.postMessage({ type: "EnginePort:Terminate" });
}
/**
* @param {Request} request
* @returns {Promise<Response>}
*/
run(request) {
const resolvers = Promise.withResolvers();
const requestId = this.#nextRequestId++;
this.#requests.set(requestId, resolvers);
this.#port.postMessage({
type: "EnginePort:Run",
requestId,
request,
});
return resolvers.promise;
}
}