From 9f5282c8f5f51177ff279eb978d7f20a7f7f7c40 Mon Sep 17 00:00:00 2001 From: Henry Heino <46334387+personalizedrefrigerator@users.noreply.github.com> Date: Sat, 26 Oct 2024 13:00:56 -0700 Subject: [PATCH] Android: Allow switching the voice typing library to Whisper (#11158) Co-authored-by: Laurent Cozic --- .eslintignore | 4 + .gitignore | 4 + packages/app-mobile/android/app/build.gradle | 4 + .../android/app/src/main/AndroidManifest.xml | 1 + .../java/net/cozic/joplin/MainApplication.kt | 2 + .../net/cozic/joplin/audio/AudioRecorder.kt | 101 +++++++++++++ .../joplin/audio/InvalidSessionIdException.kt | 5 + .../joplin/audio/SpeechToTextConverter.kt | 136 +++++++++++++++++ .../cozic/joplin/audio/SpeechToTextPackage.kt | 86 +++++++++++ .../audio/SpeechToTextSessionManager.kt | 111 ++++++++++++++ .../app-mobile/components/screens/Note.tsx | 5 +- .../voiceTyping/VoiceTypingDialog.tsx | 110 +++++++++----- packages/app-mobile/jest.setup.js | 4 + .../services/voiceTyping/VoiceTyping.ts | 139 ++++++++++++++++++ .../utils/splitWhisperText.test.ts | 61 ++++++++ .../voiceTyping/utils/splitWhisperText.ts | 65 ++++++++ .../services/voiceTyping/vosk.android.ts | 115 ++++----------- .../app-mobile/services/voiceTyping/vosk.ts | 45 ++---- .../services/voiceTyping/whisper.ts | 119 +++++++++++++++ .../lib/models/settings/builtInMetadata.ts | 19 +++ packages/tools/cspell/dictionary4.txt | 2 + 21 files changed, 980 insertions(+), 158 deletions(-) create mode 100644 packages/app-mobile/android/app/src/main/java/net/cozic/joplin/audio/AudioRecorder.kt create mode 100644 packages/app-mobile/android/app/src/main/java/net/cozic/joplin/audio/InvalidSessionIdException.kt create mode 100644 packages/app-mobile/android/app/src/main/java/net/cozic/joplin/audio/SpeechToTextConverter.kt create mode 100644 packages/app-mobile/android/app/src/main/java/net/cozic/joplin/audio/SpeechToTextPackage.kt create mode 100644 packages/app-mobile/android/app/src/main/java/net/cozic/joplin/audio/SpeechToTextSessionManager.kt create mode 100644 packages/app-mobile/services/voiceTyping/VoiceTyping.ts create mode 100644 packages/app-mobile/services/voiceTyping/utils/splitWhisperText.test.ts create mode 100644 packages/app-mobile/services/voiceTyping/utils/splitWhisperText.ts create mode 100644 packages/app-mobile/services/voiceTyping/whisper.ts diff --git a/.eslintignore b/.eslintignore index 3bb40cf20..9ede21387 100644 --- a/.eslintignore +++ b/.eslintignore @@ -737,8 +737,12 @@ packages/app-mobile/services/BackButtonService.js packages/app-mobile/services/e2ee/RSA.react-native.js packages/app-mobile/services/plugins/PlatformImplementation.js packages/app-mobile/services/profiles/index.js +packages/app-mobile/services/voiceTyping/VoiceTyping.js +packages/app-mobile/services/voiceTyping/utils/splitWhisperText.test.js +packages/app-mobile/services/voiceTyping/utils/splitWhisperText.js packages/app-mobile/services/voiceTyping/vosk.android.js packages/app-mobile/services/voiceTyping/vosk.js +packages/app-mobile/services/voiceTyping/whisper.js packages/app-mobile/setupQuickActions.js packages/app-mobile/tools/buildInjectedJs/BundledFile.js packages/app-mobile/tools/buildInjectedJs/constants.js diff --git a/.gitignore b/.gitignore index f1c19367a..f5b7a8a6f 100644 --- a/.gitignore +++ b/.gitignore @@ -714,8 +714,12 @@ packages/app-mobile/services/BackButtonService.js packages/app-mobile/services/e2ee/RSA.react-native.js packages/app-mobile/services/plugins/PlatformImplementation.js packages/app-mobile/services/profiles/index.js 
+packages/app-mobile/services/voiceTyping/VoiceTyping.js
+packages/app-mobile/services/voiceTyping/utils/splitWhisperText.test.js
+packages/app-mobile/services/voiceTyping/utils/splitWhisperText.js
 packages/app-mobile/services/voiceTyping/vosk.android.js
 packages/app-mobile/services/voiceTyping/vosk.js
+packages/app-mobile/services/voiceTyping/whisper.js
 packages/app-mobile/setupQuickActions.js
 packages/app-mobile/tools/buildInjectedJs/BundledFile.js
 packages/app-mobile/tools/buildInjectedJs/constants.js
diff --git a/packages/app-mobile/android/app/build.gradle b/packages/app-mobile/android/app/build.gradle
index f736977da..111ff095e 100644
--- a/packages/app-mobile/android/app/build.gradle
+++ b/packages/app-mobile/android/app/build.gradle
@@ -136,6 +136,10 @@ dependencies {
     } else {
         implementation jscFlavor
     }
+
+    // Needed for Whisper speech-to-text
+    implementation 'com.microsoft.onnxruntime:onnxruntime-android:latest.release'
+    implementation 'com.microsoft.onnxruntime:onnxruntime-extensions-android:latest.release'
 }
 
 apply from: file("../../node_modules/@react-native-community/cli-platform-android/native_modules.gradle"); applyNativeModulesAppBuildGradle(project)
diff --git a/packages/app-mobile/android/app/src/main/AndroidManifest.xml b/packages/app-mobile/android/app/src/main/AndroidManifest.xml
index 60022c031..3cb36fdad 100644
--- a/packages/app-mobile/android/app/src/main/AndroidManifest.xml
+++ b/packages/app-mobile/android/app/src/main/AndroidManifest.xml
@@ -10,6 +10,7 @@
+    <uses-permission android:name="android.permission.RECORD_AUDIO" />
diff --git a/packages/app-mobile/android/app/src/main/java/net/cozic/joplin/MainApplication.kt b/packages/app-mobile/android/app/src/main/java/net/cozic/joplin/MainApplication.kt
index 76075257e..3c4a4bd1f 100644
--- a/packages/app-mobile/android/app/src/main/java/net/cozic/joplin/MainApplication.kt
+++ b/packages/app-mobile/android/app/src/main/java/net/cozic/joplin/MainApplication.kt
@@ -11,6 +11,7 @@ import com.facebook.react.defaults.DefaultNewArchitectureEntryPoint.load
 import com.facebook.react.defaults.DefaultReactHost.getDefaultReactHost
 import com.facebook.react.defaults.DefaultReactNativeHost
 import com.facebook.soloader.SoLoader
+import net.cozic.joplin.audio.SpeechToTextPackage
 import net.cozic.joplin.versioninfo.SystemVersionInformationPackage
 import net.cozic.joplin.share.SharePackage
 import net.cozic.joplin.ssl.SslPackage
@@ -25,6 +26,7 @@ class MainApplication : Application(), ReactApplication {
 				add(SslPackage())
 				add(TextInputPackage())
 				add(SystemVersionInformationPackage())
+				add(SpeechToTextPackage())
 			}
 
 			override fun getJSMainModuleName(): String = "index"
diff --git a/packages/app-mobile/android/app/src/main/java/net/cozic/joplin/audio/AudioRecorder.kt b/packages/app-mobile/android/app/src/main/java/net/cozic/joplin/audio/AudioRecorder.kt
new file mode 100644
index 000000000..6376ea9d6
--- /dev/null
+++ b/packages/app-mobile/android/app/src/main/java/net/cozic/joplin/audio/AudioRecorder.kt
@@ -0,0 +1,101 @@
+package net.cozic.joplin.audio
+
+import android.Manifest
+import android.annotation.SuppressLint
+import android.content.Context
+import android.content.pm.PackageManager
+import android.media.AudioFormat
+import android.media.AudioRecord
+import android.media.MediaRecorder.AudioSource
+import java.io.Closeable
+import kotlin.math.max
+import kotlin.math.min
+
+typealias AudioRecorderFactory = (context: Context)->AudioRecorder;
+
+class AudioRecorder(context: Context) : Closeable {
+	private val sampleRate = 16_000
+	private val maxLengthSeconds = 30 // Whisper supports a maximum of 30s
+	private val maxBufferSize = sampleRate * maxLengthSeconds
+	private val buffer = FloatArray(maxBufferSize)
+	private var bufferWriteOffset = 0
+
+	// Accessor must not modify result
+	val bufferedData: FloatArray get() = buffer.sliceArray(0 until bufferWriteOffset)
+	val bufferLengthSeconds: Double get() = bufferWriteOffset.toDouble() / sampleRate
+
+	init {
+		val permissionResult = context.checkSelfPermission(Manifest.permission.RECORD_AUDIO)
+		if (permissionResult == PackageManager.PERMISSION_DENIED) {
+			throw SecurityException("Missing RECORD_AUDIO permission!")
+		}
+	}
+
+	// Permissions check is included above
+	@SuppressLint("MissingPermission")
+	private val recorder = AudioRecord.Builder()
+		.setAudioSource(AudioSource.MIC)
+		.setAudioFormat(
+			AudioFormat.Builder()
+				// PCM: A WAV format
+				.setEncoding(AudioFormat.ENCODING_PCM_FLOAT)
+				.setSampleRate(sampleRate)
+				.setChannelMask(AudioFormat.CHANNEL_IN_MONO)
+				.build()
+		)
+		.setBufferSizeInBytes(maxBufferSize * Float.SIZE_BYTES)
+		.build()
+
+	// Discards the first [samples] samples from the start of the buffer. Conceptually, this
+	// advances the buffer's start point.
+	private fun advanceStartBySamples(samples: Int) {
+		val samplesClamped = min(samples, maxBufferSize)
+		val remainingBuffer = buffer.sliceArray(samplesClamped until maxBufferSize)
+
+		buffer.fill(0f, samplesClamped, maxBufferSize)
+		remainingBuffer.copyInto(buffer, 0)
+		bufferWriteOffset = max(bufferWriteOffset - samplesClamped, 0)
+	}
+
+	fun dropFirstSeconds(seconds: Double) {
+		advanceStartBySamples((seconds * sampleRate).toInt())
+	}
+
+	fun start() {
+		recorder.startRecording()
+	}
+
+	private fun read(requestedSize: Int, mode: Int) {
+		val size = min(requestedSize, maxBufferSize - bufferWriteOffset)
+		val sizeRead = recorder.read(buffer, bufferWriteOffset, size, mode)
+		if (sizeRead > 0) {
+			bufferWriteOffset += sizeRead
+		}
+	}
+
+	// Pulls all available data from the audio recorder's buffer
+	fun pullAvailable() {
+		return read(maxBufferSize, AudioRecord.READ_NON_BLOCKING)
+	}
+
+	fun pullNextSeconds(seconds: Double) {
+		val remainingSize = maxBufferSize - bufferWriteOffset
+		val requestedSize = (seconds * sampleRate).toInt()
+
+		// If low on size, make more room.
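+		// Advancing by a third of the buffer drops the oldest ~10s of audio
+		// (a third of the 30s Whisper window) while keeping the most recent ~20s.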
+		if (remainingSize < maxBufferSize / 3) {
+			advanceStartBySamples(maxBufferSize / 3)
+		}
+
+		return read(requestedSize, AudioRecord.READ_BLOCKING)
+	}
+
+	override fun close() {
+		recorder.stop()
+		recorder.release()
+	}
+
+	companion object {
+		val factory: AudioRecorderFactory = { context -> AudioRecorder(context) }
+	}
+}
\ No newline at end of file
diff --git a/packages/app-mobile/android/app/src/main/java/net/cozic/joplin/audio/InvalidSessionIdException.kt b/packages/app-mobile/android/app/src/main/java/net/cozic/joplin/audio/InvalidSessionIdException.kt
new file mode 100644
index 000000000..38615fe42
--- /dev/null
+++ b/packages/app-mobile/android/app/src/main/java/net/cozic/joplin/audio/InvalidSessionIdException.kt
@@ -0,0 +1,5 @@
+package net.cozic.joplin.audio
+
+
+class InvalidSessionIdException(id: Int) : IllegalArgumentException("Invalid session ID $id") {
+}
\ No newline at end of file
diff --git a/packages/app-mobile/android/app/src/main/java/net/cozic/joplin/audio/SpeechToTextConverter.kt b/packages/app-mobile/android/app/src/main/java/net/cozic/joplin/audio/SpeechToTextConverter.kt
new file mode 100644
index 000000000..6db9faa76
--- /dev/null
+++ b/packages/app-mobile/android/app/src/main/java/net/cozic/joplin/audio/SpeechToTextConverter.kt
@@ -0,0 +1,136 @@
+package net.cozic.joplin.audio
+
+import ai.onnxruntime.OnnxTensor
+import ai.onnxruntime.OrtEnvironment
+import ai.onnxruntime.OrtSession
+import ai.onnxruntime.extensions.OrtxPackage
+import android.annotation.SuppressLint
+import android.content.Context
+import android.util.Log
+import java.io.Closeable
+import java.nio.FloatBuffer
+import java.nio.IntBuffer
+import kotlin.time.DurationUnit
+import kotlin.time.measureTimedValue
+
+class SpeechToTextConverter(
+	modelPath: String,
+	locale: String,
+	recorderFactory: AudioRecorderFactory,
+	private val environment: OrtEnvironment,
+	context: Context,
+) : Closeable {
+	private val recorder = recorderFactory(context)
+	private val session: OrtSession = environment.createSession(
+		modelPath,
+		OrtSession.SessionOptions().apply {
+			// Needed for audio decoding
+			registerCustomOpLibrary(OrtxPackage.getLibraryPath())
+		},
+	)
+	private val languageCode = Regex("_.*").replace(locale, "")
+	private val decoderInputIds = when (languageCode) {
+		// Add 50363 to the end to omit timestamps
+		"en" -> intArrayOf(50258, 50259, 50359)
+		"fr" -> intArrayOf(50258, 50265, 50359)
+		"es" -> intArrayOf(50258, 50262, 50359)
+		"de" -> intArrayOf(50258, 50261, 50359)
+		"it" -> intArrayOf(50258, 50274, 50359)
+		"nl" -> intArrayOf(50258, 50271, 50359)
+		"ko" -> intArrayOf(50258, 50264, 50359)
+		"th" -> intArrayOf(50258, 50289, 50359)
+		"ru" -> intArrayOf(50258, 50263, 50359)
+		"pt" -> intArrayOf(50258, 50267, 50359)
+		"pl" -> intArrayOf(50258, 50269, 50359)
+		"id" -> intArrayOf(50258, 50275, 50359)
+		"hi" -> intArrayOf(50258, 50276, 50359)
+		// Let Whisper guess the language
+		else -> intArrayOf(50258)
+	}
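+	// (The IDs above come from Whisper's multilingual tokenizer: 50258 is
+	// <|startoftranscript|>, the middle token selects the language, and
+	// 50359 is the <|transcribe|> task token.)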
+
+	fun start() {
+		recorder.start()
+	}
+
+	private fun getInputs(data: FloatArray): MutableMap<String, OnnxTensor> {
+		fun intTensor(value: Int) = OnnxTensor.createTensor(
+			environment,
+			IntBuffer.wrap(intArrayOf(value)),
+			longArrayOf(1),
+		)
+		fun floatTensor(value: Float) = OnnxTensor.createTensor(
+			environment,
+			FloatBuffer.wrap(floatArrayOf(value)),
+			longArrayOf(1),
+		)
+		val audioPcmTensor = OnnxTensor.createTensor(
+			environment,
+			FloatBuffer.wrap(data),
+			longArrayOf(1, data.size.toLong()),
+		)
+		val decoderInputIdsTensor = OnnxTensor.createTensor(
+			environment,
+			IntBuffer.wrap(decoderInputIds),
+			longArrayOf(1, decoderInputIds.size.toLong())
+		)
+
+		return mutableMapOf(
+			"audio_pcm" to audioPcmTensor,
+			"max_length" to intTensor(412),
+			"min_length" to intTensor(0),
+			"num_return_sequences" to intTensor(1),
+			"num_beams" to intTensor(1),
+			"length_penalty" to floatTensor(1.1f),
+			"repetition_penalty" to floatTensor(3f),
+			"decoder_input_ids" to decoderInputIdsTensor,
+
+			// Required for timestamps
+			"logits_processor" to intTensor(1)
+		)
+	}
+
+	// TODO: .get() fails on older Android versions
+	@SuppressLint("NewApi")
+	private fun convert(data: FloatArray): String {
+		val (inputs, convertInputsTime) = measureTimedValue {
+			getInputs(data)
+		}
+		val (outputs, getOutputsTime) = measureTimedValue {
+			session.run(inputs, setOf("str"))
+		}
+		val mainOutput = outputs.get("str").get().value as Array<Array<String>>
+		outputs.close()
+
+		Log.i("Whisper", "Converted ${data.size / 16000}s of data in ${
+			getOutputsTime.toString(DurationUnit.SECONDS, 2)
+		}; converted inputs in ${convertInputsTime.inWholeMilliseconds}ms")
+		return mainOutput[0][0]
+	}
+
+	fun dropFirstSeconds(seconds: Double) {
+		Log.i("Whisper", "Drop first seconds $seconds")
+		recorder.dropFirstSeconds(seconds)
+	}
+
+	val bufferLengthSeconds: Double get() = recorder.bufferLengthSeconds
+
+	fun expandBufferAndConvert(seconds: Double): String {
+		recorder.pullNextSeconds(seconds)
+		// Also pull any extra available data, in case the speech-to-text converter
+		// is lagging behind the audio recorder.
+		recorder.pullAvailable()
+
+		return convert(recorder.bufferedData)
+	}
+
+	// Converts as many seconds of buffered data as possible, without waiting
+	fun expandBufferAndConvert(): String {
+		recorder.pullAvailable()
+		return convert(recorder.bufferedData)
+	}
+
+	override fun close() {
+		recorder.close()
+		session.close()
+	}
+}
\ No newline at end of file
diff --git a/packages/app-mobile/android/app/src/main/java/net/cozic/joplin/audio/SpeechToTextPackage.kt b/packages/app-mobile/android/app/src/main/java/net/cozic/joplin/audio/SpeechToTextPackage.kt
new file mode 100644
index 000000000..8845d6ebd
--- /dev/null
+++ b/packages/app-mobile/android/app/src/main/java/net/cozic/joplin/audio/SpeechToTextPackage.kt
@@ -0,0 +1,86 @@
+package net.cozic.joplin.audio
+
+import ai.onnxruntime.OrtEnvironment
+import com.facebook.react.ReactPackage
+import com.facebook.react.bridge.LifecycleEventListener
+import com.facebook.react.bridge.NativeModule
+import com.facebook.react.bridge.Promise
+import com.facebook.react.bridge.ReactApplicationContext
+import com.facebook.react.bridge.ReactContextBaseJavaModule
+import com.facebook.react.bridge.ReactMethod
+import com.facebook.react.uimanager.ViewManager
+import java.util.concurrent.ExecutorService
+import java.util.concurrent.Executors
+
+class SpeechToTextPackage : ReactPackage {
+	override fun createNativeModules(reactContext: ReactApplicationContext): List<NativeModule> {
+		return listOf(SpeechToTextModule(reactContext))
+	}
+
+	override fun createViewManagers(reactContext: ReactApplicationContext): List<ViewManager<*, *>> {
+		return emptyList()
+	}
+
+	class SpeechToTextModule(
+		private var context: ReactApplicationContext,
+	) : ReactContextBaseJavaModule(context), LifecycleEventListener {
+		private var environment: OrtEnvironment? = null
+		private val executorService: ExecutorService = Executors.newFixedThreadPool(1)
+		private val sessionManager = SpeechToTextSessionManager(executorService)
+
+		override fun getName() = "SpeechToTextModule"
+
+		override fun onHostResume() { }
+		override fun onHostPause() { }
+		override fun onHostDestroy() {
+			environment?.close()
+		}
+
+		@ReactMethod
+		fun openSession(modelPath: String, locale: String, promise: Promise) {
+			val appContext = context.applicationContext
+			// Initialize the environment as late as possible, and cache it for reuse:
+			val ortEnvironment = environment ?: OrtEnvironment.getEnvironment()
+			if (environment == null) {
+				environment = ortEnvironment
+			}
+
+			try {
+				val sessionId = sessionManager.openSession(modelPath, locale, ortEnvironment, appContext)
+				promise.resolve(sessionId)
+			} catch (exception: Throwable) {
+				promise.reject(exception)
+			}
+		}
+
+		@ReactMethod
+		fun startRecording(sessionId: Int, promise: Promise) {
+			sessionManager.startRecording(sessionId, promise)
+		}
+
+		@ReactMethod
+		fun getBufferLengthSeconds(sessionId: Int, promise: Promise) {
+			sessionManager.getBufferLengthSeconds(sessionId, promise)
+		}
+
+		@ReactMethod
+		fun dropFirstSeconds(sessionId: Int, duration: Double, promise: Promise) {
+			sessionManager.dropFirstSeconds(sessionId, duration, promise)
+		}
+
+		@ReactMethod
+		fun expandBufferAndConvert(sessionId: Int, duration: Double, promise: Promise) {
+			sessionManager.expandBufferAndConvert(sessionId, duration, promise)
+		}
+
+		@ReactMethod
+		fun convertAvailable(sessionId: Int, promise: Promise) {
+			sessionManager.convertAvailable(sessionId, promise)
+		}
+
+		@ReactMethod
+		fun closeSession(sessionId: Int, promise: Promise) {
+			sessionManager.closeSession(sessionId, promise)
+		}
+	}
+}
\ No newline at end of file
diff --git a/packages/app-mobile/android/app/src/main/java/net/cozic/joplin/audio/SpeechToTextSessionManager.kt b/packages/app-mobile/android/app/src/main/java/net/cozic/joplin/audio/SpeechToTextSessionManager.kt
new file mode 100644
index 000000000..3610b4bc3
--- /dev/null
+++ b/packages/app-mobile/android/app/src/main/java/net/cozic/joplin/audio/SpeechToTextSessionManager.kt
@@ -0,0 +1,111 @@
+package net.cozic.joplin.audio
+
+import ai.onnxruntime.OrtEnvironment
+import android.content.Context
+import com.facebook.react.bridge.Promise
+import java.util.concurrent.Executor
+import java.util.concurrent.locks.ReentrantLock
+
+class SpeechToTextSession (
+	val converter: SpeechToTextConverter
+) {
+	val mutex = ReentrantLock()
+}
+
+class SpeechToTextSessionManager(
+	private var executor: Executor,
+) {
+	private val sessions: MutableMap<Int, SpeechToTextSession> = mutableMapOf()
+	private var nextSessionId: Int = 0
+
+	fun openSession(
+		modelPath: String,
+		locale: String,
+		environment: OrtEnvironment,
+		context: Context,
+	): Int {
+		val sessionId = nextSessionId++
+		sessions[sessionId] = SpeechToTextSession(
+			SpeechToTextConverter(
+				modelPath, locale, recorderFactory = AudioRecorder.factory, environment, context,
+			)
+		)
+		return sessionId
+	}
+
+	private fun getSession(id: Int): SpeechToTextSession {
+		return sessions[id] ?: throw InvalidSessionIdException(id)
+	}
+
+	private fun concurrentWithSession(
+		id: Int,
+		callback: (session: SpeechToTextSession)->Unit,
+	) {
+		executor.execute {
+			val session = getSession(id)
+			session.mutex.lock()
+			try {
+				callback(session)
+			} finally {
+				session.mutex.unlock()
+			}
+		}
+	}
+	private fun concurrentWithSession(
+		id: Int,
+		onError: (error: Throwable)->Unit,
+		callback: (session: SpeechToTextSession)->Unit,
+	) {
+		return
concurrentWithSession(id) { session -> + try { + callback(session) + } catch (error: Throwable) { + onError(error) + } + } + } + + fun startRecording(sessionId: Int, promise: Promise) { + this.concurrentWithSession(sessionId, promise::reject) { session -> + session.converter.start() + promise.resolve(null) + } + } + + // Left-shifts the recording buffer by [duration] seconds + fun dropFirstSeconds(sessionId: Int, duration: Double, promise: Promise) { + this.concurrentWithSession(sessionId, promise::reject) { session -> + session.converter.dropFirstSeconds(duration) + promise.resolve(sessionId) + } + } + + fun getBufferLengthSeconds(sessionId: Int, promise: Promise) { + this.concurrentWithSession(sessionId, promise::reject) { session -> + promise.resolve(session.converter.bufferLengthSeconds) + } + } + + // Waits for the next [duration] seconds to become available, then converts + fun expandBufferAndConvert(sessionId: Int, duration: Double, promise: Promise) { + this.concurrentWithSession(sessionId, promise::reject) { session -> + val result = session.converter.expandBufferAndConvert(duration) + promise.resolve(result) + } + } + + // Converts all available recorded data + fun convertAvailable(sessionId: Int, promise: Promise) { + this.concurrentWithSession(sessionId, promise::reject) { session -> + val result = session.converter.expandBufferAndConvert() + promise.resolve(result) + } + } + + fun closeSession(sessionId: Int, promise: Promise) { + this.concurrentWithSession(sessionId) { session -> + session.converter.close() + promise.resolve(null) + } + } +} diff --git a/packages/app-mobile/components/screens/Note.tsx b/packages/app-mobile/components/screens/Note.tsx index a5f4253d8..d97eacaf4 100644 --- a/packages/app-mobile/components/screens/Note.tsx +++ b/packages/app-mobile/components/screens/Note.tsx @@ -45,7 +45,6 @@ import ImageEditor from '../NoteEditor/ImageEditor/ImageEditor'; import promptRestoreAutosave from '../NoteEditor/ImageEditor/promptRestoreAutosave'; import isEditableResource from '../NoteEditor/ImageEditor/isEditableResource'; import VoiceTypingDialog from '../voiceTyping/VoiceTypingDialog'; -import { voskEnabled } from '../../services/voiceTyping/vosk'; import { isSupportedLanguage } from '../../services/voiceTyping/vosk'; import { ChangeEvent as EditorChangeEvent, SelectionRangeChangeEvent, UndoRedoDepthChangeEvent } from '@joplin/editor/events'; import { join } from 'path'; @@ -1204,8 +1203,8 @@ class NoteScreenComponent extends BaseScreenComponent implements B }); } - // Voice typing is enabled only for French language and on Android for now - if (voskEnabled && shim.mobilePlatform() === 'android' && isSupportedLanguage(currentLocale())) { + // Voice typing is enabled only on Android for now + if (shim.mobilePlatform() === 'android' && isSupportedLanguage(currentLocale())) { output.push({ title: _('Voice typing...'), onPress: () => { diff --git a/packages/app-mobile/components/voiceTyping/VoiceTypingDialog.tsx b/packages/app-mobile/components/voiceTyping/VoiceTypingDialog.tsx index 8989a5a89..0effd0a3c 100644 --- a/packages/app-mobile/components/voiceTyping/VoiceTypingDialog.tsx +++ b/packages/app-mobile/components/voiceTyping/VoiceTypingDialog.tsx @@ -1,14 +1,18 @@ import * as React from 'react'; -import { useState, useEffect, useCallback } from 'react'; -import { Banner, ActivityIndicator } from 'react-native-paper'; +import { useState, useEffect, useCallback, useRef, useMemo } from 'react'; +import { Banner, ActivityIndicator, Text } from 'react-native-paper'; 
 import { _, languageName } from '@joplin/lib/locale';
 import useAsyncEffect, { AsyncEffectEvent } from '@joplin/lib/hooks/useAsyncEffect';
-import { getVosk, Recorder, startRecording, Vosk } from '../../services/voiceTyping/vosk';
 import { IconSource } from 'react-native-paper/lib/typescript/components/Icon';
-import { modelIsDownloaded } from '../../services/voiceTyping/vosk';
+import VoiceTyping, { OnTextCallback, VoiceTypingSession } from '../../services/voiceTyping/VoiceTyping';
+import whisper from '../../services/voiceTyping/whisper';
+import vosk from '../../services/voiceTyping/vosk';
+import { AppState } from '../../utils/types';
+import { connect } from 'react-redux';
 
 interface Props {
 	locale: string;
+	provider: string;
 	onDismiss: ()=> void;
 	onText: (text: string)=> void;
 }
@@ -21,44 +25,77 @@ enum RecorderState {
 	Downloading = 5,
 }
 
-const useVosk = (locale: string): [Error | null, boolean, Vosk|null] => {
-	const [vosk, setVosk] = useState<Vosk>(null);
+interface UseVoiceTypingProps {
+	locale: string;
+	provider: string;
+	onSetPreview: OnTextCallback;
+	onText: OnTextCallback;
+}
+
+const useWhisper = ({ locale, provider, onSetPreview, onText }: UseVoiceTypingProps): [Error | null, boolean, VoiceTypingSession|null] => {
+	const [voiceTyping, setVoiceTyping] = useState<VoiceTypingSession|null>(null);
 	const [error, setError] = useState<Error>(null);
 	const [mustDownloadModel, setMustDownloadModel] = useState<boolean|null>(null);
 
-	useAsyncEffect(async (event: AsyncEffectEvent) => {
-		if (mustDownloadModel === null) return;
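+	// Keep the latest callbacks in refs so that the session, which is built
+	// once per builder, always calls the most recent props: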
+	const onTextRef = useRef(onText);
+	onTextRef.current = onText;
+	const onSetPreviewRef = useRef(onSetPreview);
+	onSetPreviewRef.current = onSetPreview;
+	const voiceTypingRef = useRef(voiceTyping);
+	voiceTypingRef.current = voiceTyping;
+
+	const builder = useMemo(() => {
+		return new VoiceTyping(locale, provider?.startsWith('whisper') ? [whisper] : [vosk]);
+	}, [locale, provider]);
+
+	useAsyncEffect(async (event: AsyncEffectEvent) => {
 		try {
-			const v = await getVosk(locale);
+			await voiceTypingRef.current?.stop();
+
+			if (!await builder.isDownloaded()) {
+				if (event.cancelled) return;
+				await builder.download();
+			}
 			if (event.cancelled) return;
-			setVosk(v);
+
+			const voiceTyping = await builder.build({
+				onPreview: (text) => onSetPreviewRef.current(text),
+				onFinalize: (text) => onTextRef.current(text),
+			});
+			if (event.cancelled) return;
+			setVoiceTyping(voiceTyping);
 		} catch (error) {
 			setError(error);
 		} finally {
 			setMustDownloadModel(false);
 		}
-	}, [locale, mustDownloadModel]);
+	}, [builder]);
 
 	useAsyncEffect(async (_event: AsyncEffectEvent) => {
-		setMustDownloadModel(!(await modelIsDownloaded(locale)));
-	}, [locale]);
+		setMustDownloadModel(!(await builder.isDownloaded()));
+	}, [builder]);
 
-	return [error, mustDownloadModel, vosk];
+	return [error, mustDownloadModel, voiceTyping];
 };
 
-export default (props: Props) => {
-	const [recorder, setRecorder] = useState<Recorder>(null);
+const VoiceTypingDialog: React.FC<Props> = props => {
 	const [recorderState, setRecorderState] = useState<RecorderState>(RecorderState.Loading);
-	const [voskError, mustDownloadModel, vosk] = useVosk(props.locale);
+	const [preview, setPreview] = useState<string>('');
+	const [modelError, mustDownloadModel, voiceTyping] = useWhisper({
+		locale: props.locale,
+		onSetPreview: setPreview,
+		onText: props.onText,
+		provider: props.provider,
+	});
 
 	useEffect(() => {
-		if (voskError) {
+		if (modelError) {
 			setRecorderState(RecorderState.Error);
-		} else if (vosk) {
+		} else if (voiceTyping) {
 			setRecorderState(RecorderState.Recording);
 		}
-	}, [vosk, voskError]);
+	}, [voiceTyping, modelError]);
 
 	useEffect(() => {
 		if (mustDownloadModel) {
@@ -68,27 +105,22 @@ export default (props: Props) => {
 
 	useEffect(() => {
 		if (recorderState === RecorderState.Recording) {
-			setRecorder(startRecording(vosk, {
-				onResult: (text: string) => {
-					props.onText(text);
-				},
-			}));
+			void voiceTyping.start();
 		}
-	}, [recorderState, vosk, props.onText]);
+	}, [recorderState, voiceTyping, props.onText]);
 
 	const onDismiss = useCallback(() => {
-		if (recorder) recorder.cleanup();
+		void voiceTyping?.stop();
 		props.onDismiss();
-	}, [recorder, props.onDismiss]);
+	}, [voiceTyping, props.onDismiss]);
 
 	const renderContent = () => {
-		// eslint-disable-next-line @typescript-eslint/ban-types -- Old code before rule was applied
-		const components: Record<RecorderState, Function> = {
+		const components: Record<RecorderState, ()=> string> = {
 			[RecorderState.Loading]: () => _('Loading...'),
 			[RecorderState.Recording]: () => _('Please record your voice...'),
 			[RecorderState.Processing]: () => _('Converting speech to text...'),
 			[RecorderState.Downloading]: () => _('Downloading %s language files...', languageName(props.locale)),
-			[RecorderState.Error]: () => _('Error: %s', voskError.message),
+			[RecorderState.Error]: () => _('Error: %s', modelError.message),
 		};
 
 		return components[recorderState]();
@@ -106,6 +138,11 @@ export default (props: Props) => {
 		return components[recorderState];
 	};
 
+	const renderPreview = () => {
+		return <Text>{preview}</Text>;
+	};
+
+	const headerAndStatus = <Text>{`${_('Voice typing...')}\n${renderContent()}`}</Text>;
 	return (
 		<Banner
 			visible={true}
 			icon={renderIcon()}
 			actions={[
 				{
 					label: _('Done'),
 					onPress: onDismiss,
 				},
-			]}>
-			{`${_('Voice typing...')}\n${renderContent()}`}
+			]}
+		>
+			{headerAndStatus}
+			{'\n'}
+			{renderPreview()}
 		</Banner>
 	);
 };
+
+export default connect((state: AppState) => ({
+	provider: state.settings['voiceTyping.preferredProvider'],
+}))(VoiceTypingDialog);
diff --git a/packages/app-mobile/jest.setup.js b/packages/app-mobile/jest.setup.js
index 93d7ec403..655024e06 100644
--- a/packages/app-mobile/jest.setup.js
+++ b/packages/app-mobile/jest.setup.js
@@ -80,6 +80,10 @@ jest.mock('react-native-image-picker', () => {
 	return { default: { } };
 });
 
+jest.mock('react-native-zip-archive', () => {
+	return { default: { } };
+});
+
 jest.mock('react-native-document-picker', () => ({ default: { } }));
 
 // Used by the renderer
diff --git a/packages/app-mobile/services/voiceTyping/VoiceTyping.ts b/packages/app-mobile/services/voiceTyping/VoiceTyping.ts
new file mode 100644
index 000000000..94f4e755d
--- /dev/null
+++ b/packages/app-mobile/services/voiceTyping/VoiceTyping.ts
@@ -0,0 +1,139 @@
+import shim from '@joplin/lib/shim';
+import Logger from '@joplin/utils/Logger';
+import { PermissionsAndroid, Platform } from 'react-native';
+import { unzip } from 'react-native-zip-archive';
+const md5 = require('md5');
+
+const logger = Logger.create('voiceTyping');
+
+export type OnTextCallback = (text: string)=> void;
+
+export interface SpeechToTextCallbacks {
+	// Called with a block of text that might change in the future
+	onPreview: OnTextCallback;
+	// Called with text that will not change and should be added to the document
+	onFinalize: OnTextCallback;
+}
+
+export interface VoiceTypingSession {
+	start(): Promise<void>;
+	stop(): Promise<void>;
+}
+
+export interface BuildProviderOptions {
+	locale: string;
+	modelPath: string;
+	callbacks: SpeechToTextCallbacks;
+}
+
+export interface VoiceTypingProvider {
+	modelName: string;
+	supported(): boolean;
+	modelLocalFilepath(locale: string): string;
+	getDownloadUrl(locale: string): string;
+	getUuidPath(locale: string): string;
+	build(options: BuildProviderOptions): Promise<VoiceTypingSession>;
+}
+
+export default class VoiceTyping {
+	private provider: VoiceTypingProvider|null = null;
+	public constructor(
+		private locale: string,
+		providers: VoiceTypingProvider[],
+	) {
+		this.provider = providers.find(p => p.supported()) ?? null;
+	}
+
+	public supported() {
+		return this.provider !== null;
+	}
+
+	private getModelPath() {
+		const localFilePath = shim.fsDriver().resolveRelativePathWithinDir(
+			shim.fsDriver().getAppDirectoryPath(),
+			this.provider.modelLocalFilepath(this.locale),
+		);
+		if (localFilePath === shim.fsDriver().getAppDirectoryPath()) {
+			throw new Error('Invalid local file path!');
+		}
+
+		return localFilePath;
+	}
+
+	private getUuidPath() {
+		return shim.fsDriver().resolveRelativePathWithinDir(
+			shim.fsDriver().getAppDirectoryPath(),
+			this.provider.getUuidPath(this.locale),
+		);
+	}
+
+	public async isDownloaded() {
+		return await shim.fsDriver().exists(this.getUuidPath());
+	}
+
+	public async download() {
+		const modelPath = this.getModelPath();
+		const modelUrl = this.provider.getDownloadUrl(this.locale);
+
+		await shim.fsDriver().remove(modelPath);
+		logger.info(`Downloading model from: ${modelUrl}`);
+
+		const isZipped = modelUrl.endsWith('.zip');
+		const downloadPath = isZipped ?
`${modelPath}.zip` : modelPath; + const response = await shim.fetchBlob(modelUrl, { + path: downloadPath, + }); + + if (!response.ok || response.status >= 400) throw new Error(`Could not download from ${modelUrl}: Error ${response.status}`); + + if (isZipped) { + const modelName = this.provider.modelName; + const unzipDir = `${shim.fsDriver().getCacheDirectoryPath()}/voice-typing-extract/${modelName}/${this.locale}`; + try { + logger.info(`Unzipping ${downloadPath} => ${unzipDir}`); + + await unzip(downloadPath, unzipDir); + + const contents = await shim.fsDriver().readDirStats(unzipDir); + if (contents.length !== 1) { + logger.error('Expected 1 file or directory but got', contents); + throw new Error(`Expected 1 file or directory, but got ${contents.length}`); + } + + const fullUnzipPath = `${unzipDir}/${contents[0].path}`; + + logger.info(`Moving ${fullUnzipPath} => ${modelPath}`); + await shim.fsDriver().move(fullUnzipPath, modelPath); + + await shim.fsDriver().writeFile(this.getUuidPath(), md5(modelUrl), 'utf8'); + if (!await this.isDownloaded()) { + logger.warn('Model should be downloaded!'); + } + } finally { + await shim.fsDriver().remove(unzipDir); + await shim.fsDriver().remove(downloadPath); + } + } + } + + public async build(callbacks: SpeechToTextCallbacks) { + if (!this.provider) { + throw new Error('No supported provider found!'); + } + + if (!await this.isDownloaded()) { + await this.download(); + } + + const audioPermission = 'android.permission.RECORD_AUDIO'; + if (Platform.OS === 'android' && !await PermissionsAndroid.check(audioPermission)) { + await PermissionsAndroid.request(audioPermission); + } + + return this.provider.build({ + locale: this.locale, + modelPath: this.getModelPath(), + callbacks, + }); + } +} diff --git a/packages/app-mobile/services/voiceTyping/utils/splitWhisperText.test.ts b/packages/app-mobile/services/voiceTyping/utils/splitWhisperText.test.ts new file mode 100644 index 000000000..3e5cc951e --- /dev/null +++ b/packages/app-mobile/services/voiceTyping/utils/splitWhisperText.test.ts @@ -0,0 +1,61 @@ +import splitWhisperText from './splitWhisperText'; + +describe('splitWhisperText', () => { + test.each([ + { + // Should trim at sentence breaks + input: '<|0.00|> This is a test. <|5.00|><|6.00|> This is another sentence. <|7.00|>', + recordingLength: 8, + expected: { + trimTo: 6, + dataBeforeTrim: '<|0.00|> This is a test. ', + dataAfterTrim: ' This is another sentence. <|7.00|>', + }, + }, + { + // Should prefer sentence break splits to non sentence break splits + input: '<|0.00|> This is <|4.00|><|4.50|> a test. <|5.00|><|5.50|> Testing, <|6.00|><|7.00|> this is a test. <|8.00|>', + recordingLength: 8, + expected: { + trimTo: 5.50, + dataBeforeTrim: '<|0.00|> This is <|4.00|><|4.50|> a test. ', + dataAfterTrim: ' Testing, <|6.00|><|7.00|> this is a test. <|8.00|>', + }, + }, + { + // Should avoid splitting for very small timestamps + input: '<|0.00|> This is a test. <|2.00|><|2.30|> Testing! <|3.00|>', + recordingLength: 4, + expected: { + trimTo: 0, + dataBeforeTrim: '', + dataAfterTrim: ' This is a test. <|2.00|><|2.30|> Testing! <|3.00|>', + }, + }, + { + // For larger timestamps, should allow splitting at pauses, even if not on sentence breaks. + input: '<|0.00|> This is a test, <|10.00|><|12.00|> of splitting on timestamps. <|15.00|>', + recordingLength: 16, + expected: { + trimTo: 12, + dataBeforeTrim: '<|0.00|> This is a test, ', + dataAfterTrim: ' of splitting on timestamps. 
<|15.00|>', + }, + }, + { + // Should prefer to break at the end, if a large gap after the last timestamp. + input: '<|0.00|> This is a test, <|10.00|><|12.00|> of splitting on timestamps. <|15.00|>', + recordingLength: 30, + expected: { + trimTo: 15, + dataBeforeTrim: '<|0.00|> This is a test, <|10.00|><|12.00|> of splitting on timestamps. ', + dataAfterTrim: '', + }, + }, + ])('should prefer to split at the end of sentences (case %#)', ({ input, recordingLength, expected }) => { + const actual = splitWhisperText(input, recordingLength); + expect(actual.trimTo).toBeCloseTo(expected.trimTo); + expect(actual.dataBeforeTrim).toBe(expected.dataBeforeTrim); + expect(actual.dataAfterTrim).toBe(expected.dataAfterTrim); + }); +}); diff --git a/packages/app-mobile/services/voiceTyping/utils/splitWhisperText.ts b/packages/app-mobile/services/voiceTyping/utils/splitWhisperText.ts new file mode 100644 index 000000000..343577f0a --- /dev/null +++ b/packages/app-mobile/services/voiceTyping/utils/splitWhisperText.ts @@ -0,0 +1,65 @@ +// Matches pairs of timestamps or single timestamps. +const timestampExp = /<\|(\d+\.\d*)\|>(?:<\|(\d+\.\d*)\|>)?/g; + +const timestampMatchToNumber = (match: RegExpMatchArray) => { + const firstTimestamp = match[1]; + const secondTimestamp = match[2]; + // Prefer the second timestamp in the pair, to remove leading silence. + const timestamp = Number(secondTimestamp ? secondTimestamp : firstTimestamp); + + // Should always be a finite number (i.e. not NaN) + if (!isFinite(timestamp)) throw new Error(`Timestamp match failed with ${match[0]}`); + + return timestamp; +}; + +const splitWhisperText = (textWithTimestamps: string, recordingLengthSeconds: number) => { + const timestamps = [ + ...textWithTimestamps.matchAll(timestampExp), + ].map(match => { + const timestamp = timestampMatchToNumber(match); + return { timestamp, match }; + }); + + if (!timestamps.length) { + return { trimTo: 0, dataBeforeTrim: '', dataAfterTrim: textWithTimestamps }; + } + + const firstTimestamp = timestamps[0]; + let breakAt = firstTimestamp; + + const lastTimestamp = timestamps[timestamps.length - 1]; + const hasLongPauseAfterData = lastTimestamp.timestamp + 4 < recordingLengthSeconds; + if (hasLongPauseAfterData) { + breakAt = lastTimestamp; + } else { + const textWithTimestampsContentLength = textWithTimestamps.trimEnd().length; + + for (const timestampData of timestamps) { + const { match, timestamp } = timestampData; + const contentBefore = textWithTimestamps.substring(Math.max(match.index - 3, 0), match.index); + const isNearEndOfLatinSentence = contentBefore.match(/[.?!]/); + const isNearEndOfData = match.index + match[0].length >= textWithTimestampsContentLength; + + // Use a heuristic to determine whether to move content from the preview to the document. + // These are based on the maximum buffer length of 30 seconds -- as the buffer gets longer, the + // data should be more likely to be broken into chunks. 
Where possible, the break should be near
+			// the end of a sentence:
+			const canBreak = (timestamp > 4 && isNearEndOfLatinSentence && !isNearEndOfData)
+				|| (timestamp > 8 && !isNearEndOfData)
+				|| timestamp > 16;
+			if (canBreak) {
+				breakAt = timestampData;
+				break;
+			}
+		}
+	}
+
+	const trimTo = breakAt.timestamp;
+	const dataBeforeTrim = textWithTimestamps.substring(0, breakAt.match.index);
+	const dataAfterTrim = textWithTimestamps.substring(breakAt.match.index + breakAt.match[0].length);
+
+	return { trimTo, dataBeforeTrim, dataAfterTrim };
+};
+
+export default splitWhisperText;
diff --git a/packages/app-mobile/services/voiceTyping/vosk.android.ts b/packages/app-mobile/services/voiceTyping/vosk.android.ts
index 855e997a5..157249ab5 100644
--- a/packages/app-mobile/services/voiceTyping/vosk.android.ts
+++ b/packages/app-mobile/services/voiceTyping/vosk.android.ts
@@ -4,9 +4,9 @@ import Setting from '@joplin/lib/models/Setting';
 import { rtrimSlashes } from '@joplin/lib/path-utils';
 import shim from '@joplin/lib/shim';
 import Vosk from 'react-native-vosk';
-import { unzip } from 'react-native-zip-archive';
 import RNFetchBlob from 'rn-fetch-blob';
-const md5 = require('md5');
+import { VoiceTypingProvider, VoiceTypingSession } from './VoiceTyping';
+import { join } from 'path';
 
 const logger = Logger.create('voiceTyping/vosk');
 
@@ -24,15 +24,6 @@ let vosk_: Record<string, Vosk> = {};
 
 let state_: State = State.Idle;
 
-export const voskEnabled = true;
-
-export { Vosk };
-
-export interface Recorder {
-	stop: ()=> Promise<string>;
-	cleanup: ()=> void;
-}
-
 const defaultSupportedLanguages = {
 	'en': 'https://alphacephei.com/vosk/models/vosk-model-small-en-us-0.15.zip',
 	'zh': 'https://alphacephei.com/vosk/models/vosk-model-small-cn-0.22.zip',
@@ -74,7 +65,7 @@ const getModelDir = (locale: string) => {
 	return `${getUnzipDir(locale)}/model`;
 };
 
-const languageModelUrl = (locale: string) => {
+const languageModelUrl = (locale: string): string => {
 	const lang = languageCodeOnly(locale).toLowerCase();
 	if (!(lang in defaultSupportedLanguages)) throw new Error(`No language file for: ${locale}`);
 
@@ -90,16 +81,11 @@ const languageModelUrl = (locale: string) => {
 	}
 };
 
-export const modelIsDownloaded = async (locale: string) => {
-	const uuidFile = `${getModelDir(locale)}/uuid`;
-	return shim.fsDriver().exists(uuidFile);
-};
-
-export const getVosk = async (locale: string) => {
+export const getVosk = async (modelDir: string, locale: string) => {
 	if (vosk_[locale]) return vosk_[locale];
 	const vosk = new Vosk();
-	const modelDir = await downloadModel(locale);
 	logger.info(`Loading model from ${modelDir}`);
 	await shim.fsDriver().readDirStats(modelDir);
 	const result = await vosk.loadModel(modelDir);
@@ -110,51 +96,7 @@ export const getVosk = async (locale: string) => {
 	return vosk;
 };
 
-const downloadModel = async (locale: string) => {
-	const modelUrl = languageModelUrl(locale);
-	const unzipDir = getUnzipDir(locale);
-	const zipFilePath = `${unzipDir}.zip`;
-	const modelDir = getModelDir(locale);
-	const uuidFile = `${modelDir}/uuid`;
-
-	if (await modelIsDownloaded(locale)) {
-		logger.info(`Model for ${locale} already exists at ${modelDir}`);
-		return modelDir;
-	}
-
-	await shim.fsDriver().remove(unzipDir);
-
-	logger.info(`Downloading model from: ${modelUrl}`);
-
-	const response = await shim.fetchBlob(modelUrl, {
-		path: zipFilePath,
-	});
-
-	if (!response.ok || response.status >= 400) throw new Error(`Could not download from ${modelUrl}: Error ${response.status}`);
-
-	logger.info(`Unzipping ${zipFilePath} =>
${unzipDir}`); - - await unzip(zipFilePath, unzipDir); - - const dirs = await shim.fsDriver().readDirStats(unzipDir); - if (dirs.length !== 1) { - logger.error('Expected 1 directory but got', dirs); - throw new Error(`Expected 1 directory, but got ${dirs.length}`); - } - - const fullUnzipPath = `${unzipDir}/${dirs[0].path}`; - - logger.info(`Moving ${fullUnzipPath} => ${modelDir}`); - await shim.fsDriver().rename(fullUnzipPath, modelDir); - - await shim.fsDriver().writeFile(uuidFile, md5(modelUrl)); - - await shim.fsDriver().remove(zipFilePath); - - return modelDir; -}; - -export const startRecording = (vosk: Vosk, options: StartOptions): Recorder => { +export const startRecording = (vosk: Vosk, options: StartOptions): VoiceTypingSession => { if (state_ !== State.Idle) throw new Error('Vosk is already recording'); state_ = State.Recording; @@ -163,10 +105,10 @@ export const startRecording = (vosk: Vosk, options: StartOptions): Recorder => { // eslint-disable-next-line @typescript-eslint/no-explicit-any -- Old code before rule was applied const eventHandlers: any[] = []; // eslint-disable-next-line @typescript-eslint/ban-types -- Old code before rule was applied - let finalResultPromiseResolve: Function = null; + const finalResultPromiseResolve: Function = null; // eslint-disable-next-line @typescript-eslint/ban-types -- Old code before rule was applied - let finalResultPromiseReject: Function = null; - let finalResultTimeout = false; + const finalResultPromiseReject: Function = null; + const finalResultTimeout = false; const completeRecording = (finalResult: string, error: Error) => { logger.info(`Complete recording. Final result: ${finalResult}. Error:`, error); @@ -212,31 +154,13 @@ export const startRecording = (vosk: Vosk, options: StartOptions): Recorder => { completeRecording(e.data, null); })); - logger.info('Starting recording...'); - - void vosk.start(); return { - stop: (): Promise => { - logger.info('Stopping recording...'); - - vosk.stopOnly(); - - logger.info('Waiting for final result...'); - - setTimeout(() => { - finalResultTimeout = true; - logger.warn('Timed out waiting for finalResult event'); - completeRecording('', new Error('Could not process your message. 
Please try again.')); - }, 5000); - - // eslint-disable-next-line @typescript-eslint/ban-types -- Old code before rule was applied - return new Promise((resolve: Function, reject: Function) => { - finalResultPromiseResolve = resolve; - finalResultPromiseReject = reject; - }); + start: async () => { + logger.info('Starting recording...'); + await vosk.start(); }, - cleanup: () => { + stop: async () => { if (state_ === State.Recording) { logger.info('Cancelling...'); state_ = State.Completing; @@ -246,3 +170,18 @@ export const startRecording = (vosk: Vosk, options: StartOptions): Recorder => { }, }; }; + + +const vosk: VoiceTypingProvider = { + supported: () => true, + modelLocalFilepath: (locale: string) => getModelDir(locale), + getDownloadUrl: (locale) => languageModelUrl(locale), + getUuidPath: (locale: string) => join(getModelDir(locale), 'uuid'), + build: async ({ callbacks, locale, modelPath }) => { + const vosk = await getVosk(modelPath, locale); + return startRecording(vosk, { onResult: callbacks.onFinalize }); + }, + modelName: 'vosk', +}; + +export default vosk; diff --git a/packages/app-mobile/services/voiceTyping/vosk.ts b/packages/app-mobile/services/voiceTyping/vosk.ts index 78d7ce310..d51289ced 100644 --- a/packages/app-mobile/services/voiceTyping/vosk.ts +++ b/packages/app-mobile/services/voiceTyping/vosk.ts @@ -1,37 +1,14 @@ -// Currently disabled on non-Android platforms +import { VoiceTypingProvider } from './VoiceTyping'; -// eslint-disable-next-line @typescript-eslint/no-explicit-any -- Old code before rule was applied -type Vosk = any; - -export { Vosk }; - -interface StartOptions { - onResult: (text: string)=> void; -} - -export interface Recorder { - stop: ()=> Promise; - cleanup: ()=> void; -} - -export const isSupportedLanguage = (_locale: string) => { - return false; +const vosk: VoiceTypingProvider = { + supported: () => false, + modelLocalFilepath: () => null, + getDownloadUrl: () => null, + getUuidPath: () => null, + build: async () => { + throw new Error('Unsupported!'); + }, + modelName: 'vosk', }; -export const modelIsDownloaded = async (_locale: string) => { - return false; -}; - -export const getVosk = async (_locale: string) => { - // eslint-disable-next-line @typescript-eslint/no-explicit-any -- Old code before rule was applied - return {} as any; -}; - -export const startRecording = (_vosk: Vosk, _options: StartOptions): Recorder => { - return { - stop: async () => { return ''; }, - cleanup: () => {}, - }; -}; - -export const voskEnabled = false; +export default vosk; diff --git a/packages/app-mobile/services/voiceTyping/whisper.ts b/packages/app-mobile/services/voiceTyping/whisper.ts new file mode 100644 index 000000000..0db6ae1e5 --- /dev/null +++ b/packages/app-mobile/services/voiceTyping/whisper.ts @@ -0,0 +1,119 @@ +import Setting from '@joplin/lib/models/Setting'; +import shim from '@joplin/lib/shim'; +import Logger from '@joplin/utils/Logger'; +import { rtrimSlashes } from '@joplin/utils/path'; +import { dirname, join } from 'path'; +import { NativeModules } from 'react-native'; +import { SpeechToTextCallbacks, VoiceTypingProvider, VoiceTypingSession } from './VoiceTyping'; +import splitWhisperText from './utils/splitWhisperText'; + +const logger = Logger.create('voiceTyping/whisper'); + +const { SpeechToTextModule } = NativeModules; + +// Timestamps are in the form <|0.00|>. They seem to be added: +// - After long pauses. +// - Between sentences (in pairs). +// - At the beginning and end of a sequence. 
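+// For example: <|0.00|> This is a test. <|5.00|><|6.00|> This is another sentence. <|7.00|>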
+const timestampExp = /<\|(\d+\.\d*)\|>/g; +const postProcessSpeech = (text: string) => { + return text.replace(timestampExp, '').replace(/\[BLANK_AUDIO\]/g, ''); +}; + +class Whisper implements VoiceTypingSession { + private lastPreviewData: string; + private closeCounter = 0; + + public constructor( + private sessionId: number|null, + private callbacks: SpeechToTextCallbacks, + ) { } + + public async start() { + if (this.sessionId === null) { + throw new Error('Session closed.'); + } + try { + logger.debug('starting recorder'); + await SpeechToTextModule.startRecording(this.sessionId); + logger.debug('recorder started'); + + const loopStartCounter = this.closeCounter; + while (this.closeCounter === loopStartCounter) { + logger.debug('reading block'); + const data: string = await SpeechToTextModule.expandBufferAndConvert(this.sessionId, 4); + logger.debug('done reading block. Length', data?.length); + + if (this.sessionId === null) { + logger.debug('Session stopped. Ending inference loop.'); + return; + } + + const recordingLength = await SpeechToTextModule.getBufferLengthSeconds(this.sessionId); + logger.debug('recording length so far', recordingLength); + const { trimTo, dataBeforeTrim, dataAfterTrim } = splitWhisperText(data, recordingLength); + + if (trimTo > 2) { + logger.debug('Trim to', trimTo, 'in recording with length', recordingLength); + this.callbacks.onFinalize(postProcessSpeech(dataBeforeTrim)); + this.callbacks.onPreview(postProcessSpeech(dataAfterTrim)); + this.lastPreviewData = dataAfterTrim; + await SpeechToTextModule.dropFirstSeconds(this.sessionId, trimTo); + } else { + logger.debug('Preview', data); + this.lastPreviewData = data; + this.callbacks.onPreview(postProcessSpeech(data)); + } + } + } catch (error) { + logger.error('Whisper error:', error); + this.lastPreviewData = ''; + await this.stop(); + throw error; + } + } + + public async stop() { + if (this.sessionId === null) { + logger.warn('Session already closed.'); + return; + } + + const sessionId = this.sessionId; + this.sessionId = null; + this.closeCounter ++; + await SpeechToTextModule.closeSession(sessionId); + + if (this.lastPreviewData) { + this.callbacks.onFinalize(postProcessSpeech(this.lastPreviewData)); + } + } +} + +const modelLocalFilepath = () => { + return `${shim.fsDriver().getAppDirectoryPath()}/voice-typing-models/whisper_tiny.onnx`; +}; + +const whisper: VoiceTypingProvider = { + supported: () => !!SpeechToTextModule, + modelLocalFilepath: modelLocalFilepath, + getDownloadUrl: () => { + let urlTemplate = rtrimSlashes(Setting.value('voiceTypingBaseUrl').trim()); + + if (!urlTemplate) { + urlTemplate = 'https://github.com/personalizedrefrigerator/joplin-voice-typing-test/releases/download/test-release/{task}.zip'; + } + + return urlTemplate.replace(/\{task\}/g, 'whisper_tiny.onnx'); + }, + getUuidPath: () => { + return join(dirname(modelLocalFilepath()), 'uuid'); + }, + build: async ({ modelPath, callbacks, locale }) => { + const sessionId = await SpeechToTextModule.openSession(modelPath, locale); + return new Whisper(sessionId, callbacks); + }, + modelName: 'whisper', +}; + +export default whisper; diff --git a/packages/lib/models/settings/builtInMetadata.ts b/packages/lib/models/settings/builtInMetadata.ts index e822d3689..5eed3bb8d 100644 --- a/packages/lib/models/settings/builtInMetadata.ts +++ b/packages/lib/models/settings/builtInMetadata.ts @@ -1615,6 +1615,25 @@ const builtInMetadata = (Setting: typeof SettingType) => { section: 'note', }, + 'voiceTyping.preferredProvider': { + value: 
'whisper-tiny', + type: SettingItemType.String, + public: true, + appTypes: [AppType.Mobile], + label: () => _('Preferred voice typing provider'), + isEnum: true, + // For now, iOS and web don't support voice typing. + show: () => shim.mobilePlatform() === 'android', + section: 'note', + + options: () => { + return { + 'vosk': _('Vosk'), + 'whisper-tiny': _('Whisper'), + }; + }, + }, + 'trash.autoDeletionEnabled': { value: true, type: SettingItemType.Bool, diff --git a/packages/tools/cspell/dictionary4.txt b/packages/tools/cspell/dictionary4.txt index ae72b46fd..c29723112 100644 --- a/packages/tools/cspell/dictionary4.txt +++ b/packages/tools/cspell/dictionary4.txt @@ -132,4 +132,6 @@ Famegear rcompare tabindex Backblaze +onnx +onnxruntime treeitem
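
---

How the pieces fit together (a reviewer's sketch, not part of the diff): a VoiceTyping instance picks the first supported VoiceTypingProvider, downloads its model if needed, and builds a VoiceTypingSession that streams preview and finalized text. The `startVoiceTyping` helper below is hypothetical; the real consumer is the useWhisper hook in VoiceTypingDialog.tsx.

	import VoiceTyping from './services/voiceTyping/VoiceTyping';
	import whisper from './services/voiceTyping/whisper';
	import vosk from './services/voiceTyping/vosk';

	const startVoiceTyping = async (locale: string) => {
		// Prefer Whisper; fall back to Vosk where it is supported.
		const builder = new VoiceTyping(locale, [whisper, vosk]);

		// Fetch the model on first use. A uuid file marks a finished download.
		if (!await builder.isDownloaded()) await builder.download();

		// build() also requests RECORD_AUDIO on Android before opening a session.
		const session = await builder.build({
			onPreview: text => console.info('Preview:', text), // may still change
			onFinalize: text => console.info('Final:', text), // safe to append
		});

		// start() runs the recognition loop until stop() is called, so don't
		// await it to completion here.
		void session.start();
		return session; // Later: await session.stop();
	};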