mirror of
https://github.com/laurent22/joplin.git
synced 2025-01-23 18:53:36 +02:00
Android: Allow re-downloading voice typing models on URL change and error (#11557)
This commit is contained in:
parent
4d827afccb
commit
bacaf800f2
@ -11,6 +11,9 @@ import { AppState } from '../../utils/types';
|
||||
import { connect } from 'react-redux';
|
||||
import { View, StyleSheet } from 'react-native';
|
||||
import AccessibleView from '../accessibility/AccessibleView';
|
||||
import Logger from '@joplin/utils/Logger';
|
||||
|
||||
const logger = Logger.create('VoiceTypingDialog');
|
||||
|
||||
interface Props {
|
||||
locale: string;
|
||||
@ -34,10 +37,11 @@ interface UseVoiceTypingProps {
|
||||
onText: OnTextCallback;
|
||||
}
|
||||
|
||||
const useWhisper = ({ locale, provider, onSetPreview, onText }: UseVoiceTypingProps): [Error | null, boolean, VoiceTypingSession|null] => {
|
||||
const useVoiceTyping = ({ locale, provider, onSetPreview, onText }: UseVoiceTypingProps) => {
|
||||
const [voiceTyping, setVoiceTyping] = useState<VoiceTypingSession>(null);
|
||||
const [error, setError] = useState<Error>(null);
|
||||
const [mustDownloadModel, setMustDownloadModel] = useState<boolean | null>(null);
|
||||
const [modelIsOutdated, setModelIsOutdated] = useState(false);
|
||||
|
||||
const onTextRef = useRef(onText);
|
||||
onTextRef.current = onText;
|
||||
@ -51,9 +55,20 @@ const useWhisper = ({ locale, provider, onSetPreview, onText }: UseVoiceTypingPr
|
||||
return new VoiceTyping(locale, provider?.startsWith('whisper') ? [whisper] : [vosk]);
|
||||
}, [locale, provider]);
|
||||
|
||||
const [redownloadCounter, setRedownloadCounter] = useState(0);
|
||||
|
||||
useEffect(() => {
|
||||
if (modelIsOutdated) {
|
||||
logger.info('The downloaded version of the model is from an outdated URL.');
|
||||
}
|
||||
}, [modelIsOutdated]);
|
||||
|
||||
useAsyncEffect(async (event: AsyncEffectEvent) => {
|
||||
try {
|
||||
await voiceTypingRef.current?.stop();
|
||||
onSetPreviewRef.current?.('');
|
||||
|
||||
setModelIsOutdated(await builder.isDownloadedFromOutdatedUrl());
|
||||
|
||||
if (!await builder.isDownloaded()) {
|
||||
if (event.cancelled) return;
|
||||
@ -72,7 +87,7 @@ const useWhisper = ({ locale, provider, onSetPreview, onText }: UseVoiceTypingPr
|
||||
} finally {
|
||||
setMustDownloadModel(false);
|
||||
}
|
||||
}, [builder]);
|
||||
}, [builder, redownloadCounter]);
|
||||
|
||||
useAsyncEffect(async (_event: AsyncEffectEvent) => {
|
||||
setMustDownloadModel(!(await builder.isDownloaded()));
|
||||
@ -82,7 +97,16 @@ const useWhisper = ({ locale, provider, onSetPreview, onText }: UseVoiceTypingPr
|
||||
void voiceTypingRef.current?.stop();
|
||||
}, []);
|
||||
|
||||
return [error, mustDownloadModel, voiceTyping];
|
||||
const onRequestRedownload = useCallback(async () => {
|
||||
await voiceTypingRef.current?.stop();
|
||||
await builder.clearDownloads();
|
||||
setMustDownloadModel(true);
|
||||
setRedownloadCounter(value => value + 1);
|
||||
}, [builder]);
|
||||
|
||||
return {
|
||||
error, mustDownloadModel, voiceTyping, onRequestRedownload, modelIsOutdated,
|
||||
};
|
||||
};
|
||||
|
||||
const styles = StyleSheet.create({
|
||||
@ -112,7 +136,13 @@ const styles = StyleSheet.create({
|
||||
const VoiceTypingDialog: React.FC<Props> = props => {
|
||||
const [recorderState, setRecorderState] = useState<RecorderState>(RecorderState.Loading);
|
||||
const [preview, setPreview] = useState<string>('');
|
||||
const [modelError, mustDownloadModel, voiceTyping] = useWhisper({
|
||||
const {
|
||||
error: modelError,
|
||||
mustDownloadModel,
|
||||
voiceTyping,
|
||||
onRequestRedownload,
|
||||
modelIsOutdated,
|
||||
} = useVoiceTyping({
|
||||
locale: props.locale,
|
||||
onSetPreview: setPreview,
|
||||
onText: props.onText,
|
||||
@ -172,6 +202,11 @@ const VoiceTypingDialog: React.FC<Props> = props => {
|
||||
return <Text variant='labelSmall'>{preview}</Text>;
|
||||
};
|
||||
|
||||
const reDownloadButton = <Button onPress={onRequestRedownload}>
|
||||
{modelIsOutdated ? _('Download updated model') : _('Re-download model')}
|
||||
</Button>;
|
||||
const allowReDownload = recorderState === RecorderState.Error || modelIsOutdated;
|
||||
|
||||
return (
|
||||
<Surface>
|
||||
<View style={styles.container}>
|
||||
@ -203,6 +238,7 @@ const VoiceTypingDialog: React.FC<Props> = props => {
|
||||
</View>
|
||||
</View>
|
||||
<View style={styles.actionContainer}>
|
||||
{allowReDownload ? reDownloadButton : null}
|
||||
<Button
|
||||
onPress={onDismiss}
|
||||
accessibilityHint={_('Ends voice typing')}
|
||||
|
@ -2,6 +2,7 @@ import shim from '@joplin/lib/shim';
|
||||
import Logger from '@joplin/utils/Logger';
|
||||
import { PermissionsAndroid, Platform } from 'react-native';
|
||||
import unzip from './utils/unzip';
|
||||
import { _ } from '@joplin/lib/locale';
|
||||
const md5 = require('md5');
|
||||
|
||||
const logger = Logger.create('voiceTyping');
|
||||
@ -30,6 +31,7 @@ export interface VoiceTypingProvider {
|
||||
modelName: string;
|
||||
supported(): boolean;
|
||||
modelLocalFilepath(locale: string): string;
|
||||
deleteCachedModels(locale: string): Promise<void>;
|
||||
getDownloadUrl(locale: string): string;
|
||||
getUuidPath(locale: string): string;
|
||||
build(options: BuildProviderOptions): Promise<VoiceTypingSession>;
|
||||
@ -39,9 +41,9 @@ export default class VoiceTyping {
|
||||
private provider: VoiceTypingProvider|null = null;
|
||||
public constructor(
|
||||
private locale: string,
|
||||
providers: VoiceTypingProvider[],
|
||||
allProviders: VoiceTypingProvider[],
|
||||
) {
|
||||
this.provider = providers.find(p => p.supported()) ?? null;
|
||||
this.provider = allProviders.find(p => p.supported()) ?? null;
|
||||
}
|
||||
|
||||
public supported() {
|
||||
@ -67,10 +69,31 @@ export default class VoiceTyping {
|
||||
);
|
||||
}
|
||||
|
||||
public async isDownloadedFromOutdatedUrl() {
|
||||
const uuidPath = this.getUuidPath();
|
||||
if (!await shim.fsDriver().exists(uuidPath)) {
|
||||
// Not downloaded at all
|
||||
return false;
|
||||
}
|
||||
|
||||
const modelUrl = this.provider.getDownloadUrl(this.locale);
|
||||
const urlHash = await shim.fsDriver().readFile(uuidPath);
|
||||
return urlHash.trim() !== md5(modelUrl);
|
||||
}
|
||||
|
||||
public async isDownloaded() {
|
||||
return await shim.fsDriver().exists(this.getUuidPath());
|
||||
}
|
||||
|
||||
public async clearDownloads() {
|
||||
const confirmed = await shim.showConfirmationDialog(
|
||||
_('Delete model and re-download?\nThis cannot be undone.'),
|
||||
);
|
||||
if (confirmed) {
|
||||
await this.provider.deleteCachedModels(this.locale);
|
||||
}
|
||||
}
|
||||
|
||||
public async download() {
|
||||
const modelPath = this.getModelPath();
|
||||
const modelUrl = this.provider.getDownloadUrl(this.locale);
|
||||
@ -104,16 +127,18 @@ export default class VoiceTyping {
|
||||
|
||||
logger.info(`Moving ${fullUnzipPath} => ${modelPath}`);
|
||||
await shim.fsDriver().move(fullUnzipPath, modelPath);
|
||||
|
||||
await shim.fsDriver().writeFile(this.getUuidPath(), md5(modelUrl), 'utf8');
|
||||
if (!await this.isDownloaded()) {
|
||||
logger.warn('Model should be downloaded!');
|
||||
}
|
||||
} finally {
|
||||
await shim.fsDriver().remove(unzipDir);
|
||||
await shim.fsDriver().remove(downloadPath);
|
||||
}
|
||||
}
|
||||
|
||||
await shim.fsDriver().writeFile(this.getUuidPath(), md5(modelUrl), 'utf8');
|
||||
if (!await this.isDownloaded()) {
|
||||
logger.warn('Model should be downloaded!');
|
||||
} else {
|
||||
logger.info('Model stats', await shim.fsDriver().stat(modelPath));
|
||||
}
|
||||
}
|
||||
|
||||
public async build(callbacks: SpeechToTextCallbacks) {
|
||||
|
@ -175,6 +175,10 @@ export const startRecording = (vosk: Vosk, options: StartOptions): VoiceTypingSe
|
||||
const vosk: VoiceTypingProvider = {
|
||||
supported: () => true,
|
||||
modelLocalFilepath: (locale: string) => getModelDir(locale),
|
||||
deleteCachedModels: async (locale: string) => {
|
||||
const path = getModelDir(locale);
|
||||
await shim.fsDriver().remove(path, { recursive: true });
|
||||
},
|
||||
getDownloadUrl: (locale) => languageModelUrl(locale),
|
||||
getUuidPath: (locale: string) => join(getModelDir(locale), 'uuid'),
|
||||
build: async ({ callbacks, locale, modelPath }) => {
|
||||
|
@ -5,6 +5,7 @@ const vosk: VoiceTypingProvider = {
|
||||
modelLocalFilepath: () => null,
|
||||
getDownloadUrl: () => null,
|
||||
getUuidPath: () => null,
|
||||
deleteCachedModels: () => null,
|
||||
build: async () => {
|
||||
throw new Error('Unsupported!');
|
||||
},
|
||||
|
@ -106,6 +106,10 @@ const whisper: VoiceTypingProvider = {
|
||||
|
||||
return urlTemplate.replace(/\{task\}/g, 'whisper_tiny.onnx');
|
||||
},
|
||||
deleteCachedModels: async (locale) => {
|
||||
await shim.fsDriver().remove(modelLocalFilepath());
|
||||
await shim.fsDriver().remove(whisper.getUuidPath(locale));
|
||||
},
|
||||
getUuidPath: () => {
|
||||
return join(dirname(modelLocalFilepath()), 'uuid');
|
||||
},
|
||||
|
Loading…
x
Reference in New Issue
Block a user