| // Copyright 2023 The Chromium Authors |
| // Use of this source code is governed by a BSD-style license that can be |
| // found in the LICENSE file. |
| |
| import '//resources/cr_elements/cr_button/cr_button.js'; |
| import '//resources/cr_elements/cr_checkbox/cr_checkbox.js'; |
| import '//resources/cr_elements/cr_collapse/cr_collapse.js'; |
| import '//resources/cr_elements/cr_expand_button/cr_expand_button.js'; |
| import '//resources/cr_elements/cr_input/cr_input.js'; |
| import '//resources/cr_elements/cr_textarea/cr_textarea.js'; |
| import '/strings.m.js'; |
| |
| import type {CrInputElement} from '//resources/cr_elements/cr_input/cr_input.js'; |
| import {assert} from '//resources/js/assert.js'; |
| import {CrLitElement} from '//resources/lit/v3_0/lit.rollup.js'; |
| import type {PropertyValues} from '//resources/lit/v3_0/lit.rollup.js'; |
| import type {FilePath} from '//resources/mojo/mojo/public/mojom/base/file_path.mojom-webui.js'; |
| import {loadTimeData} from 'chrome://resources/js/load_time_data.js'; |
| |
| import {BrowserProxy} from './browser_proxy.js'; |
| import type {AudioData, Capabilities, InputPiece, ResponseChunk, ResponseSummary} from './on_device_model.mojom-webui.js'; |
| import {LoadModelResult, OnDeviceModelRemote, PerformanceClass, SessionRemote, StreamingResponderCallbackRouter, Token} from './on_device_model.mojom-webui.js'; |
| import {ModelPerformanceHint} from './on_device_model_service.mojom-webui.js'; |
| import {getCss} from './tools.css.js'; |
| import {getHtml} from './tools.html.js'; |
| |
/** One prompt/response exchange shown in the tools UI. */
interface Response {
  // The prompt text that was submitted for this exchange.
  text: string;
  // Model output accumulated from streamed chunks (leading whitespace trimmed).
  response: string;
  // Set to 'response' when created; presumably a CSS class used by the
  // template — TODO confirm against tools.html.
  responseClass: string;
  // Initialized to false; not set to true anywhere in this file.
  retracted: boolean;
  // True when the model service crashed while this response was in progress.
  error: boolean;
}
| |
// Declaration-merged with the class below so that `this.$` has typed entries
// for the template elements the class touches directly (keys presumably match
// element ids in tools.html — verify there).
interface OnDeviceInternalsToolsElement {
  $: {
    // Free-form model path typed by the user.
    modelInput: CrInputElement,
    // Sampling temperature; its step is set to 0.1 in firstUpdated().
    temperatureInput: CrInputElement,
    // Prompt text input; focused after a model loads or a response finishes.
    textInput: CrInputElement,
    // Hidden file pickers opened programmatically via click().
    imageInput: HTMLInputElement,
    audioInput: HTMLInputElement,
    // Top-K sampling parameter.
    topKInput: CrInputElement,
    // Holds a ModelPerformanceHint enum key name (e.g. 'kHighestQuality').
    performanceHintSelect: HTMLSelectElement,
  };
}
| |
/**
 * Maps a device performance class to the human-readable label shown in the
 * UI. Any class without a dedicated label is reported as 'Error'.
 */
function getPerformanceClassText(performanceClass: PerformanceClass): string {
  const labels = new Map<PerformanceClass, string>([
    [PerformanceClass.kVeryLow, 'Very Low'],
    [PerformanceClass.kLow, 'Low'],
    [PerformanceClass.kMedium, 'Medium'],
    [PerformanceClass.kHigh, 'High'],
    [PerformanceClass.kVeryHigh, 'Very High'],
    [PerformanceClass.kGpuBlocked, 'GPU blocked'],
    [PerformanceClass.kFailedToLoadLibrary, 'Failed to load native library'],
  ]);
  return labels.get(performanceClass) ?? 'Error';
}
| |
/**
 * Splits prompt text into model input pieces.
 *
 * A line consisting exactly of `$SYSTEM`, `$MODEL`, `$USER` or `$END` becomes
 * the matching control token. All other consecutive lines are merged (joined
 * with '\n') into a single text piece.
 */
