	Android: Allow switching the voice typing library to Whisper (#11158)
Co-authored-by: Laurent Cozic <laurent22@users.noreply.github.com>
		| @@ -737,8 +737,12 @@ packages/app-mobile/services/BackButtonService.js | ||||
| packages/app-mobile/services/e2ee/RSA.react-native.js | ||||
| packages/app-mobile/services/plugins/PlatformImplementation.js | ||||
| packages/app-mobile/services/profiles/index.js | ||||
| packages/app-mobile/services/voiceTyping/VoiceTyping.js | ||||
| packages/app-mobile/services/voiceTyping/utils/splitWhisperText.test.js | ||||
| packages/app-mobile/services/voiceTyping/utils/splitWhisperText.js | ||||
| packages/app-mobile/services/voiceTyping/vosk.android.js | ||||
| packages/app-mobile/services/voiceTyping/vosk.js | ||||
| packages/app-mobile/services/voiceTyping/whisper.js | ||||
| packages/app-mobile/setupQuickActions.js | ||||
| packages/app-mobile/tools/buildInjectedJs/BundledFile.js | ||||
| packages/app-mobile/tools/buildInjectedJs/constants.js | ||||
|   | ||||
							
								
								
									
4	.gitignore	vendored
							| @@ -714,8 +714,12 @@ packages/app-mobile/services/BackButtonService.js | ||||
| packages/app-mobile/services/e2ee/RSA.react-native.js | ||||
| packages/app-mobile/services/plugins/PlatformImplementation.js | ||||
| packages/app-mobile/services/profiles/index.js | ||||
| packages/app-mobile/services/voiceTyping/VoiceTyping.js | ||||
| packages/app-mobile/services/voiceTyping/utils/splitWhisperText.test.js | ||||
| packages/app-mobile/services/voiceTyping/utils/splitWhisperText.js | ||||
| packages/app-mobile/services/voiceTyping/vosk.android.js | ||||
| packages/app-mobile/services/voiceTyping/vosk.js | ||||
| packages/app-mobile/services/voiceTyping/whisper.js | ||||
| packages/app-mobile/setupQuickActions.js | ||||
| packages/app-mobile/tools/buildInjectedJs/BundledFile.js | ||||
| packages/app-mobile/tools/buildInjectedJs/constants.js | ||||
|   | ||||
| @@ -136,6 +136,10 @@ dependencies { | ||||
|     } else { | ||||
|         implementation jscFlavor | ||||
|     } | ||||
|  | ||||
|     // Needed for Whisper speech-to-text | ||||
|     implementation 'com.microsoft.onnxruntime:onnxruntime-android:latest.release' | ||||
|     implementation 'com.microsoft.onnxruntime:onnxruntime-extensions-android:latest.release' | ||||
| } | ||||
|  | ||||
| apply from: file("../../node_modules/@react-native-community/cli-platform-android/native_modules.gradle"); applyNativeModulesAppBuildGradle(project) | ||||
|   | ||||
| @@ -10,6 +10,7 @@ | ||||
| 	<uses-permission android:name="android.permission.POST_NOTIFICATIONS" /> | ||||
| 	<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE" /> | ||||
| 	<uses-permission android:name="android.permission.READ_MEDIA_IMAGES" /> | ||||
| 	<uses-permission android:name="android.permission.RECORD_AUDIO" /> | ||||
|  | ||||
| 	<!-- Make these features optional to enable Chromebooks --> | ||||
| 	<!-- https://github.com/laurent22/joplin/issues/37 --> | ||||
|   | ||||
| @@ -11,6 +11,7 @@ import com.facebook.react.defaults.DefaultNewArchitectureEntryPoint.load | ||||
| import com.facebook.react.defaults.DefaultReactHost.getDefaultReactHost | ||||
| import com.facebook.react.defaults.DefaultReactNativeHost | ||||
| import com.facebook.soloader.SoLoader | ||||
| import net.cozic.joplin.audio.SpeechToTextPackage | ||||
| import net.cozic.joplin.versioninfo.SystemVersionInformationPackage | ||||
| import net.cozic.joplin.share.SharePackage | ||||
| import net.cozic.joplin.ssl.SslPackage | ||||
| @@ -25,6 +26,7 @@ class MainApplication : Application(), ReactApplication { | ||||
|                     add(SslPackage()) | ||||
|                     add(TextInputPackage()) | ||||
|                     add(SystemVersionInformationPackage()) | ||||
|                     add(SpeechToTextPackage()) | ||||
|                 } | ||||
|  | ||||
|         override fun getJSMainModuleName(): String = "index" | ||||
|   | ||||
| @@ -0,0 +1,101 @@ | ||||
| package net.cozic.joplin.audio | ||||
|  | ||||
| import android.Manifest | ||||
| import android.annotation.SuppressLint | ||||
| import android.content.Context | ||||
| import android.content.pm.PackageManager | ||||
| import android.media.AudioFormat | ||||
| import android.media.AudioRecord | ||||
| import android.media.MediaRecorder.AudioSource | ||||
| import java.io.Closeable | ||||
| import kotlin.math.max | ||||
| import kotlin.math.min | ||||
|  | ||||
| typealias AudioRecorderFactory = (context: Context) -> AudioRecorder | ||||
|  | ||||
| class AudioRecorder(context: Context) : Closeable { | ||||
| 	private val sampleRate = 16_000 | ||||
| 	private val maxLengthSeconds = 30 // Whisper supports a maximum of 30s | ||||
| 	private val maxBufferSize = sampleRate * maxLengthSeconds | ||||
| 	private val buffer = FloatArray(maxBufferSize) | ||||
| 	private var bufferWriteOffset = 0 | ||||
|  | ||||
| 	// Accessor must not modify result | ||||
| 	val bufferedData: FloatArray get() = buffer.sliceArray(0 until bufferWriteOffset) | ||||
| 	val bufferLengthSeconds: Double get() = bufferWriteOffset.toDouble() / sampleRate | ||||
|  | ||||
| 	init { | ||||
| 		val permissionResult = context.checkSelfPermission(Manifest.permission.RECORD_AUDIO) | ||||
| 		if (permissionResult == PackageManager.PERMISSION_DENIED) { | ||||
| 			throw SecurityException("Missing RECORD_AUDIO permission!") | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// Permissions check is included above | ||||
| 	@SuppressLint("MissingPermission") | ||||
| 	private val recorder = AudioRecord.Builder() | ||||
| 		.setAudioSource(AudioSource.MIC) | ||||
| 		.setAudioFormat( | ||||
| 			AudioFormat.Builder() | ||||
| 				// Uncompressed PCM samples, stored as floats | ||||
| 				.setEncoding(AudioFormat.ENCODING_PCM_FLOAT) | ||||
| 				.setSampleRate(sampleRate) | ||||
| 				.setChannelMask(AudioFormat.CHANNEL_IN_MONO) | ||||
| 				.build() | ||||
| 		) | ||||
| 		.setBufferSizeInBytes(maxBufferSize * Float.SIZE_BYTES) | ||||
| 		.build() | ||||
|  | ||||
| 	// Discards the first [samples] samples from the start of the buffer. Conceptually, this | ||||
| 	// advances the buffer's start point. | ||||
| 	private fun advanceStartBySamples(samples: Int) { | ||||
| 		val samplesClamped = min(samples, maxBufferSize) | ||||
| 		val remainingBuffer = buffer.sliceArray(samplesClamped until maxBufferSize) | ||||
|  | ||||
| 		buffer.fill(0f, samplesClamped, maxBufferSize) | ||||
| 		remainingBuffer.copyInto(buffer, 0) | ||||
| 		bufferWriteOffset = max(bufferWriteOffset - samplesClamped, 0) | ||||
| 	} | ||||
|  | ||||
| 	fun dropFirstSeconds(seconds: Double) { | ||||
| 		advanceStartBySamples((seconds * sampleRate).toInt()) | ||||
| 	} | ||||
|  | ||||
| 	fun start() { | ||||
| 		recorder.startRecording() | ||||
| 	} | ||||
|  | ||||
| 	private fun read(requestedSize: Int, mode: Int) { | ||||
| 		val size = min(requestedSize, maxBufferSize - bufferWriteOffset) | ||||
| 		val sizeRead = recorder.read(buffer, bufferWriteOffset, size, mode) | ||||
| 		if (sizeRead > 0) { | ||||
| 			bufferWriteOffset += sizeRead | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// Pulls all available data from the audio recorder's buffer | ||||
| 	fun pullAvailable() { | ||||
| 		return read(maxBufferSize, AudioRecord.READ_NON_BLOCKING) | ||||
| 	} | ||||
|  | ||||
| 	fun pullNextSeconds(seconds: Double) { | ||||
| 		val remainingSize = maxBufferSize - bufferWriteOffset | ||||
| 		val requestedSize = (seconds * sampleRate).toInt() | ||||
|  | ||||
| 		// If the buffer is running low on space, make more room. | ||||
| 		if (remainingSize < maxBufferSize / 3) { | ||||
| 			advanceStartBySamples(maxBufferSize / 3) | ||||
| 		} | ||||
|  | ||||
| 		return read(requestedSize, AudioRecord.READ_BLOCKING) | ||||
| 	} | ||||
|  | ||||
| 	override fun close() { | ||||
| 		recorder.stop() | ||||
| 		recorder.release() | ||||
| 	} | ||||
|  | ||||
| 	companion object { | ||||
| 		val factory: AudioRecorderFactory = { context -> AudioRecorder(context) } | ||||
| 	} | ||||
| } | ||||
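
A minimal TypeScript sketch of the sliding-buffer logic above (a hypothetical class for illustration; the shipped implementation is the Kotlin code in this hunk). The recorder keeps a fixed 30-second float buffer, appends new samples at a write offset, and left-shifts the array when earlier audio has been consumed:

const SAMPLE_RATE = 16_000;
const MAX_LENGTH_SECONDS = 30; // Whisper accepts at most 30s of audio
const MAX_BUFFER_SIZE = SAMPLE_RATE * MAX_LENGTH_SECONDS;

class SlidingSampleBuffer {
	private buffer = new Float32Array(MAX_BUFFER_SIZE);
	private writeOffset = 0;

	public get lengthSeconds(): number {
		return this.writeOffset / SAMPLE_RATE;
	}

	// Appends newly recorded samples, up to the remaining capacity.
	public push(samples: Float32Array) {
		const size = Math.min(samples.length, MAX_BUFFER_SIZE - this.writeOffset);
		this.buffer.set(samples.subarray(0, size), this.writeOffset);
		this.writeOffset += size;
	}

	// Discards the first [seconds] of audio by left-shifting the buffer,
	// mirroring advanceStartBySamples above.
	public dropFirstSeconds(seconds: number) {
		const samples = Math.min(Math.floor(seconds * SAMPLE_RATE), MAX_BUFFER_SIZE);
		this.buffer.copyWithin(0, samples);
		this.buffer.fill(0, MAX_BUFFER_SIZE - samples);
		this.writeOffset = Math.max(this.writeOffset - samples, 0);
	}

	// Returns a copy of the buffered audio so far.
	public get data(): Float32Array {
		return this.buffer.slice(0, this.writeOffset);
	}
}
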
| @@ -0,0 +1,5 @@ | ||||
| package net.cozic.joplin.audio | ||||
|  | ||||
|  | ||||
| class InvalidSessionIdException(id: Int) : IllegalArgumentException("Invalid session ID $id") { | ||||
| } | ||||
| @@ -0,0 +1,136 @@ | ||||
| package net.cozic.joplin.audio | ||||
|  | ||||
| import ai.onnxruntime.OnnxTensor | ||||
| import ai.onnxruntime.OrtEnvironment | ||||
| import ai.onnxruntime.OrtSession | ||||
| import ai.onnxruntime.extensions.OrtxPackage | ||||
| import android.annotation.SuppressLint | ||||
| import android.content.Context | ||||
| import android.util.Log | ||||
| import java.io.Closeable | ||||
| import java.nio.FloatBuffer | ||||
| import java.nio.IntBuffer | ||||
| import kotlin.time.DurationUnit | ||||
| import kotlin.time.measureTimedValue | ||||
|  | ||||
| class SpeechToTextConverter( | ||||
| 	modelPath: String, | ||||
| 	locale: String, | ||||
| 	recorderFactory: AudioRecorderFactory, | ||||
| 	private val environment: OrtEnvironment, | ||||
| 	context: Context, | ||||
| ) : Closeable { | ||||
| 	private val recorder = recorderFactory(context) | ||||
| 	private val session: OrtSession = environment.createSession( | ||||
| 		modelPath, | ||||
| 		OrtSession.SessionOptions().apply { | ||||
| 			// Needed for audio decoding | ||||
| 			registerCustomOpLibrary(OrtxPackage.getLibraryPath()) | ||||
| 		}, | ||||
| 	) | ||||
| 	private val languageCode = Regex("_.*").replace(locale, "") | ||||
| 	private val decoderInputIds = when (languageCode) { | ||||
| 		// Add 50363 to the end to omit timestamps | ||||
| 		"en" -> intArrayOf(50258, 50259, 50359) | ||||
| 		"fr" -> intArrayOf(50258, 50265, 50359) | ||||
| 		"es" -> intArrayOf(50258, 50262, 50359) | ||||
| 		"de" -> intArrayOf(50258, 50261, 50359) | ||||
| 		"it" -> intArrayOf(50258, 50274, 50359) | ||||
| 		"nl" -> intArrayOf(50258, 50271, 50359) | ||||
| 		"ko" -> intArrayOf(50258, 50264, 50359) | ||||
| 		"th" -> intArrayOf(50258, 50289, 50359) | ||||
| 		"ru" -> intArrayOf(50258, 50263, 50359) | ||||
| 		"pt" -> intArrayOf(50258, 50267, 50359) | ||||
| 		"pl" -> intArrayOf(50258, 50269, 50359) | ||||
| 		"id" -> intArrayOf(50258, 50275, 50359) | ||||
| 		"hi" -> intArrayOf(50258, 50276, 50359) | ||||
| 		// Let Whisper guess the language | ||||
| 		else -> intArrayOf(50258) | ||||
| 	} | ||||
|  | ||||
| 	fun start() { | ||||
| 		recorder.start() | ||||
| 	} | ||||
|  | ||||
| 	private fun getInputs(data: FloatArray): MutableMap<String, OnnxTensor> { | ||||
| 		fun intTensor(value: Int) = OnnxTensor.createTensor( | ||||
| 			environment, | ||||
| 			IntBuffer.wrap(intArrayOf(value)), | ||||
| 			longArrayOf(1), | ||||
| 		) | ||||
| 		fun floatTensor(value: Float) = OnnxTensor.createTensor( | ||||
| 			environment, | ||||
| 			FloatBuffer.wrap(floatArrayOf(value)), | ||||
| 			longArrayOf(1), | ||||
| 		) | ||||
| 		val audioPcmTensor = OnnxTensor.createTensor( | ||||
| 			environment, | ||||
| 			FloatBuffer.wrap(data), | ||||
| 			longArrayOf(1, data.size.toLong()), | ||||
| 		) | ||||
| 		val decoderInputIdsTensor = OnnxTensor.createTensor( | ||||
| 			environment, | ||||
| 			IntBuffer.wrap(decoderInputIds), | ||||
| 			longArrayOf(1, decoderInputIds.size.toLong()) | ||||
| 		) | ||||
|  | ||||
| 		return mutableMapOf( | ||||
| 			"audio_pcm" to audioPcmTensor, | ||||
| 			"max_length" to intTensor(412), | ||||
| 			"min_length" to intTensor(0), | ||||
| 			"num_return_sequences" to intTensor(1), | ||||
| 			"num_beams" to intTensor(1), | ||||
| 			"length_penalty" to floatTensor(1.1f), | ||||
| 			"repetition_penalty" to floatTensor(3f), | ||||
| 			"decoder_input_ids" to decoderInputIdsTensor, | ||||
|  | ||||
| 			// Required for timestamps | ||||
| 			"logits_processor" to intTensor(1) | ||||
| 		) | ||||
| 	} | ||||
|  | ||||
| 	// TODO .get() fails on older Android versions | ||||
| 	@SuppressLint("NewApi") | ||||
| 	private fun convert(data: FloatArray): String { | ||||
| 		val (inputs, convertInputsTime) = measureTimedValue { | ||||
| 			getInputs(data) | ||||
| 		} | ||||
| 		val (outputs, getOutputsTime) = measureTimedValue { | ||||
| 			session.run(inputs, setOf("str")) | ||||
| 		} | ||||
| 		val mainOutput = outputs.get("str").get().value as Array<Array<String>> | ||||
| 		outputs.close() | ||||
|  | ||||
| 		Log.i("Whisper", "Converted ${data.size / 16000}s of data in ${ | ||||
| 			getOutputsTime.toString(DurationUnit.SECONDS, 2) | ||||
| 		}, converted inputs in ${convertInputsTime.inWholeMilliseconds}ms") | ||||
| 		return mainOutput[0][0] | ||||
| 	} | ||||
|  | ||||
| 	fun dropFirstSeconds(seconds: Double) { | ||||
| 		Log.i("Whisper", "Drop first seconds $seconds") | ||||
| 		recorder.dropFirstSeconds(seconds) | ||||
| 	} | ||||
|  | ||||
| 	val bufferLengthSeconds: Double get() = recorder.bufferLengthSeconds | ||||
|  | ||||
| 	fun expandBufferAndConvert(seconds: Double): String { | ||||
| 		recorder.pullNextSeconds(seconds) | ||||
| 		// Also pull any extra available data, in case the speech-to-text converter | ||||
| 		// is lagging behind the audio recorder. | ||||
| 		recorder.pullAvailable() | ||||
|  | ||||
| 		return convert(recorder.bufferedData) | ||||
| 	} | ||||
|  | ||||
| 	// Converts as many seconds of buffered data as possible, without waiting | ||||
| 	fun expandBufferAndConvert(): String { | ||||
| 		recorder.pullAvailable() | ||||
| 		return convert(recorder.bufferedData) | ||||
| 	} | ||||
|  | ||||
| 	override fun close() { | ||||
| 		recorder.close() | ||||
| 		session.close() | ||||
| 	} | ||||
| } | ||||
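
The hard-coded decoder_input_ids above are Whisper's special prompt tokens. Assuming the standard multilingual Whisper vocabulary (50258 = <|startoftranscript|>, 50359 = <|transcribe|>, one token per language, and 50363 = <|notimestamps|>; these token names are an assumption, not defined anywhere in this commit), the same table could be written as:

// Hypothetical sketch of the decoder prompt construction, using the
// standard multilingual Whisper special-token IDs (an assumption; see above).
const START_OF_TRANSCRIPT = 50258; // <|startoftranscript|>
const TRANSCRIBE = 50359; // <|transcribe|>
const NO_TIMESTAMPS = 50363; // <|notimestamps|>

const languageTokens: Record<string, number> = {
	en: 50259, de: 50261, es: 50262, ru: 50263, ko: 50264, fr: 50265,
	pt: 50267, pl: 50269, nl: 50271, it: 50274, id: 50275, hi: 50276,
	th: 50289,
};

const makeDecoderInputIds = (languageCode: string, withTimestamps = true): number[] => {
	const languageToken = languageTokens[languageCode];
	// Unknown language: pass only <|startoftranscript|> and let Whisper guess.
	if (languageToken === undefined) return [START_OF_TRANSCRIPT];

	const ids = [START_OF_TRANSCRIPT, languageToken, TRANSCRIBE];
	// Appending <|notimestamps|> omits the <|0.00|>-style timestamps.
	return withTimestamps ? ids : [...ids, NO_TIMESTAMPS];
};
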
| @@ -0,0 +1,86 @@ | ||||
| package net.cozic.joplin.audio | ||||
|  | ||||
| import ai.onnxruntime.OrtEnvironment | ||||
| import com.facebook.react.ReactPackage | ||||
| import com.facebook.react.bridge.LifecycleEventListener | ||||
| import com.facebook.react.bridge.NativeModule | ||||
| import com.facebook.react.bridge.Promise | ||||
| import com.facebook.react.bridge.ReactApplicationContext | ||||
| import com.facebook.react.bridge.ReactContextBaseJavaModule | ||||
| import com.facebook.react.bridge.ReactMethod | ||||
| import com.facebook.react.uimanager.ViewManager | ||||
| import java.util.concurrent.ExecutorService | ||||
| import java.util.concurrent.Executors | ||||
|  | ||||
| class SpeechToTextPackage : ReactPackage { | ||||
| 	override fun createNativeModules(reactContext: ReactApplicationContext): List<NativeModule> { | ||||
| 		return listOf<NativeModule>(SpeechToTextModule(reactContext)) | ||||
| 	} | ||||
|  | ||||
| 	override fun createViewManagers(reactContext: ReactApplicationContext): List<ViewManager<*, *>> { | ||||
| 		return emptyList() | ||||
| 	} | ||||
|  | ||||
| 	class SpeechToTextModule( | ||||
| 		private var context: ReactApplicationContext, | ||||
| 	) : ReactContextBaseJavaModule(context), LifecycleEventListener { | ||||
| 		private var environment: OrtEnvironment? = null | ||||
| 		private val executorService: ExecutorService = Executors.newFixedThreadPool(1) | ||||
| 		private val sessionManager = SpeechToTextSessionManager(executorService) | ||||
|  | ||||
| 		override fun getName() = "SpeechToTextModule" | ||||
|  | ||||
| 		override fun onHostResume() { } | ||||
| 		override fun onHostPause() { } | ||||
| 		override fun onHostDestroy() { | ||||
| 			environment?.close() | ||||
| 		} | ||||
|  | ||||
| 		@ReactMethod | ||||
| 		fun openSession(modelPath: String, locale: String, promise: Promise) { | ||||
| 			val appContext = context.applicationContext | ||||
| 			// Initialize environment as late as possible: | ||||
| 			val ortEnvironment = environment ?: OrtEnvironment.getEnvironment() | ||||
| 			if (environment == null) { | ||||
| 				environment = ortEnvironment | ||||
| 			} | ||||
|  | ||||
| 			try { | ||||
| 				val sessionId = sessionManager.openSession(modelPath, locale, ortEnvironment, appContext) | ||||
| 				promise.resolve(sessionId) | ||||
| 			} catch (exception: Throwable) { | ||||
| 				promise.reject(exception) | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		@ReactMethod | ||||
| 		fun startRecording(sessionId: Int, promise: Promise) { | ||||
| 			sessionManager.startRecording(sessionId, promise) | ||||
| 		} | ||||
|  | ||||
| 		@ReactMethod | ||||
| 		fun getBufferLengthSeconds(sessionId: Int, promise: Promise) { | ||||
| 			sessionManager.getBufferLengthSeconds(sessionId, promise) | ||||
| 		} | ||||
|  | ||||
| 		@ReactMethod | ||||
| 		fun dropFirstSeconds(sessionId: Int, duration: Double, promise: Promise) { | ||||
| 			sessionManager.dropFirstSeconds(sessionId, duration, promise) | ||||
| 		} | ||||
|  | ||||
| 		@ReactMethod | ||||
| 		fun expandBufferAndConvert(sessionId: Int, duration: Double, promise: Promise) { | ||||
| 			sessionManager.expandBufferAndConvert(sessionId, duration, promise) | ||||
| 		} | ||||
|  | ||||
| 		@ReactMethod | ||||
| 		fun convertAvailable(sessionId: Int, promise: Promise) { | ||||
| 			sessionManager.convertAvailable(sessionId, promise) | ||||
| 		} | ||||
|  | ||||
| 		@ReactMethod | ||||
| 		fun closeSession(sessionId: Int, promise: Promise) { | ||||
| 			sessionManager.closeSession(sessionId, promise) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
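
On the JavaScript side, this package surfaces as NativeModules.SpeechToTextModule, and each @ReactMethod becomes a promise-returning function. A sketch of the resulting interface (the type annotations are illustrative only; whisper.ts later in this diff calls these methods without types):

import { NativeModules } from 'react-native';

interface SpeechToTextNativeModule {
	// Resolves with a numeric session ID for use with the other methods.
	openSession(modelPath: string, locale: string): Promise<number>;
	startRecording(sessionId: number): Promise<void>;
	getBufferLengthSeconds(sessionId: number): Promise<number>;
	// Resolves with the session ID.
	dropFirstSeconds(sessionId: number, duration: number): Promise<number>;
	// Blocks until [duration] more seconds are recorded, then transcribes.
	expandBufferAndConvert(sessionId: number, duration: number): Promise<string>;
	// Transcribes whatever has been buffered so far, without waiting.
	convertAvailable(sessionId: number): Promise<string>;
	closeSession(sessionId: number): Promise<void>;
}

const { SpeechToTextModule } = NativeModules as unknown as {
	SpeechToTextModule: SpeechToTextNativeModule;
};
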
| @@ -0,0 +1,111 @@ | ||||
| package net.cozic.joplin.audio | ||||
|  | ||||
| import ai.onnxruntime.OrtEnvironment | ||||
| import android.content.Context | ||||
| import com.facebook.react.bridge.Promise | ||||
| import java.util.concurrent.Executor | ||||
| import java.util.concurrent.locks.ReentrantLock | ||||
|  | ||||
| class SpeechToTextSession( | ||||
| 	val converter: SpeechToTextConverter | ||||
| ) { | ||||
| 	val mutex = ReentrantLock() | ||||
| } | ||||
|  | ||||
| class SpeechToTextSessionManager( | ||||
| 	private var executor: Executor, | ||||
| ) { | ||||
| 	private val sessions: MutableMap<Int, SpeechToTextSession> = mutableMapOf() | ||||
| 	private var nextSessionId: Int = 0 | ||||
|  | ||||
| 	fun openSession( | ||||
| 		modelPath: String, | ||||
| 		locale: String, | ||||
| 		environment: OrtEnvironment, | ||||
| 		context: Context, | ||||
| 	): Int { | ||||
| 		val sessionId = nextSessionId++ | ||||
| 		sessions[sessionId] = SpeechToTextSession( | ||||
| 			SpeechToTextConverter( | ||||
| 				modelPath, locale, recorderFactory = AudioRecorder.factory, environment, context, | ||||
| 			) | ||||
| 		) | ||||
| 		return sessionId | ||||
| 	} | ||||
|  | ||||
| 	private fun getSession(id: Int): SpeechToTextSession { | ||||
| 		return sessions[id] ?: throw InvalidSessionIdException(id) | ||||
| 	} | ||||
|  | ||||
| 	private fun concurrentWithSession( | ||||
| 		id: Int, | ||||
| 		callback: (session: SpeechToTextSession)->Unit, | ||||
| 	) { | ||||
| 		executor.execute { | ||||
| 			val session = getSession(id) | ||||
| 			session.mutex.lock() | ||||
| 			try { | ||||
| 				callback(session) | ||||
| 			} finally { | ||||
| 				session.mutex.unlock() | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	private fun concurrentWithSession( | ||||
| 		id: Int, | ||||
| 		onError: (error: Throwable)->Unit, | ||||
| 		callback: (session: SpeechToTextSession)->Unit, | ||||
| 	) { | ||||
| 		return concurrentWithSession(id) { session -> | ||||
| 			try { | ||||
| 				callback(session) | ||||
| 			} catch (error: Throwable) { | ||||
| 				onError(error) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	fun startRecording(sessionId: Int, promise: Promise) { | ||||
| 		this.concurrentWithSession(sessionId, promise::reject) { session -> | ||||
| 			session.converter.start() | ||||
| 			promise.resolve(null) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// Left-shifts the recording buffer by [duration] seconds | ||||
| 	fun dropFirstSeconds(sessionId: Int, duration: Double, promise: Promise) { | ||||
| 		this.concurrentWithSession(sessionId, promise::reject) { session -> | ||||
| 			session.converter.dropFirstSeconds(duration) | ||||
| 			promise.resolve(sessionId) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	fun getBufferLengthSeconds(sessionId: Int, promise: Promise) { | ||||
| 		this.concurrentWithSession(sessionId, promise::reject) { session -> | ||||
| 			promise.resolve(session.converter.bufferLengthSeconds) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// Waits for the next [duration] seconds to become available, then converts | ||||
| 	fun expandBufferAndConvert(sessionId: Int, duration: Double, promise: Promise) { | ||||
| 		this.concurrentWithSession(sessionId, promise::reject) { session -> | ||||
| 			val result = session.converter.expandBufferAndConvert(duration) | ||||
| 			promise.resolve(result) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// Converts all available recorded data | ||||
| 	fun convertAvailable(sessionId: Int, promise: Promise) { | ||||
| 		this.concurrentWithSession(sessionId, promise::reject) { session -> | ||||
| 			val result = session.converter.expandBufferAndConvert() | ||||
| 			promise.resolve(result) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	fun closeSession(sessionId: Int, promise: Promise) { | ||||
| 		this.concurrentWithSession(sessionId) { session -> | ||||
| 			session.converter.close() | ||||
| 			promise.resolve(null) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
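
All session work funnels through a single-threaded executor plus a per-session ReentrantLock, so two calls against the same session never run concurrently. In TypeScript, the same guarantee could be sketched with a per-session promise chain (illustrative only; not part of this commit):

class SessionQueue<Session> {
	private tail: Promise<unknown> = Promise.resolve();

	public constructor(private session: Session) {}

	// Runs [task] only after every previously enqueued task has settled,
	// serialising access to the wrapped session.
	public enqueue<Result>(task: (session: Session)=> Promise<Result>): Promise<Result> {
		const result = this.tail.then(() => task(this.session));
		// Keep the chain alive even if [task] rejects.
		this.tail = result.catch(() => undefined);
		return result;
	}
}
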
| @@ -45,7 +45,6 @@ import ImageEditor from '../NoteEditor/ImageEditor/ImageEditor'; | ||||
| import promptRestoreAutosave from '../NoteEditor/ImageEditor/promptRestoreAutosave'; | ||||
| import isEditableResource from '../NoteEditor/ImageEditor/isEditableResource'; | ||||
| import VoiceTypingDialog from '../voiceTyping/VoiceTypingDialog'; | ||||
| import { voskEnabled } from '../../services/voiceTyping/vosk'; | ||||
| import { isSupportedLanguage } from '../../services/voiceTyping/vosk'; | ||||
| import { ChangeEvent as EditorChangeEvent, SelectionRangeChangeEvent, UndoRedoDepthChangeEvent } from '@joplin/editor/events'; | ||||
| import { join } from 'path'; | ||||
| @@ -1204,8 +1203,8 @@ class NoteScreenComponent extends BaseScreenComponent<Props, State> implements B | ||||
| 			}); | ||||
| 		} | ||||
|  | ||||
| 		// Voice typing is enabled only for French language and on Android for now | ||||
| 		if (voskEnabled && shim.mobilePlatform() === 'android' && isSupportedLanguage(currentLocale())) { | ||||
| 		// Voice typing is enabled only on Android for now | ||||
| 		if (shim.mobilePlatform() === 'android' && isSupportedLanguage(currentLocale())) { | ||||
| 			output.push({ | ||||
| 				title: _('Voice typing...'), | ||||
| 				onPress: () => { | ||||
|   | ||||
| @@ -1,14 +1,18 @@ | ||||
| import * as React from 'react'; | ||||
| import { useState, useEffect, useCallback } from 'react'; | ||||
| import { Banner, ActivityIndicator } from 'react-native-paper'; | ||||
| import { useState, useEffect, useCallback, useRef, useMemo } from 'react'; | ||||
| import { Banner, ActivityIndicator, Text } from 'react-native-paper'; | ||||
| import { _, languageName } from '@joplin/lib/locale'; | ||||
| import useAsyncEffect, { AsyncEffectEvent } from '@joplin/lib/hooks/useAsyncEffect'; | ||||
| import { getVosk, Recorder, startRecording, Vosk } from '../../services/voiceTyping/vosk'; | ||||
| import { IconSource } from 'react-native-paper/lib/typescript/components/Icon'; | ||||
| import { modelIsDownloaded } from '../../services/voiceTyping/vosk'; | ||||
| import VoiceTyping, { OnTextCallback, VoiceTypingSession } from '../../services/voiceTyping/VoiceTyping'; | ||||
| import whisper from '../../services/voiceTyping/whisper'; | ||||
| import vosk from '../../services/voiceTyping/vosk'; | ||||
| import { AppState } from '../../utils/types'; | ||||
| import { connect } from 'react-redux'; | ||||
|  | ||||
| interface Props { | ||||
| 	locale: string; | ||||
| 	provider: string; | ||||
| 	onDismiss: ()=> void; | ||||
| 	onText: (text: string)=> void; | ||||
| } | ||||
| @@ -21,44 +25,77 @@ enum RecorderState { | ||||
| 	Downloading = 5, | ||||
| } | ||||
|  | ||||
| const useVosk = (locale: string): [Error | null, boolean, Vosk|null] => { | ||||
| 	const [vosk, setVosk] = useState<Vosk>(null); | ||||
| interface UseVoiceTypingProps { | ||||
| 	locale: string; | ||||
| 	provider: string; | ||||
| 	onSetPreview: OnTextCallback; | ||||
| 	onText: OnTextCallback; | ||||
| } | ||||
|  | ||||
| const useWhisper = ({ locale, provider, onSetPreview, onText }: UseVoiceTypingProps): [Error | null, boolean, VoiceTypingSession|null] => { | ||||
| 	const [voiceTyping, setVoiceTyping] = useState<VoiceTypingSession>(null); | ||||
| 	const [error, setError] = useState<Error>(null); | ||||
| 	const [mustDownloadModel, setMustDownloadModel] = useState<boolean | null>(null); | ||||
|  | ||||
| 	useAsyncEffect(async (event: AsyncEffectEvent) => { | ||||
| 		if (mustDownloadModel === null) return; | ||||
| 	const onTextRef = useRef(onText); | ||||
| 	onTextRef.current = onText; | ||||
| 	const onSetPreviewRef = useRef(onSetPreview); | ||||
| 	onSetPreviewRef.current = onSetPreview; | ||||
|  | ||||
| 	const voiceTypingRef = useRef(voiceTyping); | ||||
| 	voiceTypingRef.current = voiceTyping; | ||||
|  | ||||
| 	const builder = useMemo(() => { | ||||
| 		return new VoiceTyping(locale, provider?.startsWith('whisper') ? [whisper] : [vosk]); | ||||
| 	}, [locale, provider]); | ||||
|  | ||||
| 	useAsyncEffect(async (event: AsyncEffectEvent) => { | ||||
| 		try { | ||||
| 			const v = await getVosk(locale); | ||||
| 			await voiceTypingRef.current?.stop(); | ||||
|  | ||||
| 			if (!await builder.isDownloaded()) { | ||||
| 				if (event.cancelled) return; | ||||
| 				await builder.download(); | ||||
| 			} | ||||
| 			if (event.cancelled) return; | ||||
| 			setVosk(v); | ||||
|  | ||||
| 			const voiceTyping = await builder.build({ | ||||
| 				onPreview: (text) => onSetPreviewRef.current(text), | ||||
| 				onFinalize: (text) => onTextRef.current(text), | ||||
| 			}); | ||||
| 			if (event.cancelled) return; | ||||
| 			setVoiceTyping(voiceTyping); | ||||
| 		} catch (error) { | ||||
| 			setError(error); | ||||
| 		} finally { | ||||
| 			setMustDownloadModel(false); | ||||
| 		} | ||||
| 	}, [locale, mustDownloadModel]); | ||||
| 	}, [builder]); | ||||
|  | ||||
| 	useAsyncEffect(async (_event: AsyncEffectEvent) => { | ||||
| 		setMustDownloadModel(!(await modelIsDownloaded(locale))); | ||||
| 	}, [locale]); | ||||
| 		setMustDownloadModel(!(await builder.isDownloaded())); | ||||
| 	}, [builder]); | ||||
|  | ||||
| 	return [error, mustDownloadModel, vosk]; | ||||
| 	return [error, mustDownloadModel, voiceTyping]; | ||||
| }; | ||||
|  | ||||
| export default (props: Props) => { | ||||
| 	const [recorder, setRecorder] = useState<Recorder>(null); | ||||
| const VoiceTypingDialog: React.FC<Props> = props => { | ||||
| 	const [recorderState, setRecorderState] = useState<RecorderState>(RecorderState.Loading); | ||||
| 	const [voskError, mustDownloadModel, vosk] = useVosk(props.locale); | ||||
| 	const [preview, setPreview] = useState<string>(''); | ||||
| 	const [modelError, mustDownloadModel, voiceTyping] = useWhisper({ | ||||
| 		locale: props.locale, | ||||
| 		onSetPreview: setPreview, | ||||
| 		onText: props.onText, | ||||
| 		provider: props.provider, | ||||
| 	}); | ||||
|  | ||||
| 	useEffect(() => { | ||||
| 		if (voskError) { | ||||
| 		if (modelError) { | ||||
| 			setRecorderState(RecorderState.Error); | ||||
| 		} else if (vosk) { | ||||
| 		} else if (voiceTyping) { | ||||
| 			setRecorderState(RecorderState.Recording); | ||||
| 		} | ||||
| 	}, [vosk, voskError]); | ||||
| 	}, [voiceTyping, modelError]); | ||||
|  | ||||
| 	useEffect(() => { | ||||
| 		if (mustDownloadModel) { | ||||
| @@ -68,27 +105,22 @@ export default (props: Props) => { | ||||
|  | ||||
| 	useEffect(() => { | ||||
| 		if (recorderState === RecorderState.Recording) { | ||||
| 			setRecorder(startRecording(vosk, { | ||||
| 				onResult: (text: string) => { | ||||
| 					props.onText(text); | ||||
| 				}, | ||||
| 			})); | ||||
| 			void voiceTyping.start(); | ||||
| 		} | ||||
| 	}, [recorderState, vosk, props.onText]); | ||||
| 	}, [recorderState, voiceTyping, props.onText]); | ||||
|  | ||||
| 	const onDismiss = useCallback(() => { | ||||
| 		if (recorder) recorder.cleanup(); | ||||
| 		void voiceTyping?.stop(); | ||||
| 		props.onDismiss(); | ||||
| 	}, [recorder, props.onDismiss]); | ||||
| 	}, [voiceTyping, props.onDismiss]); | ||||
|  | ||||
| 	const renderContent = () => { | ||||
| 		// eslint-disable-next-line @typescript-eslint/ban-types -- Old code before rule was applied | ||||
| 		const components: Record<RecorderState, Function> = { | ||||
| 		const components: Record<RecorderState, ()=> string> = { | ||||
| 			[RecorderState.Loading]: () => _('Loading...'), | ||||
| 			[RecorderState.Recording]: () => _('Please record your voice...'), | ||||
| 			[RecorderState.Processing]: () => _('Converting speech to text...'), | ||||
| 			[RecorderState.Downloading]: () => _('Downloading %s language files...', languageName(props.locale)), | ||||
| 			[RecorderState.Error]: () => _('Error: %s', voskError.message), | ||||
| 			[RecorderState.Error]: () => _('Error: %s', modelError.message), | ||||
| 		}; | ||||
|  | ||||
| 		return components[recorderState](); | ||||
| @@ -106,6 +138,11 @@ export default (props: Props) => { | ||||
| 		return components[recorderState]; | ||||
| 	}; | ||||
|  | ||||
| 	const renderPreview = () => { | ||||
| 		return <Text variant='labelSmall'>{preview}</Text>; | ||||
| 	}; | ||||
|  | ||||
| 	const headerAndStatus = <Text variant='bodyMedium'>{`${_('Voice typing...')}\n${renderContent()}`}</Text>; | ||||
| 	return ( | ||||
| 		<Banner | ||||
| 			visible={true} | ||||
| @@ -115,8 +152,15 @@ export default (props: Props) => { | ||||
| 					label: _('Done'), | ||||
| 					onPress: onDismiss, | ||||
| 				}, | ||||
| 			]}> | ||||
| 			{`${_('Voice typing...')}\n${renderContent()}`} | ||||
| 			]} | ||||
| 		> | ||||
| 			{headerAndStatus} | ||||
| 			<Text>{'\n'}</Text> | ||||
| 			{renderPreview()} | ||||
| 		</Banner> | ||||
| 	); | ||||
| }; | ||||
|  | ||||
| export default connect((state: AppState) => ({ | ||||
| 	provider: state.settings['voiceTyping.preferredProvider'], | ||||
| }))(VoiceTypingDialog); | ||||
|   | ||||
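
The onTextRef/onSetPreviewRef assignments above are the usual React "latest ref" pattern: the long-lived voice typing session always calls the newest callbacks without re-running the effect that created it. Extracted as a hypothetical helper (not part of this commit):

import { useRef } from 'react';

function useLatest<T>(value: T) {
	const ref = useRef(value);
	// Updated on every render, so ref.current is never stale.
	ref.current = value;
	return ref;
}

// Usage: const onTextRef = useLatest(props.onText);
// The session can call onTextRef.current(text) long after the component
// re-rendered, without the effect that started the session re-firing.
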
| @@ -80,6 +80,10 @@ jest.mock('react-native-image-picker', () => { | ||||
| 	return { default: { } }; | ||||
| }); | ||||
|  | ||||
| jest.mock('react-native-zip-archive', () => { | ||||
| 	return { default: { } }; | ||||
| }); | ||||
|  | ||||
| jest.mock('react-native-document-picker', () => ({ default: { } })); | ||||
|  | ||||
| // Used by the renderer | ||||
|   | ||||
							
								
								
									
139	packages/app-mobile/services/voiceTyping/VoiceTyping.ts	Normal file
							| @@ -0,0 +1,139 @@ | ||||
| import shim from '@joplin/lib/shim'; | ||||
| import Logger from '@joplin/utils/Logger'; | ||||
| import { PermissionsAndroid, Platform } from 'react-native'; | ||||
| import { unzip } from 'react-native-zip-archive'; | ||||
| const md5 = require('md5'); | ||||
|  | ||||
| const logger = Logger.create('voiceTyping'); | ||||
|  | ||||
| export type OnTextCallback = (text: string)=> void; | ||||
|  | ||||
| export interface SpeechToTextCallbacks { | ||||
| 	// Called with a block of text that might change in the future | ||||
| 	onPreview: OnTextCallback; | ||||
| 	// Called with text that will not change and should be added to the document | ||||
| 	onFinalize: OnTextCallback; | ||||
| } | ||||
|  | ||||
| export interface VoiceTypingSession { | ||||
| 	start(): Promise<void>; | ||||
| 	stop(): Promise<void>; | ||||
| } | ||||
|  | ||||
| export interface BuildProviderOptions { | ||||
| 	locale: string; | ||||
| 	modelPath: string; | ||||
| 	callbacks: SpeechToTextCallbacks; | ||||
| } | ||||
|  | ||||
| export interface VoiceTypingProvider { | ||||
| 	modelName: string; | ||||
| 	supported(): boolean; | ||||
| 	modelLocalFilepath(locale: string): string; | ||||
| 	getDownloadUrl(locale: string): string; | ||||
| 	getUuidPath(locale: string): string; | ||||
| 	build(options: BuildProviderOptions): Promise<VoiceTypingSession>; | ||||
| } | ||||
|  | ||||
| export default class VoiceTyping { | ||||
| 	private provider: VoiceTypingProvider|null = null; | ||||
| 	public constructor( | ||||
| 		private locale: string, | ||||
| 		providers: VoiceTypingProvider[], | ||||
| 	) { | ||||
| 		this.provider = providers.find(p => p.supported()) ?? null; | ||||
| 	} | ||||
|  | ||||
| 	public supported() { | ||||
| 		return this.provider !== null; | ||||
| 	} | ||||
|  | ||||
| 	private getModelPath() { | ||||
| 		const localFilePath = shim.fsDriver().resolveRelativePathWithinDir( | ||||
| 			shim.fsDriver().getAppDirectoryPath(), | ||||
| 			this.provider.modelLocalFilepath(this.locale), | ||||
| 		); | ||||
| 		if (localFilePath === shim.fsDriver().getAppDirectoryPath()) { | ||||
| 			throw new Error('Invalid local file path!'); | ||||
| 		} | ||||
|  | ||||
| 		return localFilePath; | ||||
| 	} | ||||
|  | ||||
| 	private getUuidPath() { | ||||
| 		return shim.fsDriver().resolveRelativePathWithinDir( | ||||
| 			shim.fsDriver().getAppDirectoryPath(), | ||||
| 			this.provider.getUuidPath(this.locale), | ||||
| 		); | ||||
| 	} | ||||
|  | ||||
| 	public async isDownloaded() { | ||||
| 		return await shim.fsDriver().exists(this.getUuidPath()); | ||||
| 	} | ||||
|  | ||||
| 	public async download() { | ||||
| 		const modelPath = this.getModelPath(); | ||||
| 		const modelUrl = this.provider.getDownloadUrl(this.locale); | ||||
|  | ||||
| 		await shim.fsDriver().remove(modelPath); | ||||
| 		logger.info(`Downloading model from: ${modelUrl}`); | ||||
|  | ||||
| 		const isZipped = modelUrl.endsWith('.zip'); | ||||
| 		const downloadPath = isZipped ? `${modelPath}.zip` : modelPath; | ||||
| 		const response = await shim.fetchBlob(modelUrl, { | ||||
| 			path: downloadPath, | ||||
| 		}); | ||||
|  | ||||
| 		if (!response.ok || response.status >= 400) throw new Error(`Could not download from ${modelUrl}: Error ${response.status}`); | ||||
|  | ||||
| 		if (isZipped) { | ||||
| 			const modelName = this.provider.modelName; | ||||
| 			const unzipDir = `${shim.fsDriver().getCacheDirectoryPath()}/voice-typing-extract/${modelName}/${this.locale}`; | ||||
| 			try { | ||||
| 				logger.info(`Unzipping ${downloadPath} => ${unzipDir}`); | ||||
|  | ||||
| 				await unzip(downloadPath, unzipDir); | ||||
|  | ||||
| 				const contents = await shim.fsDriver().readDirStats(unzipDir); | ||||
| 				if (contents.length !== 1) { | ||||
| 					logger.error('Expected 1 file or directory but got', contents); | ||||
| 					throw new Error(`Expected 1 file or directory, but got ${contents.length}`); | ||||
| 				} | ||||
|  | ||||
| 				const fullUnzipPath = `${unzipDir}/${contents[0].path}`; | ||||
|  | ||||
| 				logger.info(`Moving ${fullUnzipPath} => ${modelPath}`); | ||||
| 				await shim.fsDriver().move(fullUnzipPath, modelPath); | ||||
|  | ||||
| 				await shim.fsDriver().writeFile(this.getUuidPath(), md5(modelUrl), 'utf8'); | ||||
| 				if (!await this.isDownloaded()) { | ||||
| 					logger.warn('Model should be downloaded!'); | ||||
| 				} | ||||
| 			} finally { | ||||
| 				await shim.fsDriver().remove(unzipDir); | ||||
| 				await shim.fsDriver().remove(downloadPath); | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	public async build(callbacks: SpeechToTextCallbacks) { | ||||
| 		if (!this.provider) { | ||||
| 			throw new Error('No supported provider found!'); | ||||
| 		} | ||||
|  | ||||
| 		if (!await this.isDownloaded()) { | ||||
| 			await this.download(); | ||||
| 		} | ||||
|  | ||||
| 		const audioPermission = 'android.permission.RECORD_AUDIO'; | ||||
| 		if (Platform.OS === 'android' && !await PermissionsAndroid.check(audioPermission)) { | ||||
| 			await PermissionsAndroid.request(audioPermission); | ||||
| 		} | ||||
|  | ||||
| 		return this.provider.build({ | ||||
| 			locale: this.locale, | ||||
| 			modelPath: this.getModelPath(), | ||||
| 			callbacks, | ||||
| 		}); | ||||
| 	} | ||||
| } | ||||
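
Putting the new class together, condensed from the VoiceTypingDialog changes earlier in this diff (the console-logging callbacks are placeholders):

import VoiceTyping from './VoiceTyping';
import vosk from './vosk';
import whisper from './whisper';

const startVoiceTyping = async (locale: string, preferWhisper: boolean) => {
	const builder = new VoiceTyping(locale, preferWhisper ? [whisper] : [vosk]);

	// download() fetches the model and marks it with a uuid file;
	// build() would also do this implicitly if the model is missing.
	if (!await builder.isDownloaded()) {
		await builder.download();
	}

	const session = await builder.build({
		onPreview: text => console.log('Preview:', text),
		onFinalize: text => console.log('Final:', text),
	});
	await session.start();
	return session; // Call session.stop() to finish.
};
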
| @@ -0,0 +1,61 @@ | ||||
| import splitWhisperText from './splitWhisperText'; | ||||
|  | ||||
| describe('splitWhisperText', () => { | ||||
| 	test.each([ | ||||
| 		{ | ||||
| 			// Should trim at sentence breaks | ||||
| 			input: '<|0.00|> This is a test. <|5.00|><|6.00|> This is another sentence. <|7.00|>', | ||||
| 			recordingLength: 8, | ||||
| 			expected: { | ||||
| 				trimTo: 6, | ||||
| 				dataBeforeTrim: '<|0.00|> This is a test. ', | ||||
| 				dataAfterTrim: ' This is another sentence. <|7.00|>', | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			// Should prefer sentence-break splits to non-sentence-break splits | ||||
| 			input: '<|0.00|> This is <|4.00|><|4.50|> a test. <|5.00|><|5.50|> Testing, <|6.00|><|7.00|> this is a test. <|8.00|>', | ||||
| 			recordingLength: 8, | ||||
| 			expected: { | ||||
| 				trimTo: 5.50, | ||||
| 				dataBeforeTrim: '<|0.00|> This is <|4.00|><|4.50|> a test. ', | ||||
| 				dataAfterTrim: ' Testing, <|6.00|><|7.00|> this is a test. <|8.00|>', | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			// Should avoid splitting for very small timestamps | ||||
| 			input: '<|0.00|> This is a test. <|2.00|><|2.30|> Testing! <|3.00|>', | ||||
| 			recordingLength: 4, | ||||
| 			expected: { | ||||
| 				trimTo: 0, | ||||
| 				dataBeforeTrim: '', | ||||
| 				dataAfterTrim: ' This is a test. <|2.00|><|2.30|> Testing! <|3.00|>', | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			// For larger timestamps, should allow splitting at pauses, even if not on sentence breaks. | ||||
| 			input: '<|0.00|> This is a test, <|10.00|><|12.00|> of splitting on timestamps. <|15.00|>', | ||||
| 			recordingLength: 16, | ||||
| 			expected: { | ||||
| 				trimTo: 12, | ||||
| 				dataBeforeTrim: '<|0.00|> This is a test, ', | ||||
| 				dataAfterTrim: ' of splitting on timestamps. <|15.00|>', | ||||
| 			}, | ||||
| 		}, | ||||
| 		{ | ||||
| 			// Should prefer to break at the end if there is a large gap after the last timestamp. | ||||
| 			input: '<|0.00|> This is a test, <|10.00|><|12.00|> of splitting on timestamps. <|15.00|>', | ||||
| 			recordingLength: 30, | ||||
| 			expected: { | ||||
| 				trimTo: 15, | ||||
| 				dataBeforeTrim: '<|0.00|> This is a test, <|10.00|><|12.00|> of splitting on timestamps. ', | ||||
| 				dataAfterTrim: '', | ||||
| 			}, | ||||
| 		}, | ||||
| 	])('should prefer to split at the end of sentences (case %#)', ({ input, recordingLength, expected }) => { | ||||
| 		const actual = splitWhisperText(input, recordingLength); | ||||
| 		expect(actual.trimTo).toBeCloseTo(expected.trimTo); | ||||
| 		expect(actual.dataBeforeTrim).toBe(expected.dataBeforeTrim); | ||||
| 		expect(actual.dataAfterTrim).toBe(expected.dataAfterTrim); | ||||
| 	}); | ||||
| }); | ||||
| @@ -0,0 +1,65 @@ | ||||
| // Matches pairs of timestamps or single timestamps. | ||||
| const timestampExp = /<\|(\d+\.\d*)\|>(?:<\|(\d+\.\d*)\|>)?/g; | ||||
|  | ||||
| const timestampMatchToNumber = (match: RegExpMatchArray) => { | ||||
| 	const firstTimestamp = match[1]; | ||||
| 	const secondTimestamp = match[2]; | ||||
| 	// Prefer the second timestamp in the pair, to remove leading silence. | ||||
| 	const timestamp = Number(secondTimestamp ? secondTimestamp : firstTimestamp); | ||||
|  | ||||
| 	// Should always be a finite number (i.e. not NaN) | ||||
| 	if (!isFinite(timestamp)) throw new Error(`Timestamp match failed with ${match[0]}`); | ||||
|  | ||||
| 	return timestamp; | ||||
| }; | ||||
|  | ||||
| const splitWhisperText = (textWithTimestamps: string, recordingLengthSeconds: number) => { | ||||
| 	const timestamps = [ | ||||
| 		...textWithTimestamps.matchAll(timestampExp), | ||||
| 	].map(match => { | ||||
| 		const timestamp = timestampMatchToNumber(match); | ||||
| 		return { timestamp, match }; | ||||
| 	}); | ||||
|  | ||||
| 	if (!timestamps.length) { | ||||
| 		return { trimTo: 0, dataBeforeTrim: '', dataAfterTrim: textWithTimestamps }; | ||||
| 	} | ||||
|  | ||||
| 	const firstTimestamp = timestamps[0]; | ||||
| 	let breakAt = firstTimestamp; | ||||
|  | ||||
| 	const lastTimestamp = timestamps[timestamps.length - 1]; | ||||
| 	const hasLongPauseAfterData = lastTimestamp.timestamp + 4 < recordingLengthSeconds; | ||||
| 	if (hasLongPauseAfterData) { | ||||
| 		breakAt = lastTimestamp; | ||||
| 	} else { | ||||
| 		const textWithTimestampsContentLength = textWithTimestamps.trimEnd().length; | ||||
|  | ||||
| 		for (const timestampData of timestamps) { | ||||
| 			const { match, timestamp } = timestampData; | ||||
| 			const contentBefore = textWithTimestamps.substring(Math.max(match.index - 3, 0), match.index); | ||||
| 			const isNearEndOfLatinSentence = contentBefore.match(/[.?!]/); | ||||
| 			const isNearEndOfData = match.index + match[0].length >= textWithTimestampsContentLength; | ||||
|  | ||||
| 			// Use a heuristic to determine whether to move content from the preview to the document. | ||||
| 			// These are based on the maximum buffer length of 30 seconds -- as the buffer gets longer, the | ||||
| 			// data should be more likely to be broken into chunks. Where possible, the break should be near | ||||
| 			// the end of a sentence: | ||||
| 			const canBreak = (timestamp > 4 && isNearEndOfLatinSentence && !isNearEndOfData) | ||||
| 					|| (timestamp > 8 && !isNearEndOfData) | ||||
| 					|| timestamp > 16; | ||||
| 			if (canBreak) { | ||||
| 				breakAt = timestampData; | ||||
| 				break; | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	const trimTo = breakAt.timestamp; | ||||
| 	const dataBeforeTrim = textWithTimestamps.substring(0, breakAt.match.index); | ||||
| 	const dataAfterTrim = textWithTimestamps.substring(breakAt.match.index + breakAt.match[0].length); | ||||
|  | ||||
| 	return { trimTo, dataBeforeTrim, dataAfterTrim }; | ||||
| }; | ||||
|  | ||||
| export default splitWhisperText; | ||||
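
A worked example, taken from the first test case above:

import splitWhisperText from './splitWhisperText';

const { trimTo, dataBeforeTrim, dataAfterTrim } = splitWhisperText(
	'<|0.00|> This is a test. <|5.00|><|6.00|> This is another sentence. <|7.00|>',
	8, // recording length in seconds
);
// trimTo === 6: the native recorder can now drop the first 6 seconds.
// dataBeforeTrim === '<|0.00|> This is a test. ': finalized into the note.
// dataAfterTrim === ' This is another sentence. <|7.00|>': kept as the preview.
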
| @@ -4,9 +4,9 @@ import Setting from '@joplin/lib/models/Setting'; | ||||
| import { rtrimSlashes } from '@joplin/lib/path-utils'; | ||||
| import shim from '@joplin/lib/shim'; | ||||
| import Vosk from 'react-native-vosk'; | ||||
| import { unzip } from 'react-native-zip-archive'; | ||||
| import RNFetchBlob from 'rn-fetch-blob'; | ||||
| const md5 = require('md5'); | ||||
| import { VoiceTypingProvider, VoiceTypingSession } from './VoiceTyping'; | ||||
| import { join } from 'path'; | ||||
|  | ||||
| const logger = Logger.create('voiceTyping/vosk'); | ||||
|  | ||||
| @@ -24,15 +24,6 @@ let vosk_: Record<string, Vosk> = {}; | ||||
|  | ||||
| let state_: State = State.Idle; | ||||
|  | ||||
| export const voskEnabled = true; | ||||
|  | ||||
| export { Vosk }; | ||||
|  | ||||
| export interface Recorder { | ||||
| 	stop: ()=> Promise<string>; | ||||
| 	cleanup: ()=> void; | ||||
| } | ||||
|  | ||||
| const defaultSupportedLanguages = { | ||||
| 	'en': 'https://alphacephei.com/vosk/models/vosk-model-small-en-us-0.15.zip', | ||||
| 	'zh': 'https://alphacephei.com/vosk/models/vosk-model-small-cn-0.22.zip', | ||||
| @@ -74,7 +65,7 @@ const getModelDir = (locale: string) => { | ||||
| 	return `${getUnzipDir(locale)}/model`; | ||||
| }; | ||||
|  | ||||
| const languageModelUrl = (locale: string) => { | ||||
| const languageModelUrl = (locale: string): string => { | ||||
| 	const lang = languageCodeOnly(locale).toLowerCase(); | ||||
| 	if (!(lang in defaultSupportedLanguages)) throw new Error(`No language file for: ${locale}`); | ||||
|  | ||||
| @@ -90,16 +81,11 @@ const languageModelUrl = (locale: string) => { | ||||
| 	} | ||||
| }; | ||||
|  | ||||
| export const modelIsDownloaded = async (locale: string) => { | ||||
| 	const uuidFile = `${getModelDir(locale)}/uuid`; | ||||
| 	return shim.fsDriver().exists(uuidFile); | ||||
| }; | ||||
|  | ||||
| export const getVosk = async (locale: string) => { | ||||
| export const getVosk = async (modelDir: string, locale: string) => { | ||||
| 	if (vosk_[locale]) return vosk_[locale]; | ||||
|  | ||||
| 	const vosk = new Vosk(); | ||||
| 	const modelDir = await downloadModel(locale); | ||||
| 	logger.info(`Loading model from ${modelDir}`); | ||||
| 	await shim.fsDriver().readDirStats(modelDir); | ||||
| 	const result = await vosk.loadModel(modelDir); | ||||
| @@ -110,51 +96,7 @@ export const getVosk = async (locale: string) => { | ||||
| 	return vosk; | ||||
| }; | ||||
|  | ||||
| const downloadModel = async (locale: string) => { | ||||
| 	const modelUrl = languageModelUrl(locale); | ||||
| 	const unzipDir = getUnzipDir(locale); | ||||
| 	const zipFilePath = `${unzipDir}.zip`; | ||||
| 	const modelDir = getModelDir(locale); | ||||
| 	const uuidFile = `${modelDir}/uuid`; | ||||
|  | ||||
| 	if (await modelIsDownloaded(locale)) { | ||||
| 		logger.info(`Model for ${locale} already exists at ${modelDir}`); | ||||
| 		return modelDir; | ||||
| 	} | ||||
|  | ||||
| 	await shim.fsDriver().remove(unzipDir); | ||||
|  | ||||
| 	logger.info(`Downloading model from: ${modelUrl}`); | ||||
|  | ||||
| 	const response = await shim.fetchBlob(modelUrl, { | ||||
| 		path: zipFilePath, | ||||
| 	}); | ||||
|  | ||||
| 	if (!response.ok || response.status >= 400) throw new Error(`Could not download from ${modelUrl}: Error ${response.status}`); | ||||
|  | ||||
| 	logger.info(`Unzipping ${zipFilePath} => ${unzipDir}`); | ||||
|  | ||||
| 	await unzip(zipFilePath, unzipDir); | ||||
|  | ||||
| 	const dirs = await shim.fsDriver().readDirStats(unzipDir); | ||||
| 	if (dirs.length !== 1) { | ||||
| 		logger.error('Expected 1 directory but got', dirs); | ||||
| 		throw new Error(`Expected 1 directory, but got ${dirs.length}`); | ||||
| 	} | ||||
|  | ||||
| 	const fullUnzipPath = `${unzipDir}/${dirs[0].path}`; | ||||
|  | ||||
| 	logger.info(`Moving ${fullUnzipPath} =>  ${modelDir}`); | ||||
| 	await shim.fsDriver().rename(fullUnzipPath, modelDir); | ||||
|  | ||||
| 	await shim.fsDriver().writeFile(uuidFile, md5(modelUrl)); | ||||
|  | ||||
| 	await shim.fsDriver().remove(zipFilePath); | ||||
|  | ||||
| 	return modelDir; | ||||
| }; | ||||
|  | ||||
| export const startRecording = (vosk: Vosk, options: StartOptions): Recorder => { | ||||
| export const startRecording = (vosk: Vosk, options: StartOptions): VoiceTypingSession => { | ||||
| 	if (state_ !== State.Idle) throw new Error('Vosk is already recording'); | ||||
|  | ||||
| 	state_ = State.Recording; | ||||
| @@ -163,10 +105,10 @@ export const startRecording = (vosk: Vosk, options: StartOptions): Recorder => { | ||||
| 	// eslint-disable-next-line @typescript-eslint/no-explicit-any -- Old code before rule was applied | ||||
| 	const eventHandlers: any[] = []; | ||||
| 	// eslint-disable-next-line @typescript-eslint/ban-types -- Old code before rule was applied | ||||
| 	let finalResultPromiseResolve: Function = null; | ||||
| 	const finalResultPromiseResolve: Function = null; | ||||
| 	// eslint-disable-next-line @typescript-eslint/ban-types -- Old code before rule was applied | ||||
| 	let finalResultPromiseReject: Function = null; | ||||
| 	let finalResultTimeout = false; | ||||
| 	const finalResultPromiseReject: Function = null; | ||||
| 	const finalResultTimeout = false; | ||||
|  | ||||
| 	const completeRecording = (finalResult: string, error: Error) => { | ||||
| 		logger.info(`Complete recording. Final result: ${finalResult}. Error:`, error); | ||||
| @@ -212,31 +154,13 @@ export const startRecording = (vosk: Vosk, options: StartOptions): Recorder => { | ||||
| 		completeRecording(e.data, null); | ||||
| 	})); | ||||
|  | ||||
| 	logger.info('Starting recording...'); | ||||
|  | ||||
| 	void vosk.start(); | ||||
|  | ||||
| 	return { | ||||
| 		stop: (): Promise<string> => { | ||||
| 			logger.info('Stopping recording...'); | ||||
|  | ||||
| 			vosk.stopOnly(); | ||||
|  | ||||
| 			logger.info('Waiting for final result...'); | ||||
|  | ||||
| 			setTimeout(() => { | ||||
| 				finalResultTimeout = true; | ||||
| 				logger.warn('Timed out waiting for finalResult event'); | ||||
| 				completeRecording('', new Error('Could not process your message. Please try again.')); | ||||
| 			}, 5000); | ||||
|  | ||||
| 			// eslint-disable-next-line @typescript-eslint/ban-types -- Old code before rule was applied | ||||
| 			return new Promise((resolve: Function, reject: Function) => { | ||||
| 				finalResultPromiseResolve = resolve; | ||||
| 				finalResultPromiseReject = reject; | ||||
| 			}); | ||||
| 		start: async () => { | ||||
| 			logger.info('Starting recording...'); | ||||
| 			await vosk.start(); | ||||
| 		}, | ||||
| 		cleanup: () => { | ||||
| 		stop: async () => { | ||||
| 			if (state_ === State.Recording) { | ||||
| 				logger.info('Cancelling...'); | ||||
| 				state_ = State.Completing; | ||||
| @@ -246,3 +170,18 @@ export const startRecording = (vosk: Vosk, options: StartOptions): Recorder => { | ||||
| 		}, | ||||
| 	}; | ||||
| }; | ||||
|  | ||||
|  | ||||
| const vosk: VoiceTypingProvider = { | ||||
| 	supported: () => true, | ||||
| 	modelLocalFilepath: (locale: string) => getModelDir(locale), | ||||
| 	getDownloadUrl: (locale) => languageModelUrl(locale), | ||||
| 	getUuidPath: (locale: string) => join(getModelDir(locale), 'uuid'), | ||||
| 	build: async ({ callbacks, locale, modelPath }) => { | ||||
| 		const vosk = await getVosk(modelPath, locale); | ||||
| 		return startRecording(vosk, { onResult: callbacks.onFinalize }); | ||||
| 	}, | ||||
| 	modelName: 'vosk', | ||||
| }; | ||||
|  | ||||
| export default vosk; | ||||
|   | ||||
| @@ -1,37 +1,14 @@ | ||||
| // Currently disabled on non-Android platforms | ||||
| import { VoiceTypingProvider } from './VoiceTyping'; | ||||
|  | ||||
| // eslint-disable-next-line @typescript-eslint/no-explicit-any -- Old code before rule was applied | ||||
| type Vosk = any; | ||||
|  | ||||
| export { Vosk }; | ||||
|  | ||||
| interface StartOptions { | ||||
| 	onResult: (text: string)=> void; | ||||
| } | ||||
|  | ||||
| export interface Recorder { | ||||
| 	stop: ()=> Promise<string>; | ||||
| 	cleanup: ()=> void; | ||||
| } | ||||
|  | ||||
| export const isSupportedLanguage = (_locale: string) => { | ||||
| 	return false; | ||||
| const vosk: VoiceTypingProvider = { | ||||
| 	supported: () => false, | ||||
| 	modelLocalFilepath: () => null, | ||||
| 	getDownloadUrl: () => null, | ||||
| 	getUuidPath: () => null, | ||||
| 	build: async () => { | ||||
| 		throw new Error('Unsupported!'); | ||||
| 	}, | ||||
| 	modelName: 'vosk', | ||||
| }; | ||||
|  | ||||
| export const modelIsDownloaded = async (_locale: string) => { | ||||
| 	return false; | ||||
| }; | ||||
|  | ||||
| export const getVosk = async (_locale: string) => { | ||||
| 	// eslint-disable-next-line @typescript-eslint/no-explicit-any -- Old code before rule was applied | ||||
| 	return {} as any; | ||||
| }; | ||||
|  | ||||
| export const startRecording = (_vosk: Vosk, _options: StartOptions): Recorder => { | ||||
| 	return { | ||||
| 		stop: async () => { return ''; }, | ||||
| 		cleanup: () => {}, | ||||
| 	}; | ||||
| }; | ||||
|  | ||||
| export const voskEnabled = false; | ||||
| export default vosk; | ||||
|   | ||||
							
								
								
									
119	packages/app-mobile/services/voiceTyping/whisper.ts	Normal file
							| @@ -0,0 +1,119 @@ | ||||
| import Setting from '@joplin/lib/models/Setting'; | ||||
| import shim from '@joplin/lib/shim'; | ||||
| import Logger from '@joplin/utils/Logger'; | ||||
| import { rtrimSlashes } from '@joplin/utils/path'; | ||||
| import { dirname, join } from 'path'; | ||||
| import { NativeModules } from 'react-native'; | ||||
| import { SpeechToTextCallbacks, VoiceTypingProvider, VoiceTypingSession } from './VoiceTyping'; | ||||
| import splitWhisperText from './utils/splitWhisperText'; | ||||
|  | ||||
| const logger = Logger.create('voiceTyping/whisper'); | ||||
|  | ||||
| const { SpeechToTextModule } = NativeModules; | ||||
|  | ||||
| // Timestamps are in the form <|0.00|>. They seem to be added: | ||||
| // - After long pauses. | ||||
| // - Between sentences (in pairs). | ||||
| // - At the beginning and end of a sequence. | ||||
| const timestampExp = /<\|(\d+\.\d*)\|>/g; | ||||
| const postProcessSpeech = (text: string) => { | ||||
| 	return text.replace(timestampExp, '').replace(/\[BLANK_AUDIO\]/g, ''); | ||||
| }; | ||||
|  | ||||
| class Whisper implements VoiceTypingSession { | ||||
| 	private lastPreviewData: string; | ||||
| 	private closeCounter = 0; | ||||
|  | ||||
| 	public constructor( | ||||
| 		private sessionId: number|null, | ||||
| 		private callbacks: SpeechToTextCallbacks, | ||||
| 	) { } | ||||
|  | ||||
| 	public async start() { | ||||
| 		if (this.sessionId === null) { | ||||
| 			throw new Error('Session closed.'); | ||||
| 		} | ||||
| 		try { | ||||
| 			logger.debug('starting recorder'); | ||||
| 			await SpeechToTextModule.startRecording(this.sessionId); | ||||
| 			logger.debug('recorder started'); | ||||
|  | ||||
| 			const loopStartCounter = this.closeCounter; | ||||
| 			while (this.closeCounter === loopStartCounter) { | ||||
| 				logger.debug('reading block'); | ||||
| 				const data: string = await SpeechToTextModule.expandBufferAndConvert(this.sessionId, 4); | ||||
| 				logger.debug('done reading block. Length', data?.length); | ||||
|  | ||||
| 				if (this.sessionId === null) { | ||||
| 					logger.debug('Session stopped. Ending inference loop.'); | ||||
| 					return; | ||||
| 				} | ||||
|  | ||||
| 				const recordingLength = await SpeechToTextModule.getBufferLengthSeconds(this.sessionId); | ||||
| 				logger.debug('recording length so far', recordingLength); | ||||
| 				const { trimTo, dataBeforeTrim, dataAfterTrim } = splitWhisperText(data, recordingLength); | ||||
|  | ||||
| 				if (trimTo > 2) { | ||||
| 					logger.debug('Trim to', trimTo, 'in recording with length', recordingLength); | ||||
| 					this.callbacks.onFinalize(postProcessSpeech(dataBeforeTrim)); | ||||
| 					this.callbacks.onPreview(postProcessSpeech(dataAfterTrim)); | ||||
| 					this.lastPreviewData = dataAfterTrim; | ||||
| 					await SpeechToTextModule.dropFirstSeconds(this.sessionId, trimTo); | ||||
| 				} else { | ||||
| 					logger.debug('Preview', data); | ||||
| 					this.lastPreviewData = data; | ||||
| 					this.callbacks.onPreview(postProcessSpeech(data)); | ||||
| 				} | ||||
| 			} | ||||
| 		} catch (error) { | ||||
| 			logger.error('Whisper error:', error); | ||||
| 			this.lastPreviewData = ''; | ||||
| 			await this.stop(); | ||||
| 			throw error; | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	public async stop() { | ||||
| 		if (this.sessionId === null) { | ||||
| 			logger.warn('Session already closed.'); | ||||
| 			return; | ||||
| 		} | ||||
|  | ||||
| 		const sessionId = this.sessionId; | ||||
| 		this.sessionId = null; | ||||
| 		this.closeCounter ++; | ||||
| 		await SpeechToTextModule.closeSession(sessionId); | ||||
|  | ||||
| 		if (this.lastPreviewData) { | ||||
| 			this.callbacks.onFinalize(postProcessSpeech(this.lastPreviewData)); | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| const modelLocalFilepath = () => { | ||||
| 	return `${shim.fsDriver().getAppDirectoryPath()}/voice-typing-models/whisper_tiny.onnx`; | ||||
| }; | ||||
|  | ||||
| const whisper: VoiceTypingProvider = { | ||||
| 	supported: () => !!SpeechToTextModule, | ||||
| 	modelLocalFilepath: modelLocalFilepath, | ||||
| 	getDownloadUrl: () => { | ||||
| 		let urlTemplate = rtrimSlashes(Setting.value('voiceTypingBaseUrl').trim()); | ||||
|  | ||||
| 		if (!urlTemplate) { | ||||
| 			urlTemplate = 'https://github.com/personalizedrefrigerator/joplin-voice-typing-test/releases/download/test-release/{task}.zip'; | ||||
| 		} | ||||
|  | ||||
| 		return urlTemplate.replace(/\{task\}/g, 'whisper_tiny.onnx'); | ||||
| 	}, | ||||
| 	getUuidPath: () => { | ||||
| 		return join(dirname(modelLocalFilepath()), 'uuid'); | ||||
| 	}, | ||||
| 	build: async ({ modelPath, callbacks, locale }) => { | ||||
| 		const sessionId = await SpeechToTextModule.openSession(modelPath, locale); | ||||
| 		return new Whisper(sessionId, callbacks); | ||||
| 	}, | ||||
| 	modelName: 'whisper', | ||||
| }; | ||||
|  | ||||
| export default whisper; | ||||
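
postProcessSpeech strips the timestamp tokens and blank-audio markers before text reaches the editor; for example:

postProcessSpeech('<|0.00|> Hello world. <|2.00|>[BLANK_AUDIO]');
// => ' Hello world. '
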
| @@ -1615,6 +1615,25 @@ const builtInMetadata = (Setting: typeof SettingType) => { | ||||
| 			section: 'note', | ||||
| 		}, | ||||
|  | ||||
| 		'voiceTyping.preferredProvider': { | ||||
| 			value: 'whisper-tiny', | ||||
| 			type: SettingItemType.String, | ||||
| 			public: true, | ||||
| 			appTypes: [AppType.Mobile], | ||||
| 			label: () => _('Preferred voice typing provider'), | ||||
| 			isEnum: true, | ||||
| 			// For now, iOS and web don't support voice typing. | ||||
| 			show: () => shim.mobilePlatform() === 'android', | ||||
| 			section: 'note', | ||||
|  | ||||
| 			options: () => { | ||||
| 				return { | ||||
| 					'vosk': _('Vosk'), | ||||
| 					'whisper-tiny': _('Whisper'), | ||||
| 				}; | ||||
| 			}, | ||||
| 		}, | ||||
|  | ||||
| 		'trash.autoDeletionEnabled': { | ||||
| 			value: true, | ||||
| 			type: SettingItemType.Bool, | ||||
|   | ||||
| @@ -132,4 +132,6 @@ Famegear | ||||
| rcompare | ||||
| tabindex | ||||
| Backblaze | ||||
| onnx | ||||
| onnxruntime | ||||
| treeitem | ||||
|   | ||||