diff --git a/.eslintignore b/.eslintignore index e1a62ea017..025d324ce3 100644 --- a/.eslintignore +++ b/.eslintignore @@ -1344,6 +1344,7 @@ packages/lib/services/database/migrations/44.js packages/lib/services/database/migrations/45.js packages/lib/services/database/migrations/46.js packages/lib/services/database/migrations/47.js +packages/lib/services/database/migrations/48.js packages/lib/services/database/migrations/index.js packages/lib/services/database/sqlStringToLines.js packages/lib/services/database/types.js @@ -1412,6 +1413,8 @@ packages/lib/services/ocr/OcrDriverBase.js packages/lib/services/ocr/OcrService.test.js packages/lib/services/ocr/OcrService.js packages/lib/services/ocr/drivers/OcrDriverTesseract.js +packages/lib/services/ocr/drivers/OcrDriverTranscribe.test.js +packages/lib/services/ocr/drivers/OcrDriverTranscribe.js packages/lib/services/ocr/utils/filterOcrText.test.js packages/lib/services/ocr/utils/filterOcrText.js packages/lib/services/ocr/utils/types.js diff --git a/.gitignore b/.gitignore index 1a9ec7e911..a3dd43499a 100644 --- a/.gitignore +++ b/.gitignore @@ -1317,6 +1317,7 @@ packages/lib/services/database/migrations/44.js packages/lib/services/database/migrations/45.js packages/lib/services/database/migrations/46.js packages/lib/services/database/migrations/47.js +packages/lib/services/database/migrations/48.js packages/lib/services/database/migrations/index.js packages/lib/services/database/sqlStringToLines.js packages/lib/services/database/types.js @@ -1385,6 +1386,8 @@ packages/lib/services/ocr/OcrDriverBase.js packages/lib/services/ocr/OcrService.test.js packages/lib/services/ocr/OcrService.js packages/lib/services/ocr/drivers/OcrDriverTesseract.js +packages/lib/services/ocr/drivers/OcrDriverTranscribe.test.js +packages/lib/services/ocr/drivers/OcrDriverTranscribe.js packages/lib/services/ocr/utils/filterOcrText.test.js packages/lib/services/ocr/utils/filterOcrText.js packages/lib/services/ocr/utils/types.js diff --git 
a/packages/app-desktop/app.ts b/packages/app-desktop/app.ts index c3727b1639..61e6b4db23 100644 --- a/packages/app-desktop/app.ts +++ b/packages/app-desktop/app.ts @@ -55,11 +55,13 @@ import userFetcher, { initializeUserFetcher } from '@joplin/lib/utils/userFetche import { parseNotesParent } from '@joplin/lib/reducer'; import OcrService from '@joplin/lib/services/ocr/OcrService'; import OcrDriverTesseract from '@joplin/lib/services/ocr/drivers/OcrDriverTesseract'; +import OcrDriverTranscribe from '@joplin/lib/services/ocr/drivers/OcrDriverTranscribe'; import SearchEngine from '@joplin/lib/services/search/SearchEngine'; import { PackageInfo } from '@joplin/lib/versionInfo'; import { CustomProtocolHandler } from './utils/customProtocols/handleCustomProtocols'; import { refreshFolders } from '@joplin/lib/folders-screen-utils'; import initializeCommandService from './utils/initializeCommandService'; +import OcrDriverBase from '@joplin/lib/services/ocr/OcrDriverBase'; import PerformanceLogger from '@joplin/lib/PerformanceLogger'; const perfLogger = PerformanceLogger.create('app-desktop/app'); @@ -353,16 +355,19 @@ class Application extends BaseApplication { // eslint-disable-next-line @typescript-eslint/no-explicit-any -- Old code before rule was applied const Tesseract = (window as any).Tesseract; - const driver = new OcrDriverTesseract( + const drivers: OcrDriverBase[] = []; + drivers.push(new OcrDriverTesseract( { createWorker: Tesseract.createWorker }, { workerPath: `${bridge().buildDir()}/tesseract.js/worker.min.js`, corePath: `${bridge().buildDir()}/tesseract.js-core`, languageDataPath: Setting.value('ocr.languageDataPath') || null, }, - ); + )); - this.ocrService_ = new OcrService(driver); + drivers.push(new OcrDriverTranscribe()); + + this.ocrService_ = new OcrService(drivers); } void this.ocrService_.runInBackground(); diff --git a/packages/app-desktop/gui/NoteEditor/utils/contextMenu.ts b/packages/app-desktop/gui/NoteEditor/utils/contextMenu.ts index 
f1fe8559ff..cef133e96d 100644 --- a/packages/app-desktop/gui/NoteEditor/utils/contextMenu.ts +++ b/packages/app-desktop/gui/NoteEditor/utils/contextMenu.ts @@ -8,14 +8,15 @@ const MenuItem = bridge().MenuItem; import Resource, { resourceOcrStatusToString } from '@joplin/lib/models/Resource'; import BaseItem from '@joplin/lib/models/BaseItem'; import BaseModel, { ModelType } from '@joplin/lib/BaseModel'; -import { NoteEntity, ResourceEntity, ResourceOcrStatus } from '@joplin/lib/services/database/types'; +import { NoteEntity, ResourceEntity, ResourceOcrDriverId, ResourceOcrStatus } from '@joplin/lib/services/database/types'; import { TinyMceEditorEvents } from '../NoteBody/TinyMCE/utils/types'; import { itemIsReadOnlySync, ItemSlice } from '@joplin/lib/models/utils/readOnly'; import Setting from '@joplin/lib/models/Setting'; import ItemChange from '@joplin/lib/models/ItemChange'; -import shim from '@joplin/lib/shim'; +import shim, { MessageBoxType } from '@joplin/lib/shim'; import { openFileWithExternalEditor } from '@joplin/lib/services/ExternalEditWatcher/utils'; import CommandService from '@joplin/lib/services/CommandService'; +import SyncTargetRegistry from '@joplin/lib/SyncTargetRegistry'; const fs = require('fs-extra'); const { writeFile } = require('fs-extra'); const { clipboard } = require('electron'); @@ -137,6 +138,40 @@ export function menuItems(dispatch: Function): ContextMenuItems { }, isActive: (itemType: ContextMenuItemType, options: ContextMenuOptions) => !!options.textToCopy && itemType === ContextMenuItemType.Image && options.mime?.startsWith('image/svg'), }, + recognizeHandwrittenImage: { + label: _('Recognize handwritten image'), + onAction: async (options: ContextMenuOptions) => { + const syncTargetId = Setting.value('sync.target'); + if (!SyncTargetRegistry.isJoplinServerOrCloud(syncTargetId)) { + await shim.showMessageBox(_('This feature is only available on Joplin Cloud and Joplin Server.'), { type: MessageBoxType.Error }); + return; + } + + 
if (!Setting.value('ocr.handwrittenTextDriverEnabled')) { + await shim.showMessageBox(_('This feature is disabled by default, you need to manually enable it by turning on the option to \'Enable handwritten transcription\'.'), { type: MessageBoxType.Error }); + return; + } + + const { resource } = await resourceInfo(options); + + if (!['image/png', 'image/jpg', 'image/jpeg', 'image/bmp'].includes(resource.mime)) { + await shim.showMessageBox(_('This image type is not supported by the recognition system.'), { type: MessageBoxType.Error }); + return; + } + + await Resource.save({ + id: resource.id, + ocr_status: ResourceOcrStatus.Todo, + ocr_driver_id: ResourceOcrDriverId.HandwrittenText, + ocr_details: '', + ocr_error: '', + ocr_text: '', + }); + }, + isActive: (itemType: ContextMenuItemType, options: ContextMenuOptions) => { + return itemType === ContextMenuItemType.Resource || (itemType === ContextMenuItemType.Image && options.resourceId); + }, + }, revealInFolder: { label: _('Reveal file in folder'), onAction: async (options: ContextMenuOptions) => { diff --git a/packages/lib/SyncTargetRegistry.ts b/packages/lib/SyncTargetRegistry.ts index 043a72ae51..96b6e9ea3e 100644 --- a/packages/lib/SyncTargetRegistry.ts +++ b/packages/lib/SyncTargetRegistry.ts @@ -97,4 +97,12 @@ export default class SyncTargetRegistry { ]; } + public static isJoplinServerOrCloud(id: number) { + return [ + SyncTargetRegistry.nameToId('joplinServer'), + SyncTargetRegistry.nameToId('joplinCloud'), + SyncTargetRegistry.nameToId('joplinServerSaml'), + ].includes(id); + } + } diff --git a/packages/lib/models/Resource.ts b/packages/lib/models/Resource.ts index 43f5d634f9..ee02defe5c 100644 --- a/packages/lib/models/Resource.ts +++ b/packages/lib/models/Resource.ts @@ -518,12 +518,13 @@ export default class Resource extends BaseItem { SELECT ${selectSql} FROM resources WHERE - ocr_status = ? AND + (ocr_status = ? or ocr_status = ?) 
AND encryption_applied = 0 AND mime IN ('${supportedMimeTypes.join('\',\'')}') `, params: [ ResourceOcrStatus.Todo, + ResourceOcrStatus.Processing, ], }; } diff --git a/packages/lib/models/settings/builtInMetadata.ts b/packages/lib/models/settings/builtInMetadata.ts index de202ae01b..850e493086 100644 --- a/packages/lib/models/settings/builtInMetadata.ts +++ b/packages/lib/models/settings/builtInMetadata.ts @@ -556,6 +556,16 @@ const builtInMetadata = (Setting: typeof SettingType) => { isGlobal: true, }, + 'ocr.handwrittenTextDriverEnabled': { + value: true, + type: SettingItemType.Bool, + public: true, + appTypes: [AppType.Desktop], + label: () => _('Enable handwritten transcription'), + storage: SettingStorage.File, + isGlobal: true, + }, + 'ocr.languageDataPath': { value: '', type: SettingItemType.String, diff --git a/packages/lib/services/database/migrations/48.ts b/packages/lib/services/database/migrations/48.ts new file mode 100644 index 0000000000..31a3288b9a --- /dev/null +++ b/packages/lib/services/database/migrations/48.ts @@ -0,0 +1,7 @@ +import { SqlQuery } from '../types'; + +export default (): (SqlQuery|string)[] => { + return [ + 'ALTER TABLE `resources` ADD COLUMN `ocr_driver_id` INT NOT NULL DEFAULT "1"', + ]; +}; diff --git a/packages/lib/services/database/migrations/index.ts b/packages/lib/services/database/migrations/index.ts index 7f0a5f1704..0f21cccf66 100644 --- a/packages/lib/services/database/migrations/index.ts +++ b/packages/lib/services/database/migrations/index.ts @@ -5,6 +5,7 @@ import migration44 from './44'; import migration45 from './45'; import migration46 from './46'; import migration47 from './47'; +import migration48 from './48'; import { Migration } from '../types'; @@ -15,6 +16,7 @@ const index: Migration[] = [ migration45, migration46, migration47, + migration48, ]; export default index; diff --git a/packages/lib/services/database/types.ts b/packages/lib/services/database/types.ts index fe48d33c92..634bc557b1 100644 --- 
a/packages/lib/services/database/types.ts +++ b/packages/lib/services/database/types.ts @@ -76,6 +76,11 @@ interface DatabaseTables { [key: string]: DatabaseTable; } +export enum ResourceOcrDriverId { + PrintedText = 1, + HandwrittenText = 2, +} + // AUTO-GENERATED BY packages/tools/generate-database-types.js /* @@ -283,6 +288,7 @@ export interface ResourceEntity { 'master_key_id'?: string; 'mime'?: string; 'ocr_details'?: string; + 'ocr_driver_id'?: number; 'ocr_error'?: string; 'ocr_status'?: number; 'ocr_text'?: string; @@ -330,9 +336,9 @@ export interface SyncItemEntity { 'item_type'?: number; 'sync_disabled'?: number; 'sync_disabled_reason'?: string; - 'sync_warning_ignored'?: number; 'sync_target'?: number; 'sync_time'?: number; + 'sync_warning_ignored'?: number; 'type_'?: number; } export interface TableFieldEntity { @@ -435,9 +441,9 @@ export const databaseSchema: DatabaseTables = { item_type: { type: 'number' }, sync_disabled: { type: 'number' }, sync_disabled_reason: { type: 'string' }, - sync_warning_ignored: { type: 'number' }, sync_target: { type: 'number' }, sync_time: { type: 'number' }, + sync_warning_ignored: { type: 'number' }, type_: { type: 'number' }, }, version: { @@ -502,6 +508,7 @@ export const databaseSchema: DatabaseTables = { master_key_id: { type: 'string' }, mime: { type: 'string' }, ocr_details: { type: 'string' }, + ocr_driver_id: { type: 'number' }, ocr_error: { type: 'string' }, ocr_status: { type: 'number' }, ocr_text: { type: 'string' }, diff --git a/packages/lib/services/ocr/OcrDriverBase.ts b/packages/lib/services/ocr/OcrDriverBase.ts index 90c4418bc9..88cb905be6 100644 --- a/packages/lib/services/ocr/OcrDriverBase.ts +++ b/packages/lib/services/ocr/OcrDriverBase.ts @@ -1,11 +1,16 @@ +import { ResourceOcrDriverId } from '../database/types'; import { RecognizeResult } from './utils/types'; export default class OcrDriverBase { - public async recognize(_language: string, _filePath: string): Promise { + public async 
recognize(_language: string, _filePath: string, _id: string): Promise { throw new Error('Not implemented'); } public async dispose(): Promise {} + public get driverId() { + return ResourceOcrDriverId.PrintedText; + } + } diff --git a/packages/lib/services/ocr/OcrService.ts b/packages/lib/services/ocr/OcrService.ts index 5b21641600..620f53a906 100644 --- a/packages/lib/services/ocr/OcrService.ts +++ b/packages/lib/services/ocr/OcrService.ts @@ -2,12 +2,11 @@ import { toIso639Alpha3 } from '../../locale'; import Resource from '../../models/Resource'; import Setting from '../../models/Setting'; import shim from '../../shim'; -import { ResourceEntity, ResourceOcrStatus } from '../database/types'; +import { ResourceEntity, ResourceOcrDriverId, ResourceOcrStatus } from '../database/types'; import OcrDriverBase from './OcrDriverBase'; -import { RecognizeResult } from './utils/types'; +import { emptyRecognizeResult, RecognizeResult } from './utils/types'; import { Minute } from '@joplin/utils/time'; import Logger from '@joplin/utils/Logger'; -import filterOcrText from './utils/filterOcrText'; import TaskQueue from '../../TaskQueue'; import eventManager, { EventName } from '../../eventManager'; @@ -30,19 +29,24 @@ const resourceInfo = (resource: ResourceEntity) => { export default class OcrService { - private driver_: OcrDriverBase; + private drivers_: OcrDriverBase[]; private isRunningInBackground_ = false; // eslint-disable-next-line @typescript-eslint/no-explicit-any -- Old code before rule was applied private maintenanceTimer_: any = null; private pdfExtractDir_: string = null; private isProcessingResources_ = false; - private recognizeQueue_: TaskQueue = null; + private printedTextQueue_: TaskQueue = null; + private handwrittenTextQueue_: TaskQueue = null; - public constructor(driver: OcrDriverBase) { - this.driver_ = driver; - this.recognizeQueue_ = new TaskQueue('recognize', logger); - this.recognizeQueue_.setConcurrency(5); - this.recognizeQueue_.keepTaskResults = 
false; + public constructor(drivers: OcrDriverBase[]) { + this.drivers_ = drivers; + this.printedTextQueue_ = new TaskQueue('printed', logger); + this.printedTextQueue_.setConcurrency(5); + this.printedTextQueue_.keepTaskResults = false; + + this.handwrittenTextQueue_ = new TaskQueue('handwritten', logger); + this.handwrittenTextQueue_.setConcurrency(1); + this.handwrittenTextQueue_.keepTaskResults = false; } private async pdfExtractDir(): Promise { @@ -62,6 +66,9 @@ export default class OcrService { const resourceFilePath = Resource.fullPath(resource); + const driver = this.drivers_.find(d => d.driverId === resource.ocr_driver_id); + if (!driver) throw new Error(`Unknown driver ID: ${resource.ocr_driver_id}`); + if (resource.mime === 'application/pdf') { // OCR can be slow for large PDFs. // Skip it if the PDF already includes text. @@ -70,7 +77,9 @@ export default class OcrService { if (pagesWithText.length > 0) { return { - text: pageTexts.join('\n'), + ...emptyRecognizeResult(), + ocr_status: ResourceOcrStatus.Done, + ocr_text: pageTexts.join('\n'), }; } @@ -80,7 +89,7 @@ export default class OcrService { let pageIndex = 0; for (const imageFilePath of imageFilePaths) { logger.info(`Recognize: ${resourceInfo(resource)}: Processing PDF page ${pageIndex + 1} / ${imageFilePaths.length}...`); - results.push(await this.driver_.recognize(language, imageFilePath)); + results.push(await driver.recognize(language, imageFilePath, resource.id)); pageIndex++; } @@ -89,15 +98,19 @@ export default class OcrService { } return { - text: results.map(r => r.text).join('\n'), + ...emptyRecognizeResult(), + ocr_status: ResourceOcrStatus.Done, + ocr_text: results.map(r => r.ocr_text).join('\n'), }; } else { - return this.driver_.recognize(language, resourceFilePath); + return driver.recognize(language, resourceFilePath, resource.id); } } public async dispose() { - await this.driver_.dispose(); + for (const d of this.drivers_) { + await d.dispose(); + } } public async 
processResources() { @@ -115,7 +128,7 @@ export default class OcrService { return async () => { logger.info(`Processing resource ${totalProcessed + 1} / ${totalResourcesToProcess}: ${resourceInfo(resource)}...`); - const toSave: ResourceEntity = { + let toSave: ResourceEntity = { id: resource.id, }; @@ -132,11 +145,11 @@ export default class OcrService { return; } - const result = await this.recognize(language, resource); - toSave.ocr_status = ResourceOcrStatus.Done; - toSave.ocr_text = filterOcrText(result.text); - toSave.ocr_details = Resource.serializeOcrDetails(result.lines); - toSave.ocr_error = ''; + const recognizeResult = await this.recognize(language, resource); + toSave = { + ...toSave, + ...recognizeResult, + }; } catch (error) { const errorMessage = typeof error === 'string' ? error : error?.message; logger.warn(`Could not process resource ${resourceInfo(resource)}`, error); @@ -162,18 +175,29 @@ export default class OcrService { 'mime', 'file_extension', 'encryption_applied', + 'ocr_driver_id', ], }); if (!resources.length) break; - for (const resource of resources) { + const ocrResources = resources.filter(r => r.ocr_driver_id === ResourceOcrDriverId.PrintedText); + + for (const resource of ocrResources) { inProcessResourceIds.push(resource.id); - await this.recognizeQueue_.pushAsync(resource.id, makeQueueAction(totalProcessed++, language, resource)); + await this.printedTextQueue_.pushAsync(resource.id, makeQueueAction(totalProcessed++, language, resource)); + } + + const htrResources = resources.filter(r => r.ocr_driver_id === ResourceOcrDriverId.HandwrittenText); + + for (const resource of htrResources) { + inProcessResourceIds.push(resource.id); + await this.handwrittenTextQueue_.pushAsync(resource.id, makeQueueAction(totalProcessed++, language, resource)); } } - await this.recognizeQueue_.waitForAll(); + await this.printedTextQueue_.waitForAll(); + await this.handwrittenTextQueue_.waitForAll(); if (totalProcessed) { 
eventManager.emit(EventName.OcrServiceResourcesProcessed); @@ -212,7 +236,8 @@ export default class OcrService { if (this.maintenanceTimer_) shim.clearInterval(this.maintenanceTimer_); this.maintenanceTimer_ = null; this.isRunningInBackground_ = false; - await this.recognizeQueue_.stop(); + await this.printedTextQueue_.stop(); + await this.handwrittenTextQueue_.stop(); } } diff --git a/packages/lib/services/ocr/drivers/OcrDriverTesseract.ts b/packages/lib/services/ocr/drivers/OcrDriverTesseract.ts index e2792ee3fd..ceefc605fe 100644 --- a/packages/lib/services/ocr/drivers/OcrDriverTesseract.ts +++ b/packages/lib/services/ocr/drivers/OcrDriverTesseract.ts @@ -4,6 +4,9 @@ import OcrDriverBase from '../OcrDriverBase'; import { Minute } from '@joplin/utils/time'; import shim from '../../../shim'; import Logger from '@joplin/utils/Logger'; +import filterOcrText from '../utils/filterOcrText'; +import Resource from '../../../models/Resource'; +import { ResourceOcrDriverId, ResourceOcrStatus } from '../../database/types'; const logger = Logger.create('OcrDriverTesseract'); @@ -55,6 +58,10 @@ export default class OcrDriverTesseract extends OcrDriverBase { this.languageDataPath_ = languageDataPath; } + public get driverId() { + return ResourceOcrDriverId.PrintedText; + } + public static async clearLanguageDataCache() { if (typeof indexedDB === 'undefined') { throw new Error('Missing indexedDB access!'); @@ -224,8 +231,10 @@ export default class OcrDriverTesseract extends OcrDriverBase { // Note that Tesseract provides a `.text` property too, but it's the // concatenation of all lines, even those with a low confidence // score, so we recreate it here based on the good lines. 
- text: goodParagraphs.map(p => p.text).join('\n'), - lines: goodLines, + ocr_text: filterOcrText(goodParagraphs.map(p => p.text).join('\n')), + ocr_details: Resource.serializeOcrDetails(goodLines), + ocr_status: ResourceOcrStatus.Done, + ocr_error: '', }); }); } diff --git a/packages/lib/services/ocr/drivers/OcrDriverTranscribe.test.ts b/packages/lib/services/ocr/drivers/OcrDriverTranscribe.test.ts new file mode 100644 index 0000000000..75a0488aad --- /dev/null +++ b/packages/lib/services/ocr/drivers/OcrDriverTranscribe.test.ts @@ -0,0 +1,107 @@ +import Setting from '../../../models/Setting'; +import { createNoteAndResource, setupDatabaseAndSynchronizer, switchClient } from '../../../testing/test-utils'; +import { ResourceOcrStatus } from '../../database/types'; +import OcrDriverTranscribe from './OcrDriverTranscribe'; +import { reg } from '../../../registry'; + +type JobGenerated = { jobId: string }; +type GetResultPending = { state: string; jobId: string }; +type GetResultCompleted = { state: 'completed'; jobId: string; output: { result: string } }; +type GetResultFailed = { state: 'failed'; jobId: string; output: { stack: string; message: string } }; + +type Response = JobGenerated | GetResultPending | GetResultCompleted | GetResultFailed | Error; + +interface MockApi { + exec: jest.MockedFunction<( + method: string, + path: string, + query?: unknown, + body?: unknown, + headers?: Record, + options?: Record + )=> Promise>; +} + +describe('OcrDriverTranscribe', () => { + let mockApi: MockApi; + + beforeEach(async () => { + await setupDatabaseAndSynchronizer(1); + await switchClient(1); + + mockApi = { + exec: jest.fn(), + }; + + const mockApiMethod = jest.fn().mockResolvedValue(mockApi); + const mockDriver = { api: mockApiMethod }; + const mockFileApi = { driver: jest.fn().mockReturnValue(mockDriver) }; + const mockSyncTarget = { fileApi: jest.fn().mockResolvedValue(mockFileApi) }; + + reg.syncTarget = jest.fn().mockReturnValue(mockSyncTarget); + }); + + 
it('should return an error if synchronization target is not set', async () => { + const { resource } = await createNoteAndResource(); + const htr = new OcrDriverTranscribe(); + const response = await htr.recognize('', 'mock-path', resource.id); + + expect(response.ocr_status).toBe(ResourceOcrStatus.Error); + }); + + it('should return correct response when successful', async () => { + const { resource } = await createNoteAndResource(); + + mockApi.exec.mockResolvedValue(Promise.resolve({ jobId: 'not-a-real-job-id' })); + mockApi.exec.mockResolvedValue(Promise.resolve({ state: 'pending', jobId: 'not-a-real-job-id' })); + mockApi.exec.mockResolvedValue(Promise.resolve({ state: 'completed', jobId: 'not-a-real-job-id', output: { result: 'this is the final transcription' } })); + + const htr = new OcrDriverTranscribe([1]); + Setting.setValue('sync.target', 9); + + const response = await htr.recognize('', resource.filename, resource.id); + + expect(response.ocr_status).toBe(ResourceOcrStatus.Done); + expect(response.ocr_text).toBe('this is the final transcription'); + }); + + it('should return error when unsuccessful', async () => { + const { resource } = await createNoteAndResource(); + + mockApi.exec.mockResolvedValue(Promise.resolve({ jobId: 'not-a-real-job-id' })); + mockApi.exec.mockResolvedValue(Promise.resolve({ state: 'failed', jobId: 'not-a-real-job-id', output: { stack: '', message: 'Something went wrong' } })); + + const htr = new OcrDriverTranscribe([1]); + Setting.setValue('sync.target', 9); + + const response = await htr.recognize('', resource.filename, resource.id); + + expect(response.ocr_status).toBe(ResourceOcrStatus.Error); + expect(response.ocr_error).toEqual({ stack: '', message: 'Something went wrong' }); + }); + + it('should be able to retrieve jobId from database instead of creating a new job', async () => { + const { resource } = await createNoteAndResource(); + const jobId = 'jobIdThat should be reused latter'; + + 
mockApi.exec.mockResolvedValue(Promise.resolve({ jobId })); + mockApi.exec.mockImplementationOnce(() => { throw new Error('Network request failed'); }); + + const htr = new OcrDriverTranscribe([1]); + Setting.setValue('sync.target', 9); + + const response = await htr.recognize('', resource.filename, resource.id); + await htr.dispose(); + expect(response.ocr_status).toBe(ResourceOcrStatus.Todo); + expect(response.ocr_error).toBe(''); + + // Simulating closing/opening application + mockApi.exec.mockResolvedValue({ jobId, state: 'completed', output: { result: 'result' } }); + const htr2 = new OcrDriverTranscribe([1]); + + const response2 = await htr2.recognize('', resource.filename, resource.id); + expect(response2.ocr_status).toBe(ResourceOcrStatus.Done); + expect(response2.ocr_text).toBe('result'); + + }); +}); diff --git a/packages/lib/services/ocr/drivers/OcrDriverTranscribe.ts b/packages/lib/services/ocr/drivers/OcrDriverTranscribe.ts new file mode 100644 index 0000000000..412518811b --- /dev/null +++ b/packages/lib/services/ocr/drivers/OcrDriverTranscribe.ts @@ -0,0 +1,134 @@ +import { emptyRecognizeResult, RecognizeResult } from '../utils/types'; +import OcrDriverBase from '../OcrDriverBase'; +import Logger from '@joplin/utils/Logger'; +import { ResourceOcrDriverId, ResourceOcrStatus } from '../../database/types'; +import KvStore from '../../KvStore'; +import shim from '../../../shim'; +import { msleep } from '@joplin/utils/time'; +import Resource from '../../../models/Resource'; +import { reg } from '../../../registry'; + +const logger = Logger.create('OcrDriverTranscribe'); + +type CreateJobResult = { jobId: string }; + +export default class OcrDriverTranscribe extends OcrDriverBase { + + private retryIntervals_ = [10 * 1000, 15 * 1000, 30 * 1000, 60 * 1000]; + private jobIdKeyPrefix_ = 'OcrDriverTranscribe::JobId::'; + private disposed_ = false; + + public constructor(interval?: number[]) { + super(); + this.retryIntervals_ = interval ?? 
this.retryIntervals_; + } + + public get driverId() { + return ResourceOcrDriverId.HandwrittenText; + } + + public async recognize(_language: string, filePath: string, resourceId: string): Promise { + logger.info(`${resourceId}: Starting to recognize resource from ${filePath}`); + + const key = `${this.jobIdKeyPrefix_}${resourceId}`; + let jobId = await KvStore.instance().value(key); + + try { + if (!jobId) { + await Resource.save({ + id: resourceId, + ocr_status: ResourceOcrStatus.Processing, + }); + logger.info(`${resourceId}: Job does not exist yet, creating...`); + jobId = await this.queueJob(filePath, resourceId); + + logger.info(`${resourceId}: Job created, reference: ${jobId}`); + await KvStore.instance().setValue(key, jobId); + } + + const ocrResult = await this.checkJobIsFinished(jobId, resourceId); + await KvStore.instance().deleteValue(key); + + return { + ...emptyRecognizeResult(), + ...ocrResult, + }; + } catch (error) { + if (shim.fetchRequestCanBeRetried(error) || error.code === 503) { + return emptyRecognizeResult(); + } + await KvStore.instance().deleteValue(key); + return { + ...emptyRecognizeResult(), + ocr_status: ResourceOcrStatus.Error, + ocr_error: error.message, + }; + } + } + + private async queueJob(filePath: string, resourceId: string) { + const api = await this.api(); + + const result: CreateJobResult = await api.exec('POST', 'api/transcribe', null, null, { + 'Content-Type': 'application/octet-stream', + }, { path: filePath, source: 'file' }); + + logger.info(`${resourceId}: Job queued`); + return result.jobId; + } + + private async checkJobIsFinished(jobId: string, resourceId: string) { + logger.info(`${resourceId}: Checking if job is finished...`); + let i = 0; + while (true) { + if (this.disposed_) break; + + const api = await this.api(); + + const response = await api.exec('GET', `api/transcribe/${jobId}`); + + if (this.disposed_) break; + + if (response.state === 'completed') { + logger.info(`${resourceId}: Finished.`); + return { + 
ocr_status: ResourceOcrStatus.Done, + ocr_text: response.output.result, + }; + } else if (response.state === 'failed') { + logger.info(`${resourceId}: Failed.`); + return { + ocr_status: ResourceOcrStatus.Error, + ocr_error: response.output, + }; + } + + logger.info(`${resourceId}: Job not finished yet, waiting... ${this.getInterval(i)}`); + await msleep(this.getInterval(i)); + i += 1; + } + + return { + ocr_status: ResourceOcrStatus.Error, + ocr_error: 'OcrDriverTranscribe was stopped while waiting for a transcription', + }; + } + + private getInterval(index: number) { + if (index >= this.retryIntervals_.length) { + return this.retryIntervals_[this.retryIntervals_.length - 1]; + } + return this.retryIntervals_[index]; + } + + private async api() { + const fileApi = await reg.syncTarget().fileApi(); + return fileApi.driver().api(); + } + + public dispose() { + this.disposed_ = true; + return Promise.resolve(); + } + +} diff --git a/packages/lib/services/ocr/utils/types.ts b/packages/lib/services/ocr/utils/types.ts index 0cdfd1dc2d..83e6be0c1c 100644 --- a/packages/lib/services/ocr/utils/types.ts +++ b/packages/lib/services/ocr/utils/types.ts @@ -1,7 +1,11 @@ +import { ResourceOcrStatus } from '../../database/types'; + export const emptyRecognizeResult = (): RecognizeResult => { return { - text: '', - lines: [], + ocr_status: ResourceOcrStatus.Todo, + ocr_text: '', + ocr_details: '', + ocr_error: '', }; }; @@ -18,6 +22,8 @@ export interface RecognizeResultLine { } export interface RecognizeResult { - text: string; - lines?: RecognizeResultLine[]; // We do not store detailed data for PDFs + ocr_status: ResourceOcrStatus; + ocr_text: string; + ocr_details: string; + ocr_error: string; } diff --git a/packages/lib/testing/test-utils.ts b/packages/lib/testing/test-utils.ts index 24067e4c4b..dbd9cd1dc1 100644 --- a/packages/lib/testing/test-utils.ts +++ b/packages/lib/testing/test-utils.ts @@ -1116,7 +1116,7 @@ const simulateReadOnlyShareEnv = (shareIds: string[]|string, 
store?: Store) => { export const newOcrService = () => { const driver = new OcrDriverTesseract({ createWorker }, { workerPath: null, corePath: null, languageDataPath: null }); - return new OcrService(driver); + return new OcrService([driver]); }; export const mockMobilePlatform = (platform: string) => { diff --git a/packages/server/assets/tests/htr_example.png b/packages/server/assets/tests/htr_example.png new file mode 100644 index 0000000000..8c0335f5eb Binary files /dev/null and b/packages/server/assets/tests/htr_example.png differ diff --git a/packages/server/src/env.ts b/packages/server/src/env.ts index f796766524..ccf35206e7 100644 --- a/packages/server/src/env.ts +++ b/packages/server/src/env.ts @@ -161,6 +161,14 @@ const defaultEnvValues: EnvVariables = { SAML_IDP_CONFIG_FILE: '', // Config file for the Identity Provider. Should point to an XML file generated by the Identity Provider. SAML_SP_CONFIG_FILE: '', // Config file for the Service Provider (Joplin, in this case). Should point to an XML file generated by the Identity Provider. SAML_ORGANIZATION_DISPLAY_NAME: '', // The name of the organization to display on the login screen. Optional. 
+ + // ================================================== + // Transcribe Server + // ================================================== + + TRANSCRIBE_ENABLED: false, + TRANSCRIBE_API_KEY: '', + TRANSCRIBE_BASE_URL: '', }; export interface EnvVariables { @@ -260,6 +268,10 @@ export interface EnvVariables { SAML_ORGANIZATION_DISPLAY_NAME: string; LOCAL_AUTH_ENABLED: boolean; + + TRANSCRIBE_ENABLED: boolean; + TRANSCRIBE_API_KEY: string; + TRANSCRIBE_BASE_URL: string; } const parseBoolean = (s: string): boolean => { diff --git a/packages/server/src/routes/api/transcribe.test.ts b/packages/server/src/routes/api/transcribe.test.ts new file mode 100644 index 0000000000..d73bbb2ab4 --- /dev/null +++ b/packages/server/src/routes/api/transcribe.test.ts @@ -0,0 +1,177 @@ +import { readFile } from 'fs-extra'; +import { ApiError } from '../../utils/errors'; +import { getApi, postApi } from '../../utils/testing/apiUtils'; +import { beforeAllDb, afterAllTests, beforeEachDb, createUserAndSession, testAssetDir, checkThrowAsync, expectThrow, makeTempFileWithContent } from '../../utils/testing/testUtils'; + +export type TranscribeJob = { + jobId: number; +}; + +type OutputError = { stack: string; message: string }; +type OutputSuccess = { result: string }; +type Output = OutputError | OutputSuccess; + +type JobWithResult = { + id: string; + completedOn?: Date; + result?: Output; + state: string; +}; + + +describe('api_transcribe', () => { + + beforeAll(async () => { + await beforeAllDb('api_transcribe', { + envValues: { + TRANSCRIBE_ENABLED: 'true', + TRANSCRIBE_API_KEY: 'something', + TRANSCRIBE_SERVER_ADDRESS: 'something', + }, + }); + }); + + afterAll(async () => { + await afterAllTests(); + }); + + beforeEach(async () => { + await beforeEachDb(); + }); + + test('should create job', async () => { + const { session } = await createUserAndSession(1); + + jest.spyOn(global, 'fetch').mockImplementation( + jest.fn(() => Promise.resolve( + { + json: () => Promise.resolve( + { jobId: 
'608626f1-cad9-4b07-a02e-ec427c47147f' }, + ), + status: 200, + })) as jest.Mock, + ); + const fileContent = await readFile(`${testAssetDir}/htr_example.png`); + const tempFilePath = await makeTempFileWithContent(fileContent); + const response = await postApi(session.id, 'transcribe', {}, + { + filePath: tempFilePath, + }, + ); + + expect(response.jobId).toBe('608626f1-cad9-4b07-a02e-ec427c47147f'); + }); + + test('should create job and return response eventually', async () => { + const { session } = await createUserAndSession(1); + + jest.spyOn(global, 'fetch').mockImplementation( + jest.fn(() => Promise.resolve( + { + json: () => Promise.resolve( + { jobId: '608626f1-cad9-4b07-a02e-ec427c47147f' }, + ), + status: 200, + })) as jest.Mock, + ); + + const fileContent = await readFile(`${testAssetDir}/htr_example.png`); + const tempFilePath = await makeTempFileWithContent(fileContent); + const postResponse = await postApi(session.id, 'transcribe', {}, + { + filePath: tempFilePath, + }, + ); + + expect(postResponse.jobId).not.toBe(undefined); + + jest.spyOn(global, 'fetch').mockImplementation( + jest.fn(() => Promise.resolve( + { + json: (): Promise<JobWithResult> => Promise.resolve( + { + id: '608626f1-cad9-4b07-a02e-ec427c47147f', + state: 'completed', + result: { result: 'transcription' }, + }, + ), + status: 200, + })) as jest.Mock, + ); + + const getResponse = await getApi(session.id, `transcribe/${postResponse.jobId}`, {}); + expect(getResponse.id).toBe(postResponse.jobId); + expect(getResponse.state).toBe('completed'); + expect((getResponse.result as OutputSuccess).result).toBe('transcription'); + }); + + test('should throw a error if API returns error 400', async () => { + const { session } = await createUserAndSession(1); + + jest.spyOn(global, 'fetch').mockImplementation( + jest.fn(() => Promise.resolve( + { + json: () => Promise.resolve(''), + status: 400, + })) as jest.Mock, + ); + + const fileContent = await readFile(`${testAssetDir}/htr_example.png`); + const 
tempFilePath = await makeTempFileWithContent(fileContent); + const error = await checkThrowAsync(() => + postApi(session.id, 'transcribe', {}, + { + filePath: tempFilePath, + }, + )); + + expect(error instanceof ApiError).toBe(true); + }); + + test('should throw error if API returns error 500', async () => { + const { session } = await createUserAndSession(1); + + jest.spyOn(global, 'fetch').mockImplementation( + jest.fn(() => Promise.resolve( + { + json: () => Promise.resolve(''), + status: 500, + })) as jest.Mock, + ); + + const fileContent = await readFile(`${testAssetDir}/htr_example.png`); + const tempFilePath = await makeTempFileWithContent(fileContent); + const error = await checkThrowAsync(() => + postApi(session.id, 'transcribe', {}, + { + filePath: tempFilePath, + }, + )); + + expect(error instanceof ApiError).toBe(true); + }); + test('should throw 500 error is something unexpected', async () => { + const { session } = await createUserAndSession(1); + + jest.spyOn(global, 'fetch').mockImplementation( + jest.fn(() => Promise.resolve( + { + json: () => Promise.reject(new Error('Something went wrong')), + status: 200, + })) as jest.Mock, + ); + + const fileContent = await readFile(`${testAssetDir}/htr_example.png`); + const tempFilePath = await makeTempFileWithContent(fileContent); + const error = await expectThrow(() => + postApi(session.id, 'transcribe', {}, + { + filePath: tempFilePath, + }, + )); + + expect(error.httpCode).toBe(500); + expect(error.message.startsWith('POST /api/transcribe {"status":500,"body":{"error":"Something went wrong"')).toBe(true); + }); + +}); diff --git a/packages/server/src/routes/api/transcribe.ts b/packages/server/src/routes/api/transcribe.ts new file mode 100644 index 0000000000..56df2d6a17 --- /dev/null +++ b/packages/server/src/routes/api/transcribe.ts @@ -0,0 +1,97 @@ +import { readFile } from 'fs-extra'; +import { ErrorBadGateway, ErrorBadRequest, ErrorNotImplemented, ErrorServiceUnavailable } from '../../utils/errors'; 
+import { formParse } from '../../utils/requestUtils'; +import Router from '../../utils/Router'; +import { SubPath } from '../../utils/routeUtils'; +import { AppContext, RouteType } from '../../utils/types'; +import Logger from '@joplin/utils/Logger'; +import shim from '@joplin/lib/shim'; +import config from '../../config'; +import { safeRemove } from '../../utils/fileUtils'; + +const logger = Logger.create('api/transcribe'); + +const router = new Router(RouteType.Api); + +const isHtrSupported = () => { + return config().TRANSCRIBE_ENABLED; +}; + +router.get('api/transcribe/:id', async (path: SubPath, _ctx: AppContext) => { + if (!isHtrSupported()) { + throw new ErrorNotImplemented('HTR feature is not enabled in this server'); + } + + try { + logger.info(`Checking Transcribe for Job: ${path.id}`); + const response = await fetch(`${config().TRANSCRIBE_BASE_URL}/transcribe/${encodeURIComponent(path.id)}`, // Encode the caller-supplied id so it stays a single path segment + { + headers: { + 'Authorization': config().TRANSCRIBE_API_KEY, + }, + }, + ); + + if (response.status >= 400 && response.status < 500) { + const responseJson = await response.json(); + throw new ErrorBadRequest(responseJson.error); + } else if (response.status >= 500) { + const responseJson = await response.json(); + throw new ErrorBadGateway(responseJson.error); + } + + const responseJson = await response.json(); + return responseJson; + } catch (error) { + if (shim.fetchRequestCanBeRetried(error) || shim.fetchRequestCanBeRetried(error.cause)) { + throw new ErrorServiceUnavailable('Transcribe Server not available right now.', error); + } + throw error; + } +}); + +router.post('api/transcribe', async (_path: SubPath, ctx: AppContext) => { + if (!isHtrSupported()) { + throw new ErrorNotImplemented('HTR feature is not enabled in this server'); + } + + const request = await formParse(ctx.req); + if (!request.files.file) throw new ErrorBadRequest('No file provided. 
Use a multipart/form request with a \'file\' property.'); + + const form = new FormData(); + const file = await readFile(request.files.file.filepath); + const blob = new Blob([file]); + form.append('file', blob, 'file'); + + try { + logger.info('Sending file to Transcribe Server'); + const response = await fetch(`${config().TRANSCRIBE_BASE_URL}/transcribe`, { + method: 'POST', + body: form, + headers: { + 'Authorization': config().TRANSCRIBE_API_KEY, + }, + }); + + if (response.status >= 400 && response.status < 500) { + const responseJson = await response.json(); + throw new ErrorBadRequest(responseJson.error); + } else if (response.status >= 500) { + const responseJson = await response.json(); + throw new ErrorBadGateway(responseJson.error); + } + + const responseJson = await response.json(); + logger.info(`Job created successfully: ${responseJson.jobId}`); + return responseJson; + } catch (error) { + if (shim.fetchRequestCanBeRetried(error) || shim.fetchRequestCanBeRetried(error.cause)) { + throw new ErrorServiceUnavailable('Transcribe Server not available right now.', error); + } + throw error; + } finally { + await safeRemove(request.files.file.filepath); + } +}); + +export default router; diff --git a/packages/server/src/routes/routes.ts b/packages/server/src/routes/routes.ts index 39d6f305cc..0399ba0def 100644 --- a/packages/server/src/routes/routes.ts +++ b/packages/server/src/routes/routes.ts @@ -12,6 +12,7 @@ import apiShares from './api/shares'; import apiShareUsers from './api/share_users'; import apiUsers from './api/users'; import apiLogin from './api/login'; +import apiTranscribe from './api/transcribe'; import adminDashboard from './admin/dashboard'; import adminEmails from './admin/emails'; @@ -52,6 +53,7 @@ const routes: Routers = { 'api/share_users': apiShareUsers, 'api/shares': apiShares, 'api/users': apiUsers, + 'api/transcribe': apiTranscribe, 'admin/dashboard': adminDashboard, 'admin/emails': adminEmails, diff --git 
a/packages/server/src/utils/errors.ts b/packages/server/src/utils/errors.ts index b47539aa9d..a455c281e4 100644 --- a/packages/server/src/utils/errors.ts +++ b/packages/server/src/utils/errors.ts @@ -142,6 +142,37 @@ export class ErrorTooManyRequests extends ApiError { } } +export class ErrorNotImplemented extends ApiError { + public static httpCode = 501; + public retryAfterMs = 0; + + public constructor(message = 'Not Implemented', options: ErrorOptions = null) { + super(message, ErrorNotImplemented.httpCode, options); + Object.setPrototypeOf(this, ErrorNotImplemented.prototype); + } +} + +export class ErrorBadGateway extends ApiError { + public static httpCode = 502; + public retryAfterMs = 0; + + public constructor(message = 'Bad Gateway', options: ErrorOptions = null) { + super(message, ErrorBadGateway.httpCode, options); + Object.setPrototypeOf(this, ErrorBadGateway.prototype); + } +} + +export class ErrorServiceUnavailable extends ApiError { + public static httpCode = 503; + public retryAfterMs = 0; + + public constructor(message = 'Service Unavailable', options: ErrorOptions = null) { + super(message, ErrorServiceUnavailable.httpCode, options); + Object.setPrototypeOf(this, ErrorServiceUnavailable.prototype); + } +} + + export function errorToString(error: Error): string { // const msg: string[] = []; // msg.push(error.message ? error.message : 'Unknown error'); diff --git a/readme/privacy.md b/readme/privacy.md index de45172d3f..cbb91ea1a3 100644 --- a/readme/privacy.md +++ b/readme/privacy.md @@ -15,6 +15,7 @@ In order to provide certain features, Joplin may need to connect to third-party | Voice typing | If you use the voice typing feature on Android, the application will download the language files from https://github.com/joplin/voice-typing-models/ or https://alphacephei.com/vosk/models. 
| Disabled | Yes | OCR | If you have enabled optical character recognition on desktop, the application will download the language files from https://cdn.jsdelivr.net/npm/@tesseract.js-data/. | Disabled | Yes | Crash reports | If you have enabled crash auto-upload, the application will upload the report to Sentry when a crash happens. When Sentry is initialised it will also connect to `sentry.io`. | Disabled | Yes +| Handwriting recognition | This option allows the user to send images to Joplin Server/Cloud to be transcribed. Only images selected with the 'Recognize handwritten image' action are affected. | Enabled | Yes (1) https://github.com/laurent22/joplin/issues/5705
(2) If the spellchecker is disabled, [it will not download the dictionary](https://discourse.joplinapp.org/t/new-version-of-joplin-contacting-google-servers-on-startup/23000/40?u=laurent).