function textToInputPieces(text: string): InputPiece[] {
  const markers = new Map<string, Token>([
    ['$SYSTEM', Token.kSystem],
    ['$MODEL', Token.kModel],
    ['$USER', Token.kUser],
    ['$END', Token.kEnd],
  ]);
  const pieces: InputPiece[] = [];
  for (const line of text.split('\n')) {
    const token = markers.get(line);
    if (token !== undefined) {
      pieces.push({token});
      continue;
    }
    const last = pieces[pieces.length - 1];
    if (last !== undefined && last.text !== undefined) {
      // Continue the text piece that is already open.
      last.text += '\n' + line;
    } else {
      pieces.push({text: line});
    }
  }
  return pieces;
}
| |
/**
 * Converts a mojo FilePath into a display string. On platforms where the
 * path is transported as an array of UTF-16 code units (rather than a
 * string), the units are decoded as UTF-16.
 */
function filePathToString(filePath: FilePath): string {
  const {path} = filePath;
  return typeof path === 'string' ?
      path :
      new TextDecoder('utf-16').decode(new Uint16Array(path));
}
| |
/**
 * Developer tools panel for on-device models. Lets the user load a model
 * (from a typed path, the default path, or as a platform model), append
 * context to a session, and run prompts with optional image and audio
 * inputs while streamed responses are displayed.
 */
class OnDeviceInternalsToolsElement extends CrLitElement {
  static get is() {
    return 'on-device-internals-tools';
  }

  static override get styles() {
    return getCss();
  }

  override render() {
    return getHtml.bind(this)();
  }

  static override get properties() {
    return {
      modelPath_: {type: String},
      error_: {type: String},
      imageError_: {type: String},
      text_: {type: String},
      loadModelStart_: {type: Number},
      currentResponse_: {type: Object},
      responses_: {type: Array},
      model_: {type: Object},
      performanceClassText_: {type: String},
      usePlatformModel_: {type: Boolean},
      contextExpanded_: {type: Boolean},
      contextLength_: {type: Number},
      contextText_: {type: String},
      topK_: {type: Number},
      temperature_: {type: Number},
      imageFile_: {type: Object},
      audioFile_: {type: Object},
      audioError_: {type: String},
      performanceHint_: {type: String},
      loadedPerformanceHint_: {type: Number},
    };
  }

  // Capabilities reported when the current model was loaded. Not a reactive
  // property.
  private capabilities_: Capabilities = {imageInput: false, audioInput: false};
  protected accessor contextExpanded_: boolean = false;
  // Running count of pieces produced by splitting appended context on
  // whitespace runs (see onAddContextClick_); reset per session.
  protected accessor contextLength_: number = 0;
  protected accessor contextText_: string = '';
  // The exchange currently being generated, or null when idle.
  protected accessor currentResponse_: Response|null = null;
  protected accessor error_: string = '';
  protected accessor imageError_: string = '';
  // Milliseconds the last successful model load took. Not reactive.
  private loadModelDuration_: number = -1;
  // Time (ms since epoch) the in-flight model load began, or 0 when no load
  // is in progress (see isLoading_).
  private accessor loadModelStart_: number = 0;
  private accessor modelPath_: string = '';
  protected accessor model_: OnDeviceModelRemote|null = null;
  protected accessor performanceClassText_: string = 'Loading...';
  protected showPlatformModelCheckbox_: boolean =
      loadTimeData.getBoolean('useChromeOSModelService');
  protected accessor usePlatformModel_: boolean = false;
  // Completed exchanges, most recent first.
  protected accessor responses_: Response[] = [];
  protected accessor temperature_: number = 0;
  protected accessor text_: string = '';
  protected accessor topK_: number = 1;
  protected accessor imageFile_: File|null = null;
  protected accessor audioFile_: File|null = null;
  protected accessor audioError_: string = '';
  // Name of a ModelPerformanceHint enum member, as chosen in the select.
  protected accessor performanceHint_: string = 'kHighestQuality';
  private accessor loadedPerformanceHint_: ModelPerformanceHint|null = null;

  private session_: SessionRemote|null = null;
  private proxy_: BrowserProxy = BrowserProxy.getInstance();
  private responseRouter_: StreamingResponderCallbackRouter =
      new StreamingResponderCallbackRouter();

  override firstUpdated() {
    this.getPerformanceClass_();
    this.$.temperatureInput.inputElement.step = '0.1';
    this.$.imageInput.addEventListener(
        'change', this.onImageChange_.bind(this));
    this.$.audioInput.addEventListener(
        'change', this.onAudioChange_.bind(this));
  }

  override updated(changedProperties: PropertyValues<this>) {
    super.updated(changedProperties);

    // `model_` and `error_` are non-public, so they are not keys of
    // PropertyValues<this>; widen the map type to query them.
    const changedPrivateProperties =
        changedProperties as Map<PropertyKey, unknown>;

    if (changedPrivateProperties.has('model_') ||
        changedPrivateProperties.has('error_')) {
      this.onModelOrErrorChanged_();
    }
  }

