1
0
mirror of https://github.com/laurent22/joplin.git synced 2025-01-23 18:53:36 +02:00

Android: Allow re-downloading voice typing models on URL change and error (#11557)

This commit is contained in:
Henry Heino 2025-01-06 09:33:44 -08:00 committed by GitHub
parent 4d827afccb
commit bacaf800f2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 81 additions and 11 deletions

View File

@ -11,6 +11,9 @@ import { AppState } from '../../utils/types';
import { connect } from 'react-redux';
import { View, StyleSheet } from 'react-native';
import AccessibleView from '../accessibility/AccessibleView';
import Logger from '@joplin/utils/Logger';
const logger = Logger.create('VoiceTypingDialog');
interface Props {
locale: string;
@ -34,10 +37,11 @@ interface UseVoiceTypingProps {
onText: OnTextCallback;
}
const useWhisper = ({ locale, provider, onSetPreview, onText }: UseVoiceTypingProps): [Error | null, boolean, VoiceTypingSession|null] => {
const useVoiceTyping = ({ locale, provider, onSetPreview, onText }: UseVoiceTypingProps) => {
const [voiceTyping, setVoiceTyping] = useState<VoiceTypingSession>(null);
const [error, setError] = useState<Error>(null);
const [mustDownloadModel, setMustDownloadModel] = useState<boolean | null>(null);
const [modelIsOutdated, setModelIsOutdated] = useState(false);
const onTextRef = useRef(onText);
onTextRef.current = onText;
@ -51,9 +55,20 @@ const useWhisper = ({ locale, provider, onSetPreview, onText }: UseVoiceTypingPr
return new VoiceTyping(locale, provider?.startsWith('whisper') ? [whisper] : [vosk]);
}, [locale, provider]);
const [redownloadCounter, setRedownloadCounter] = useState(0);
useEffect(() => {
if (modelIsOutdated) {
logger.info('The downloaded version of the model is from an outdated URL.');
}
}, [modelIsOutdated]);
useAsyncEffect(async (event: AsyncEffectEvent) => {
try {
await voiceTypingRef.current?.stop();
onSetPreviewRef.current?.('');
setModelIsOutdated(await builder.isDownloadedFromOutdatedUrl());
if (!await builder.isDownloaded()) {
if (event.cancelled) return;
@ -72,7 +87,7 @@ const useWhisper = ({ locale, provider, onSetPreview, onText }: UseVoiceTypingPr
} finally {
setMustDownloadModel(false);
}
}, [builder]);
}, [builder, redownloadCounter]);
useAsyncEffect(async (_event: AsyncEffectEvent) => {
setMustDownloadModel(!(await builder.isDownloaded()));
@ -82,7 +97,16 @@ const useWhisper = ({ locale, provider, onSetPreview, onText }: UseVoiceTypingPr
void voiceTypingRef.current?.stop();
}, []);
return [error, mustDownloadModel, voiceTyping];
const onRequestRedownload = useCallback(async () => {
await voiceTypingRef.current?.stop();
await builder.clearDownloads();
setMustDownloadModel(true);
setRedownloadCounter(value => value + 1);
}, [builder]);
return {
error, mustDownloadModel, voiceTyping, onRequestRedownload, modelIsOutdated,
};
};
const styles = StyleSheet.create({
@ -112,7 +136,13 @@ const styles = StyleSheet.create({
const VoiceTypingDialog: React.FC<Props> = props => {
const [recorderState, setRecorderState] = useState<RecorderState>(RecorderState.Loading);
const [preview, setPreview] = useState<string>('');
const [modelError, mustDownloadModel, voiceTyping] = useWhisper({
const {
error: modelError,
mustDownloadModel,
voiceTyping,
onRequestRedownload,
modelIsOutdated,
} = useVoiceTyping({
locale: props.locale,
onSetPreview: setPreview,
onText: props.onText,
@ -172,6 +202,11 @@ const VoiceTypingDialog: React.FC<Props> = props => {
return <Text variant='labelSmall'>{preview}</Text>;
};
const reDownloadButton = <Button onPress={onRequestRedownload}>
{modelIsOutdated ? _('Download updated model') : _('Re-download model')}
</Button>;
const allowReDownload = recorderState === RecorderState.Error || modelIsOutdated;
return (
<Surface>
<View style={styles.container}>
@ -203,6 +238,7 @@ const VoiceTypingDialog: React.FC<Props> = props => {
</View>
</View>
<View style={styles.actionContainer}>
{allowReDownload ? reDownloadButton : null}
<Button
onPress={onDismiss}
accessibilityHint={_('Ends voice typing')}

View File

@ -2,6 +2,7 @@ import shim from '@joplin/lib/shim';
import Logger from '@joplin/utils/Logger';
import { PermissionsAndroid, Platform } from 'react-native';
import unzip from './utils/unzip';
import { _ } from '@joplin/lib/locale';
const md5 = require('md5');
const logger = Logger.create('voiceTyping');
@ -30,6 +31,7 @@ export interface VoiceTypingProvider {
modelName: string;
supported(): boolean;
modelLocalFilepath(locale: string): string;
deleteCachedModels(locale: string): Promise<void>;
getDownloadUrl(locale: string): string;
getUuidPath(locale: string): string;
build(options: BuildProviderOptions): Promise<VoiceTypingSession>;
@ -39,9 +41,9 @@ export default class VoiceTyping {
private provider: VoiceTypingProvider|null = null;
public constructor(
private locale: string,
providers: VoiceTypingProvider[],
allProviders: VoiceTypingProvider[],
) {
this.provider = providers.find(p => p.supported()) ?? null;
this.provider = allProviders.find(p => p.supported()) ?? null;
}
public supported() {
@ -67,10 +69,31 @@ export default class VoiceTyping {
);
}
public async isDownloadedFromOutdatedUrl() {
const uuidPath = this.getUuidPath();
if (!await shim.fsDriver().exists(uuidPath)) {
// Not downloaded at all
return false;
}
const modelUrl = this.provider.getDownloadUrl(this.locale);
const urlHash = await shim.fsDriver().readFile(uuidPath);
return urlHash.trim() !== md5(modelUrl);
}
public async isDownloaded() {
return await shim.fsDriver().exists(this.getUuidPath());
}
public async clearDownloads() {
const confirmed = await shim.showConfirmationDialog(
_('Delete model and re-download?\nThis cannot be undone.'),
);
if (confirmed) {
await this.provider.deleteCachedModels(this.locale);
}
}
public async download() {
const modelPath = this.getModelPath();
const modelUrl = this.provider.getDownloadUrl(this.locale);
@ -104,16 +127,18 @@ export default class VoiceTyping {
logger.info(`Moving ${fullUnzipPath} => ${modelPath}`);
await shim.fsDriver().move(fullUnzipPath, modelPath);
await shim.fsDriver().writeFile(this.getUuidPath(), md5(modelUrl), 'utf8');
if (!await this.isDownloaded()) {
logger.warn('Model should be downloaded!');
}
} finally {
await shim.fsDriver().remove(unzipDir);
await shim.fsDriver().remove(downloadPath);
}
}
await shim.fsDriver().writeFile(this.getUuidPath(), md5(modelUrl), 'utf8');
if (!await this.isDownloaded()) {
logger.warn('Model should be downloaded!');
} else {
logger.info('Model stats', await shim.fsDriver().stat(modelPath));
}
}
public async build(callbacks: SpeechToTextCallbacks) {

View File

@ -175,6 +175,10 @@ export const startRecording = (vosk: Vosk, options: StartOptions): VoiceTypingSe
const vosk: VoiceTypingProvider = {
supported: () => true,
modelLocalFilepath: (locale: string) => getModelDir(locale),
deleteCachedModels: async (locale: string) => {
const path = getModelDir(locale);
await shim.fsDriver().remove(path, { recursive: true });
},
getDownloadUrl: (locale) => languageModelUrl(locale),
getUuidPath: (locale: string) => join(getModelDir(locale), 'uuid'),
build: async ({ callbacks, locale, modelPath }) => {

View File

@ -5,6 +5,7 @@ const vosk: VoiceTypingProvider = {
modelLocalFilepath: () => null,
getDownloadUrl: () => null,
getUuidPath: () => null,
deleteCachedModels: () => null,
build: async () => {
throw new Error('Unsupported!');
},

View File

@ -106,6 +106,10 @@ const whisper: VoiceTypingProvider = {
return urlTemplate.replace(/\{task\}/g, 'whisper_tiny.onnx');
},
deleteCachedModels: async (locale) => {
await shim.fsDriver().remove(modelLocalFilepath());
await shim.fsDriver().remove(whisper.getUuidPath(locale));
},
getUuidPath: () => {
return join(dirname(modelLocalFilepath()), 'uuid');
},