Android: Allow switching the voice typing library to Whisper (#11158)

Co-authored-by: Laurent Cozic <laurent22@users.noreply.github.com>
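At a high level, this change turns voice typing into a pluggable system: a VoiceTypingProvider interface (defined in packages/app-mobile/services/voiceTyping/VoiceTyping.ts below) abstracts over the existing Vosk backend and the new Whisper/ONNX one, and the dialog picks a provider based on the voiceTyping.preferredProvider setting. A minimal sketch of how a caller would select and drive a provider, assuming only the interfaces shown later in this diff:

import VoiceTyping from './services/voiceTyping/VoiceTyping';
import vosk from './services/voiceTyping/vosk';
import whisper from './services/voiceTyping/whisper';

// Choose the provider list from the user's setting, then let VoiceTyping pick
// the first provider whose supported() returns true (as VoiceTypingDialog does).
const startVoiceTyping = async (locale: string, preferredProvider: string) => {
	const providers = preferredProvider?.startsWith('whisper') ? [whisper] : [vosk];
	const voiceTyping = new VoiceTyping(locale, providers);

	// build() also downloads the model on first use; checking first allows
	// showing a "Downloading..." state, as the dialog does.
	if (!await voiceTyping.isDownloaded()) await voiceTyping.download();

	const session = await voiceTyping.build({
		onPreview: text => console.log('Preview (may still change):', text),
		onFinalize: text => console.log('Finalize (append to note):', text),
	});
	await session.start();
	return session; // later: await session.stop();
};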
@@ -737,8 +737,12 @@ packages/app-mobile/services/BackButtonService.js
 packages/app-mobile/services/e2ee/RSA.react-native.js
 packages/app-mobile/services/plugins/PlatformImplementation.js
 packages/app-mobile/services/profiles/index.js
+packages/app-mobile/services/voiceTyping/VoiceTyping.js
+packages/app-mobile/services/voiceTyping/utils/splitWhisperText.test.js
+packages/app-mobile/services/voiceTyping/utils/splitWhisperText.js
 packages/app-mobile/services/voiceTyping/vosk.android.js
 packages/app-mobile/services/voiceTyping/vosk.js
+packages/app-mobile/services/voiceTyping/whisper.js
 packages/app-mobile/setupQuickActions.js
 packages/app-mobile/tools/buildInjectedJs/BundledFile.js
 packages/app-mobile/tools/buildInjectedJs/constants.js
.gitignore (vendored): 4 changes
@@ -714,8 +714,12 @@ packages/app-mobile/services/BackButtonService.js
 packages/app-mobile/services/e2ee/RSA.react-native.js
 packages/app-mobile/services/plugins/PlatformImplementation.js
 packages/app-mobile/services/profiles/index.js
+packages/app-mobile/services/voiceTyping/VoiceTyping.js
+packages/app-mobile/services/voiceTyping/utils/splitWhisperText.test.js
+packages/app-mobile/services/voiceTyping/utils/splitWhisperText.js
 packages/app-mobile/services/voiceTyping/vosk.android.js
 packages/app-mobile/services/voiceTyping/vosk.js
+packages/app-mobile/services/voiceTyping/whisper.js
 packages/app-mobile/setupQuickActions.js
 packages/app-mobile/tools/buildInjectedJs/BundledFile.js
 packages/app-mobile/tools/buildInjectedJs/constants.js
@@ -136,6 +136,10 @@ dependencies {
     } else {
         implementation jscFlavor
     }
+
+    // Needed for Whisper speech-to-text
+    implementation 'com.microsoft.onnxruntime:onnxruntime-android:latest.release'
+    implementation 'com.microsoft.onnxruntime:onnxruntime-extensions-android:latest.release'
 }

 apply from: file("../../node_modules/@react-native-community/cli-platform-android/native_modules.gradle"); applyNativeModulesAppBuildGradle(project)
@@ -10,6 +10,7 @@
 	<uses-permission android:name="android.permission.POST_NOTIFICATIONS" />
 	<uses-permission android:name="android.permission.READ_EXTERNAL_STORAGE" />
 	<uses-permission android:name="android.permission.READ_MEDIA_IMAGES" />
+	<uses-permission android:name="android.permission.RECORD_AUDIO" />

 	<!-- Make these features optional to enable Chromebooks -->
 	<!-- https://github.com/laurent22/joplin/issues/37 -->
@@ -11,6 +11,7 @@ import com.facebook.react.defaults.DefaultNewArchitectureEntryPoint.load
 import com.facebook.react.defaults.DefaultReactHost.getDefaultReactHost
 import com.facebook.react.defaults.DefaultReactNativeHost
 import com.facebook.soloader.SoLoader
+import net.cozic.joplin.audio.SpeechToTextPackage
 import net.cozic.joplin.versioninfo.SystemVersionInformationPackage
 import net.cozic.joplin.share.SharePackage
 import net.cozic.joplin.ssl.SslPackage
@@ -25,6 +26,7 @@ class MainApplication : Application(), ReactApplication {
                     add(SslPackage())
                     add(TextInputPackage())
                     add(SystemVersionInformationPackage())
+                    add(SpeechToTextPackage())
                 }

         override fun getJSMainModuleName(): String = "index"
@@ -0,0 +1,101 @@
package net.cozic.joplin.audio

import android.Manifest
import android.annotation.SuppressLint
import android.content.Context
import android.content.pm.PackageManager
import android.media.AudioFormat
import android.media.AudioRecord
import android.media.MediaRecorder.AudioSource
import java.io.Closeable
import kotlin.math.max
import kotlin.math.min

typealias AudioRecorderFactory = (context: Context)->AudioRecorder;

class AudioRecorder(context: Context) : Closeable {
	private val sampleRate = 16_000
	private val maxLengthSeconds = 30 // Whisper supports a maximum of 30s
	private val maxBufferSize = sampleRate * maxLengthSeconds
	private val buffer = FloatArray(maxBufferSize)
	private var bufferWriteOffset = 0

	// Accessor must not modify result
	val bufferedData: FloatArray get() = buffer.sliceArray(0 until bufferWriteOffset)
	val bufferLengthSeconds: Double get() = bufferWriteOffset.toDouble() / sampleRate

	init {
		val permissionResult = context.checkSelfPermission(Manifest.permission.RECORD_AUDIO)
		if (permissionResult == PackageManager.PERMISSION_DENIED) {
			throw SecurityException("Missing RECORD_AUDIO permission!")
		}
	}

	// Permissions check is included above
	@SuppressLint("MissingPermission")
	private val recorder = AudioRecord.Builder()
		.setAudioSource(AudioSource.MIC)
		.setAudioFormat(
			AudioFormat.Builder()
				// PCM: A WAV format
				.setEncoding(AudioFormat.ENCODING_PCM_FLOAT)
				.setSampleRate(sampleRate)
				.setChannelMask(AudioFormat.CHANNEL_IN_MONO)
				.build()
		)
		.setBufferSizeInBytes(maxBufferSize * Float.SIZE_BYTES)
		.build()

	// Discards the first [samples] samples from the start of the buffer. Conceptually, this
	// advances the buffer's start point.
	private fun advanceStartBySamples(samples: Int) {
		val samplesClamped = min(samples, maxBufferSize)
		val remainingBuffer = buffer.sliceArray(samplesClamped until maxBufferSize)

		buffer.fill(0f, samplesClamped, maxBufferSize)
		remainingBuffer.copyInto(buffer, 0)
		bufferWriteOffset = max(bufferWriteOffset - samplesClamped, 0)
	}

	fun dropFirstSeconds(seconds: Double) {
		advanceStartBySamples((seconds * sampleRate).toInt())
	}

	fun start() {
		recorder.startRecording()
	}

	private fun read(requestedSize: Int, mode: Int) {
		val size = min(requestedSize, maxBufferSize - bufferWriteOffset)
		val sizeRead = recorder.read(buffer, bufferWriteOffset, size, mode)
		if (sizeRead > 0) {
			bufferWriteOffset += sizeRead
		}
	}

	// Pulls all available data from the audio recorder's buffer
	fun pullAvailable() {
		return read(maxBufferSize, AudioRecord.READ_NON_BLOCKING)
	}

	fun pullNextSeconds(seconds: Double) {
		val remainingSize = maxBufferSize - bufferWriteOffset
		val requestedSize = (seconds * sampleRate).toInt()

		// If low on size, make more room.
		if (remainingSize < maxBufferSize / 3) {
			advanceStartBySamples(maxBufferSize / 3)
		}

		return read(requestedSize, AudioRecord.READ_BLOCKING)
	}

	override fun close() {
		recorder.stop()
		recorder.release()
	}

	companion object {
		val factory: AudioRecorderFactory = { context -> AudioRecorder(context) }
	}
}
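AudioRecorder maintains a fixed 30-second window of 16 kHz mono float samples: pullNextSeconds/pullAvailable append at the write offset, and advanceStartBySamples shifts the window left once the converter has consumed the start. A rough TypeScript sketch of the same bookkeeping, for illustration only (the real implementation is the Kotlin FloatArray logic above):

// Illustrative sliding-window buffer, mirroring AudioRecorder's bookkeeping.
class SlidingSampleBuffer {
	private readonly sampleRate = 16_000;
	private readonly maxSamples = this.sampleRate * 30; // Whisper's 30s window
	private buffer = new Float32Array(this.maxSamples);
	private writeOffset = 0;

	public get lengthSeconds() { return this.writeOffset / this.sampleRate; }

	// Append newly recorded samples, up to the remaining capacity.
	public push(samples: Float32Array) {
		const count = Math.min(samples.length, this.maxSamples - this.writeOffset);
		this.buffer.set(samples.subarray(0, count), this.writeOffset);
		this.writeOffset += count;
	}

	// Drop the first `seconds` seconds: shift the remainder to index 0.
	public dropFirstSeconds(seconds: number) {
		const drop = Math.min(Math.floor(seconds * this.sampleRate), this.maxSamples);
		this.buffer.copyWithin(0, drop, this.maxSamples);
		this.buffer.fill(0, this.maxSamples - drop);
		this.writeOffset = Math.max(this.writeOffset - drop, 0);
	}
}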
@@ -0,0 +1,5 @@
package net.cozic.joplin.audio


class InvalidSessionIdException(id: Int) : IllegalArgumentException("Invalid session ID $id")
@@ -0,0 +1,136 @@
package net.cozic.joplin.audio

import ai.onnxruntime.OnnxTensor
import ai.onnxruntime.OrtEnvironment
import ai.onnxruntime.OrtSession
import ai.onnxruntime.extensions.OrtxPackage
import android.annotation.SuppressLint
import android.content.Context
import android.util.Log
import java.io.Closeable
import java.nio.FloatBuffer
import java.nio.IntBuffer
import kotlin.time.DurationUnit
import kotlin.time.measureTimedValue

class SpeechToTextConverter(
	modelPath: String,
	locale: String,
	recorderFactory: AudioRecorderFactory,
	private val environment: OrtEnvironment,
	context: Context,
) : Closeable {
	private val recorder = recorderFactory(context)
	private val session: OrtSession = environment.createSession(
		modelPath,
		OrtSession.SessionOptions().apply {
			// Needed for audio decoding
			registerCustomOpLibrary(OrtxPackage.getLibraryPath())
		},
	)
	private val languageCode = Regex("_.*").replace(locale, "")
	private val decoderInputIds = when (languageCode) {
		// Add 50363 to the end to omit timestamps
		"en" -> intArrayOf(50258, 50259, 50359)
		"fr" -> intArrayOf(50258, 50265, 50359)
		"es" -> intArrayOf(50258, 50262, 50359)
		"de" -> intArrayOf(50258, 50261, 50359)
		"it" -> intArrayOf(50258, 50274, 50359)
		"nl" -> intArrayOf(50258, 50271, 50359)
		"ko" -> intArrayOf(50258, 50264, 50359)
		"th" -> intArrayOf(50258, 50289, 50359)
		"ru" -> intArrayOf(50258, 50263, 50359)
		"pt" -> intArrayOf(50258, 50267, 50359)
		"pl" -> intArrayOf(50258, 50269, 50359)
		"id" -> intArrayOf(50258, 50275, 50359)
		"hi" -> intArrayOf(50258, 50276, 50359)
		// Let Whisper guess the language
		else -> intArrayOf(50258)
	}

	fun start() {
		recorder.start()
	}

	private fun getInputs(data: FloatArray): MutableMap<String, OnnxTensor> {
		fun intTensor(value: Int) = OnnxTensor.createTensor(
			environment,
			IntBuffer.wrap(intArrayOf(value)),
			longArrayOf(1),
		)
		fun floatTensor(value: Float) = OnnxTensor.createTensor(
			environment,
			FloatBuffer.wrap(floatArrayOf(value)),
			longArrayOf(1),
		)
		val audioPcmTensor = OnnxTensor.createTensor(
			environment,
			FloatBuffer.wrap(data),
			longArrayOf(1, data.size.toLong()),
		)
		val decoderInputIdsTensor = OnnxTensor.createTensor(
			environment,
			IntBuffer.wrap(decoderInputIds),
			longArrayOf(1, decoderInputIds.size.toLong())
		)

		return mutableMapOf(
			"audio_pcm" to audioPcmTensor,
			"max_length" to intTensor(412),
			"min_length" to intTensor(0),
			"num_return_sequences" to intTensor(1),
			"num_beams" to intTensor(1),
			"length_penalty" to floatTensor(1.1f),
			"repetition_penalty" to floatTensor(3f),
			"decoder_input_ids" to decoderInputIdsTensor,

			// Required for timestamps
			"logits_processor" to intTensor(1)
		)
	}

	// TODO: .get() fails on older Android versions
	@SuppressLint("NewApi")
	private fun convert(data: FloatArray): String {
		val (inputs, convertInputsTime) = measureTimedValue {
			getInputs(data)
		}
		val (outputs, getOutputsTime) = measureTimedValue {
			session.run(inputs, setOf("str"))
		}
		val mainOutput = outputs.get("str").get().value as Array<Array<String>>
		outputs.close()

		Log.i("Whisper", "Converted ${data.size / 16000}s of data in ${
			getOutputsTime.toString(DurationUnit.SECONDS, 2)
		}, converted inputs in ${convertInputsTime.inWholeMilliseconds}ms")
		return mainOutput[0][0]
	}

	fun dropFirstSeconds(seconds: Double) {
		Log.i("Whisper", "Dropping first $seconds seconds")
		recorder.dropFirstSeconds(seconds)
	}

	val bufferLengthSeconds: Double get() = recorder.bufferLengthSeconds

	fun expandBufferAndConvert(seconds: Double): String {
		recorder.pullNextSeconds(seconds)
		// Also pull any extra available data, in case the speech-to-text converter
		// is lagging behind the audio recorder.
		recorder.pullAvailable()

		return convert(recorder.bufferedData)
	}

	// Converts as many seconds of buffered data as possible, without waiting
	fun expandBufferAndConvert(): String {
		recorder.pullAvailable()
		return convert(recorder.bufferedData)
	}

	override fun close() {
		recorder.close()
		session.close()
	}
}
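The hard-coded decoder_input_ids are Whisper's special prompt tokens: in the multilingual vocabulary, 50258 is <|startoftranscript|>, the middle ID is a language tag (50259 for English, 50265 for French, and so on, per the table above), and 50359 is the <|transcribe|> task token; appending 50363 (<|notimestamps|>) would suppress the timestamp markers that the splitting logic later relies on. The same table, restated as a TypeScript sketch (token meanings as described above):

// Whisper multilingual special tokens, matching the Kotlin table above.
const START_OF_TRANSCRIPT = 50258; // <|startoftranscript|>
const TRANSCRIBE = 50359;          // <|transcribe|>
const NO_TIMESTAMPS = 50363;       // <|notimestamps|> -- append to omit timestamps

const languageTokens: Record<string, number> = {
	en: 50259, fr: 50265, es: 50262, de: 50261, it: 50274, nl: 50271,
	ko: 50264, th: 50289, ru: 50263, pt: 50267, pl: 50269, id: 50275, hi: 50276,
};

const decoderInputIds = (languageCode: string, withTimestamps = true): number[] => {
	const language = languageTokens[languageCode];
	// Unknown language: emit only the start token and let Whisper detect it.
	if (language === undefined) return [START_OF_TRANSCRIPT];
	const ids = [START_OF_TRANSCRIPT, language, TRANSCRIBE];
	if (!withTimestamps) ids.push(NO_TIMESTAMPS);
	return ids;
};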
@@ -0,0 +1,86 @@
package net.cozic.joplin.audio

import ai.onnxruntime.OrtEnvironment
import com.facebook.react.ReactPackage
import com.facebook.react.bridge.LifecycleEventListener
import com.facebook.react.bridge.NativeModule
import com.facebook.react.bridge.Promise
import com.facebook.react.bridge.ReactApplicationContext
import com.facebook.react.bridge.ReactContextBaseJavaModule
import com.facebook.react.bridge.ReactMethod
import com.facebook.react.uimanager.ViewManager
import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors

class SpeechToTextPackage : ReactPackage {
	override fun createNativeModules(reactContext: ReactApplicationContext): List<NativeModule> {
		return listOf<NativeModule>(SpeechToTextModule(reactContext))
	}

	override fun createViewManagers(reactContext: ReactApplicationContext): List<ViewManager<*, *>> {
		return emptyList()
	}

	class SpeechToTextModule(
		private var context: ReactApplicationContext,
	) : ReactContextBaseJavaModule(context), LifecycleEventListener {
		private var environment: OrtEnvironment? = null
		private val executorService: ExecutorService = Executors.newFixedThreadPool(1)
		private val sessionManager = SpeechToTextSessionManager(executorService)

		override fun getName() = "SpeechToTextModule"

		override fun onHostResume() { }
		override fun onHostPause() { }
		override fun onHostDestroy() {
			environment?.close()
		}

		@ReactMethod
		fun openSession(modelPath: String, locale: String, promise: Promise) {
			val appContext = context.applicationContext
			// Initialize the environment as late as possible, and cache it for reuse:
			val ortEnvironment = environment ?: OrtEnvironment.getEnvironment()
			if (environment == null) {
				environment = ortEnvironment
			}

			try {
				val sessionId = sessionManager.openSession(modelPath, locale, ortEnvironment, appContext)
				promise.resolve(sessionId)
			} catch (exception: Throwable) {
				promise.reject(exception)
			}
		}

		@ReactMethod
		fun startRecording(sessionId: Int, promise: Promise) {
			sessionManager.startRecording(sessionId, promise)
		}

		@ReactMethod
		fun getBufferLengthSeconds(sessionId: Int, promise: Promise) {
			sessionManager.getBufferLengthSeconds(sessionId, promise)
		}

		@ReactMethod
		fun dropFirstSeconds(sessionId: Int, duration: Double, promise: Promise) {
			sessionManager.dropFirstSeconds(sessionId, duration, promise)
		}

		@ReactMethod
		fun expandBufferAndConvert(sessionId: Int, duration: Double, promise: Promise) {
			sessionManager.expandBufferAndConvert(sessionId, duration, promise)
		}

		@ReactMethod
		fun convertAvailable(sessionId: Int, promise: Promise) {
			sessionManager.convertAvailable(sessionId, promise)
		}

		@ReactMethod
		fun closeSession(sessionId: Int, promise: Promise) {
			sessionManager.closeSession(sessionId, promise)
		}
	}
}
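On the JavaScript side, each @ReactMethod above surfaces as a promise-returning function on NativeModules.SpeechToTextModule (the name comes from getName()). The new services/voiceTyping/whisper.ts is not shown in this excerpt, so the exact call sequence below is an assumption, but a driver would look roughly like this:

import { NativeModules } from 'react-native';

// Method names mirror the @ReactMethods registered above; how whisper.ts
// actually sequences them is an assumption, since that file is not shown here.
const { SpeechToTextModule } = NativeModules;

const transcribeLoop = async (modelPath: string, locale: string, onText: (text: string)=> void) => {
	const sessionId: number = await SpeechToTextModule.openSession(modelPath, locale);
	try {
		await SpeechToTextModule.startRecording(sessionId);
		for (let i = 0; i < 3; i++) { // demo: three 5-second windows
			// Blocks until ~5 more seconds are buffered, then runs Whisper on the buffer.
			const text: string = await SpeechToTextModule.expandBufferAndConvert(sessionId, 5);
			onText(text);
		}
	} finally {
		await SpeechToTextModule.closeSession(sessionId);
	}
};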
@@ -0,0 +1,111 @@
package net.cozic.joplin.audio

import ai.onnxruntime.OrtEnvironment
import android.content.Context
import com.facebook.react.bridge.Promise
import java.util.concurrent.Executor
import java.util.concurrent.locks.ReentrantLock

class SpeechToTextSession(
	val converter: SpeechToTextConverter
) {
	val mutex = ReentrantLock()
}

class SpeechToTextSessionManager(
	private var executor: Executor,
) {
	private val sessions: MutableMap<Int, SpeechToTextSession> = mutableMapOf()
	private var nextSessionId: Int = 0

	fun openSession(
		modelPath: String,
		locale: String,
		environment: OrtEnvironment,
		context: Context,
	): Int {
		val sessionId = nextSessionId++
		sessions[sessionId] = SpeechToTextSession(
			SpeechToTextConverter(
				modelPath, locale, recorderFactory = AudioRecorder.factory, environment, context,
			)
		)
		return sessionId
	}

	private fun getSession(id: Int): SpeechToTextSession {
		return sessions[id] ?: throw InvalidSessionIdException(id)
	}

	private fun concurrentWithSession(
		id: Int,
		callback: (session: SpeechToTextSession)->Unit,
	) {
		executor.execute {
			val session = getSession(id)
			session.mutex.lock()
			try {
				callback(session)
			} finally {
				session.mutex.unlock()
			}
		}
	}

	private fun concurrentWithSession(
		id: Int,
		onError: (error: Throwable)->Unit,
		callback: (session: SpeechToTextSession)->Unit,
	) {
		return concurrentWithSession(id) { session ->
			try {
				callback(session)
			} catch (error: Throwable) {
				onError(error)
			}
		}
	}

	fun startRecording(sessionId: Int, promise: Promise) {
		this.concurrentWithSession(sessionId, promise::reject) { session ->
			session.converter.start()
			promise.resolve(null)
		}
	}

	// Left-shifts the recording buffer by [duration] seconds
	fun dropFirstSeconds(sessionId: Int, duration: Double, promise: Promise) {
		this.concurrentWithSession(sessionId, promise::reject) { session ->
			session.converter.dropFirstSeconds(duration)
			promise.resolve(sessionId)
		}
	}

	fun getBufferLengthSeconds(sessionId: Int, promise: Promise) {
		this.concurrentWithSession(sessionId, promise::reject) { session ->
			promise.resolve(session.converter.bufferLengthSeconds)
		}
	}

	// Waits for the next [duration] seconds to become available, then converts
	fun expandBufferAndConvert(sessionId: Int, duration: Double, promise: Promise) {
		this.concurrentWithSession(sessionId, promise::reject) { session ->
			val result = session.converter.expandBufferAndConvert(duration)
			promise.resolve(result)
		}
	}

	// Converts all available recorded data
	fun convertAvailable(sessionId: Int, promise: Promise) {
		this.concurrentWithSession(sessionId, promise::reject) { session ->
			val result = session.converter.expandBufferAndConvert()
			promise.resolve(result)
		}
	}

	fun closeSession(sessionId: Int, promise: Promise) {
		this.concurrentWithSession(sessionId) { session ->
			session.converter.close()
			promise.resolve(null)
		}
	}
}
@@ -45,7 +45,6 @@ import ImageEditor from '../NoteEditor/ImageEditor/ImageEditor';
 import promptRestoreAutosave from '../NoteEditor/ImageEditor/promptRestoreAutosave';
 import isEditableResource from '../NoteEditor/ImageEditor/isEditableResource';
 import VoiceTypingDialog from '../voiceTyping/VoiceTypingDialog';
-import { voskEnabled } from '../../services/voiceTyping/vosk';
 import { isSupportedLanguage } from '../../services/voiceTyping/vosk';
 import { ChangeEvent as EditorChangeEvent, SelectionRangeChangeEvent, UndoRedoDepthChangeEvent } from '@joplin/editor/events';
 import { join } from 'path';
@@ -1204,8 +1203,8 @@ class NoteScreenComponent extends BaseScreenComponent<Props, State> implements B
 			});
 		}

-		// Voice typing is enabled only for French language and on Android for now
-		if (voskEnabled && shim.mobilePlatform() === 'android' && isSupportedLanguage(currentLocale())) {
+		// Voice typing is enabled only on Android for now
+		if (shim.mobilePlatform() === 'android' && isSupportedLanguage(currentLocale())) {
 			output.push({
 				title: _('Voice typing...'),
 				onPress: () => {
@@ -1,14 +1,18 @@
 import * as React from 'react';
-import { useState, useEffect, useCallback } from 'react';
+import { useState, useEffect, useCallback, useRef, useMemo } from 'react';
-import { Banner, ActivityIndicator } from 'react-native-paper';
+import { Banner, ActivityIndicator, Text } from 'react-native-paper';
 import { _, languageName } from '@joplin/lib/locale';
 import useAsyncEffect, { AsyncEffectEvent } from '@joplin/lib/hooks/useAsyncEffect';
-import { getVosk, Recorder, startRecording, Vosk } from '../../services/voiceTyping/vosk';
 import { IconSource } from 'react-native-paper/lib/typescript/components/Icon';
-import { modelIsDownloaded } from '../../services/voiceTyping/vosk';
+import VoiceTyping, { OnTextCallback, VoiceTypingSession } from '../../services/voiceTyping/VoiceTyping';
+import whisper from '../../services/voiceTyping/whisper';
+import vosk from '../../services/voiceTyping/vosk';
+import { AppState } from '../../utils/types';
+import { connect } from 'react-redux';

 interface Props {
 	locale: string;
+	provider: string;
 	onDismiss: ()=> void;
 	onText: (text: string)=> void;
 }
@@ -21,44 +25,77 @@ enum RecorderState {
 	Downloading = 5,
 }

-const useVosk = (locale: string): [Error | null, boolean, Vosk|null] => {
-	const [vosk, setVosk] = useState<Vosk>(null);
+interface UseVoiceTypingProps {
+	locale: string;
+	provider: string;
+	onSetPreview: OnTextCallback;
+	onText: OnTextCallback;
+}
+
+const useWhisper = ({ locale, provider, onSetPreview, onText }: UseVoiceTypingProps): [Error | null, boolean, VoiceTypingSession|null] => {
+	const [voiceTyping, setVoiceTyping] = useState<VoiceTypingSession>(null);
 	const [error, setError] = useState<Error>(null);
 	const [mustDownloadModel, setMustDownloadModel] = useState<boolean | null>(null);

-	useAsyncEffect(async (event: AsyncEffectEvent) => {
-		if (mustDownloadModel === null) return;
+	const onTextRef = useRef(onText);
+	onTextRef.current = onText;
+	const onSetPreviewRef = useRef(onSetPreview);
+	onSetPreviewRef.current = onSetPreview;
+
+	const voiceTypingRef = useRef(voiceTyping);
+	voiceTypingRef.current = voiceTyping;
+
+	const builder = useMemo(() => {
+		return new VoiceTyping(locale, provider?.startsWith('whisper') ? [whisper] : [vosk]);
+	}, [locale, provider]);
+
+	useAsyncEffect(async (event: AsyncEffectEvent) => {
 		try {
-			const v = await getVosk(locale);
+			await voiceTypingRef.current?.stop();
+
+			if (!await builder.isDownloaded()) {
+				if (event.cancelled) return;
+				await builder.download();
+			}
 			if (event.cancelled) return;
-			setVosk(v);
+			const voiceTyping = await builder.build({
+				onPreview: (text) => onSetPreviewRef.current(text),
+				onFinalize: (text) => onTextRef.current(text),
+			});
+			if (event.cancelled) return;
+			setVoiceTyping(voiceTyping);
 		} catch (error) {
 			setError(error);
 		} finally {
 			setMustDownloadModel(false);
 		}
-	}, [locale, mustDownloadModel]);
+	}, [builder]);

 	useAsyncEffect(async (_event: AsyncEffectEvent) => {
-		setMustDownloadModel(!(await modelIsDownloaded(locale)));
-	}, [locale]);
+		setMustDownloadModel(!(await builder.isDownloaded()));
+	}, [builder]);

-	return [error, mustDownloadModel, vosk];
+	return [error, mustDownloadModel, voiceTyping];
 };

-export default (props: Props) => {
-	const [recorder, setRecorder] = useState<Recorder>(null);
+const VoiceTypingDialog: React.FC<Props> = props => {
 	const [recorderState, setRecorderState] = useState<RecorderState>(RecorderState.Loading);
-	const [voskError, mustDownloadModel, vosk] = useVosk(props.locale);
+	const [preview, setPreview] = useState<string>('');
+	const [modelError, mustDownloadModel, voiceTyping] = useWhisper({
+		locale: props.locale,
+		onSetPreview: setPreview,
+		onText: props.onText,
+		provider: props.provider,
+	});

 	useEffect(() => {
-		if (voskError) {
+		if (modelError) {
 			setRecorderState(RecorderState.Error);
-		} else if (vosk) {
+		} else if (voiceTyping) {
 			setRecorderState(RecorderState.Recording);
 		}
-	}, [vosk, voskError]);
+	}, [voiceTyping, modelError]);

 	useEffect(() => {
 		if (mustDownloadModel) {
@@ -68,27 +105,22 @@ export default (props: Props) => {

 	useEffect(() => {
 		if (recorderState === RecorderState.Recording) {
-			setRecorder(startRecording(vosk, {
-				onResult: (text: string) => {
-					props.onText(text);
-				},
-			}));
+			void voiceTyping.start();
 		}
-	}, [recorderState, vosk, props.onText]);
+	}, [recorderState, voiceTyping, props.onText]);

 	const onDismiss = useCallback(() => {
-		if (recorder) recorder.cleanup();
+		void voiceTyping?.stop();
 		props.onDismiss();
-	}, [recorder, props.onDismiss]);
+	}, [voiceTyping, props.onDismiss]);

 	const renderContent = () => {
-		// eslint-disable-next-line @typescript-eslint/ban-types -- Old code before rule was applied
-		const components: Record<RecorderState, Function> = {
+		const components: Record<RecorderState, ()=> string> = {
 			[RecorderState.Loading]: () => _('Loading...'),
 			[RecorderState.Recording]: () => _('Please record your voice...'),
 			[RecorderState.Processing]: () => _('Converting speech to text...'),
 			[RecorderState.Downloading]: () => _('Downloading %s language files...', languageName(props.locale)),
-			[RecorderState.Error]: () => _('Error: %s', voskError.message),
+			[RecorderState.Error]: () => _('Error: %s', modelError.message),
 		};

 		return components[recorderState]();
@@ -106,6 +138,11 @@ export default (props: Props) => {
 		return components[recorderState];
 	};

+	const renderPreview = () => {
+		return <Text variant='labelSmall'>{preview}</Text>;
+	};
+
+	const headerAndStatus = <Text variant='bodyMedium'>{`${_('Voice typing...')}\n${renderContent()}`}</Text>;
 	return (
 		<Banner
 			visible={true}
@@ -115,8 +152,15 @@ export default (props: Props) => {
 					label: _('Done'),
 					onPress: onDismiss,
 				},
-			]}>
-			{`${_('Voice typing...')}\n${renderContent()}`}
+			]}
+		>
+			{headerAndStatus}
+			<Text>{'\n'}</Text>
+			{renderPreview()}
 		</Banner>
 	);
 };
+
+export default connect((state: AppState) => ({
+	provider: state.settings['voiceTyping.preferredProvider'],
+}))(VoiceTypingDialog);
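The onTextRef/onSetPreviewRef pairs in the hook above are the standard "latest ref" pattern: the async effect depends only on builder, so the refs let the long-lived session call the newest callbacks without tearing the session down on every render. The pattern in isolation, as a generic sketch:

import { useRef } from 'react';

// Keeps the newest value reachable from long-lived closures without
// adding it to any dependency array.
function useLatest<T>(value: T) {
	const ref = useRef(value);
	ref.current = value; // refreshed on every render
	return ref;
}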
@@ -80,6 +80,10 @@ jest.mock('react-native-image-picker', () => {
 	return { default: { } };
 });

+jest.mock('react-native-zip-archive', () => {
+	return { default: { } };
+});
+
 jest.mock('react-native-document-picker', () => ({ default: { } }));

 // Used by the renderer
packages/app-mobile/services/voiceTyping/VoiceTyping.ts (new file, +139)
@@ -0,0 +1,139 @@
import shim from '@joplin/lib/shim';
import Logger from '@joplin/utils/Logger';
import { PermissionsAndroid, Platform } from 'react-native';
import { unzip } from 'react-native-zip-archive';
const md5 = require('md5');

const logger = Logger.create('voiceTyping');

export type OnTextCallback = (text: string)=> void;

export interface SpeechToTextCallbacks {
	// Called with a block of text that might change in the future
	onPreview: OnTextCallback;
	// Called with text that will not change and should be added to the document
	onFinalize: OnTextCallback;
}

export interface VoiceTypingSession {
	start(): Promise<void>;
	stop(): Promise<void>;
}

export interface BuildProviderOptions {
	locale: string;
	modelPath: string;
	callbacks: SpeechToTextCallbacks;
}

export interface VoiceTypingProvider {
	modelName: string;
	supported(): boolean;
	modelLocalFilepath(locale: string): string;
	getDownloadUrl(locale: string): string;
	getUuidPath(locale: string): string;
	build(options: BuildProviderOptions): Promise<VoiceTypingSession>;
}

export default class VoiceTyping {
	private provider: VoiceTypingProvider|null = null;

	public constructor(
		private locale: string,
		providers: VoiceTypingProvider[],
	) {
		this.provider = providers.find(p => p.supported()) ?? null;
	}

	public supported() {
		return this.provider !== null;
	}

	private getModelPath() {
		const localFilePath = shim.fsDriver().resolveRelativePathWithinDir(
			shim.fsDriver().getAppDirectoryPath(),
			this.provider.modelLocalFilepath(this.locale),
		);
		if (localFilePath === shim.fsDriver().getAppDirectoryPath()) {
			throw new Error('Invalid local file path!');
		}

		return localFilePath;
	}

	private getUuidPath() {
		return shim.fsDriver().resolveRelativePathWithinDir(
			shim.fsDriver().getAppDirectoryPath(),
			this.provider.getUuidPath(this.locale),
		);
	}

	public async isDownloaded() {
		return await shim.fsDriver().exists(this.getUuidPath());
	}

	public async download() {
		const modelPath = this.getModelPath();
		const modelUrl = this.provider.getDownloadUrl(this.locale);

		await shim.fsDriver().remove(modelPath);
		logger.info(`Downloading model from: ${modelUrl}`);

		const isZipped = modelUrl.endsWith('.zip');
		const downloadPath = isZipped ? `${modelPath}.zip` : modelPath;
		const response = await shim.fetchBlob(modelUrl, {
			path: downloadPath,
		});

		if (!response.ok || response.status >= 400) throw new Error(`Could not download from ${modelUrl}: Error ${response.status}`);

		if (isZipped) {
			const modelName = this.provider.modelName;
			const unzipDir = `${shim.fsDriver().getCacheDirectoryPath()}/voice-typing-extract/${modelName}/${this.locale}`;
			try {
				logger.info(`Unzipping ${downloadPath} => ${unzipDir}`);

				await unzip(downloadPath, unzipDir);

				const contents = await shim.fsDriver().readDirStats(unzipDir);
				if (contents.length !== 1) {
					logger.error('Expected 1 file or directory but got', contents);
					throw new Error(`Expected 1 file or directory, but got ${contents.length}`);
				}

				const fullUnzipPath = `${unzipDir}/${contents[0].path}`;

				logger.info(`Moving ${fullUnzipPath} => ${modelPath}`);
				await shim.fsDriver().move(fullUnzipPath, modelPath);

				await shim.fsDriver().writeFile(this.getUuidPath(), md5(modelUrl), 'utf8');
				if (!await this.isDownloaded()) {
					logger.warn('Model should be downloaded!');
				}
			} finally {
				await shim.fsDriver().remove(unzipDir);
				await shim.fsDriver().remove(downloadPath);
			}
		}
	}

	public async build(callbacks: SpeechToTextCallbacks) {
		if (!this.provider) {
			throw new Error('No supported provider found!');
		}

		if (!await this.isDownloaded()) {
			await this.download();
		}

		const audioPermission = 'android.permission.RECORD_AUDIO';
		if (Platform.OS === 'android' && !await PermissionsAndroid.check(audioPermission)) {
			await PermissionsAndroid.request(audioPermission);
		}

		return this.provider.build({
			locale: this.locale,
			modelPath: this.getModelPath(),
			callbacks,
		});
	}
}
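Note how, for zipped models, download() writes md5(modelUrl) to a separate UUID file only after the unzip and move have succeeded, so an interrupted download never looks complete to isDownloaded(). A minimal sketch of this marker-file pattern, with hypothetical exists/writeFile helpers standing in for shim.fsDriver():

// Hypothetical fs helpers standing in for shim.fsDriver() (illustration only).
declare const exists: (path: string)=> Promise<boolean>;
declare const writeFile: (path: string, data: string)=> Promise<void>;
const md5 = require('md5');

// Fast check: the marker only exists if a past download fully succeeded.
const isDownloaded = (uuidPath: string) => exists(uuidPath);

// Called only after every fallible step (fetch, unzip, move) has completed.
const markDownloaded = async (uuidPath: string, modelUrl: string) => {
	await writeFile(uuidPath, md5(modelUrl));
};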
@@ -0,0 +1,61 @@
import splitWhisperText from './splitWhisperText';

describe('splitWhisperText', () => {
	test.each([
		{
			// Should trim at sentence breaks
			input: '<|0.00|> This is a test. <|5.00|><|6.00|> This is another sentence. <|7.00|>',
			recordingLength: 8,
			expected: {
				trimTo: 6,
				dataBeforeTrim: '<|0.00|> This is a test. ',
				dataAfterTrim: ' This is another sentence. <|7.00|>',
			},
		},
		{
			// Should prefer sentence break splits to non sentence break splits
			input: '<|0.00|> This is <|4.00|><|4.50|> a test. <|5.00|><|5.50|> Testing, <|6.00|><|7.00|> this is a test. <|8.00|>',
			recordingLength: 8,
			expected: {
				trimTo: 5.50,
				dataBeforeTrim: '<|0.00|> This is <|4.00|><|4.50|> a test. ',
				dataAfterTrim: ' Testing, <|6.00|><|7.00|> this is a test. <|8.00|>',
			},
		},
		{
			// Should avoid splitting for very small timestamps
			input: '<|0.00|> This is a test. <|2.00|><|2.30|> Testing! <|3.00|>',
			recordingLength: 4,
			expected: {
				trimTo: 0,
				dataBeforeTrim: '',
				dataAfterTrim: ' This is a test. <|2.00|><|2.30|> Testing! <|3.00|>',
			},
		},
		{
			// For larger timestamps, should allow splitting at pauses, even if not on sentence breaks.
			input: '<|0.00|> This is a test, <|10.00|><|12.00|> of splitting on timestamps. <|15.00|>',
			recordingLength: 16,
			expected: {
				trimTo: 12,
				dataBeforeTrim: '<|0.00|> This is a test, ',
				dataAfterTrim: ' of splitting on timestamps. <|15.00|>',
			},
		},
		{
			// Should prefer to break at the end, if a large gap after the last timestamp.
			input: '<|0.00|> This is a test, <|10.00|><|12.00|> of splitting on timestamps. <|15.00|>',
			recordingLength: 30,
			expected: {
				trimTo: 15,
				dataBeforeTrim: '<|0.00|> This is a test, <|10.00|><|12.00|> of splitting on timestamps. ',
				dataAfterTrim: '',
			},
		},
	])('should prefer to split at the end of sentences (case %#)', ({ input, recordingLength, expected }) => {
		const actual = splitWhisperText(input, recordingLength);
		expect(actual.trimTo).toBeCloseTo(expected.trimTo);
		expect(actual.dataBeforeTrim).toBe(expected.dataBeforeTrim);
		expect(actual.dataAfterTrim).toBe(expected.dataAfterTrim);
	});
});
@@ -0,0 +1,65 @@
// Matches pairs of timestamps or single timestamps.
const timestampExp = /<\|(\d+\.\d*)\|>(?:<\|(\d+\.\d*)\|>)?/g;

const timestampMatchToNumber = (match: RegExpMatchArray) => {
	const firstTimestamp = match[1];
	const secondTimestamp = match[2];
	// Prefer the second timestamp in the pair, to remove leading silence.
	const timestamp = Number(secondTimestamp ? secondTimestamp : firstTimestamp);

	// Should always be a finite number (i.e. not NaN)
	if (!isFinite(timestamp)) throw new Error(`Timestamp match failed with ${match[0]}`);

	return timestamp;
};

const splitWhisperText = (textWithTimestamps: string, recordingLengthSeconds: number) => {
	const timestamps = [
		...textWithTimestamps.matchAll(timestampExp),
	].map(match => {
		const timestamp = timestampMatchToNumber(match);
		return { timestamp, match };
	});

	if (!timestamps.length) {
		return { trimTo: 0, dataBeforeTrim: '', dataAfterTrim: textWithTimestamps };
	}

	const firstTimestamp = timestamps[0];
	let breakAt = firstTimestamp;

	const lastTimestamp = timestamps[timestamps.length - 1];
	const hasLongPauseAfterData = lastTimestamp.timestamp + 4 < recordingLengthSeconds;
	if (hasLongPauseAfterData) {
		breakAt = lastTimestamp;
	} else {
		const textWithTimestampsContentLength = textWithTimestamps.trimEnd().length;

		for (const timestampData of timestamps) {
			const { match, timestamp } = timestampData;
			const contentBefore = textWithTimestamps.substring(Math.max(match.index - 3, 0), match.index);
			const isNearEndOfLatinSentence = contentBefore.match(/[.?!]/);
			const isNearEndOfData = match.index + match[0].length >= textWithTimestampsContentLength;

			// Use a heuristic to determine whether to move content from the preview to the document.
			// These are based on the maximum buffer length of 30 seconds -- as the buffer gets longer, the
			// data should be more likely to be broken into chunks. Where possible, the break should be near
			// the end of a sentence:
			const canBreak = (timestamp > 4 && isNearEndOfLatinSentence && !isNearEndOfData)
					|| (timestamp > 8 && !isNearEndOfData)
					|| timestamp > 16;
			if (canBreak) {
				breakAt = timestampData;
				break;
			}
		}
	}

	const trimTo = breakAt.timestamp;
	const dataBeforeTrim = textWithTimestamps.substring(0, breakAt.match.index);
	const dataAfterTrim = textWithTimestamps.substring(breakAt.match.index + breakAt.match[0].length);

	return { trimTo, dataBeforeTrim, dataAfterTrim };
};

export default splitWhisperText;
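To make the heuristic concrete: with a 10-second recording, the first sentence-final timestamp above 4 s becomes the split point, the text before it can be finalized, and trimTo tells the caller how much audio to discard from the recorder (presumably via dropFirstSeconds; the calling code in whisper.ts is not part of this excerpt). A small usage sketch:

import splitWhisperText from './splitWhisperText';

const result = splitWhisperText(
	'<|0.00|> This is a test. <|5.00|><|6.00|> This is another sentence. <|7.00|>',
	10, // recording length in seconds
);
// result.dataBeforeTrim: '<|0.00|> This is a test. '  -> finalized text
// result.dataAfterTrim:  ' This is another sentence. <|7.00|>' -> stays in the preview
// result.trimTo: 6 -> seconds of audio the caller can now drop from the buffer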
@@ -4,9 +4,9 @@ import Setting from '@joplin/lib/models/Setting';
 import { rtrimSlashes } from '@joplin/lib/path-utils';
 import shim from '@joplin/lib/shim';
 import Vosk from 'react-native-vosk';
-import { unzip } from 'react-native-zip-archive';
 import RNFetchBlob from 'rn-fetch-blob';
-const md5 = require('md5');
+import { VoiceTypingProvider, VoiceTypingSession } from './VoiceTyping';
+import { join } from 'path';

 const logger = Logger.create('voiceTyping/vosk');

@@ -24,15 +24,6 @@ let vosk_: Record<string, Vosk> = {};

 let state_: State = State.Idle;

-export const voskEnabled = true;
-
-export { Vosk };
-
-export interface Recorder {
-	stop: ()=> Promise<string>;
-	cleanup: ()=> void;
-}
-
 const defaultSupportedLanguages = {
 	'en': 'https://alphacephei.com/vosk/models/vosk-model-small-en-us-0.15.zip',
 	'zh': 'https://alphacephei.com/vosk/models/vosk-model-small-cn-0.22.zip',
@@ -74,7 +65,7 @@ const getModelDir = (locale: string) => {
 	return `${getUnzipDir(locale)}/model`;
 };

-const languageModelUrl = (locale: string) => {
+const languageModelUrl = (locale: string): string => {
 	const lang = languageCodeOnly(locale).toLowerCase();
 	if (!(lang in defaultSupportedLanguages)) throw new Error(`No language file for: ${locale}`);

@@ -90,16 +81,11 @@ const languageModelUrl = (locale: string) => {
 	}
 };

-export const modelIsDownloaded = async (locale: string) => {
-	const uuidFile = `${getModelDir(locale)}/uuid`;
-	return shim.fsDriver().exists(uuidFile);
-};
-
-export const getVosk = async (locale: string) => {
+export const getVosk = async (modelDir: string, locale: string) => {
 	if (vosk_[locale]) return vosk_[locale];

 	const vosk = new Vosk();
-	const modelDir = await downloadModel(locale);
 	logger.info(`Loading model from ${modelDir}`);
 	await shim.fsDriver().readDirStats(modelDir);
 	const result = await vosk.loadModel(modelDir);
@@ -110,51 +96,7 @@ export const getVosk = async (locale: string) => {
 	return vosk;
 };

-const downloadModel = async (locale: string) => {
-	const modelUrl = languageModelUrl(locale);
-	const unzipDir = getUnzipDir(locale);
-	const zipFilePath = `${unzipDir}.zip`;
-	const modelDir = getModelDir(locale);
-	const uuidFile = `${modelDir}/uuid`;
-
-	if (await modelIsDownloaded(locale)) {
-		logger.info(`Model for ${locale} already exists at ${modelDir}`);
-		return modelDir;
-	}
-
-	await shim.fsDriver().remove(unzipDir);
-
-	logger.info(`Downloading model from: ${modelUrl}`);
-
-	const response = await shim.fetchBlob(modelUrl, {
-		path: zipFilePath,
-	});
-
-	if (!response.ok || response.status >= 400) throw new Error(`Could not download from ${modelUrl}: Error ${response.status}`);
-
-	logger.info(`Unzipping ${zipFilePath} => ${unzipDir}`);
-
-	await unzip(zipFilePath, unzipDir);
-
-	const dirs = await shim.fsDriver().readDirStats(unzipDir);
-	if (dirs.length !== 1) {
-		logger.error('Expected 1 directory but got', dirs);
-		throw new Error(`Expected 1 directory, but got ${dirs.length}`);
-	}
-
-	const fullUnzipPath = `${unzipDir}/${dirs[0].path}`;
-
-	logger.info(`Moving ${fullUnzipPath} => ${modelDir}`);
-	await shim.fsDriver().rename(fullUnzipPath, modelDir);
-
-	await shim.fsDriver().writeFile(uuidFile, md5(modelUrl));
-
-	await shim.fsDriver().remove(zipFilePath);
-
-	return modelDir;
-};
-
-export const startRecording = (vosk: Vosk, options: StartOptions): Recorder => {
+export const startRecording = (vosk: Vosk, options: StartOptions): VoiceTypingSession => {
 	if (state_ !== State.Idle) throw new Error('Vosk is already recording');

 	state_ = State.Recording;
@@ -163,10 +105,10 @@ export const startRecording = (vosk: Vosk, options: StartOptions): Recorder => {
 	// eslint-disable-next-line @typescript-eslint/no-explicit-any -- Old code before rule was applied
 	const eventHandlers: any[] = [];
 	// eslint-disable-next-line @typescript-eslint/ban-types -- Old code before rule was applied
-	let finalResultPromiseResolve: Function = null;
+	const finalResultPromiseResolve: Function = null;
 	// eslint-disable-next-line @typescript-eslint/ban-types -- Old code before rule was applied
-	let finalResultPromiseReject: Function = null;
+	const finalResultPromiseReject: Function = null;
-	let finalResultTimeout = false;
+	const finalResultTimeout = false;

 	const completeRecording = (finalResult: string, error: Error) => {
 		logger.info(`Complete recording. Final result: ${finalResult}. Error:`, error);
@@ -212,31 +154,13 @@ export const startRecording = (vosk: Vosk, options: StartOptions): Recorder => {
 		completeRecording(e.data, null);
 	}));

-	logger.info('Starting recording...');
-
-	void vosk.start();
-
 	return {
-		stop: (): Promise<string> => {
-			logger.info('Stopping recording...');
-
-			vosk.stopOnly();
-
-			logger.info('Waiting for final result...');
-
-			setTimeout(() => {
-				finalResultTimeout = true;
-				logger.warn('Timed out waiting for finalResult event');
-				completeRecording('', new Error('Could not process your message. Please try again.'));
-			}, 5000);
-
-			// eslint-disable-next-line @typescript-eslint/ban-types -- Old code before rule was applied
-			return new Promise((resolve: Function, reject: Function) => {
-				finalResultPromiseResolve = resolve;
-				finalResultPromiseReject = reject;
-			});
+		start: async () => {
+			logger.info('Starting recording...');
+			await vosk.start();
 		},
-		cleanup: () => {
+		stop: async () => {
 			if (state_ === State.Recording) {
 				logger.info('Cancelling...');
 				state_ = State.Completing;
@@ -246,3 +170,18 @@ export const startRecording = (vosk: Vosk, options: StartOptions): Recorder => {
 		},
 	};
 };
+
+
+const vosk: VoiceTypingProvider = {
+	supported: () => true,
+	modelLocalFilepath: (locale: string) => getModelDir(locale),
|  | 	getDownloadUrl: (locale) => languageModelUrl(locale), | ||||||
|  | 	getUuidPath: (locale: string) => join(getModelDir(locale), 'uuid'), | ||||||
|  | 	build: async ({ callbacks, locale, modelPath }) => { | ||||||
|  | 		const vosk = await getVosk(modelPath, locale); | ||||||
|  | 		return startRecording(vosk, { onResult: callbacks.onFinalize }); | ||||||
|  | 	}, | ||||||
|  | 	modelName: 'vosk', | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | export default vosk; | ||||||
|   | |||||||
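
Note: VoiceTyping.ts itself is not part of the hunks shown here. The sketch below reconstructs the shared interfaces from how vosk.android.ts (above) and whisper.ts (below) use them; it is an inference from this diff, not the file's actual contents, and the inline argument type for build() is an assumption.

// Hypothetical reconstruction, inferred from usage in this diff. The real
// definitions live in packages/app-mobile/services/voiceTyping/VoiceTyping.ts.
export interface VoiceTypingSession {
	start(): Promise<void>;
	stop(): Promise<void>;
}

export interface SpeechToTextCallbacks {
	// Tentative text that may still change (used by whisper.ts).
	onPreview(text: string): void;
	// Text that is final and can be committed to the note.
	onFinalize(text: string): void;
}

export interface VoiceTypingProvider {
	supported(): boolean;
	modelLocalFilepath(locale: string): string|null;
	getDownloadUrl(locale: string): string|null;
	getUuidPath(locale: string): string|null;
	build(args: { modelPath: string; locale: string; callbacks: SpeechToTextCallbacks }): Promise<VoiceTypingSession>;
	modelName: string;
}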

packages/app-mobile/services/voiceTyping/vosk.ts:

@@ -1,37 +1,14 @@
-// Currently disabled on non-Android platforms
-
-// eslint-disable-next-line @typescript-eslint/no-explicit-any -- Old code before rule was applied
-type Vosk = any;
-
-export { Vosk };
-
-interface StartOptions {
-	onResult: (text: string)=> void;
-}
-
-export interface Recorder {
-	stop: ()=> Promise<string>;
-	cleanup: ()=> void;
-}
-
-export const isSupportedLanguage = (_locale: string) => {
-	return false;
-};
-
-export const modelIsDownloaded = async (_locale: string) => {
-	return false;
-};
-
-export const getVosk = async (_locale: string) => {
-	// eslint-disable-next-line @typescript-eslint/no-explicit-any -- Old code before rule was applied
-	return {} as any;
-};
-
-export const startRecording = (_vosk: Vosk, _options: StartOptions): Recorder => {
-	return {
-		stop: async () => { return ''; },
-		cleanup: () => {},
-	};
-};
-
-export const voskEnabled = false;
+import { VoiceTypingProvider } from './VoiceTyping';
+
+const vosk: VoiceTypingProvider = {
+	supported: () => false,
+	modelLocalFilepath: () => null,
+	getDownloadUrl: () => null,
+	getUuidPath: () => null,
+	build: async () => {
+		throw new Error('Unsupported!');
+	},
+	modelName: 'vosk',
+};
+
+export default vosk;
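
Note: the vosk.android.ts / vosk.ts pair above relies on React Native's platform-specific extensions: Metro resolves the .android.ts file on Android and falls back to this stub elsewhere, which is why the stub reports supported: () => false. A minimal, hypothetical consumer:

// Hypothetical consumer code: the same import resolves to the real Android
// implementation or to the stub, depending on the platform being bundled.
import voskProvider from './vosk';

if (voskProvider.supported()) {
	// Only reached on Android: resolve the model path, then build a session.
}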

packages/app-mobile/services/voiceTyping/whisper.ts (new file, 119 lines):

@@ -0,0 +1,119 @@
+import Setting from '@joplin/lib/models/Setting';
+import shim from '@joplin/lib/shim';
+import Logger from '@joplin/utils/Logger';
+import { rtrimSlashes } from '@joplin/utils/path';
+import { dirname, join } from 'path';
+import { NativeModules } from 'react-native';
+import { SpeechToTextCallbacks, VoiceTypingProvider, VoiceTypingSession } from './VoiceTyping';
+import splitWhisperText from './utils/splitWhisperText';
+
+const logger = Logger.create('voiceTyping/whisper');
+
+const { SpeechToTextModule } = NativeModules;
+
+// Timestamps are in the form <|0.00|>. They seem to be added:
+// - After long pauses.
+// - Between sentences (in pairs).
+// - At the beginning and end of a sequence.
+const timestampExp = /<\|(\d+\.\d*)\|>/g;
+const postProcessSpeech = (text: string) => {
+	return text.replace(timestampExp, '').replace(/\[BLANK_AUDIO\]/g, '');
+};
+
+class Whisper implements VoiceTypingSession {
+	private lastPreviewData: string;
+	private closeCounter = 0;
+
+	public constructor(
+		private sessionId: number|null,
+		private callbacks: SpeechToTextCallbacks,
+	) { }
+
+	public async start() {
+		if (this.sessionId === null) {
+			throw new Error('Session closed.');
+		}
+		try {
+			logger.debug('starting recorder');
+			await SpeechToTextModule.startRecording(this.sessionId);
+			logger.debug('recorder started');
+
+			const loopStartCounter = this.closeCounter;
+			while (this.closeCounter === loopStartCounter) {
+				logger.debug('reading block');
+				const data: string = await SpeechToTextModule.expandBufferAndConvert(this.sessionId, 4);
+				logger.debug('done reading block. Length', data?.length);
+
+				if (this.sessionId === null) {
+					logger.debug('Session stopped. Ending inference loop.');
+					return;
+				}
+
+				const recordingLength = await SpeechToTextModule.getBufferLengthSeconds(this.sessionId);
+				logger.debug('recording length so far', recordingLength);
+				const { trimTo, dataBeforeTrim, dataAfterTrim } = splitWhisperText(data, recordingLength);
+
+				if (trimTo > 2) {
+					logger.debug('Trim to', trimTo, 'in recording with length', recordingLength);
+					this.callbacks.onFinalize(postProcessSpeech(dataBeforeTrim));
+					this.callbacks.onPreview(postProcessSpeech(dataAfterTrim));
+					this.lastPreviewData = dataAfterTrim;
+					await SpeechToTextModule.dropFirstSeconds(this.sessionId, trimTo);
+				} else {
+					logger.debug('Preview', data);
+					this.lastPreviewData = data;
+					this.callbacks.onPreview(postProcessSpeech(data));
+				}
+			}
+		} catch (error) {
+			logger.error('Whisper error:', error);
+			this.lastPreviewData = '';
+			await this.stop();
+			throw error;
+		}
+	}
+
+	public async stop() {
+		if (this.sessionId === null) {
+			logger.warn('Session already closed.');
+			return;
+		}
+
+		const sessionId = this.sessionId;
+		this.sessionId = null;
+		this.closeCounter ++;
+		await SpeechToTextModule.closeSession(sessionId);
+
+		if (this.lastPreviewData) {
+			this.callbacks.onFinalize(postProcessSpeech(this.lastPreviewData));
+		}
+	}
+}
+
+const modelLocalFilepath = () => {
+	return `${shim.fsDriver().getAppDirectoryPath()}/voice-typing-models/whisper_tiny.onnx`;
+};
+
+const whisper: VoiceTypingProvider = {
+	supported: () => !!SpeechToTextModule,
+	modelLocalFilepath: modelLocalFilepath,
+	getDownloadUrl: () => {
+		let urlTemplate = rtrimSlashes(Setting.value('voiceTypingBaseUrl').trim());
+
+		if (!urlTemplate) {
+			urlTemplate = 'https://github.com/personalizedrefrigerator/joplin-voice-typing-test/releases/download/test-release/{task}.zip';
+		}
+
+		return urlTemplate.replace(/\{task\}/g, 'whisper_tiny.onnx');
+	},
+	getUuidPath: () => {
+		return join(dirname(modelLocalFilepath()), 'uuid');
+	},
+	build: async ({ modelPath, callbacks, locale }) => {
+		const sessionId = await SpeechToTextModule.openSession(modelPath, locale);
+		return new Whisper(sessionId, callbacks);
+	},
+	modelName: 'whisper',
+};
+
+export default whisper;
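
Note: a small illustration (not part of this commit) of what postProcessSpeech strips, run against a made-up Whisper output string:

// Timestamp tokens and [BLANK_AUDIO] markers are removed; whitespace between
// them is left untouched, so the result keeps its surrounding spaces.
const raw = '<|0.00|> And the text continues.<|4.20|><|4.20|> [BLANK_AUDIO]<|7.00|>';
const cleaned = raw.replace(/<\|(\d+\.\d*)\|>/g, '').replace(/\[BLANK_AUDIO\]/g, '');
// cleaned === ' And the text continues. '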

packages/lib/models/settings/builtInMetadata.ts:

@@ -1615,6 +1615,25 @@ const builtInMetadata = (Setting: typeof SettingType) => {
 			section: 'note',
 		},
 
+		'voiceTyping.preferredProvider': {
+			value: 'whisper-tiny',
+			type: SettingItemType.String,
+			public: true,
+			appTypes: [AppType.Mobile],
+			label: () => _('Preferred voice typing provider'),
+			isEnum: true,
+			// For now, iOS and web don't support voice typing.
+			show: () => shim.mobilePlatform() === 'android',
+			section: 'note',
+
+			options: () => {
+				return {
+					'vosk': _('Vosk'),
+					'whisper-tiny': _('Whisper'),
+				};
+			},
+		},
+
 		'trash.autoDeletionEnabled': {
 			value: true,
 			type: SettingItemType.Bool,
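
Note: a sketch (not code from this commit) of how the stored value could pick between the two providers from this diff; the selection logic and import paths are assumptions:

import Setting from '@joplin/lib/models/Setting';
import vosk from './vosk';
import whisper from './whisper';

// 'voiceTyping.preferredProvider' holds 'vosk' or 'whisper-tiny'; matching on
// the prefix maps 'whisper-tiny' to the provider whose modelName is 'whisper'.
const preferred: string = Setting.value('voiceTyping.preferredProvider');
const provider = preferred.startsWith('whisper') ? whisper : vosk;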

Spell-checker dictionary:

@@ -132,4 +132,6 @@ Famegear
 rcompare
 tabindex
 Backblaze
+onnx
+onnxruntime
 treeitem