  /** Fetches the device performance class and stores its display label. */
  private async getPerformanceClass_() {
    this.performanceClassText_ = getPerformanceClassText(
        (await this.proxy_.handler.getDeviceAndPerformanceInfo())
            .performanceInfo.performanceClass);
  }

  /**
   * Finishes a model load once it has either produced a model or an error:
   * records the load duration on success and clears the in-progress marker.
   */
  private onModelOrErrorChanged_() {
    // onModelSelected_() clears `model_` and `error_` and then sets
    // `loadModelStart_`. Lit reports that change asynchronously, while the
    // load is still pending, so this "cleared" state is not a load result.
    // Resetting here would hide the loading state and make the duration
    // below be measured from the epoch.
    if (this.model_ === null && this.error_ === '') {
      return;
    }
    if (this.model_ !== null) {
      this.loadModelDuration_ = new Date().getTime() - this.loadModelStart_;
      this.$.textInput.focus();
    }
    this.loadModelStart_ = 0;
  }

  /** Loads the model whose path was typed into the model input. */
  protected onLoadClick_() {
    const modelPathString = this.$.modelInput.value;
    // <if expr="is_win">
    // Windows file paths are std::wstring, so use Array<Number>. Index by
    // UTF-16 code unit: iterating the string directly yields whole code
    // points, and charCodeAt(0) would drop the low surrogate of non-BMP
    // characters.
    const processedPath = Array.from(
        {length: modelPathString.length},
        (_, i) => modelPathString.charCodeAt(i));
    // </if>
    // <if expr="not is_win">
    const processedPath = modelPathString;
    // </if>
    this.onModelSelected_({path: processedPath});
  }

  /** Loads the model from the path reported by getDefaultModelPath(). */
  protected async onLoadDefaultClick_() {
    const defaultModelPath = await this.proxy_.handler.getDefaultModelPath();
    if (defaultModelPath.modelPath === null) {
      this.error_ = 'Unable to get default model path.';
      return;
    }
    this.onModelSelected_(defaultModelPath.modelPath);
  }

  protected onAddImageClick_() {
    this.$.imageInput.click();
  }

  protected onAddAudioClick_() {
    this.$.audioInput.click();
  }

  /** Clears the selected image (including the file input's own value). */
  protected onRemoteImageClick_() {
    this.imageFile_ = null;
    this.$.imageInput.value = '';
  }

  /** Clears the selected audio (including the file input's own value). */
  protected onRemoteAudioClick_() {
    this.audioFile_ = null;
    this.$.audioInput.value = '';
  }

  protected onPerformanceHintChange_() {
    this.performanceHint_ = this.$.performanceHintSelect.value;
  }

  /**
   * Handles loss of the model connection: marks any in-progress response as
   * failed, drops the model, and asks the user to reload.
   */
  private onServiceCrashed_() {
    if (this.currentResponse_) {
      this.currentResponse_.error = true;
      this.addResponse_();
    }
    this.error_ = 'Service crashed, please reload the model.';
    this.model_ = null;
    this.modelPath_ = '';
    this.loadModelStart_ = 0;
    this.$.modelInput.focus();
  }

  private onImageChange_() {
    this.imageError_ = '';
    if ((this.$.imageInput.files?.length ?? 0) > 0) {
      this.imageFile_ = this.$.imageInput.files!.item(0) ?? null;
    } else {
      this.imageFile_ = null;
    }
  }

  private onAudioChange_() {
    this.audioError_ = '';
    if ((this.$.audioInput.files?.length ?? 0) > 0) {
      this.audioFile_ = this.$.audioInput.files!.item(0) ?? null;
    } else {
      this.audioFile_ = null;
    }
  }

  /**
   * Replaces the current model with one loaded from `modelPath`, using either
   * the platform model path or the regular loader with the selected
   * performance hint. On success a fresh session is started.
   */
  private async onModelSelected_(modelPath: FilePath) {
    this.error_ = '';
    if (this.model_) {
      this.model_.$.close();
    }
    this.imageFile_ = null;
    this.audioFile_ = null;
    this.model_ = null;
    this.capabilities_ = {imageInput: false, audioInput: false};
    this.loadModelStart_ = new Date().getTime();
    const performanceHint = ModelPerformanceHint[(
        this.performanceHint_ as keyof typeof ModelPerformanceHint)];
    const newModel = new OnDeviceModelRemote();

    let result: LoadModelResult;
    let capabilities: Capabilities;
    if (this.usePlatformModel_) {
      const loadedData = await this.proxy_.handler.loadPlatformModel(
          modelPath, newModel.$.bindNewPipeAndPassReceiver());
      result = loadedData.result;
      capabilities = {imageInput: false, audioInput: false};
    } else {
      const loadedData = await this.proxy_.handler.loadModel(
          modelPath, performanceHint,
          newModel.$.bindNewPipeAndPassReceiver());
      result = loadedData.result;
      capabilities = loadedData.capabilities;
    }

    if (result !== LoadModelResult.kSuccess) {
      // Nothing will use this remote; release its pipe.
      newModel.$.close();
      this.error_ =
          'Unable to load model. Specify a correct and absolute path.';
      return;
    }

    this.model_ = newModel;
    this.capabilities_ = capabilities;
    this.model_.onConnectionError.addListener(() => {
      this.onServiceCrashed_();
    });
    this.startNewSession_();
    this.modelPath_ = filePathToString(modelPath);
    this.loadedPerformanceHint_ = performanceHint;
  }

