mirror of
https://github.com/laurent22/joplin.git
synced 2024-12-21 09:38:01 +02:00
Android: Allow switching the voice typing library to Whisper (#11158)
Co-authored-by: Laurent Cozic <laurent22@users.noreply.github.com>
This commit is contained in:
parent
3a316a1dbc
commit
9f5282c8f5
@ -737,8 +737,12 @@ packages/app-mobile/services/BackButtonService.js
|
||||
packages/app-mobile/services/e2ee/RSA.react-native.js
|
||||
packages/app-mobile/services/plugins/PlatformImplementation.js
|
||||
packages/app-mobile/services/profiles/index.js
|
||||
packages/app-mobile/services/voiceTyping/VoiceTyping.js
|
||||
packages/app-mobile/services/voiceTyping/utils/splitWhisperText.test.js
|
||||
packages/app-mobile/services/voiceTyping/utils/splitWhisperText.js
|
||||
packages/app-mobile/services/voiceTyping/vosk.android.js
|
||||
packages/app-mobile/services/voiceTyping/vosk.js
|
||||
packages/app-mobile/services/voiceTyping/whisper.js
|
||||
packages/app-mobile/setupQuickActions.js
|
||||
packages/app-mobile/tools/buildInjectedJs/BundledFile.js
|
||||
packages/app-mobile/tools/buildInjectedJs/constants.js
|
||||
|
4
.gitignore
vendored
4
.gitignore
vendored
@ -714,8 +714,12 @@ packages/app-mobile/services/BackButtonService.js
|
||||
packages/app-mobile/services/e2ee/RSA.react-native.js
|
||||
packages/app-mobile/services/plugins/PlatformImplementation.js
|
||||
packages/app-mobile/services/profiles/index.js
|
||||
packages/app-mobile/services/voiceTyping/VoiceTyping.js
|
||||
packages/app-mobile/services/voiceTyping/utils/splitWhisperText.test.js
|
||||
packages/app-mobile/services/voiceTyping/utils/splitWhisperText.js
|
||||
packages/app-mobile/services/voiceTyping/vosk.android.js
|
||||
packages/app-mobile/services/voiceTyping/vosk.js
|
||||
packages/app-mobile/services/voiceTyping/whisper.js
|
||||
packages/app-mobile/setupQuickActions.js
|
||||
packages/app-mobile/tools/buildInjectedJs/BundledFile.js
|
||||
packages/app-mobile/tools/buildInjectedJs/constants.js
|
||||
|
@ -136,6 +136,10 @@ dependencies {
|
||||
} else {
|
||||
implementation jscFlavor
|
||||
}
|
||||
|
||||
// Needed for Whisper speech-to-text
|
||||
implementation 'com.microsoft.onnxruntime:onnxruntime-android:latest.release'
|
||||
implementation 'com.microsoft.onnxruntime:onnxruntime-extensions-android:latest.release'
|
||||
}
|
||||
|
||||
apply from: file("../../node_modules/@react-native-community/cli-platform-android/native_modules.gradle"); applyNativeModulesAppBuildGradle(project)
|
||||
|
@ -10,6 +10,7 @@
|
||||
<uses-permission android:name="android.permission.POST_NOTIFICATIONS" />
|
||||
<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE" />
|
||||
<uses-permission android:name="android.permission.READ_MEDIA_IMAGES" />
|
||||
<uses-permission android:name="android.permission.RECORD_AUDIO" />
|
||||
|
||||
<!-- Make these features optional to enable Chromebooks -->
|
||||
<!-- https://github.com/laurent22/joplin/issues/37 -->
|
||||
|
@ -11,6 +11,7 @@ import com.facebook.react.defaults.DefaultNewArchitectureEntryPoint.load
|
||||
import com.facebook.react.defaults.DefaultReactHost.getDefaultReactHost
|
||||
import com.facebook.react.defaults.DefaultReactNativeHost
|
||||
import com.facebook.soloader.SoLoader
|
||||
import net.cozic.joplin.audio.SpeechToTextPackage
|
||||
import net.cozic.joplin.versioninfo.SystemVersionInformationPackage
|
||||
import net.cozic.joplin.share.SharePackage
|
||||
import net.cozic.joplin.ssl.SslPackage
|
||||
@ -25,6 +26,7 @@ class MainApplication : Application(), ReactApplication {
|
||||
add(SslPackage())
|
||||
add(TextInputPackage())
|
||||
add(SystemVersionInformationPackage())
|
||||
add(SpeechToTextPackage())
|
||||
}
|
||||
|
||||
override fun getJSMainModuleName(): String = "index"
|
||||
|
@ -0,0 +1,101 @@
|
||||
package net.cozic.joplin.audio
|
||||
|
||||
import android.Manifest
|
||||
import android.annotation.SuppressLint
|
||||
import android.content.Context
|
||||
import android.content.pm.PackageManager
|
||||
import android.media.AudioFormat
|
||||
import android.media.AudioRecord
|
||||
import android.media.MediaRecorder.AudioSource
|
||||
import java.io.Closeable
|
||||
import kotlin.math.max
|
||||
import kotlin.math.min
|
||||
|
||||
typealias AudioRecorderFactory = (context: Context)->AudioRecorder;
|
||||
|
||||
class AudioRecorder(context: Context) : Closeable {
|
||||
private val sampleRate = 16_000
|
||||
private val maxLengthSeconds = 30 // Whisper supports a maximum of 30s
|
||||
private val maxBufferSize = sampleRate * maxLengthSeconds
|
||||
private val buffer = FloatArray(maxBufferSize)
|
||||
private var bufferWriteOffset = 0
|
||||
|
||||
// Accessor must not modify result
|
||||
val bufferedData: FloatArray get() = buffer.sliceArray(0 until bufferWriteOffset)
|
||||
val bufferLengthSeconds: Double get() = bufferWriteOffset.toDouble() / sampleRate
|
||||
|
||||
init {
|
||||
val permissionResult = context.checkSelfPermission(Manifest.permission.RECORD_AUDIO)
|
||||
if (permissionResult == PackageManager.PERMISSION_DENIED) {
|
||||
throw SecurityException("Missing RECORD_AUDIO permission!")
|
||||
}
|
||||
}
|
||||
|
||||
// Permissions check is included above
|
||||
@SuppressLint("MissingPermission")
|
||||
private val recorder = AudioRecord.Builder()
|
||||
.setAudioSource(AudioSource.MIC)
|
||||
.setAudioFormat(
|
||||
AudioFormat.Builder()
|
||||
// PCM: A WAV format
|
||||
.setEncoding(AudioFormat.ENCODING_PCM_FLOAT)
|
||||
.setSampleRate(sampleRate)
|
||||
.setChannelMask(AudioFormat.CHANNEL_IN_MONO)
|
||||
.build()
|
||||
)
|
||||
.setBufferSizeInBytes(maxBufferSize * Float.SIZE_BYTES)
|
||||
.build()
|
||||
|
||||
// Discards the first [samples] samples from the start of the buffer. Conceptually, this
|
||||
// advances the buffer's start point.
|
||||
private fun advanceStartBySamples(samples: Int) {
|
||||
val samplesClamped = min(samples, maxBufferSize)
|
||||
val remainingBuffer = buffer.sliceArray(samplesClamped until maxBufferSize)
|
||||
|
||||
buffer.fill(0f, samplesClamped, maxBufferSize)
|
||||
remainingBuffer.copyInto(buffer, 0)
|
||||
bufferWriteOffset = max(bufferWriteOffset - samplesClamped, 0)
|
||||
}
|
||||
|
||||
fun dropFirstSeconds(seconds: Double) {
|
||||
advanceStartBySamples((seconds * sampleRate).toInt())
|
||||
}
|
||||
|
||||
fun start() {
|
||||
recorder.startRecording()
|
||||
}
|
||||
|
||||
private fun read(requestedSize: Int, mode: Int) {
|
||||
val size = min(requestedSize, maxBufferSize - bufferWriteOffset)
|
||||
val sizeRead = recorder.read(buffer, bufferWriteOffset, size, mode)
|
||||
if (sizeRead > 0) {
|
||||
bufferWriteOffset += sizeRead
|
||||
}
|
||||
}
|
||||
|
||||
// Pulls all available data from the audio recorder's buffer
|
||||
fun pullAvailable() {
|
||||
return read(maxBufferSize, AudioRecord.READ_NON_BLOCKING)
|
||||
}
|
||||
|
||||
fun pullNextSeconds(seconds: Double) {
|
||||
val remainingSize = maxBufferSize - bufferWriteOffset
|
||||
val requestedSize = (seconds * sampleRate).toInt()
|
||||
|
||||
// If low on size, make more room.
|
||||
if (remainingSize < maxBufferSize / 3) {
|
||||
advanceStartBySamples(maxBufferSize / 3)
|
||||
}
|
||||
|
||||
return read(requestedSize, AudioRecord.READ_BLOCKING)
|
||||
}
|
||||
|
||||
override fun close() {
|
||||
recorder.stop()
|
||||
recorder.release()
|
||||
}
|
||||
|
||||
companion object {
|
||||
val factory: AudioRecorderFactory = { context -> AudioRecorder(context) }
|
||||
}
|
||||
}
|
@ -0,0 +1,5 @@
|
||||
package net.cozic.joplin.audio
|
||||
|
||||
|
||||
class InvalidSessionIdException(id: Int) : IllegalArgumentException("Invalid session ID $id") {
|
||||
}
|
@ -0,0 +1,136 @@
|
||||
package net.cozic.joplin.audio
|
||||
|
||||
import ai.onnxruntime.OnnxTensor
|
||||
import ai.onnxruntime.OrtEnvironment
|
||||
import ai.onnxruntime.OrtSession
|
||||
import ai.onnxruntime.extensions.OrtxPackage
|
||||
import android.annotation.SuppressLint
|
||||
import android.content.Context
|
||||
import android.util.Log
|
||||
import java.io.Closeable
|
||||
import java.nio.FloatBuffer
|
||||
import java.nio.IntBuffer
|
||||
import kotlin.time.DurationUnit
|
||||
import kotlin.time.measureTimedValue
|
||||
|
||||
class SpeechToTextConverter(
|
||||
modelPath: String,
|
||||
locale: String,
|
||||
recorderFactory: AudioRecorderFactory,
|
||||
private val environment: OrtEnvironment,
|
||||
context: Context,
|
||||
) : Closeable {
|
||||
private val recorder = recorderFactory(context)
|
||||
private val session: OrtSession = environment.createSession(
|
||||
modelPath,
|
||||
OrtSession.SessionOptions().apply {
|
||||
// Needed for audio decoding
|
||||
registerCustomOpLibrary(OrtxPackage.getLibraryPath())
|
||||
},
|
||||
)
|
||||
private val languageCode = Regex("_.*").replace(locale, "")
|
||||
private val decoderInputIds = when (languageCode) {
|
||||
// Add 50363 to the end to omit timestamps
|
||||
"en" -> intArrayOf(50258, 50259, 50359)
|
||||
"fr" -> intArrayOf(50258, 50265, 50359)
|
||||
"es" -> intArrayOf(50258, 50262, 50359)
|
||||
"de" -> intArrayOf(50258, 50261, 50359)
|
||||
"it" -> intArrayOf(50258, 50274, 50359)
|
||||
"nl" -> intArrayOf(50258, 50271, 50359)
|
||||
"ko" -> intArrayOf(50258, 50264, 50359)
|
||||
"th" -> intArrayOf(50258, 50289, 50359)
|
||||
"ru" -> intArrayOf(50258, 50263, 50359)
|
||||
"pt" -> intArrayOf(50258, 50267, 50359)
|
||||
"pl" -> intArrayOf(50258, 50269, 50359)
|
||||
"id" -> intArrayOf(50258, 50275, 50359)
|
||||
"hi" -> intArrayOf(50258, 50276, 50359)
|
||||
// Let Whisper guess the language
|
||||
else -> intArrayOf(50258)
|
||||
}
|
||||
|
||||
fun start() {
|
||||
recorder.start()
|
||||
}
|
||||
|
||||
private fun getInputs(data: FloatArray): MutableMap<String, OnnxTensor> {
|
||||
fun intTensor(value: Int) = OnnxTensor.createTensor(
|
||||
environment,
|
||||
IntBuffer.wrap(intArrayOf(value)),
|
||||
longArrayOf(1),
|
||||
)
|
||||
fun floatTensor(value: Float) = OnnxTensor.createTensor(
|
||||
environment,
|
||||
FloatBuffer.wrap(floatArrayOf(value)),
|
||||
longArrayOf(1),
|
||||
)
|
||||
val audioPcmTensor = OnnxTensor.createTensor(
|
||||
environment,
|
||||
FloatBuffer.wrap(data),
|
||||
longArrayOf(1, data.size.toLong()),
|
||||
)
|
||||
val decoderInputIdsTensor = OnnxTensor.createTensor(
|
||||
environment,
|
||||
IntBuffer.wrap(decoderInputIds),
|
||||
longArrayOf(1, decoderInputIds.size.toLong())
|
||||
)
|
||||
|
||||
return mutableMapOf(
|
||||
"audio_pcm" to audioPcmTensor,
|
||||
"max_length" to intTensor(412),
|
||||
"min_length" to intTensor(0),
|
||||
"num_return_sequences" to intTensor(1),
|
||||
"num_beams" to intTensor(1),
|
||||
"length_penalty" to floatTensor(1.1f),
|
||||
"repetition_penalty" to floatTensor(3f),
|
||||
"decoder_input_ids" to decoderInputIdsTensor,
|
||||
|
||||
// Required for timestamps
|
||||
"logits_processor" to intTensor(1)
|
||||
)
|
||||
}
|
||||
|
||||
// TODO .get() fails on older Android versions
|
||||
@SuppressLint("NewApi")
|
||||
private fun convert(data: FloatArray): String {
|
||||
val (inputs, convertInputsTime) = measureTimedValue {
|
||||
getInputs(data)
|
||||
}
|
||||
val (outputs, getOutputsTime) = measureTimedValue {
|
||||
session.run(inputs, setOf("str"))
|
||||
}
|
||||
val mainOutput = outputs.get("str").get().value as Array<Array<String>>
|
||||
outputs.close()
|
||||
|
||||
Log.i("Whisper", "Converted ${data.size / 16000}s of data in ${
|
||||
getOutputsTime.toString(DurationUnit.SECONDS, 2)
|
||||
} converted inputs in ${convertInputsTime.inWholeMilliseconds}ms")
|
||||
return mainOutput[0][0]
|
||||
}
|
||||
|
||||
fun dropFirstSeconds(seconds: Double) {
|
||||
Log.i("Whisper", "Drop first seconds $seconds")
|
||||
recorder.dropFirstSeconds(seconds)
|
||||
}
|
||||
|
||||
val bufferLengthSeconds: Double get() = recorder.bufferLengthSeconds
|
||||
|
||||
fun expandBufferAndConvert(seconds: Double): String {
|
||||
recorder.pullNextSeconds(seconds)
|
||||
// Also pull any extra available data, in case the speech-to-text converter
|
||||
// is lagging behind the audio recorder.
|
||||
recorder.pullAvailable()
|
||||
|
||||
return convert(recorder.bufferedData)
|
||||
}
|
||||
|
||||
// Converts as many seconds of buffered data as possible, without waiting
|
||||
fun expandBufferAndConvert(): String {
|
||||
recorder.pullAvailable()
|
||||
return convert(recorder.bufferedData)
|
||||
}
|
||||
|
||||
override fun close() {
|
||||
recorder.close()
|
||||
session.close()
|
||||
}
|
||||
}
|
@ -0,0 +1,86 @@
|
||||
package net.cozic.joplin.audio
|
||||
|
||||
import ai.onnxruntime.OrtEnvironment
|
||||
import com.facebook.react.ReactPackage
|
||||
import com.facebook.react.bridge.LifecycleEventListener
|
||||
import com.facebook.react.bridge.NativeModule
|
||||
import com.facebook.react.bridge.Promise
|
||||
import com.facebook.react.bridge.ReactApplicationContext
|
||||
import com.facebook.react.bridge.ReactContextBaseJavaModule
|
||||
import com.facebook.react.bridge.ReactMethod
|
||||
import com.facebook.react.uimanager.ViewManager
|
||||
import java.util.concurrent.ExecutorService
|
||||
import java.util.concurrent.Executors
|
||||
|
||||
class SpeechToTextPackage : ReactPackage {
|
||||
override fun createNativeModules(reactContext: ReactApplicationContext): List<NativeModule> {
|
||||
return listOf<NativeModule>(SpeechToTextModule(reactContext))
|
||||
}
|
||||
|
||||
override fun createViewManagers(reactContext: ReactApplicationContext): List<ViewManager<*, *>> {
|
||||
return emptyList()
|
||||
}
|
||||
|
||||
class SpeechToTextModule(
|
||||
private var context: ReactApplicationContext,
|
||||
) : ReactContextBaseJavaModule(context), LifecycleEventListener {
|
||||
private var environment: OrtEnvironment? = null
|
||||
private val executorService: ExecutorService = Executors.newFixedThreadPool(1)
|
||||
private val sessionManager = SpeechToTextSessionManager(executorService)
|
||||
|
||||
override fun getName() = "SpeechToTextModule"
|
||||
|
||||
override fun onHostResume() { }
|
||||
override fun onHostPause() { }
|
||||
override fun onHostDestroy() {
|
||||
environment?.close()
|
||||
}
|
||||
|
||||
@ReactMethod
|
||||
fun openSession(modelPath: String, locale: String, promise: Promise) {
|
||||
val appContext = context.applicationContext
|
||||
// Initialize environment as late as possible:
|
||||
val ortEnvironment = environment ?: OrtEnvironment.getEnvironment()
|
||||
if (environment != null) {
|
||||
environment = ortEnvironment
|
||||
}
|
||||
|
||||
try {
|
||||
val sessionId = sessionManager.openSession(modelPath, locale, ortEnvironment, appContext)
|
||||
promise.resolve(sessionId)
|
||||
} catch (exception: Throwable) {
|
||||
promise.reject(exception)
|
||||
}
|
||||
}
|
||||
|
||||
@ReactMethod
|
||||
fun startRecording(sessionId: Int, promise: Promise) {
|
||||
sessionManager.startRecording(sessionId, promise)
|
||||
}
|
||||
|
||||
@ReactMethod
|
||||
fun getBufferLengthSeconds(sessionId: Int, promise: Promise) {
|
||||
sessionManager.getBufferLengthSeconds(sessionId, promise)
|
||||
}
|
||||
|
||||
@ReactMethod
|
||||
fun dropFirstSeconds(sessionId: Int, duration: Double, promise: Promise) {
|
||||
sessionManager.dropFirstSeconds(sessionId, duration, promise)
|
||||
}
|
||||
|
||||
@ReactMethod
|
||||
fun expandBufferAndConvert(sessionId: Int, duration: Double, promise: Promise) {
|
||||
sessionManager.expandBufferAndConvert(sessionId, duration, promise)
|
||||
}
|
||||
|
||||
@ReactMethod
|
||||
fun convertAvailable(sessionId: Int, promise: Promise) {
|
||||
sessionManager.convertAvailable(sessionId, promise)
|
||||
}
|
||||
|
||||
@ReactMethod
|
||||
fun closeSession(sessionId: Int, promise: Promise) {
|
||||
sessionManager.closeSession(sessionId, promise)
|
||||
}
|
||||
}
|
||||
}
|
@ -0,0 +1,111 @@
|
||||
package net.cozic.joplin.audio
|
||||
|
||||
import ai.onnxruntime.OrtEnvironment
|
||||
import android.content.Context
|
||||
import com.facebook.react.bridge.Promise
|
||||
import java.util.concurrent.Executor
|
||||
import java.util.concurrent.locks.ReentrantLock
|
||||
|
||||
class SpeechToTextSession (
|
||||
val converter: SpeechToTextConverter
|
||||
) {
|
||||
val mutex = ReentrantLock()
|
||||
}
|
||||
|
||||
class SpeechToTextSessionManager(
|
||||
private var executor: Executor,
|
||||
) {
|
||||
private val sessions: MutableMap<Int, SpeechToTextSession> = mutableMapOf()
|
||||
private var nextSessionId: Int = 0
|
||||
|
||||
fun openSession(
|
||||
modelPath: String,
|
||||
locale: String,
|
||||
environment: OrtEnvironment,
|
||||
context: Context,
|
||||
): Int {
|
||||
val sessionId = nextSessionId++
|
||||
sessions[sessionId] = SpeechToTextSession(
|
||||
SpeechToTextConverter(
|
||||
modelPath, locale, recorderFactory = AudioRecorder.factory, environment, context,
|
||||
)
|
||||
)
|
||||
return sessionId
|
||||
}
|
||||
|
||||
private fun getSession(id: Int): SpeechToTextSession {
|
||||
return sessions[id] ?: throw InvalidSessionIdException(id)
|
||||
}
|
||||
|
||||
private fun concurrentWithSession(
|
||||
id: Int,
|
||||
callback: (session: SpeechToTextSession)->Unit,
|
||||
) {
|
||||
executor.execute {
|
||||
val session = getSession(id)
|
||||
session.mutex.lock()
|
||||
try {
|
||||
callback(session)
|
||||
} finally {
|
||||
session.mutex.unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
private fun concurrentWithSession(
|
||||
id: Int,
|
||||
onError: (error: Throwable)->Unit,
|
||||
callback: (session: SpeechToTextSession)->Unit,
|
||||
) {
|
||||
return concurrentWithSession(id) { session ->
|
||||
try {
|
||||
callback(session)
|
||||
} catch (error: Throwable) {
|
||||
onError(error)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fun startRecording(sessionId: Int, promise: Promise) {
|
||||
this.concurrentWithSession(sessionId, promise::reject) { session ->
|
||||
session.converter.start()
|
||||
promise.resolve(null)
|
||||
}
|
||||
}
|
||||
|
||||
// Left-shifts the recording buffer by [duration] seconds
|
||||
fun dropFirstSeconds(sessionId: Int, duration: Double, promise: Promise) {
|
||||
this.concurrentWithSession(sessionId, promise::reject) { session ->
|
||||
session.converter.dropFirstSeconds(duration)
|
||||
promise.resolve(sessionId)
|
||||
}
|
||||
}
|
||||
|
||||
fun getBufferLengthSeconds(sessionId: Int, promise: Promise) {
|
||||
this.concurrentWithSession(sessionId, promise::reject) { session ->
|
||||
promise.resolve(session.converter.bufferLengthSeconds)
|
||||
}
|
||||
}
|
||||
|
||||
// Waits for the next [duration] seconds to become available, then converts
|
||||
fun expandBufferAndConvert(sessionId: Int, duration: Double, promise: Promise) {
|
||||
this.concurrentWithSession(sessionId, promise::reject) { session ->
|
||||
val result = session.converter.expandBufferAndConvert(duration)
|
||||
promise.resolve(result)
|
||||
}
|
||||
}
|
||||
|
||||
// Converts all available recorded data
|
||||
fun convertAvailable(sessionId: Int, promise: Promise) {
|
||||
this.concurrentWithSession(sessionId, promise::reject) { session ->
|
||||
val result = session.converter.expandBufferAndConvert()
|
||||
promise.resolve(result)
|
||||
}
|
||||
}
|
||||
|
||||
fun closeSession(sessionId: Int, promise: Promise) {
|
||||
this.concurrentWithSession(sessionId) { session ->
|
||||
session.converter.close()
|
||||
promise.resolve(null)
|
||||
}
|
||||
}
|
||||
}
|
@ -45,7 +45,6 @@ import ImageEditor from '../NoteEditor/ImageEditor/ImageEditor';
|
||||
import promptRestoreAutosave from '../NoteEditor/ImageEditor/promptRestoreAutosave';
|
||||
import isEditableResource from '../NoteEditor/ImageEditor/isEditableResource';
|
||||
import VoiceTypingDialog from '../voiceTyping/VoiceTypingDialog';
|
||||
import { voskEnabled } from '../../services/voiceTyping/vosk';
|
||||
import { isSupportedLanguage } from '../../services/voiceTyping/vosk';
|
||||
import { ChangeEvent as EditorChangeEvent, SelectionRangeChangeEvent, UndoRedoDepthChangeEvent } from '@joplin/editor/events';
|
||||
import { join } from 'path';
|
||||
@ -1204,8 +1203,8 @@ class NoteScreenComponent extends BaseScreenComponent<Props, State> implements B
|
||||
});
|
||||
}
|
||||
|
||||
// Voice typing is enabled only for French language and on Android for now
|
||||
if (voskEnabled && shim.mobilePlatform() === 'android' && isSupportedLanguage(currentLocale())) {
|
||||
// Voice typing is enabled only on Android for now
|
||||
if (shim.mobilePlatform() === 'android' && isSupportedLanguage(currentLocale())) {
|
||||
output.push({
|
||||
title: _('Voice typing...'),
|
||||
onPress: () => {
|
||||
|
@ -1,14 +1,18 @@
|
||||
import * as React from 'react';
|
||||
import { useState, useEffect, useCallback } from 'react';
|
||||
import { Banner, ActivityIndicator } from 'react-native-paper';
|
||||
import { useState, useEffect, useCallback, useRef, useMemo } from 'react';
|
||||
import { Banner, ActivityIndicator, Text } from 'react-native-paper';
|
||||
import { _, languageName } from '@joplin/lib/locale';
|
||||
import useAsyncEffect, { AsyncEffectEvent } from '@joplin/lib/hooks/useAsyncEffect';
|
||||
import { getVosk, Recorder, startRecording, Vosk } from '../../services/voiceTyping/vosk';
|
||||
import { IconSource } from 'react-native-paper/lib/typescript/components/Icon';
|
||||
import { modelIsDownloaded } from '../../services/voiceTyping/vosk';
|
||||
import VoiceTyping, { OnTextCallback, VoiceTypingSession } from '../../services/voiceTyping/VoiceTyping';
|
||||
import whisper from '../../services/voiceTyping/whisper';
|
||||
import vosk from '../../services/voiceTyping/vosk';
|
||||
import { AppState } from '../../utils/types';
|
||||
import { connect } from 'react-redux';
|
||||
|
||||
interface Props {
|
||||
locale: string;
|
||||
provider: string;
|
||||
onDismiss: ()=> void;
|
||||
onText: (text: string)=> void;
|
||||
}
|
||||
@ -21,44 +25,77 @@ enum RecorderState {
|
||||
Downloading = 5,
|
||||
}
|
||||
|
||||
const useVosk = (locale: string): [Error | null, boolean, Vosk|null] => {
|
||||
const [vosk, setVosk] = useState<Vosk>(null);
|
||||
interface UseVoiceTypingProps {
|
||||
locale: string;
|
||||
provider: string;
|
||||
onSetPreview: OnTextCallback;
|
||||
onText: OnTextCallback;
|
||||
}
|
||||
|
||||
const useWhisper = ({ locale, provider, onSetPreview, onText }: UseVoiceTypingProps): [Error | null, boolean, VoiceTypingSession|null] => {
|
||||
const [voiceTyping, setVoiceTyping] = useState<VoiceTypingSession>(null);
|
||||
const [error, setError] = useState<Error>(null);
|
||||
const [mustDownloadModel, setMustDownloadModel] = useState<boolean | null>(null);
|
||||
|
||||
useAsyncEffect(async (event: AsyncEffectEvent) => {
|
||||
if (mustDownloadModel === null) return;
|
||||
const onTextRef = useRef(onText);
|
||||
onTextRef.current = onText;
|
||||
const onSetPreviewRef = useRef(onSetPreview);
|
||||
onSetPreviewRef.current = onSetPreview;
|
||||
|
||||
const voiceTypingRef = useRef(voiceTyping);
|
||||
voiceTypingRef.current = voiceTyping;
|
||||
|
||||
const builder = useMemo(() => {
|
||||
return new VoiceTyping(locale, provider?.startsWith('whisper') ? [whisper] : [vosk]);
|
||||
}, [locale, provider]);
|
||||
|
||||
useAsyncEffect(async (event: AsyncEffectEvent) => {
|
||||
try {
|
||||
const v = await getVosk(locale);
|
||||
await voiceTypingRef.current?.stop();
|
||||
|
||||
if (!await builder.isDownloaded()) {
|
||||
if (event.cancelled) return;
|
||||
await builder.download();
|
||||
}
|
||||
if (event.cancelled) return;
|
||||
setVosk(v);
|
||||
|
||||
const voiceTyping = await builder.build({
|
||||
onPreview: (text) => onSetPreviewRef.current(text),
|
||||
onFinalize: (text) => onTextRef.current(text),
|
||||
});
|
||||
if (event.cancelled) return;
|
||||
setVoiceTyping(voiceTyping);
|
||||
} catch (error) {
|
||||
setError(error);
|
||||
} finally {
|
||||
setMustDownloadModel(false);
|
||||
}
|
||||
}, [locale, mustDownloadModel]);
|
||||
}, [builder]);
|
||||
|
||||
useAsyncEffect(async (_event: AsyncEffectEvent) => {
|
||||
setMustDownloadModel(!(await modelIsDownloaded(locale)));
|
||||
}, [locale]);
|
||||
setMustDownloadModel(!(await builder.isDownloaded()));
|
||||
}, [builder]);
|
||||
|
||||
return [error, mustDownloadModel, vosk];
|
||||
return [error, mustDownloadModel, voiceTyping];
|
||||
};
|
||||
|
||||
export default (props: Props) => {
|
||||
const [recorder, setRecorder] = useState<Recorder>(null);
|
||||
const VoiceTypingDialog: React.FC<Props> = props => {
|
||||
const [recorderState, setRecorderState] = useState<RecorderState>(RecorderState.Loading);
|
||||
const [voskError, mustDownloadModel, vosk] = useVosk(props.locale);
|
||||
const [preview, setPreview] = useState<string>('');
|
||||
const [modelError, mustDownloadModel, voiceTyping] = useWhisper({
|
||||
locale: props.locale,
|
||||
onSetPreview: setPreview,
|
||||
onText: props.onText,
|
||||
provider: props.provider,
|
||||
});
|
||||
|
||||
useEffect(() => {
|
||||
if (voskError) {
|
||||
if (modelError) {
|
||||
setRecorderState(RecorderState.Error);
|
||||
} else if (vosk) {
|
||||
} else if (voiceTyping) {
|
||||
setRecorderState(RecorderState.Recording);
|
||||
}
|
||||
}, [vosk, voskError]);
|
||||
}, [voiceTyping, modelError]);
|
||||
|
||||
useEffect(() => {
|
||||
if (mustDownloadModel) {
|
||||
@ -68,27 +105,22 @@ export default (props: Props) => {
|
||||
|
||||
useEffect(() => {
|
||||
if (recorderState === RecorderState.Recording) {
|
||||
setRecorder(startRecording(vosk, {
|
||||
onResult: (text: string) => {
|
||||
props.onText(text);
|
||||
},
|
||||
}));
|
||||
void voiceTyping.start();
|
||||
}
|
||||
}, [recorderState, vosk, props.onText]);
|
||||
}, [recorderState, voiceTyping, props.onText]);
|
||||
|
||||
const onDismiss = useCallback(() => {
|
||||
if (recorder) recorder.cleanup();
|
||||
void voiceTyping?.stop();
|
||||
props.onDismiss();
|
||||
}, [recorder, props.onDismiss]);
|
||||
}, [voiceTyping, props.onDismiss]);
|
||||
|
||||
const renderContent = () => {
|
||||
// eslint-disable-next-line @typescript-eslint/ban-types -- Old code before rule was applied
|
||||
const components: Record<RecorderState, Function> = {
|
||||
const components: Record<RecorderState, ()=> string> = {
|
||||
[RecorderState.Loading]: () => _('Loading...'),
|
||||
[RecorderState.Recording]: () => _('Please record your voice...'),
|
||||
[RecorderState.Processing]: () => _('Converting speech to text...'),
|
||||
[RecorderState.Downloading]: () => _('Downloading %s language files...', languageName(props.locale)),
|
||||
[RecorderState.Error]: () => _('Error: %s', voskError.message),
|
||||
[RecorderState.Error]: () => _('Error: %s', modelError.message),
|
||||
};
|
||||
|
||||
return components[recorderState]();
|
||||
@ -106,6 +138,11 @@ export default (props: Props) => {
|
||||
return components[recorderState];
|
||||
};
|
||||
|
||||
const renderPreview = () => {
|
||||
return <Text variant='labelSmall'>{preview}</Text>;
|
||||
};
|
||||
|
||||
const headerAndStatus = <Text variant='bodyMedium'>{`${_('Voice typing...')}\n${renderContent()}`}</Text>;
|
||||
return (
|
||||
<Banner
|
||||
visible={true}
|
||||
@ -115,8 +152,15 @@ export default (props: Props) => {
|
||||
label: _('Done'),
|
||||
onPress: onDismiss,
|
||||
},
|
||||
]}>
|
||||
{`${_('Voice typing...')}\n${renderContent()}`}
|
||||
]}
|
||||
>
|
||||
{headerAndStatus}
|
||||
<Text>{'\n'}</Text>
|
||||
{renderPreview()}
|
||||
</Banner>
|
||||
);
|
||||
};
|
||||
|
||||
export default connect((state: AppState) => ({
|
||||
provider: state.settings['voiceTyping.preferredProvider'],
|
||||
}))(VoiceTypingDialog);
|
||||
|
@ -80,6 +80,10 @@ jest.mock('react-native-image-picker', () => {
|
||||
return { default: { } };
|
||||
});
|
||||
|
||||
jest.mock('react-native-zip-archive', () => {
|
||||
return { default: { } };
|
||||
});
|
||||
|
||||
jest.mock('react-native-document-picker', () => ({ default: { } }));
|
||||
|
||||
// Used by the renderer
|
||||
|
139
packages/app-mobile/services/voiceTyping/VoiceTyping.ts
Normal file
139
packages/app-mobile/services/voiceTyping/VoiceTyping.ts
Normal file
@ -0,0 +1,139 @@
|
||||
import shim from '@joplin/lib/shim';
|
||||
import Logger from '@joplin/utils/Logger';
|
||||
import { PermissionsAndroid, Platform } from 'react-native';
|
||||
import { unzip } from 'react-native-zip-archive';
|
||||
const md5 = require('md5');
|
||||
|
||||
const logger = Logger.create('voiceTyping');
|
||||
|
||||
export type OnTextCallback = (text: string)=> void;
|
||||
|
||||
export interface SpeechToTextCallbacks {
|
||||
// Called with a block of text that might change in the future
|
||||
onPreview: OnTextCallback;
|
||||
// Called with text that will not change and should be added to the document
|
||||
onFinalize: OnTextCallback;
|
||||
}
|
||||
|
||||
export interface VoiceTypingSession {
|
||||
start(): Promise<void>;
|
||||
stop(): Promise<void>;
|
||||
}
|
||||
|
||||
export interface BuildProviderOptions {
|
||||
locale: string;
|
||||
modelPath: string;
|
||||
callbacks: SpeechToTextCallbacks;
|
||||
}
|
||||
|
||||
export interface VoiceTypingProvider {
|
||||
modelName: string;
|
||||
supported(): boolean;
|
||||
modelLocalFilepath(locale: string): string;
|
||||
getDownloadUrl(locale: string): string;
|
||||
getUuidPath(locale: string): string;
|
||||
build(options: BuildProviderOptions): Promise<VoiceTypingSession>;
|
||||
}
|
||||
|
||||
export default class VoiceTyping {
|
||||
private provider: VoiceTypingProvider|null = null;
|
||||
public constructor(
|
||||
private locale: string,
|
||||
providers: VoiceTypingProvider[],
|
||||
) {
|
||||
this.provider = providers.find(p => p.supported()) ?? null;
|
||||
}
|
||||
|
||||
public supported() {
|
||||
return this.provider !== null;
|
||||
}
|
||||
|
||||
private getModelPath() {
|
||||
const localFilePath = shim.fsDriver().resolveRelativePathWithinDir(
|
||||
shim.fsDriver().getAppDirectoryPath(),
|
||||
this.provider.modelLocalFilepath(this.locale),
|
||||
);
|
||||
if (localFilePath === shim.fsDriver().getAppDirectoryPath()) {
|
||||
throw new Error('Invalid local file path!');
|
||||
}
|
||||
|
||||
return localFilePath;
|
||||
}
|
||||
|
||||
private getUuidPath() {
|
||||
return shim.fsDriver().resolveRelativePathWithinDir(
|
||||
shim.fsDriver().getAppDirectoryPath(),
|
||||
this.provider.getUuidPath(this.locale),
|
||||
);
|
||||
}
|
||||
|
||||
public async isDownloaded() {
|
||||
return await shim.fsDriver().exists(this.getUuidPath());
|
||||
}
|
||||
|
||||
public async download() {
|
||||
const modelPath = this.getModelPath();
|
||||
const modelUrl = this.provider.getDownloadUrl(this.locale);
|
||||
|
||||
await shim.fsDriver().remove(modelPath);
|
||||
logger.info(`Downloading model from: ${modelUrl}`);
|
||||
|
||||
const isZipped = modelUrl.endsWith('.zip');
|
||||
const downloadPath = isZipped ? `${modelPath}.zip` : modelPath;
|
||||
const response = await shim.fetchBlob(modelUrl, {
|
||||
path: downloadPath,
|
||||
});
|
||||
|
||||
if (!response.ok || response.status >= 400) throw new Error(`Could not download from ${modelUrl}: Error ${response.status}`);
|
||||
|
||||
if (isZipped) {
|
||||
const modelName = this.provider.modelName;
|
||||
const unzipDir = `${shim.fsDriver().getCacheDirectoryPath()}/voice-typing-extract/${modelName}/${this.locale}`;
|
||||
try {
|
||||
logger.info(`Unzipping ${downloadPath} => ${unzipDir}`);
|
||||
|
||||
await unzip(downloadPath, unzipDir);
|
||||
|
||||
const contents = await shim.fsDriver().readDirStats(unzipDir);
|
||||
if (contents.length !== 1) {
|
||||
logger.error('Expected 1 file or directory but got', contents);
|
||||
throw new Error(`Expected 1 file or directory, but got ${contents.length}`);
|
||||
}
|
||||
|
||||
const fullUnzipPath = `${unzipDir}/${contents[0].path}`;
|
||||
|
||||
logger.info(`Moving ${fullUnzipPath} => ${modelPath}`);
|
||||
await shim.fsDriver().move(fullUnzipPath, modelPath);
|
||||
|
||||
await shim.fsDriver().writeFile(this.getUuidPath(), md5(modelUrl), 'utf8');
|
||||
if (!await this.isDownloaded()) {
|
||||
logger.warn('Model should be downloaded!');
|
||||
}
|
||||
} finally {
|
||||
await shim.fsDriver().remove(unzipDir);
|
||||
await shim.fsDriver().remove(downloadPath);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public async build(callbacks: SpeechToTextCallbacks) {
|
||||
if (!this.provider) {
|
||||
throw new Error('No supported provider found!');
|
||||
}
|
||||
|
||||
if (!await this.isDownloaded()) {
|
||||
await this.download();
|
||||
}
|
||||
|
||||
const audioPermission = 'android.permission.RECORD_AUDIO';
|
||||
if (Platform.OS === 'android' && !await PermissionsAndroid.check(audioPermission)) {
|
||||
await PermissionsAndroid.request(audioPermission);
|
||||
}
|
||||
|
||||
return this.provider.build({
|
||||
locale: this.locale,
|
||||
modelPath: this.getModelPath(),
|
||||
callbacks,
|
||||
});
|
||||
}
|
||||
}
|
@ -0,0 +1,61 @@
|
||||
import splitWhisperText from './splitWhisperText';
|
||||
|
||||
describe('splitWhisperText', () => {
|
||||
test.each([
|
||||
{
|
||||
// Should trim at sentence breaks
|
||||
input: '<|0.00|> This is a test. <|5.00|><|6.00|> This is another sentence. <|7.00|>',
|
||||
recordingLength: 8,
|
||||
expected: {
|
||||
trimTo: 6,
|
||||
dataBeforeTrim: '<|0.00|> This is a test. ',
|
||||
dataAfterTrim: ' This is another sentence. <|7.00|>',
|
||||
},
|
||||
},
|
||||
{
|
||||
// Should prefer sentence break splits to non sentence break splits
|
||||
input: '<|0.00|> This is <|4.00|><|4.50|> a test. <|5.00|><|5.50|> Testing, <|6.00|><|7.00|> this is a test. <|8.00|>',
|
||||
recordingLength: 8,
|
||||
expected: {
|
||||
trimTo: 5.50,
|
||||
dataBeforeTrim: '<|0.00|> This is <|4.00|><|4.50|> a test. ',
|
||||
dataAfterTrim: ' Testing, <|6.00|><|7.00|> this is a test. <|8.00|>',
|
||||
},
|
||||
},
|
||||
{
|
||||
// Should avoid splitting for very small timestamps
|
||||
input: '<|0.00|> This is a test. <|2.00|><|2.30|> Testing! <|3.00|>',
|
||||
recordingLength: 4,
|
||||
expected: {
|
||||
trimTo: 0,
|
||||
dataBeforeTrim: '',
|
||||
dataAfterTrim: ' This is a test. <|2.00|><|2.30|> Testing! <|3.00|>',
|
||||
},
|
||||
},
|
||||
{
|
||||
// For larger timestamps, should allow splitting at pauses, even if not on sentence breaks.
|
||||
input: '<|0.00|> This is a test, <|10.00|><|12.00|> of splitting on timestamps. <|15.00|>',
|
||||
recordingLength: 16,
|
||||
expected: {
|
||||
trimTo: 12,
|
||||
dataBeforeTrim: '<|0.00|> This is a test, ',
|
||||
dataAfterTrim: ' of splitting on timestamps. <|15.00|>',
|
||||
},
|
||||
},
|
||||
{
|
||||
// Should prefer to break at the end, if a large gap after the last timestamp.
|
||||
input: '<|0.00|> This is a test, <|10.00|><|12.00|> of splitting on timestamps. <|15.00|>',
|
||||
recordingLength: 30,
|
||||
expected: {
|
||||
trimTo: 15,
|
||||
dataBeforeTrim: '<|0.00|> This is a test, <|10.00|><|12.00|> of splitting on timestamps. ',
|
||||
dataAfterTrim: '',
|
||||
},
|
||||
},
|
||||
])('should prefer to split at the end of sentences (case %#)', ({ input, recordingLength, expected }) => {
|
||||
const actual = splitWhisperText(input, recordingLength);
|
||||
expect(actual.trimTo).toBeCloseTo(expected.trimTo);
|
||||
expect(actual.dataBeforeTrim).toBe(expected.dataBeforeTrim);
|
||||
expect(actual.dataAfterTrim).toBe(expected.dataAfterTrim);
|
||||
});
|
||||
});
|
@ -0,0 +1,65 @@
|
||||
// Matches pairs of timestamps or single timestamps.
|
||||
const timestampExp = /<\|(\d+\.\d*)\|>(?:<\|(\d+\.\d*)\|>)?/g;
|
||||
|
||||
const timestampMatchToNumber = (match: RegExpMatchArray) => {
|
||||
const firstTimestamp = match[1];
|
||||
const secondTimestamp = match[2];
|
||||
// Prefer the second timestamp in the pair, to remove leading silence.
|
||||
const timestamp = Number(secondTimestamp ? secondTimestamp : firstTimestamp);
|
||||
|
||||
// Should always be a finite number (i.e. not NaN)
|
||||
if (!isFinite(timestamp)) throw new Error(`Timestamp match failed with ${match[0]}`);
|
||||
|
||||
return timestamp;
|
||||
};
|
||||
|
||||
const splitWhisperText = (textWithTimestamps: string, recordingLengthSeconds: number) => {
|
||||
const timestamps = [
|
||||
...textWithTimestamps.matchAll(timestampExp),
|
||||
].map(match => {
|
||||
const timestamp = timestampMatchToNumber(match);
|
||||
return { timestamp, match };
|
||||
});
|
||||
|
||||
if (!timestamps.length) {
|
||||
return { trimTo: 0, dataBeforeTrim: '', dataAfterTrim: textWithTimestamps };
|
||||
}
|
||||
|
||||
const firstTimestamp = timestamps[0];
|
||||
let breakAt = firstTimestamp;
|
||||
|
||||
const lastTimestamp = timestamps[timestamps.length - 1];
|
||||
const hasLongPauseAfterData = lastTimestamp.timestamp + 4 < recordingLengthSeconds;
|
||||
if (hasLongPauseAfterData) {
|
||||
breakAt = lastTimestamp;
|
||||
} else {
|
||||
const textWithTimestampsContentLength = textWithTimestamps.trimEnd().length;
|
||||
|
||||
for (const timestampData of timestamps) {
|
||||
const { match, timestamp } = timestampData;
|
||||
const contentBefore = textWithTimestamps.substring(Math.max(match.index - 3, 0), match.index);
|
||||
const isNearEndOfLatinSentence = contentBefore.match(/[.?!]/);
|
||||
const isNearEndOfData = match.index + match[0].length >= textWithTimestampsContentLength;
|
||||
|
||||
// Use a heuristic to determine whether to move content from the preview to the document.
|
||||
// These are based on the maximum buffer length of 30 seconds -- as the buffer gets longer, the
|
||||
// data should be more likely to be broken into chunks. Where possible, the break should be near
|
||||
// the end of a sentence:
|
||||
const canBreak = (timestamp > 4 && isNearEndOfLatinSentence && !isNearEndOfData)
|
||||
|| (timestamp > 8 && !isNearEndOfData)
|
||||
|| timestamp > 16;
|
||||
if (canBreak) {
|
||||
breakAt = timestampData;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const trimTo = breakAt.timestamp;
|
||||
const dataBeforeTrim = textWithTimestamps.substring(0, breakAt.match.index);
|
||||
const dataAfterTrim = textWithTimestamps.substring(breakAt.match.index + breakAt.match[0].length);
|
||||
|
||||
return { trimTo, dataBeforeTrim, dataAfterTrim };
|
||||
};
|
||||
|
||||
export default splitWhisperText;
|
@ -4,9 +4,9 @@ import Setting from '@joplin/lib/models/Setting';
|
||||
import { rtrimSlashes } from '@joplin/lib/path-utils';
|
||||
import shim from '@joplin/lib/shim';
|
||||
import Vosk from 'react-native-vosk';
|
||||
import { unzip } from 'react-native-zip-archive';
|
||||
import RNFetchBlob from 'rn-fetch-blob';
|
||||
const md5 = require('md5');
|
||||
import { VoiceTypingProvider, VoiceTypingSession } from './VoiceTyping';
|
||||
import { join } from 'path';
|
||||
|
||||
const logger = Logger.create('voiceTyping/vosk');
|
||||
|
||||
@ -24,15 +24,6 @@ let vosk_: Record<string, Vosk> = {};
|
||||
|
||||
let state_: State = State.Idle;
|
||||
|
||||
export const voskEnabled = true;
|
||||
|
||||
export { Vosk };
|
||||
|
||||
export interface Recorder {
|
||||
stop: ()=> Promise<string>;
|
||||
cleanup: ()=> void;
|
||||
}
|
||||
|
||||
const defaultSupportedLanguages = {
|
||||
'en': 'https://alphacephei.com/vosk/models/vosk-model-small-en-us-0.15.zip',
|
||||
'zh': 'https://alphacephei.com/vosk/models/vosk-model-small-cn-0.22.zip',
|
||||
@ -74,7 +65,7 @@ const getModelDir = (locale: string) => {
|
||||
return `${getUnzipDir(locale)}/model`;
|
||||
};
|
||||
|
||||
const languageModelUrl = (locale: string) => {
|
||||
const languageModelUrl = (locale: string): string => {
|
||||
const lang = languageCodeOnly(locale).toLowerCase();
|
||||
if (!(lang in defaultSupportedLanguages)) throw new Error(`No language file for: ${locale}`);
|
||||
|
||||
@ -90,16 +81,11 @@ const languageModelUrl = (locale: string) => {
|
||||
}
|
||||
};
|
||||
|
||||
export const modelIsDownloaded = async (locale: string) => {
|
||||
const uuidFile = `${getModelDir(locale)}/uuid`;
|
||||
return shim.fsDriver().exists(uuidFile);
|
||||
};
|
||||
|
||||
export const getVosk = async (locale: string) => {
|
||||
export const getVosk = async (modelDir: string, locale: string) => {
|
||||
if (vosk_[locale]) return vosk_[locale];
|
||||
|
||||
const vosk = new Vosk();
|
||||
const modelDir = await downloadModel(locale);
|
||||
logger.info(`Loading model from ${modelDir}`);
|
||||
await shim.fsDriver().readDirStats(modelDir);
|
||||
const result = await vosk.loadModel(modelDir);
|
||||
@ -110,51 +96,7 @@ export const getVosk = async (locale: string) => {
|
||||
return vosk;
|
||||
};
|
||||
|
||||
const downloadModel = async (locale: string) => {
|
||||
const modelUrl = languageModelUrl(locale);
|
||||
const unzipDir = getUnzipDir(locale);
|
||||
const zipFilePath = `${unzipDir}.zip`;
|
||||
const modelDir = getModelDir(locale);
|
||||
const uuidFile = `${modelDir}/uuid`;
|
||||
|
||||
if (await modelIsDownloaded(locale)) {
|
||||
logger.info(`Model for ${locale} already exists at ${modelDir}`);
|
||||
return modelDir;
|
||||
}
|
||||
|
||||
await shim.fsDriver().remove(unzipDir);
|
||||
|
||||
logger.info(`Downloading model from: ${modelUrl}`);
|
||||
|
||||
const response = await shim.fetchBlob(modelUrl, {
|
||||
path: zipFilePath,
|
||||
});
|
||||
|
||||
if (!response.ok || response.status >= 400) throw new Error(`Could not download from ${modelUrl}: Error ${response.status}`);
|
||||
|
||||
logger.info(`Unzipping ${zipFilePath} => ${unzipDir}`);
|
||||
|
||||
await unzip(zipFilePath, unzipDir);
|
||||
|
||||
const dirs = await shim.fsDriver().readDirStats(unzipDir);
|
||||
if (dirs.length !== 1) {
|
||||
logger.error('Expected 1 directory but got', dirs);
|
||||
throw new Error(`Expected 1 directory, but got ${dirs.length}`);
|
||||
}
|
||||
|
||||
const fullUnzipPath = `${unzipDir}/${dirs[0].path}`;
|
||||
|
||||
logger.info(`Moving ${fullUnzipPath} => ${modelDir}`);
|
||||
await shim.fsDriver().rename(fullUnzipPath, modelDir);
|
||||
|
||||
await shim.fsDriver().writeFile(uuidFile, md5(modelUrl));
|
||||
|
||||
await shim.fsDriver().remove(zipFilePath);
|
||||
|
||||
return modelDir;
|
||||
};
|
||||
|
||||
export const startRecording = (vosk: Vosk, options: StartOptions): Recorder => {
|
||||
export const startRecording = (vosk: Vosk, options: StartOptions): VoiceTypingSession => {
|
||||
if (state_ !== State.Idle) throw new Error('Vosk is already recording');
|
||||
|
||||
state_ = State.Recording;
|
||||
@ -163,10 +105,10 @@ export const startRecording = (vosk: Vosk, options: StartOptions): Recorder => {
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any -- Old code before rule was applied
|
||||
const eventHandlers: any[] = [];
|
||||
// eslint-disable-next-line @typescript-eslint/ban-types -- Old code before rule was applied
|
||||
let finalResultPromiseResolve: Function = null;
|
||||
const finalResultPromiseResolve: Function = null;
|
||||
// eslint-disable-next-line @typescript-eslint/ban-types -- Old code before rule was applied
|
||||
let finalResultPromiseReject: Function = null;
|
||||
let finalResultTimeout = false;
|
||||
const finalResultPromiseReject: Function = null;
|
||||
const finalResultTimeout = false;
|
||||
|
||||
const completeRecording = (finalResult: string, error: Error) => {
|
||||
logger.info(`Complete recording. Final result: ${finalResult}. Error:`, error);
|
||||
@ -212,31 +154,13 @@ export const startRecording = (vosk: Vosk, options: StartOptions): Recorder => {
|
||||
completeRecording(e.data, null);
|
||||
}));
|
||||
|
||||
logger.info('Starting recording...');
|
||||
|
||||
void vosk.start();
|
||||
|
||||
return {
|
||||
stop: (): Promise<string> => {
|
||||
logger.info('Stopping recording...');
|
||||
|
||||
vosk.stopOnly();
|
||||
|
||||
logger.info('Waiting for final result...');
|
||||
|
||||
setTimeout(() => {
|
||||
finalResultTimeout = true;
|
||||
logger.warn('Timed out waiting for finalResult event');
|
||||
completeRecording('', new Error('Could not process your message. Please try again.'));
|
||||
}, 5000);
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/ban-types -- Old code before rule was applied
|
||||
return new Promise((resolve: Function, reject: Function) => {
|
||||
finalResultPromiseResolve = resolve;
|
||||
finalResultPromiseReject = reject;
|
||||
});
|
||||
start: async () => {
|
||||
logger.info('Starting recording...');
|
||||
await vosk.start();
|
||||
},
|
||||
cleanup: () => {
|
||||
stop: async () => {
|
||||
if (state_ === State.Recording) {
|
||||
logger.info('Cancelling...');
|
||||
state_ = State.Completing;
|
||||
@ -246,3 +170,18 @@ export const startRecording = (vosk: Vosk, options: StartOptions): Recorder => {
|
||||
},
|
||||
};
|
||||
};
|
||||
|
||||
|
||||
const vosk: VoiceTypingProvider = {
|
||||
supported: () => true,
|
||||
modelLocalFilepath: (locale: string) => getModelDir(locale),
|
||||
getDownloadUrl: (locale) => languageModelUrl(locale),
|
||||
getUuidPath: (locale: string) => join(getModelDir(locale), 'uuid'),
|
||||
build: async ({ callbacks, locale, modelPath }) => {
|
||||
const vosk = await getVosk(modelPath, locale);
|
||||
return startRecording(vosk, { onResult: callbacks.onFinalize });
|
||||
},
|
||||
modelName: 'vosk',
|
||||
};
|
||||
|
||||
export default vosk;
|
||||
|
@ -1,37 +1,14 @@
|
||||
// Currently disabled on non-Android platforms
|
||||
import { VoiceTypingProvider } from './VoiceTyping';
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any -- Old code before rule was applied
|
||||
type Vosk = any;
|
||||
|
||||
export { Vosk };
|
||||
|
||||
interface StartOptions {
|
||||
onResult: (text: string)=> void;
|
||||
}
|
||||
|
||||
export interface Recorder {
|
||||
stop: ()=> Promise<string>;
|
||||
cleanup: ()=> void;
|
||||
}
|
||||
|
||||
export const isSupportedLanguage = (_locale: string) => {
|
||||
return false;
|
||||
const vosk: VoiceTypingProvider = {
|
||||
supported: () => false,
|
||||
modelLocalFilepath: () => null,
|
||||
getDownloadUrl: () => null,
|
||||
getUuidPath: () => null,
|
||||
build: async () => {
|
||||
throw new Error('Unsupported!');
|
||||
},
|
||||
modelName: 'vosk',
|
||||
};
|
||||
|
||||
export const modelIsDownloaded = async (_locale: string) => {
|
||||
return false;
|
||||
};
|
||||
|
||||
export const getVosk = async (_locale: string) => {
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any -- Old code before rule was applied
|
||||
return {} as any;
|
||||
};
|
||||
|
||||
export const startRecording = (_vosk: Vosk, _options: StartOptions): Recorder => {
|
||||
return {
|
||||
stop: async () => { return ''; },
|
||||
cleanup: () => {},
|
||||
};
|
||||
};
|
||||
|
||||
export const voskEnabled = false;
|
||||
export default vosk;
|
||||
|
119
packages/app-mobile/services/voiceTyping/whisper.ts
Normal file
119
packages/app-mobile/services/voiceTyping/whisper.ts
Normal file
@ -0,0 +1,119 @@
|
||||
import Setting from '@joplin/lib/models/Setting';
|
||||
import shim from '@joplin/lib/shim';
|
||||
import Logger from '@joplin/utils/Logger';
|
||||
import { rtrimSlashes } from '@joplin/utils/path';
|
||||
import { dirname, join } from 'path';
|
||||
import { NativeModules } from 'react-native';
|
||||
import { SpeechToTextCallbacks, VoiceTypingProvider, VoiceTypingSession } from './VoiceTyping';
|
||||
import splitWhisperText from './utils/splitWhisperText';
|
||||
|
||||
const logger = Logger.create('voiceTyping/whisper');
|
||||
|
||||
const { SpeechToTextModule } = NativeModules;
|
||||
|
||||
// Timestamps are in the form <|0.00|>. They seem to be added:
|
||||
// - After long pauses.
|
||||
// - Between sentences (in pairs).
|
||||
// - At the beginning and end of a sequence.
|
||||
const timestampExp = /<\|(\d+\.\d*)\|>/g;
|
||||
const postProcessSpeech = (text: string) => {
|
||||
return text.replace(timestampExp, '').replace(/\[BLANK_AUDIO\]/g, '');
|
||||
};
|
||||
|
||||
class Whisper implements VoiceTypingSession {
|
||||
private lastPreviewData: string;
|
||||
private closeCounter = 0;
|
||||
|
||||
public constructor(
|
||||
private sessionId: number|null,
|
||||
private callbacks: SpeechToTextCallbacks,
|
||||
) { }
|
||||
|
||||
public async start() {
|
||||
if (this.sessionId === null) {
|
||||
throw new Error('Session closed.');
|
||||
}
|
||||
try {
|
||||
logger.debug('starting recorder');
|
||||
await SpeechToTextModule.startRecording(this.sessionId);
|
||||
logger.debug('recorder started');
|
||||
|
||||
const loopStartCounter = this.closeCounter;
|
||||
while (this.closeCounter === loopStartCounter) {
|
||||
logger.debug('reading block');
|
||||
const data: string = await SpeechToTextModule.expandBufferAndConvert(this.sessionId, 4);
|
||||
logger.debug('done reading block. Length', data?.length);
|
||||
|
||||
if (this.sessionId === null) {
|
||||
logger.debug('Session stopped. Ending inference loop.');
|
||||
return;
|
||||
}
|
||||
|
||||
const recordingLength = await SpeechToTextModule.getBufferLengthSeconds(this.sessionId);
|
||||
logger.debug('recording length so far', recordingLength);
|
||||
const { trimTo, dataBeforeTrim, dataAfterTrim } = splitWhisperText(data, recordingLength);
|
||||
|
||||
if (trimTo > 2) {
|
||||
logger.debug('Trim to', trimTo, 'in recording with length', recordingLength);
|
||||
this.callbacks.onFinalize(postProcessSpeech(dataBeforeTrim));
|
||||
this.callbacks.onPreview(postProcessSpeech(dataAfterTrim));
|
||||
this.lastPreviewData = dataAfterTrim;
|
||||
await SpeechToTextModule.dropFirstSeconds(this.sessionId, trimTo);
|
||||
} else {
|
||||
logger.debug('Preview', data);
|
||||
this.lastPreviewData = data;
|
||||
this.callbacks.onPreview(postProcessSpeech(data));
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('Whisper error:', error);
|
||||
this.lastPreviewData = '';
|
||||
await this.stop();
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
public async stop() {
|
||||
if (this.sessionId === null) {
|
||||
logger.warn('Session already closed.');
|
||||
return;
|
||||
}
|
||||
|
||||
const sessionId = this.sessionId;
|
||||
this.sessionId = null;
|
||||
this.closeCounter ++;
|
||||
await SpeechToTextModule.closeSession(sessionId);
|
||||
|
||||
if (this.lastPreviewData) {
|
||||
this.callbacks.onFinalize(postProcessSpeech(this.lastPreviewData));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const modelLocalFilepath = () => {
|
||||
return `${shim.fsDriver().getAppDirectoryPath()}/voice-typing-models/whisper_tiny.onnx`;
|
||||
};
|
||||
|
||||
const whisper: VoiceTypingProvider = {
|
||||
supported: () => !!SpeechToTextModule,
|
||||
modelLocalFilepath: modelLocalFilepath,
|
||||
getDownloadUrl: () => {
|
||||
let urlTemplate = rtrimSlashes(Setting.value('voiceTypingBaseUrl').trim());
|
||||
|
||||
if (!urlTemplate) {
|
||||
urlTemplate = 'https://github.com/personalizedrefrigerator/joplin-voice-typing-test/releases/download/test-release/{task}.zip';
|
||||
}
|
||||
|
||||
return urlTemplate.replace(/\{task\}/g, 'whisper_tiny.onnx');
|
||||
},
|
||||
getUuidPath: () => {
|
||||
return join(dirname(modelLocalFilepath()), 'uuid');
|
||||
},
|
||||
build: async ({ modelPath, callbacks, locale }) => {
|
||||
const sessionId = await SpeechToTextModule.openSession(modelPath, locale);
|
||||
return new Whisper(sessionId, callbacks);
|
||||
},
|
||||
modelName: 'whisper',
|
||||
};
|
||||
|
||||
export default whisper;
|
@ -1615,6 +1615,25 @@ const builtInMetadata = (Setting: typeof SettingType) => {
|
||||
section: 'note',
|
||||
},
|
||||
|
||||
'voiceTyping.preferredProvider': {
|
||||
value: 'whisper-tiny',
|
||||
type: SettingItemType.String,
|
||||
public: true,
|
||||
appTypes: [AppType.Mobile],
|
||||
label: () => _('Preferred voice typing provider'),
|
||||
isEnum: true,
|
||||
// For now, iOS and web don't support voice typing.
|
||||
show: () => shim.mobilePlatform() === 'android',
|
||||
section: 'note',
|
||||
|
||||
options: () => {
|
||||
return {
|
||||
'vosk': _('Vosk'),
|
||||
'whisper-tiny': _('Whisper'),
|
||||
};
|
||||
},
|
||||
},
|
||||
|
||||
'trash.autoDeletionEnabled': {
|
||||
value: true,
|
||||
type: SettingItemType.Bool,
|
||||
|
@ -132,4 +132,6 @@ Famegear
|
||||
rcompare
|
||||
tabindex
|
||||
Backblaze
|
||||
onnx
|
||||
onnxruntime
|
||||
treeitem
|
||||
|
Loading…
Reference in New Issue
Block a user