  /** Appends the context text to the current session and clears the box. */
  protected onAddContextClick_() {
    if (this.session_ === null) {
      return;
    }
    this.session_.append(
        {
          maxTokens: 0,
          input: {pieces: textToInputPieces(this.contextText_)},
        },
        null);
    this.contextLength_ += this.contextText_.split(/(\s+)/).length;
    this.contextText_ = '';
  }

  /**
   * Starts a new session on the current model with the current sampling
   * parameters and model capabilities, resetting the context length.
   */
  protected startNewSession_() {
    if (this.model_ === null) {
      return;
    }
    this.contextLength_ = 0;
    this.session_ = new SessionRemote();
    this.model_.startSession(this.session_.$.bindNewPipeAndPassReceiver(), {
      maxTokens: 0,
      topK: this.topK_,
      temperature: this.temperature_,
      capabilities: {
        imageInput: this.imagesEnabled_(),
        audioInput: this.audioEnabled_(),
      },
    });
  }

  /**
   * Stops receiving the in-progress response by closing the responder router
   * (which also drops its listeners) and records the partial response.
   */
  protected onCancelClick_() {
    this.responseRouter_.$.close();
    this.responseRouter_ = new StreamingResponderCallbackRouter();
    this.addResponse_();
  }

  protected onExecuteClick_() {
    this.onExecute_();
  }

  /** Moves the current response to the top of the history list. */
  private async addResponse_() {
    assert(this.currentResponse_);
    this.responses_.unshift(this.currentResponse_);
    this.currentResponse_ = null;
    // `responses_` was mutated in place, which Lit does not detect.
    this.requestUpdate();
    await this.updateComplete;
    this.$.textInput.focus();
  }

  /**
   * Copies the selected image into a shared buffer and asks the browser to
   * decode it. Returns null for an empty file.
   */
  private async decodeBitmap_() {
    const file = this.imageFile_;
    assert(file);
    const data = new Uint8Array(await file.arrayBuffer());
    if (data.byteLength <= 0) {
      return null;
    }
    const handle = Mojo.createSharedBuffer(data.byteLength).handle;
    const buffer = new Uint8Array(handle.mapBuffer(0, data.byteLength).buffer);
    buffer.set(data);

    // BigBuffer type wants all properties but Mojo expects only one of them.
    const bigBuffer = {
      sharedMemory: {
        bufferHandle: handle,
        size: data.byteLength,
      },
      bytes: undefined,
      invalidBuffer: undefined,
    };
    delete bigBuffer.invalidBuffer;
    delete bigBuffer.bytes;
    const {bitmap} = await this.proxy_.handler.decodeBitmap(bigBuffer);
    return bitmap;
  }

  /**
   * Decodes the selected audio file into mono samples using an AudioContext
   * created at 48000 Hz.
   * @throws if decoding fails or the audio has more than one channel.
   */
  private async decodeAudio_(): Promise<AudioData> {
    const file = this.audioFile_;
    assert(file);
    const audioCtx = new AudioContext({sampleRate: 48000});
    try {
      const arrayBuffer = await file.arrayBuffer();
      const buffer = await audioCtx.decodeAudioData(arrayBuffer);
      if (buffer.numberOfChannels > 1) {
        throw new Error('Multichannel audio is not supported');
      }
      return {
        sampleRate: buffer.sampleRate,
        channelCount: buffer.numberOfChannels,
        frameCount: buffer.length,
        data: Array.from(buffer.getChannelData(0)),
      };
    } finally {
      // The context is only needed for decoding; an AudioContext keeps audio
      // resources alive until it is explicitly closed.
      void audioCtx.close();
    }
  }

  /**
   * Runs the prompt (plus any image/audio) on a clone of the current session,
   * so the session's own context is left untouched, and streams the output
   * into `currentResponse_`.
   */
  private async onExecute_() {
    // Clear errors from a previous attempt for both attachment kinds.
    this.imageError_ = '';
    this.audioError_ = '';
    if (this.session_ === null) {
      return;
    }
    if (!this.$.topKInput.validate()) {
      return;
    }
    if (!this.$.temperatureInput.validate()) {
      return;
    }
    const pieces = textToInputPieces(this.text_);
    if (this.imageFile_ !== null) {
      const bitmap = await this.decodeBitmap_();
      if (bitmap) {
        pieces.unshift({bitmap});
      } else {
        this.imageFile_ = null;
        this.imageError_ = 'Image is invalid';
        return;
      }
    }
    if (this.audioFile_ !== null) {
      try {
        const audio = await this.decodeAudio_();
        pieces.unshift({audio});
      } catch (error) {
        this.audioFile_ = null;
        this.audioError_ = `Audio is invalid: ${error}`;
        return;
      }
    }
    const clonedSession = new SessionRemote();
    this.session_.clone(clonedSession.$.bindNewPipeAndPassReceiver());
    clonedSession.append(
        {
          maxTokens: 0,
          input: {pieces: pieces},
        },
        null);
    clonedSession.generate(
        {
          maxOutputTokens: 0,
          constraint: null,
        },
        this.responseRouter_.$.bindNewPipeAndPassRemote());
    const onResponseId =
        this.responseRouter_.onResponse.addListener((chunk: ResponseChunk) => {
          assert(this.currentResponse_);
          this.currentResponse_.response =
              (this.currentResponse_.response + chunk.text).trimStart();
          // The response object is mutated in place, which Lit cannot see.
          this.requestUpdate();
        });
    const onCompleteId =
        this.responseRouter_.onComplete.addListener((_: ResponseSummary) => {
          this.addResponse_();
          this.responseRouter_.removeListener(onResponseId);
          this.responseRouter_.removeListener(onCompleteId);
        });
    this.currentResponse_ = {
      text: this.text_,
      response: '',
      responseClass: 'response',
      retracted: false,
      error: false,
    };
    this.text_ = '';
  }

  /** True when a model is loaded and no response is being generated. */
  protected canEnterInput_(): boolean {
    return !this.currentResponse_ && this.model_ !== null;
  }

  protected canExecute_(): boolean {
    return this.canEnterInput_() && this.text_.length > 0;
  }

  protected canUploadFile_(): boolean {
    return this.canEnterInput_() && this.imageFile_ === null;
  }

  protected isLoading_(): boolean {
    return this.loadModelStart_ !== 0;
  }

  protected imagesEnabled_(): boolean {
    return this.capabilities_.imageInput;
  }

  protected audioEnabled_(): boolean {
    return this.capabilities_.audioInput;
  }

  /**
   * Status line describing the loaded model, or '' when no model path is
   * recorded.
   */
  protected getModelText_(): string {
    if (this.modelPath_.length === 0) {
      return '';
    }
    let text = 'Model loaded from ' + this.modelPath_ + ' in ' +
        this.loadModelDuration_ + 'ms ';
    if (this.imagesEnabled_()) {
      text += '[images enabled]';
    }
    if (this.audioEnabled_()) {
      text += '[audio enabled]';
    }
    if (this.loadedPerformanceHint_ ===
        ModelPerformanceHint.kFastestInference) {
      text += '[fastest inference]';
    }
    return text;
  }

  protected onContextExpandedChanged_(e: CustomEvent<{value: boolean}>) {
    this.contextExpanded_ = e.detail.value;
  }

  protected onContextTextChanged_(e: CustomEvent<{value: string}>) {
    this.contextText_ = e.detail.value;
  }

  protected onTextChanged_(e: CustomEvent<{value: string}>) {
    this.text_ = e.detail.value;
  }

  protected onTopKChanged_(e: CustomEvent<{value: number}>) {
    this.topK_ = e.detail.value;
  }

  protected onTemperatureChanged_(e: CustomEvent<{value: number}>) {
    this.temperature_ = e.detail.value;
  }

  protected onUsePlatformModelChanged_(e: CustomEvent<{value: boolean}>) {
    this.usePlatformModel_ = e.detail.value;
  }
}
| |
| export type ToolsElement = OnDeviceInternalsToolsElement; |
| |
| declare global { |
| interface HTMLElementTagNameMap { |
| 'on-device-internals-tools': OnDeviceInternalsToolsElement; |
| } |
| } |
| |
| customElements.define( |
| OnDeviceInternalsToolsElement.is, OnDeviceInternalsToolsElement); |