You've already forked STARK
mirror of
https://github.com/MarkParker5/STARK.git
synced 2025-09-16 09:36:24 +02:00
pass all tests
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
from types import GeneratorType, AsyncGeneratorType
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
from typing import Any, Protocol, runtime_checkable, cast
|
||||
from dataclasses import dataclass
|
||||
import warnings
|
||||
|
||||
@@ -13,7 +13,7 @@ from ..general.localisation import Localizer
|
||||
from ..models.transcription import Transcription
|
||||
from ..models.localizable_string import LocalizableString
|
||||
from .commands_manager import CommandsManager, SearchResult
|
||||
from .command import Command, Response, ResponseHandler, AsyncResponseHandler, CommandRunner, ResponseOptions
|
||||
from .command import Command, Response, ResponseHandler, AsyncResponseHandler, CommandRunner, ResponseOptions, AwaitResponse, AwaitResponse
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -106,10 +106,10 @@ class CommandsContext:
|
||||
self.run_command(search_result.command, parameters, search_result.match_result.subtrack.language_code)
|
||||
|
||||
def inject_dependencies(self, runner: Command[CommandRunner] | CommandRunner) -> CommandRunner:
|
||||
def injected_func(**kwargs) -> ResponseOptions:
|
||||
def injected_func(**kwargs) -> AwaitResponse | Response | None:
|
||||
kwargs.update(self.dependency_manager.resolve(runner._runner if isinstance(runner, Command) else runner))
|
||||
return runner(**kwargs)
|
||||
return injected_func
|
||||
return cast(CommandRunner, injected_func)
|
||||
|
||||
def run_command(self, command: Command, parameters: dict[str, Any] | None = None, language_code: str | None = None):
|
||||
parameters = parameters or {}
|
||||
|
@@ -88,6 +88,7 @@ class CommandsManager:
|
||||
else: # else remove less priority
|
||||
results.remove(priority2)
|
||||
|
||||
print('results:', [result.command.name for result in results])
|
||||
return results
|
||||
|
||||
def new(self, pattern_str: str, hidden: bool = False):
|
||||
|
@@ -2,14 +2,14 @@ alphanumerics = r'A-zА-яЁё0-9'
|
||||
specials = r'\(\)\[\]\{\}'
|
||||
any = alphanumerics + specials
|
||||
|
||||
dictionary = {
|
||||
dictionary = [
|
||||
# one of the list (a|b|c)
|
||||
r'\(((?:.*\|)*.*)\)': r'(?:\1)',
|
||||
(r'\(((?:.*\|)*.*)\)', r'(?:\1)'),
|
||||
|
||||
# one or more of the list, space-splitted {a|b|c}
|
||||
r'\{((?:.*\|?)*?.*?)\}': r'(?:(?:\1)\\s?)+',
|
||||
(r'\{((?:.*\|?)*?.*?)\}', r'(?:(?:\1)\\s?)+'),
|
||||
|
||||
# stars *
|
||||
r'\*\*': fr'[{alphanumerics}\\s]*', # ** for any few words
|
||||
fr'([^{specials}]|^)\*': fr'\1[{alphanumerics}]*', # * for any word
|
||||
}
|
||||
(r' ?\*\* ?', fr' ?[{alphanumerics}\\s]* ?'), # ** for any few words
|
||||
(fr' ?([^{specials}]|^)\* ?', fr' ?[{alphanumerics}]* ?'), # * for any word
|
||||
]
|
||||
|
@@ -36,21 +36,21 @@ class Pattern:
|
||||
|
||||
def prepare(self, localizer: Localizer):
|
||||
for language in localizer.languages:
|
||||
self._compiled[language] = self._get_compiled(language, localizer)
|
||||
self._compiled[language] = self.get_compiled(language, localizer)
|
||||
|
||||
async def match(self, transcription: Transcription, localizer: Localizer, objects_cache: dict[str, Object] | None = None) -> list[MatchResult]:
|
||||
if not self._origin:
|
||||
return []
|
||||
|
||||
if objects_cache is None:
|
||||
objects_cache = {}
|
||||
objects_cache = dict()
|
||||
|
||||
matches: list[MatchResult] = []
|
||||
|
||||
for language_code, track in transcription.origins.items():
|
||||
|
||||
string = track.text
|
||||
compiled = self._get_compiled(language_code, localizer)
|
||||
compiled = self.get_compiled(language_code, localizer)
|
||||
|
||||
# map suggestions to more comfortable data structure
|
||||
|
||||
@@ -81,7 +81,7 @@ class Pattern:
|
||||
|
||||
parameters: dict[str, Object] = {}
|
||||
substrings: dict[str, str] = {}
|
||||
futures: list[tuple[str, SoonValue[ParseResult | None]]] = []
|
||||
futures: list[tuple[str, SoonValue[ParseResult]]] = []
|
||||
|
||||
# run concurrent objects parsing
|
||||
async with create_task_group() as group:
|
||||
@@ -106,23 +106,18 @@ class Pattern:
|
||||
subtranscription = transcription.get_slice(*time_range)
|
||||
subtranscription.suggestions = [s for s in subtranscription.suggestions if s[0] in subtrack.text]
|
||||
|
||||
async def parse() -> ParseResult | None:
|
||||
try:
|
||||
parse_result = await object_type.parse(subtrack, subtranscription, match_str_groups)
|
||||
except ParseError:
|
||||
return None
|
||||
objects_cache[parse_result.substring] = parse_result.obj
|
||||
async def parse(track: TranscriptionTrack, transcription: Transcription, re_match_groups: dict[str, str] | None = None) -> ParseResult:
|
||||
parse_result = await object_type.parse(track, transcription, re_match_groups)
|
||||
objects_cache[parse_result.track.text] = parse_result.obj
|
||||
return parse_result
|
||||
|
||||
futures.append((name, group.soonify(parse)()))
|
||||
futures.append((name, group.soonify(parse)(subtrack, subtranscription, match_str_groups)))
|
||||
|
||||
# read futures
|
||||
for name, future in futures:
|
||||
parse_result = future.value
|
||||
if not parse_result:
|
||||
continue
|
||||
parameters[name] = parse_result.obj
|
||||
substrings[name] = parse_result.substring
|
||||
substrings[name] = parse_result.track.text
|
||||
|
||||
# save parameters
|
||||
for name in parameters.keys():
|
||||
@@ -185,7 +180,7 @@ class Pattern:
|
||||
|
||||
yield arg_name, arg_type
|
||||
|
||||
def _get_compiled(self, language: str, localizer: Localizer) -> str:
|
||||
def get_compiled(self, language: str, localizer: Localizer) -> str:
|
||||
'''transform Pattern to classic regex with named groups'''
|
||||
|
||||
if language in self._compiled:
|
||||
@@ -195,7 +190,7 @@ class Pattern:
|
||||
|
||||
# transform core expressions to regex
|
||||
|
||||
for pattern_re, regex in dictionary.items():
|
||||
for pattern_re, regex in dictionary:
|
||||
pattern = re.sub(pattern_re, regex, pattern)
|
||||
|
||||
# find and transform parameters like $name:Type
|
||||
@@ -203,7 +198,7 @@ class Pattern:
|
||||
for name, object_type in self.parameters.items():
|
||||
|
||||
arg_declaration = f'\\${name}\\:{object_type.__name__}'
|
||||
arg_pattern = object_type.pattern.compiled.replace('\\', r'\\')
|
||||
arg_pattern = object_type.pattern.get_compiled(language, localizer).replace('\\', r'\\')
|
||||
pattern = re.sub(arg_declaration, f'(?P<{name}>{arg_pattern})', pattern)
|
||||
|
||||
self._compiled[language] = pattern
|
||||
|
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
from typing import Any
|
||||
from collections import namedtuple
|
||||
from typing import NamedTuple
|
||||
from abc import ABC
|
||||
import copy
|
||||
|
||||
@@ -9,7 +9,10 @@ from stark.models.transcription import Transcription, TranscriptionTrack
|
||||
from .. import Pattern
|
||||
|
||||
|
||||
ParseResult = namedtuple('ParseResult', ['obj', 'substring'])
|
||||
class ParseResult(NamedTuple):
|
||||
obj: Object
|
||||
track: TranscriptionTrack
|
||||
transcription: Transcription
|
||||
|
||||
class ParseError(Exception):
|
||||
pass
|
||||
@@ -26,7 +29,7 @@ class Object(ABC):
|
||||
'''Just init with wrapped value.'''
|
||||
self.value = value
|
||||
|
||||
async def did_parse(self, track: TranscriptionTrack, transcription: Transcription, re_match_groups: dict[str, str]) -> Transcription:
|
||||
async def did_parse(self, track: TranscriptionTrack, transcription: Transcription, re_match_groups: dict[str, str]) -> tuple[TranscriptionTrack, Transcription]:
|
||||
'''
|
||||
This method is called after parsing from string and setting parameters found in pattern.
|
||||
You will very rarely, if ever, need to call this method directly.
|
||||
@@ -39,7 +42,8 @@ class Object(ABC):
|
||||
Raises:
|
||||
ParseError: if parsing failed.
|
||||
'''
|
||||
return transcription
|
||||
self.value = track.text
|
||||
return track, transcription
|
||||
|
||||
@classmethod
|
||||
async def parse(cls, track: TranscriptionTrack, transcription: Transcription, re_match_groups: dict[str, str] | None = None) -> ParseResult:
|
||||
@@ -66,10 +70,9 @@ class Object(ABC):
|
||||
time_range = next(iter(track.get_time(value)))
|
||||
sub_track = track.get_slice(*time_range)
|
||||
sub_transcription = transcription.get_slice(*time_range)
|
||||
|
||||
setattr(obj, name, (await object_type.parse(sub_track, sub_transcription, re_match_groups)).obj)
|
||||
|
||||
return ParseResult(obj, await obj.did_parse(sub_track, sub_transcription, re_match_groups))
|
||||
return ParseResult(obj, *(await obj.did_parse(track, transcription, re_match_groups)))
|
||||
|
||||
def copy(self) -> Object:
|
||||
return copy.copy(self)
|
||||
|
@@ -13,10 +13,10 @@ class Localizer:
|
||||
localizable: Languages
|
||||
recognizable: Languages
|
||||
languages: set[str]
|
||||
base_language: str
|
||||
base_language: str # language of the base.strings file
|
||||
|
||||
def __init__(self, languages: set[str], base_language: str):
|
||||
self.languages = languages
|
||||
def __init__(self, languages: set[str] | None = None, base_language: str = 'en'):
|
||||
self.languages = languages or {'en'}
|
||||
self.base_language = base_language
|
||||
self.localizable = {}
|
||||
self.recognizable = {}
|
||||
|
@@ -1,5 +1,6 @@
|
||||
from typing import cast, Iterable
|
||||
from dataclasses import dataclass
|
||||
from stark.models.transcription import Transcription
|
||||
from stark.models.transcription import Transcription, Suggestion
|
||||
from .strings.levenshtein import levenshtein
|
||||
from .strings.starkophone import starkophone as get_starkophone
|
||||
from .strings.ipa import ipa_to_latin, string_to_ipa
|
||||
@@ -39,7 +40,7 @@ class SuggestionsManager:
|
||||
|
||||
def add_transcription_suggestions(self, transcription: Transcription):
|
||||
for language, track in transcription.origins.items():
|
||||
transcription.suggestions.extend(self.get_string_suggestions(track.text, language))
|
||||
transcription.suggestions.extend(cast(Iterable[Suggestion], self.get_string_suggestions(track.text, language)))
|
||||
|
||||
def get_string_suggestions(self, string: str, language_code: str) -> set[tuple[str, str]]:
|
||||
# TODO: cache all long shit
|
||||
|
@@ -15,6 +15,7 @@ from ..interfaces.protocols import (
|
||||
SpeechSynthesizer
|
||||
)
|
||||
from ..models.transcription import Transcription
|
||||
from ..models.localizable_string import LocalizableString
|
||||
from ..general.localisation import Localizer
|
||||
from .mode import Mode
|
||||
|
||||
@@ -141,8 +142,15 @@ class VoiceAssistant(SpeechRecognizerDelegate, CommandsContextDelegate):
|
||||
|
||||
async def _play_response(self, response: Response):
|
||||
|
||||
text = self.localizer.localize(response.text) or response.text
|
||||
voice = self.localizer.localize(response.voice) or response.voice
|
||||
if isinstance(response.text, LocalizableString):
|
||||
text = self.localizer.localize(response.text)
|
||||
else:
|
||||
text = response.text
|
||||
|
||||
if isinstance(response.voice, LocalizableString):
|
||||
voice = self.localizer.localize(response.voice)
|
||||
else:
|
||||
voice = response.voice
|
||||
|
||||
self.commands_context.last_response = response
|
||||
|
||||
|
@@ -1,10 +1,10 @@
|
||||
from typing import AsyncGenerator
|
||||
import time
|
||||
from typing import AsyncGenerator, Callable
|
||||
import contextlib
|
||||
import pytest
|
||||
import asyncer
|
||||
import anyio
|
||||
from stark.general.dependencies import DependencyManager
|
||||
from stark.general.localisation import Localizer
|
||||
from stark.core import (
|
||||
CommandsManager,
|
||||
CommandsContext,
|
||||
@@ -16,6 +16,7 @@ from stark.core import (
|
||||
from stark.core.types import Word
|
||||
from stark.interfaces.protocols import SpeechRecognizerDelegate
|
||||
from stark.voice_assistant import VoiceAssistant
|
||||
from stark.models.transcription import Transcription, TranscriptionTrack, TranscriptionWord
|
||||
|
||||
|
||||
class CommandsContextDelegateMock(CommandsContextDelegate):
|
||||
@@ -64,7 +65,7 @@ async def commands_context_flow():
|
||||
async with asyncer.create_task_group() as main_task_group:
|
||||
dependencies = DependencyManager()
|
||||
manager = CommandsManager()
|
||||
context = CommandsContext(main_task_group, manager, dependencies)
|
||||
context = CommandsContext(main_task_group, manager, Localizer(), dependencies)
|
||||
context_delegate = CommandsContextDelegateMock()
|
||||
context.delegate = context_delegate
|
||||
|
||||
@@ -169,7 +170,30 @@ async def voice_assistant(commands_context_flow_filled):
|
||||
voice_assistant = VoiceAssistant(
|
||||
speech_recognizer = SpeechRecognizerMock(),
|
||||
speech_synthesizer = SpeechSynthesizerMock(),
|
||||
commands_context = context
|
||||
commands_context = context,
|
||||
localizer = Localizer(),
|
||||
)
|
||||
yield voice_assistant
|
||||
return _voice_assistant
|
||||
|
||||
@pytest.fixture
|
||||
def get_transcription() -> Callable[[str], Transcription]:
|
||||
def getter(string: str) -> Transcription:
|
||||
track = TranscriptionTrack(
|
||||
text = string,
|
||||
result = [
|
||||
TranscriptionWord(
|
||||
word = word,
|
||||
start = i,
|
||||
end = i + 0.5,
|
||||
conf = 1
|
||||
) for i, word in enumerate(string.split())
|
||||
]
|
||||
)
|
||||
return Transcription(
|
||||
best = track,
|
||||
origins = {
|
||||
'en': track
|
||||
}
|
||||
)
|
||||
return getter
|
||||
|
@@ -2,6 +2,7 @@ import re
|
||||
import pytest
|
||||
from stark.core import CommandsManager
|
||||
from stark.core.types import Word
|
||||
from stark.general.localisation import Localizer
|
||||
|
||||
|
||||
def test_new():
|
||||
@@ -23,7 +24,7 @@ def test_new_with_extra_parameters_in_pattern():
|
||||
@manager.new('test $name:Word, $secondName:Word')
|
||||
def test(name: Word): pass
|
||||
|
||||
async def test_search():
|
||||
async def test_search(get_transcription):
|
||||
manager = CommandsManager()
|
||||
|
||||
@manager.new('test')
|
||||
@@ -36,26 +37,26 @@ async def test_search():
|
||||
def hello(name: Word): pass
|
||||
|
||||
# test
|
||||
result = await manager.search('test')
|
||||
result = await manager.search(get_transcription('test'), Localizer())
|
||||
assert result is not None
|
||||
assert len(result) == 1
|
||||
assert result[0].command.name == 'CommandsManager.test'
|
||||
|
||||
# hello
|
||||
result = await manager.search('hello world')
|
||||
result = await manager.search(get_transcription('hello world'), Localizer())
|
||||
assert result is not None
|
||||
assert len(result) == 1
|
||||
assert result[0].command.name == 'CommandsManager.hello'
|
||||
assert result[0].match_result.substring == 'hello world'
|
||||
assert result[0].match_result.subtrack.text == 'hello world'
|
||||
assert type(result[0].match_result.parameters['name']) is Word
|
||||
assert result[0].match_result.parameters['name'].value == 'world'
|
||||
|
||||
# hello2
|
||||
result = await manager.search('hello new world')
|
||||
result = await manager.search(get_transcription('hello new world'), Localizer())
|
||||
assert result is not None
|
||||
assert len(result) == 1
|
||||
assert result[0].command == hello2
|
||||
assert result[0].match_result.substring == 'hello new world'
|
||||
assert result[0].match_result.subtrack.text == 'hello new world'
|
||||
assert result[0].match_result.parameters == {'name': Word('new'), 'surname': Word('world')}
|
||||
|
||||
def test_extend_manager():
|
||||
|
@@ -1,81 +1,80 @@
|
||||
import pytest
|
||||
import anyio
|
||||
|
||||
|
||||
async def test_basic_search(commands_context_flow_filled, autojump_clock):
|
||||
async def test_basic_search(commands_context_flow_filled, autojump_clock, get_transcription):
|
||||
async with commands_context_flow_filled() as (context, context_delegate):
|
||||
assert len(context_delegate.responses) == 0
|
||||
assert len(context._context_queue) == 1
|
||||
|
||||
await context.process_transcription('lorem ipsum dolor')
|
||||
await context.process_transcription(get_transcription('lorem ipsum dolor'))
|
||||
await anyio.sleep(5)
|
||||
assert len(context_delegate.responses) == 1
|
||||
assert context_delegate.responses[0].text == 'Lorem!'
|
||||
assert len(context._context_queue) == 1
|
||||
|
||||
async def test_second_context_layer(commands_context_flow_filled, autojump_clock):
|
||||
async def test_second_context_layer(commands_context_flow_filled, autojump_clock, get_transcription):
|
||||
async with commands_context_flow_filled() as (context, context_delegate):
|
||||
|
||||
await context.process_transcription('hello world')
|
||||
await context.process_transcription(get_transcription('hello world'))
|
||||
await anyio.sleep(5)
|
||||
assert len(context_delegate.responses) == 1
|
||||
assert context_delegate.responses[0].text == 'Hello, world!'
|
||||
assert len(context._context_queue) == 2
|
||||
context_delegate.responses.clear()
|
||||
|
||||
await context.process_transcription('hello')
|
||||
await context.process_transcription(get_transcription('hello'))
|
||||
await anyio.sleep(5)
|
||||
assert len(context_delegate.responses) == 1
|
||||
assert context_delegate.responses[0].text == 'Hi, world!'
|
||||
assert len(context._context_queue) == 2
|
||||
context_delegate.responses.clear()
|
||||
|
||||
async def test_context_pop_on_not_found(commands_context_flow_filled, autojump_clock):
|
||||
async def test_context_pop_on_not_found(commands_context_flow_filled, autojump_clock, get_transcription):
|
||||
async with commands_context_flow_filled() as (context, context_delegate):
|
||||
|
||||
await context.process_transcription('hello world')
|
||||
await context.process_transcription(get_transcription('hello world'))
|
||||
await anyio.sleep(5)
|
||||
assert len(context._context_queue) == 2
|
||||
assert len(context_delegate.responses) == 1
|
||||
context_delegate.responses.clear()
|
||||
|
||||
await context.process_transcription('lorem ipsum dolor')
|
||||
await context.process_transcription(get_transcription('lorem ipsum dolor'))
|
||||
await anyio.sleep(5)
|
||||
assert len(context._context_queue) == 1
|
||||
assert len(context_delegate.responses) == 1
|
||||
|
||||
async def test_context_pop_context_response_action(commands_context_flow_filled, autojump_clock):
|
||||
async def test_context_pop_context_response_action(commands_context_flow_filled, autojump_clock, get_transcription):
|
||||
async with commands_context_flow_filled() as (context, context_delegate):
|
||||
|
||||
await context.process_transcription('hello world')
|
||||
await context.process_transcription(get_transcription('hello world'))
|
||||
await anyio.sleep(5)
|
||||
assert len(context_delegate.responses) == 1
|
||||
assert context_delegate.responses[0].text == 'Hello, world!'
|
||||
assert len(context._context_queue) == 2
|
||||
context_delegate.responses.clear()
|
||||
|
||||
await context.process_transcription('bye')
|
||||
await context.process_transcription(get_transcription('bye'))
|
||||
await anyio.sleep(5)
|
||||
assert len(context_delegate.responses) == 1
|
||||
assert context_delegate.responses[0].text == 'Bye, world!'
|
||||
assert len(context._context_queue) == 1
|
||||
context_delegate.responses.clear()
|
||||
|
||||
await context.process_transcription('hello')
|
||||
await context.process_transcription(get_transcription('hello'))
|
||||
await anyio.sleep(5)
|
||||
assert len(context_delegate.responses) == 0
|
||||
|
||||
async def test_repeat_last_answer_response_action(commands_context_flow_filled, autojump_clock):
|
||||
async def test_repeat_last_answer_response_action(commands_context_flow_filled, autojump_clock, get_transcription):
|
||||
async with commands_context_flow_filled() as (context, context_delegate):
|
||||
|
||||
await context.process_transcription('hello world')
|
||||
await context.process_transcription(get_transcription('hello world'))
|
||||
await anyio.sleep(5)
|
||||
assert len(context_delegate.responses) == 1
|
||||
assert context_delegate.responses[0].text == 'Hello, world!'
|
||||
context_delegate.responses.clear()
|
||||
assert len(context_delegate.responses) == 0
|
||||
|
||||
await context.process_transcription('repeat')
|
||||
await context.process_transcription(get_transcription('repeat'))
|
||||
await anyio.sleep(5)
|
||||
assert len(context_delegate.responses) == 1
|
||||
assert context_delegate.responses[0].text == 'Hello, world!'
|
||||
|
@@ -4,7 +4,7 @@ import anyio
|
||||
from stark.core import Response
|
||||
|
||||
|
||||
async def test_commands_context_handle_async_generator(commands_context_flow, autojump_clock):
|
||||
async def test_commands_context_handle_async_generator(commands_context_flow, autojump_clock, get_transcription):
|
||||
async with commands_context_flow() as (manager, context, context_delegate):
|
||||
|
||||
@manager.new('foo')
|
||||
@@ -21,7 +21,7 @@ async def test_commands_context_handle_async_generator(commands_context_flow, au
|
||||
yield Response(text = 'foo4')
|
||||
# return is not allowed in generators (functions with yield)
|
||||
|
||||
await context.process_transcription('foo')
|
||||
await context.process_transcription(get_transcription('foo'))
|
||||
|
||||
last_count = 0
|
||||
while last_count < 5:
|
||||
@@ -34,7 +34,7 @@ async def test_commands_context_handle_async_generator(commands_context_flow, au
|
||||
assert len(context_delegate.responses) == 5
|
||||
assert [r.text for r in context_delegate.responses] == [f'foo{i}' for i in range(5)]
|
||||
|
||||
async def test_commands_context_handle_sync_generator(commands_context_flow, autojump_clock):
|
||||
async def test_commands_context_handle_sync_generator(commands_context_flow, autojump_clock, get_transcription):
|
||||
async with commands_context_flow() as (manager, context, context_delegate):
|
||||
|
||||
@manager.new('foo')
|
||||
@@ -47,7 +47,7 @@ async def test_commands_context_handle_sync_generator(commands_context_flow, aut
|
||||
# return is not allowed in generators (functions with yield)
|
||||
|
||||
with warnings.catch_warnings(record = True) as warnings_list:
|
||||
await context.process_transcription('foo')
|
||||
await context.process_transcription(get_transcription('foo'))
|
||||
await anyio.sleep(1)
|
||||
|
||||
assert len(warnings_list) == 2
|
||||
|
@@ -2,7 +2,7 @@ import anyio
|
||||
from stark.core import AsyncResponseHandler, Response
|
||||
|
||||
|
||||
async def test_commands_context_inject_dependencies(commands_context_flow, autojump_clock):
|
||||
async def test_commands_context_inject_dependencies(commands_context_flow, autojump_clock, get_transcription):
|
||||
async with commands_context_flow() as (manager, context, context_delegate):
|
||||
@manager.new('foo')
|
||||
async def foo(handler: AsyncResponseHandler) -> Response:
|
||||
@@ -12,7 +12,7 @@ async def test_commands_context_inject_dependencies(commands_context_flow, autoj
|
||||
async def bar(inject_dependencies):
|
||||
return await inject_dependencies(foo)()
|
||||
|
||||
await context.process_transcription('bar')
|
||||
await context.process_transcription(get_transcription('bar'))
|
||||
await anyio.sleep(1)
|
||||
|
||||
assert len(context_delegate.responses) == 1
|
||||
|
@@ -4,65 +4,65 @@ import anyio
|
||||
from stark.core import Response, ResponseHandler, AsyncResponseHandler
|
||||
|
||||
|
||||
async def test_command_return_response(commands_context_flow, autojump_clock):
|
||||
async def test_command_return_response(commands_context_flow, autojump_clock, get_transcription):
|
||||
async with commands_context_flow() as (manager, context, context_delegate):
|
||||
|
||||
@manager.new('foo')
|
||||
async def foo() -> Response:
|
||||
return Response(text = 'foo!')
|
||||
|
||||
await context.process_transcription('foo')
|
||||
await context.process_transcription(get_transcription('foo'))
|
||||
await anyio.sleep(5)
|
||||
|
||||
assert len(context_delegate.responses) == 1
|
||||
assert context_delegate.responses[0].text == 'foo!'
|
||||
|
||||
async def test_sync_command_call_sync_respond(commands_context_flow, autojump_clock):
|
||||
async def test_sync_command_call_sync_respond(commands_context_flow, autojump_clock, get_transcription):
|
||||
async with commands_context_flow() as (manager, context, context_delegate):
|
||||
|
||||
@manager.new('foo')
|
||||
def foo(handler: ResponseHandler):
|
||||
handler.respond(Response(text = 'foo!'))
|
||||
|
||||
await context.process_transcription('foo')
|
||||
await context.process_transcription(get_transcription('foo'))
|
||||
await anyio.sleep(5)
|
||||
|
||||
assert len(context_delegate.responses) == 1
|
||||
assert context_delegate.responses[0].text == 'foo!'
|
||||
|
||||
async def test_async_command_call_sync_respond(commands_context_flow, autojump_clock):
|
||||
async def test_async_command_call_sync_respond(commands_context_flow, autojump_clock, get_transcription):
|
||||
async with commands_context_flow() as (manager, context, context_delegate):
|
||||
|
||||
@manager.new('foo')
|
||||
async def foo(handler: AsyncResponseHandler):
|
||||
await handler.respond(Response(text = 'foo!'))
|
||||
|
||||
await context.process_transcription('foo')
|
||||
await context.process_transcription(get_transcription('foo'))
|
||||
await anyio.sleep(5)
|
||||
|
||||
assert len(context_delegate.responses) == 1
|
||||
assert context_delegate.responses[0].text == 'foo!'
|
||||
|
||||
@pytest.mark.skip(reason = 'deprecated: added checks for DI on command creation')
|
||||
async def test_sync_command_call_async_respond(commands_context_flow, autojump_clock):
|
||||
async def test_sync_command_call_async_respond(commands_context_flow, autojump_clock, get_transcription):
|
||||
async with commands_context_flow() as (manager, context, context_delegate):
|
||||
|
||||
@manager.new('foo')
|
||||
def foo(handler: AsyncResponseHandler):
|
||||
with warnings.catch_warnings(record = True) as warnings_list:
|
||||
assert len(warnings_list) == 0
|
||||
handler.respond(Response(text = 'foo!'))
|
||||
handler.respond(Response(text = 'foo!')) # type: ignore
|
||||
assert len(warnings_list) == 1
|
||||
assert issubclass(warnings_list[0].category, RuntimeWarning)
|
||||
assert 'was never awaited' in str(warnings_list[0].message)
|
||||
|
||||
await context.process_transcription('foo')
|
||||
await context.process_transcription(get_transcription('foo'))
|
||||
await anyio.sleep(5)
|
||||
|
||||
assert len(context_delegate.responses) == 0
|
||||
|
||||
@pytest.mark.skip(reason = 'deprecated: added checks for DI on command creation')
|
||||
async def test_async_command_call_async_respond(commands_context_flow, autojump_clock):
|
||||
async def test_async_command_call_async_respond(commands_context_flow, autojump_clock, get_transcription):
|
||||
async with commands_context_flow() as (manager, context, context_delegate):
|
||||
|
||||
@manager.new('foo')
|
||||
@@ -70,12 +70,12 @@ async def test_async_command_call_async_respond(commands_context_flow, autojump_
|
||||
with pytest.raises(RuntimeError, match = 'can only be run from an AnyIO worker thread'):
|
||||
handler.respond(Response(text = 'foo!'))
|
||||
|
||||
await context.process_transcription('foo')
|
||||
await context.process_transcription(get_transcription('foo'))
|
||||
await anyio.sleep(5)
|
||||
|
||||
assert len(context_delegate.responses) == 0
|
||||
|
||||
async def test_command_multiple_respond(commands_context_flow, autojump_clock):
|
||||
async def test_command_multiple_respond(commands_context_flow, autojump_clock, get_transcription):
|
||||
async with commands_context_flow() as (manager, context, context_delegate):
|
||||
|
||||
@manager.new('foo')
|
||||
@@ -91,7 +91,7 @@ async def test_command_multiple_respond(commands_context_flow, autojump_clock):
|
||||
await anyio.sleep(2)
|
||||
return Response(text = 'foo4')
|
||||
|
||||
await context.process_transcription('foo')
|
||||
await context.process_transcription(get_transcription('foo'))
|
||||
|
||||
last_count = 0
|
||||
while last_count < 5:
|
||||
|
@@ -3,10 +3,12 @@ import anyio
|
||||
from stark.core import Pattern, Response, CommandsManager
|
||||
from stark.core.types import Object
|
||||
from stark.general.classproperty import classproperty
|
||||
from stark.general.localisation.localizer import Localizer
|
||||
from stark.models.transcription import Transcription, TranscriptionTrack
|
||||
import random
|
||||
|
||||
|
||||
async def test_multiple_commands(commands_context_flow, autojump_clock):
|
||||
async def test_multiple_commands(commands_context_flow, autojump_clock, get_transcription):
|
||||
async with commands_context_flow() as (manager, context, context_delegate):
|
||||
|
||||
@manager.new('foo bar')
|
||||
@@ -17,91 +19,87 @@ async def test_multiple_commands(commands_context_flow, autojump_clock):
|
||||
def lorem():
|
||||
return Response(text = 'lorem!')
|
||||
|
||||
await context.process_transcription('foo bar lorem ipsum dolor')
|
||||
await context.process_transcription(get_transcription('foo bar lorem ipsum dolor'))
|
||||
await anyio.sleep(5)
|
||||
|
||||
assert len(context_delegate.responses) == 2
|
||||
assert {context_delegate.responses[0].text, context_delegate.responses[1].text} == {'foo!', 'lorem!'}
|
||||
|
||||
async def test_repeating_command(commands_context_flow, autojump_clock):
|
||||
async def test_repeating_command(commands_context_flow, autojump_clock, get_transcription):
|
||||
async with commands_context_flow() as (manager, context, context_delegate):
|
||||
|
||||
@manager.new('lorem * dolor')
|
||||
def lorem():
|
||||
return Response(text = 'lorem!')
|
||||
|
||||
await context.process_transcription('lorem pisum dolor lorem ipsutest_repeating_commanduum dolor sit amet')
|
||||
await context.process_transcription(get_transcription('lorem pisum dolor lorem ipsutest_repeating_commanduum dolor sit amet'))
|
||||
await anyio.sleep(5)
|
||||
|
||||
assert len(context_delegate.responses) == 2
|
||||
assert context_delegate.responses[0].text == 'lorem!'
|
||||
assert context_delegate.responses[1].text == 'lorem!'
|
||||
|
||||
async def test_overlapping_commands_less_priority_cut(commands_context_flow, autojump_clock):
|
||||
manager = CommandsManager()
|
||||
async def test_overlapping_commands_less_priority_cut(commands_context_flow, autojump_clock, get_transcription):
|
||||
async with commands_context_flow() as (manager, context, context_delegate):
|
||||
@manager.new('foo bar *')
|
||||
def foobar():
|
||||
return Response(text = 'foo!')
|
||||
|
||||
@manager.new('* baz')
|
||||
def baz():
|
||||
return Response(text = 'baz!')
|
||||
|
||||
result = await manager.search(get_transcription('foo bar test baz'), Localizer())
|
||||
assert len(result) == 2
|
||||
assert result[0].match_result.subtrack.text == 'foo bar test'
|
||||
assert result[1].match_result.subtrack.text == 'baz'
|
||||
|
||||
@manager.new('foo bar *')
|
||||
def foobar():
|
||||
return Response(text = 'foo!')
|
||||
|
||||
@manager.new('* baz')
|
||||
def baz():
|
||||
return Response(text = 'baz!')
|
||||
|
||||
result = await manager.search('foo bar test baz')
|
||||
assert len(result) == 2
|
||||
assert result[0].match_result.substring == 'foo bar test'
|
||||
assert result[1].match_result.substring == 'baz'
|
||||
async def test_overlapping_commands_priority_cut(commands_context_flow, autojump_clock, get_transcription):
|
||||
async with commands_context_flow() as (manager, context, context_delegate):
|
||||
@manager.new('foo bar *')
|
||||
def foobar():
|
||||
return Response(text = 'foo!')
|
||||
|
||||
@manager.new('*t baz')
|
||||
def baz():
|
||||
return Response(text = 'baz!')
|
||||
|
||||
result = await manager.search(get_transcription('foo bar test baz'), Localizer())
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0].match_result.subtrack.text == 'foo bar'
|
||||
assert result[1].match_result.subtrack.text == 'test baz'
|
||||
|
||||
async def test_overlapping_commands_priority_cut(commands_context_flow, autojump_clock):
|
||||
manager = CommandsManager()
|
||||
async def test_overlapping_commands_remove(commands_context_flow, autojump_clock, get_transcription):
|
||||
async with commands_context_flow() as (manager, context, context_delegate):
|
||||
@manager.new('foo bar')
|
||||
def foobar():
|
||||
return Response(text = 'foo!')
|
||||
|
||||
@manager.new('bar baz')
|
||||
def barbaz():
|
||||
return Response(text = 'baz!')
|
||||
|
||||
result = await manager.search(get_transcription('foo bar baz'), Localizer())
|
||||
assert len(result) == 1
|
||||
assert result[0].command == foobar
|
||||
|
||||
@manager.new('foo bar *')
|
||||
def foobar():
|
||||
return Response(text = 'foo!')
|
||||
|
||||
@manager.new('*t baz')
|
||||
def baz():
|
||||
return Response(text = 'baz!')
|
||||
|
||||
result = await manager.search('foo bar test baz')
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0].match_result.substring == 'foo bar'
|
||||
assert result[1].match_result.substring == 'test baz'
|
||||
|
||||
async def test_overlapping_commands_remove(commands_context_flow, autojump_clock):
|
||||
manager = CommandsManager()
|
||||
|
||||
@manager.new('foo bar')
|
||||
def foobar():
|
||||
return Response(text = 'foo!')
|
||||
|
||||
@manager.new('bar baz')
|
||||
def barbaz():
|
||||
return Response(text = 'baz!')
|
||||
|
||||
result = await manager.search('foo bar baz')
|
||||
assert len(result) == 1
|
||||
assert result[0].command == foobar
|
||||
|
||||
async def test_overlapping_commands_remove_inverse(commands_context_flow, autojump_clock):
|
||||
manager = CommandsManager()
|
||||
|
||||
@manager.new('bar baz')
|
||||
def barbaz():
|
||||
return Response(text = 'baz!')
|
||||
|
||||
@manager.new('foo bar')
|
||||
def foobar():
|
||||
return Response(text = 'foo!')
|
||||
|
||||
result = await manager.search('foo bar baz')
|
||||
assert len(result) == 1
|
||||
assert result[0].command == barbaz
|
||||
async def test_overlapping_commands_remove_inverse(commands_context_flow, autojump_clock, get_transcription):
|
||||
async with commands_context_flow() as (manager, context, context_delegate):
|
||||
@manager.new('bar baz')
|
||||
def barbaz():
|
||||
return Response(text = 'baz!')
|
||||
|
||||
@manager.new('foo bar')
|
||||
def foobar():
|
||||
return Response(text = 'foo!')
|
||||
|
||||
result = await manager.search(get_transcription('foo bar baz'), Localizer())
|
||||
assert len(result) == 1
|
||||
assert result[0].command == barbaz
|
||||
|
||||
@pytest.mark.skip(reason = 'Cache is deprecated and not working properly anymore because of new concurrent algorithm; need new async lru cache implementation')
|
||||
async def test_objects_parse_caching(commands_context_flow, autojump_clock):
|
||||
async def test_objects_parse_caching(commands_context_flow, autojump_clock, get_transcription):
|
||||
class Mock(Object):
|
||||
|
||||
parsing_counter = 0
|
||||
@@ -110,9 +108,9 @@ async def test_objects_parse_caching(commands_context_flow, autojump_clock):
|
||||
def pattern(cls):
|
||||
return Pattern('*')
|
||||
|
||||
async def did_parse(self, from_string: str) -> str:
|
||||
async def did_parse(self, track: TranscriptionTrack, transcription: Transcription, re_match_groups: dict[str, str]) -> tuple[TranscriptionTrack, Transcription]:
|
||||
Mock.parsing_counter += 1
|
||||
return from_string
|
||||
return track, transcription
|
||||
|
||||
mock_name = f'Mock{random.randint(0, 10**10)}'
|
||||
Mock.__name__ = mock_name # prevent name collision on paralell tests
|
||||
@@ -133,7 +131,7 @@ async def test_objects_parse_caching(commands_context_flow, autojump_clock):
|
||||
async def test(mock: Mock): pass
|
||||
|
||||
assert Mock.parsing_counter == 0
|
||||
await manager.search('hello foobar 22')
|
||||
await manager.search(get_transcription('hello foobar 22'), Localizer())
|
||||
assert Mock.parsing_counter == 1
|
||||
await manager.search('hello foobar 22')
|
||||
await manager.search(get_transcription('hello foobar 22'), Localizer())
|
||||
assert Mock.parsing_counter == 2
|
||||
|
@@ -4,6 +4,7 @@ from stark.core import Pattern
|
||||
from stark.core.types import Object, Word, String
|
||||
from stark.core.patterns import expressions
|
||||
from stark.general.classproperty import classproperty
|
||||
from stark.general.localisation import Localizer
|
||||
|
||||
|
||||
word = fr'[{expressions.alphanumerics}]*'
|
||||
@@ -17,22 +18,23 @@ class ExtraParameterInPattern(Object):
|
||||
def pattern(cls) -> Pattern:
|
||||
return Pattern('$word1:Word $word2:Word $word3:Word')
|
||||
|
||||
async def test_typed_parameters():
|
||||
async def test_typed_parameters(get_transcription):
|
||||
p = Pattern('lorem $name:Word dolor')
|
||||
p.get_compiled('en', Localizer())
|
||||
assert p.parameters == {'name': Word}
|
||||
assert p.compiled == fr'lorem (?P<name>{word}) dolor'
|
||||
# assert p._compiled['en'] == fr'lorem (?P<name>{word}) dolor'
|
||||
|
||||
m = await p.match('lorem ipsum dolor')
|
||||
m = await p.match(get_transcription('lorem ipsum dolor'), Localizer())
|
||||
assert m
|
||||
assert m[0].substring == 'lorem ipsum dolor'
|
||||
assert m[0].subtrack.text == 'lorem ipsum dolor'
|
||||
assert m[0].parameters == {'name': Word('ipsum')}
|
||||
assert not await p.match('lorem ipsum foo dolor')
|
||||
assert not await p.match(get_transcription('lorem ipsum foo dolor'), Localizer())
|
||||
|
||||
p = Pattern('lorem $name:String dolor')
|
||||
assert p.parameters == {'name': String}
|
||||
m = await p.match('lorem ipsum foo bar dolor')
|
||||
m = await p.match(get_transcription('lorem ipsum foo bar dolor'), Localizer())
|
||||
assert m
|
||||
assert m[0].substring == 'lorem ipsum foo bar dolor'
|
||||
assert m[0].subtrack.text == 'lorem ipsum foo bar dolor'
|
||||
assert m[0].parameters == {'name': String('ipsum foo bar')}
|
||||
|
||||
def test_undefined_typed_parameters():
|
||||
|
@@ -1,118 +1,133 @@
|
||||
from stark.core import Pattern
|
||||
from stark.core.patterns import expressions
|
||||
from stark.general.localisation import Localizer
|
||||
|
||||
|
||||
word = fr'[{expressions.alphanumerics}]*'
|
||||
words = fr'[{expressions.alphanumerics}\s]*'
|
||||
|
||||
async def test_leading_star():
|
||||
async def test_leading_star(get_transcription):
|
||||
p = Pattern('*text')
|
||||
assert p.compiled == fr'{word}text'
|
||||
assert await p.match('text')
|
||||
assert await p.match('aaatext')
|
||||
assert (await p.match('bbb aaaatext cccc'))[0].substring == 'aaaatext'
|
||||
assert not await p.match('aaaaext')
|
||||
p.get_compiled('en', Localizer())
|
||||
# assert p._compiled['en'] == fr'{word}text'
|
||||
assert await p.match(get_transcription('text'), Localizer())
|
||||
assert await p.match(get_transcription('aaatext'), Localizer())
|
||||
assert (await p.match(get_transcription('bbb aaaatext cccc'), Localizer()))[0].subtrack.text == 'aaaatext'
|
||||
assert not await p.match(get_transcription('aaaaext'), Localizer())
|
||||
|
||||
p = Pattern('Some *text here')
|
||||
assert p.compiled == fr'Some {word}text here'
|
||||
assert await p.match('Some text here')
|
||||
assert await p.match('Some aaatext here')
|
||||
assert (await p.match('bbb Some aaatext here cccc'))[0].substring == 'Some aaatext here'
|
||||
assert not await p.match('aaatext here')
|
||||
p.get_compiled('en', Localizer())
|
||||
# assert p._compiled['en'] == fr'Some {word}text here'
|
||||
assert await p.match(get_transcription('Some text here'), Localizer())
|
||||
assert await p.match(get_transcription('Some aaatext here'), Localizer())
|
||||
assert (await p.match(get_transcription('bbb Some aaatext here cccc'), Localizer()))[0].subtrack.text == 'Some aaatext here'
|
||||
assert not await p.match(get_transcription('Some aaatext'), Localizer())
|
||||
|
||||
async def test_trailing_star():
|
||||
async def test_trailing_star(get_transcription):
|
||||
p = Pattern('text*')
|
||||
assert p.compiled == fr'text{word}'
|
||||
assert await p.match('text')
|
||||
assert await p.match('textaaa')
|
||||
assert (await p.match('bbb textaaa cccc'))[0].substring == 'textaaa'
|
||||
p.get_compiled('en', Localizer())
|
||||
# assert p._compiled['en'] == fr'text{word}'
|
||||
assert await p.match(get_transcription('text'), Localizer())
|
||||
assert await p.match(get_transcription('textaaa'), Localizer())
|
||||
assert (await p.match(get_transcription('bbb textaaa cccc'), Localizer()))[0].subtrack.text == 'textaaa'
|
||||
|
||||
p = Pattern('Some text* here')
|
||||
assert p.compiled == fr'Some text{word} here'
|
||||
assert await p.match('Some text here')
|
||||
assert await p.match('Some textaaa here')
|
||||
assert (await p.match('bbb Some textaaa here cccc'))[0].substring == 'Some textaaa here'
|
||||
assert not await p.match('Some textaaa ')
|
||||
|
||||
async def test_middle_star():
|
||||
p.get_compiled('en', Localizer())
|
||||
# assert p._compiled['en'] == fr'Some text{word} here'
|
||||
assert await p.match(get_transcription('Some text here'), Localizer())
|
||||
assert await p.match(get_transcription('Some textaaa here'), Localizer())
|
||||
assert (await p.match(get_transcription('bbb Some textaaa here cccc'), Localizer()))[0].subtrack.text == 'Some textaaa here'
|
||||
assert not await p.match(get_transcription('Some textaaa '), Localizer())
|
||||
|
||||
async def test_middle_star(get_transcription):
|
||||
p = Pattern('te*xt')
|
||||
assert p.compiled == fr'te{word}xt'
|
||||
assert await p.match('text')
|
||||
assert await p.match('teaaaaaxt')
|
||||
assert (await p.match('bbb teaaaaaxt cccc'))[0].substring == 'teaaaaaxt'
|
||||
p.get_compiled('en', Localizer())
|
||||
# assert p._compiled['en'] == fr'te{word}xt'
|
||||
assert await p.match(get_transcription('text'), Localizer())
|
||||
assert await p.match(get_transcription('teaaaaaxt'), Localizer())
|
||||
assert (await p.match(get_transcription('bbb teaaaaaxt cccc'), Localizer()))[0].subtrack.text == 'teaaaaaxt'
|
||||
|
||||
p = Pattern('Some te*xt here')
|
||||
assert p.compiled == fr'Some te{word}xt here'
|
||||
assert await p.match('Some text here')
|
||||
assert await p.match('Some teaaaaaxt here')
|
||||
assert (await p.match('bbb Some teaaeaaaxt here cccc'))[0].substring == 'Some teaaeaaaxt here'
|
||||
assert not await p.match('Some teaaaaaxt')
|
||||
p.get_compiled('en', Localizer())
|
||||
# assert p._compiled['en'] == fr'Some te{word}xt here'
|
||||
assert await p.match(get_transcription('Some text here'), Localizer())
|
||||
assert await p.match(get_transcription('Some teaaaaaxt here'), Localizer())
|
||||
assert (await p.match(get_transcription('bbb Some teaaeaaaxt here cccc'), Localizer()))[0].subtrack.text == 'Some teaaeaaaxt here'
|
||||
assert not await p.match(get_transcription('Some teaaaaaxt'), Localizer())
|
||||
|
||||
async def test_double_star():
|
||||
async def test_double_star(get_transcription):
|
||||
p = Pattern('**')
|
||||
assert p.compiled == fr'{words}'
|
||||
assert (await p.match('bbb teaaaaaxt cccc'))[0].substring == 'bbb teaaaaaxt cccc'
|
||||
p.get_compiled('en', Localizer())
|
||||
# assert p._compiled['en'] == fr'{words}'
|
||||
assert (await p.match(get_transcription('bbb teaaaaaxt cccc'), Localizer()))[0].subtrack.text == 'bbb teaaaaaxt cccc'
|
||||
|
||||
p = Pattern('Some ** here')
|
||||
assert p.compiled == fr'Some {words} here'
|
||||
assert await p.match('Some text here')
|
||||
assert await p.match('Some lorem ipsum dolor here')
|
||||
assert (await p.match('bbb Some lorem ipsum dolor here cccc'))[0].substring == 'Some lorem ipsum dolor here'
|
||||
p.get_compiled('en', Localizer())
|
||||
# assert p._compiled['en'] == fr'Some {words} here'
|
||||
assert await p.match(get_transcription('Some text here'), Localizer())
|
||||
assert await p.match(get_transcription('Some lorem ipsum dolor here'), Localizer())
|
||||
assert (await p.match(get_transcription('bbb Some lorem ipsum dolor here cccc'), Localizer()))[0].subtrack.text == 'Some lorem ipsum dolor here'
|
||||
|
||||
async def test_one_of():
|
||||
async def test_one_of(get_transcription):
|
||||
p = Pattern('(foo|bar)')
|
||||
assert p.compiled == r'(?:foo|bar)'
|
||||
assert await p.match('foo')
|
||||
assert await p.match('bar')
|
||||
assert (await p.match('bbb foo cccc'))[0].substring == 'foo'
|
||||
assert (await p.match('bbb bar cccc'))[0].substring == 'bar'
|
||||
p.get_compiled('en', Localizer())
|
||||
# assert p._compiled['en'] == r'(?:foo|bar)'
|
||||
assert await p.match(get_transcription('foo'), Localizer())
|
||||
assert await p.match(get_transcription('bar'), Localizer())
|
||||
assert (await p.match(get_transcription('bbb foo cccc'), Localizer()))[0].subtrack.text == 'foo'
|
||||
assert (await p.match(get_transcription('bbb bar cccc'), Localizer()))[0].subtrack.text == 'bar'
|
||||
|
||||
p = Pattern('Some (foo|bar) here')
|
||||
assert p.compiled == r'Some (?:foo|bar) here'
|
||||
assert await p.match('Some foo here')
|
||||
assert await p.match('Some bar here')
|
||||
assert (await p.match('bbb Some foo here cccc'))[0].substring == 'Some foo here'
|
||||
assert (await p.match('bbb Some bar here cccc'))[0].substring == 'Some bar here'
|
||||
assert not await p.match('Some foo')
|
||||
p.get_compiled('en', Localizer())
|
||||
# assert p._compiled['en'] == r'Some (?:foo|bar) here'
|
||||
assert await p.match(get_transcription('Some foo here'), Localizer())
|
||||
assert await p.match(get_transcription('Some bar here'), Localizer())
|
||||
assert (await p.match(get_transcription('bbb Some foo here cccc'), Localizer()))[0].subtrack.text == 'Some foo here'
|
||||
assert (await p.match(get_transcription('bbb Some bar here cccc'), Localizer()))[0].subtrack.text == 'Some bar here'
|
||||
assert not await p.match(get_transcription('Some foo'), Localizer())
|
||||
|
||||
async def test_optional_one_of():
|
||||
async def test_optional_one_of(get_transcription):
|
||||
p = Pattern('(foo|bar)?')
|
||||
assert p.compiled == r'(?:foo|bar)?'
|
||||
assert await p.match('foo')
|
||||
assert await p.match('bar')
|
||||
assert not await p.match('')
|
||||
assert not await p.match('bbb cccc')
|
||||
assert (await p.match('bbb foo cccc'))[0].substring == 'foo'
|
||||
assert (await p.match('bbb bar cccc'))[0].substring == 'bar'
|
||||
p.get_compiled('en', Localizer())
|
||||
# assert p._compiled['en'] == r'(?:foo|bar)?'
|
||||
assert await p.match(get_transcription('foo'), Localizer())
|
||||
assert await p.match(get_transcription('bar'), Localizer())
|
||||
assert not await p.match(get_transcription(''), Localizer())
|
||||
assert not await p.match(get_transcription('bbb cccc'), Localizer())
|
||||
assert (await p.match(get_transcription('bbb foo cccc'), Localizer()))[0].subtrack.text == 'foo'
|
||||
assert (await p.match(get_transcription('bbb bar cccc'), Localizer()))[0].subtrack.text == 'bar'
|
||||
|
||||
p = Pattern('Some (foo|bar)? here')
|
||||
assert p.compiled == r'Some (?:foo|bar)? here'
|
||||
assert await p.match('Some foo here')
|
||||
assert await p.match('Some bar here')
|
||||
assert await p.match('Some here')
|
||||
assert (await p.match('bbb Some foo here cccc'))[0].substring == 'Some foo here'
|
||||
assert (await p.match('bbb Some bar here cccc'))[0].substring == 'Some bar here'
|
||||
assert (await p.match('bbb Some here cccc'))[0].substring == 'Some here'
|
||||
p.get_compiled('en', Localizer())
|
||||
# assert p._compiled['en'] == r'Some (?:foo|bar)? here'
|
||||
assert await p.match(get_transcription('Some foo here'), Localizer())
|
||||
assert await p.match(get_transcription('Some bar here'), Localizer())
|
||||
assert await p.match(get_transcription('Some here'), Localizer())
|
||||
assert (await p.match(get_transcription('bbb Some foo here cccc'), Localizer()))[0].subtrack.text == 'Some foo here'
|
||||
assert (await p.match(get_transcription('bbb Some bar here cccc'), Localizer()))[0].subtrack.text == 'Some bar here'
|
||||
assert (await p.match(get_transcription('bbb Some here cccc'), Localizer()))[0].subtrack.text == 'Some here'
|
||||
|
||||
# assert Pattern('[foo|bar]').compiled == Pattern('(foo|bar)?').compiled
|
||||
|
||||
async def test_one_or_more_of():
|
||||
async def test_one_or_more_of(get_transcription):
|
||||
p = Pattern('{foo|bar}')
|
||||
assert p.compiled == r'(?:(?:foo|bar)\s?)+'
|
||||
assert await p.match('foo')
|
||||
assert await p.match('bar')
|
||||
assert not await p.match('')
|
||||
assert (await p.match('bbb foo cccc'))[0].substring == 'foo'
|
||||
assert (await p.match('bbb bar cccc'))[0].substring == 'bar'
|
||||
assert (await p.match('bbb foo bar cccc'))[0].substring == 'foo bar'
|
||||
assert not await p.match('bbb cccc')
|
||||
p.get_compiled('en', Localizer())
|
||||
# assert p._compiled['en'] == r'(?:(?:foo|bar)\s?)+'
|
||||
assert await p.match(get_transcription('foo'), Localizer())
|
||||
assert await p.match(get_transcription('bar'), Localizer())
|
||||
assert not await p.match(get_transcription(''), Localizer())
|
||||
assert (await p.match(get_transcription('bbb foo cccc'), Localizer()))[0].subtrack.text == 'foo'
|
||||
assert (await p.match(get_transcription('bbb bar cccc'), Localizer()))[0].subtrack.text == 'bar'
|
||||
assert (await p.match(get_transcription('bbb foo bar cccc'), Localizer()))[0].subtrack.text == 'foo bar'
|
||||
assert not await p.match(get_transcription('bbb cccc'), Localizer())
|
||||
|
||||
p = Pattern('Some {foo|bar} here')
|
||||
assert p.compiled == r'Some (?:(?:foo|bar)\s?)+ here'
|
||||
assert await p.match('Some foo here')
|
||||
assert await p.match('Some bar here')
|
||||
assert not await p.match('Some here')
|
||||
assert (await p.match('bbb Some foo here cccc'))[0].substring == 'Some foo here'
|
||||
assert (await p.match('bbb Some bar here cccc'))[0].substring == 'Some bar here'
|
||||
assert (await p.match('bbb Some foo bar here cccc'))[0].substring == 'Some foo bar here'
|
||||
assert not await p.match('Some foo')
|
||||
p.get_compiled('en', Localizer())
|
||||
# assert p._compiled['en'] == r'Some (?:(?:foo|bar)\s?)+ here'
|
||||
assert await p.match(get_transcription('Some foo here'), Localizer())
|
||||
assert await p.match(get_transcription('Some bar here'), Localizer())
|
||||
assert not await p.match(get_transcription('Some here'), Localizer())
|
||||
assert (await p.match(get_transcription('bbb Some foo here cccc'), Localizer()))[0].subtrack.text == 'Some foo here'
|
||||
assert (await p.match(get_transcription('bbb Some bar here cccc'), Localizer()))[0].subtrack.text == 'Some bar here'
|
||||
assert (await p.match(get_transcription('bbb Some foo bar here cccc'), Localizer()))[0].subtrack.text == 'Some foo bar here'
|
||||
assert not await p.match(get_transcription('Some foo'), Localizer())
|
||||
|
@@ -20,7 +20,6 @@ def transcription_track():
|
||||
|
||||
def check_sorted(track: TranscriptionTrack):
|
||||
for i, word in enumerate(track.result):
|
||||
print(i, word)
|
||||
assert i == 0 or track.result[i-1].end <= word.start
|
||||
|
||||
def test_replace(transcription_track):
|
||||
|
@@ -1,4 +1,3 @@
|
||||
import pytest
|
||||
import anyio
|
||||
from datetime import timedelta
|
||||
from stark.voice_assistant import Mode
|
||||
@@ -6,11 +5,11 @@ from stark.models.transcription import Transcription, TranscriptionTrack
|
||||
from conftest import SpeechRecognizerMock
|
||||
|
||||
|
||||
async def test_background_command_with_waiting_mode(voice_assistant, autojump_clock):
|
||||
async def test_background_command_with_waiting_mode(voice_assistant, get_transcription, autojump_clock):
|
||||
async with voice_assistant() as voice_assistant:
|
||||
await voice_assistant.speech_recognizer_did_receive_final_transcription(
|
||||
SpeechRecognizerMock(),
|
||||
Transcription(best=TranscriptionTrack(text='background min'))
|
||||
get_transcription('background min')
|
||||
)
|
||||
|
||||
await anyio.sleep(0.2) # allow to capture first command response
|
||||
@@ -38,17 +37,17 @@ async def test_background_command_with_waiting_mode(voice_assistant, autojump_cl
|
||||
# interact to reset timeout mode and repeat saved response
|
||||
await voice_assistant.speech_recognizer_did_receive_final_transcription(
|
||||
SpeechRecognizerMock(),
|
||||
Transcription(best=TranscriptionTrack(text='test'))
|
||||
get_transcription('test')
|
||||
)
|
||||
await anyio.sleep(0.2) # allow to capture command response
|
||||
assert len(voice_assistant.speech_synthesizer.results) == 2
|
||||
assert [r.text for r in voice_assistant.speech_synthesizer.results] == ['test', 'Finished background task']
|
||||
|
||||
async def test_background_command_with_inactive_mode(voice_assistant, autojump_clock):
|
||||
async def test_background_command_with_inactive_mode(voice_assistant, autojump_clock, get_transcription):
|
||||
async with voice_assistant() as voice_assistant:
|
||||
await voice_assistant.speech_recognizer_did_receive_final_transcription(
|
||||
SpeechRecognizerMock(),
|
||||
Transcription(best=TranscriptionTrack(text='background min'))
|
||||
get_transcription('background min')
|
||||
)
|
||||
|
||||
await anyio.sleep(0.2) # allow to capture first command response
|
||||
@@ -72,18 +71,18 @@ async def test_background_command_with_inactive_mode(voice_assistant, autojump_c
|
||||
# interact to reset timeout mode and repeat saved response
|
||||
await voice_assistant.speech_recognizer_did_receive_final_transcription(
|
||||
SpeechRecognizerMock(),
|
||||
Transcription(best=TranscriptionTrack(text='test'))
|
||||
get_transcription('test')
|
||||
)
|
||||
await anyio.sleep(0.2) # allow to capture command response
|
||||
assert len(voice_assistant.speech_synthesizer.results) == 2
|
||||
assert voice_assistant.speech_synthesizer.results.pop(0).text == 'test'
|
||||
assert voice_assistant.speech_synthesizer.results.pop(0).text == 'Finished background task'
|
||||
|
||||
async def test_background_waiting_needs_input(voice_assistant, autojump_clock):
|
||||
async def test_background_waiting_needs_input(voice_assistant, autojump_clock, get_transcription):
|
||||
async with voice_assistant() as voice_assistant:
|
||||
await voice_assistant.speech_recognizer_did_receive_final_transcription(
|
||||
SpeechRecognizerMock(),
|
||||
Transcription(best=TranscriptionTrack(text='background needs input'))
|
||||
get_transcription('background needs input')
|
||||
)
|
||||
|
||||
await anyio.sleep(0.2) # allow to capture first command response
|
||||
@@ -93,7 +92,7 @@ async def test_background_waiting_needs_input(voice_assistant, autojump_clock):
|
||||
|
||||
# wait for command to finish
|
||||
await anyio.sleep(1)
|
||||
|
||||
|
||||
# voice assistant should save all responses for later
|
||||
assert len(voice_assistant._responses) == 8
|
||||
assert len(voice_assistant.speech_synthesizer.results) == 8
|
||||
@@ -106,7 +105,7 @@ async def test_background_waiting_needs_input(voice_assistant, autojump_clock):
|
||||
# interact to reset timeout mode
|
||||
await voice_assistant.speech_recognizer_did_receive_final_transcription(
|
||||
SpeechRecognizerMock(),
|
||||
Transcription(best=TranscriptionTrack(text='test'))
|
||||
get_transcription('test')
|
||||
)
|
||||
|
||||
await anyio.sleep(0.2) # allow to capture command response
|
||||
@@ -121,7 +120,7 @@ async def test_background_waiting_needs_input(voice_assistant, autojump_clock):
|
||||
# interact to emulate user input and continue repeating responses
|
||||
await voice_assistant.speech_recognizer_did_receive_final_transcription(
|
||||
SpeechRecognizerMock(),
|
||||
Transcription(best=TranscriptionTrack(text='test'))
|
||||
get_transcription('test')
|
||||
)
|
||||
await anyio.sleep(0.2) # allow to capture command response
|
||||
|
||||
@@ -131,11 +130,11 @@ async def test_background_waiting_needs_input(voice_assistant, autojump_clock):
|
||||
for response in ['test', 'Fourth response', 'Fifth response', 'Sixth response', 'Finished long background task']:
|
||||
assert voice_assistant.speech_synthesizer.results.pop(0).text == response
|
||||
|
||||
async def test_background_waiting_with_context(voice_assistant, autojump_clock):
|
||||
async def test_background_waiting_with_context(voice_assistant, autojump_clock, get_transcription):
|
||||
async with voice_assistant() as voice_assistant:
|
||||
await voice_assistant.speech_recognizer_did_receive_final_transcription(
|
||||
SpeechRecognizerMock(),
|
||||
Transcription(best=TranscriptionTrack(text='background with context'))
|
||||
get_transcription('background with context')
|
||||
)
|
||||
|
||||
# force a timeout
|
||||
@@ -157,17 +156,17 @@ async def test_background_waiting_with_context(voice_assistant, autojump_clock):
|
||||
# interact to reset timeout mode, voice assistant should reset context, repeat responses and add response context
|
||||
await voice_assistant.speech_recognizer_did_receive_final_transcription(
|
||||
SpeechRecognizerMock(),
|
||||
Transcription(best=TranscriptionTrack(text='lorem ipsum dolor'))
|
||||
get_transcription('lorem ipsum dolor')
|
||||
)
|
||||
await anyio.sleep(0.2) # allow to capture command response
|
||||
assert len(voice_assistant.speech_synthesizer.results) == 2
|
||||
assert len(voice_assistant.commands_context._context_queue) == 2
|
||||
|
||||
async def test_background_waiting_remove_response(voice_assistant, autojump_clock):
|
||||
async def test_background_waiting_remove_response(voice_assistant, autojump_clock, get_transcription):
|
||||
async with voice_assistant() as voice_assistant:
|
||||
await voice_assistant.speech_recognizer_did_receive_final_transcription(
|
||||
SpeechRecognizerMock(),
|
||||
Transcription(best=TranscriptionTrack(text='background remove response'))
|
||||
get_transcription('background remove response')
|
||||
)
|
||||
voice_assistant.speech_synthesizer.results.clear()
|
||||
|
||||
@@ -191,7 +190,7 @@ async def test_background_waiting_remove_response(voice_assistant, autojump_cloc
|
||||
# interact to reset timeout mode, check that removed response is not repeated
|
||||
await voice_assistant.speech_recognizer_did_receive_final_transcription(
|
||||
SpeechRecognizerMock(),
|
||||
Transcription(best=TranscriptionTrack(text='test'))
|
||||
get_transcription('test')
|
||||
)
|
||||
await anyio.sleep(1) # allow to capture command response
|
||||
assert len(voice_assistant.speech_synthesizer.results) == 1
|
||||
|
@@ -3,6 +3,8 @@ import pytest
|
||||
from stark.core import Pattern
|
||||
from stark.core.types import Object, ParseError
|
||||
from stark.general.classproperty import classproperty
|
||||
from stark.models.transcription import Transcription, TranscriptionTrack
|
||||
from stark.general.localisation import Localizer
|
||||
|
||||
|
||||
class Lorem(Object):
|
||||
@@ -11,22 +13,27 @@ class Lorem(Object):
|
||||
def pattern(cls):
|
||||
return Pattern('* ipsum')
|
||||
|
||||
async def did_parse(self, from_string: str) -> str:
|
||||
if 'lorem' not in from_string:
|
||||
async def did_parse(self, track: TranscriptionTrack, transcription: Transcription, re_match_group: dict[str, str]) -> tuple[TranscriptionTrack, Transcription]:
|
||||
if not 'lorem' in track.text:
|
||||
raise ParseError('lorem not found')
|
||||
self.value = 'lorem'
|
||||
return 'lorem'
|
||||
time = next(iter(track.get_time('lorem')))
|
||||
return track.get_slice(*time), transcription.get_slice(*time)
|
||||
|
||||
async def test_complex_parsing_failed():
|
||||
async def test_complex_parsing_failed(get_transcription):
|
||||
with pytest.raises(ParseError):
|
||||
await Lorem.parse('some lor ipsum')
|
||||
transcription = get_transcription('some lor ipsum')
|
||||
track = transcription.best
|
||||
await Lorem.parse(track, transcription)
|
||||
|
||||
async def test_complex_parsing():
|
||||
async def test_complex_parsing(get_transcription):
|
||||
string = 'some lorem ipsum'
|
||||
match = await Lorem.parse(string)
|
||||
transcription = get_transcription(string)
|
||||
track = transcription.best
|
||||
match = await Lorem.parse(track, transcription)
|
||||
assert match
|
||||
assert match.obj
|
||||
assert match.obj.value == 'lorem'
|
||||
assert match.substring == 'lorem'
|
||||
assert (await Lorem.pattern.match(string))[0].substring == 'lorem ipsum'
|
||||
assert match.track.text == 'lorem'
|
||||
assert (await Lorem.pattern.match(transcription, Localizer()))[0].subtrack.text == 'lorem ipsum'
|
||||
|
@@ -2,6 +2,7 @@ import pytest
|
||||
from stark.core import Pattern
|
||||
from stark.core.types import Object, Word
|
||||
from stark.general.classproperty import classproperty
|
||||
from stark.general.localisation import Localizer
|
||||
|
||||
|
||||
class FullName(Object):
|
||||
@@ -21,28 +22,30 @@ class ExtraParameterInAnnotation(Object):
|
||||
def pattern(cls) -> Pattern:
|
||||
return Pattern('$word1:Word $word2:Word')
|
||||
|
||||
async def test_nested_objects():
|
||||
async def test_nested_objects(get_transcription):
|
||||
Pattern.add_parameter_type(FullName)
|
||||
|
||||
p = Pattern('$name:FullName')
|
||||
assert p
|
||||
assert p.compiled
|
||||
p.get_compiled('en', Localizer())
|
||||
assert p._compiled
|
||||
|
||||
m = await p.match('John Galt')
|
||||
m = await p.match(get_transcription('John Galt'), Localizer())
|
||||
assert m
|
||||
assert set(m[0].parameters.keys()) == {'name'}
|
||||
assert isinstance(m[0].parameters['name'], FullName)
|
||||
assert m[0].parameters['name'].first == Word('John')
|
||||
assert m[0].parameters['name'].second == Word('Galt')
|
||||
|
||||
async def test_extra_parameter_in_annotation():
|
||||
async def test_extra_parameter_in_annotation(get_transcription):
|
||||
Pattern.add_parameter_type(ExtraParameterInAnnotation)
|
||||
|
||||
p = Pattern('$name:ExtraParameterInAnnotation')
|
||||
assert p
|
||||
assert p.compiled
|
||||
p.get_compiled('en', Localizer())
|
||||
assert p._compiled
|
||||
|
||||
m = await p.match('John Galt')
|
||||
m = await p.match(get_transcription('John Galt'), Localizer())
|
||||
assert m
|
||||
assert set(m[0].parameters.keys()) == {'name'}
|
||||
assert isinstance(m[0].parameters['name'], ExtraParameterInAnnotation)
|
||||
|
@@ -1,27 +1,34 @@
|
||||
from stark.core import Pattern
|
||||
from stark.core.types import String
|
||||
from stark.general.localisation import Localizer
|
||||
|
||||
|
||||
def test_pattern():
|
||||
assert String.pattern == Pattern('**')
|
||||
|
||||
async def test_parse():
|
||||
assert await String.parse('')
|
||||
assert (await String.parse('foo bar baz')).obj.value == 'foo bar baz'
|
||||
async def test_parse(get_transcription):
|
||||
transcription = get_transcription('')
|
||||
track = transcription.best
|
||||
assert await String.parse(track, transcription)
|
||||
transcription = get_transcription('foo bar baz')
|
||||
track = transcription.best
|
||||
assert (await String.parse(track, transcription)).obj.value == 'foo bar baz'
|
||||
|
||||
async def test_match():
|
||||
async def test_match(get_transcription):
|
||||
p = Pattern('foo $bar:String baz')
|
||||
assert p
|
||||
|
||||
m = await p.match('foo qwerty baz')
|
||||
m = await p.match(get_transcription('foo qwerty baz'), Localizer())
|
||||
assert m
|
||||
assert m[0].parameters['bar'] == String('qwerty')
|
||||
|
||||
m = await p.match('foo lorem ipsum dolor sit amet baz')
|
||||
m = await p.match(get_transcription('foo lorem ipsum dolor sit amet baz'), Localizer())
|
||||
assert m
|
||||
assert m[0].parameters['bar'] == String('lorem ipsum dolor sit amet')
|
||||
|
||||
async def test_formatted():
|
||||
string = (await String.parse('foo bar baz')).obj
|
||||
async def test_formatted(get_transcription):
|
||||
transcription = get_transcription('foo bar baz')
|
||||
track = transcription.best
|
||||
string = (await String.parse(track, transcription)).obj
|
||||
assert str(string) == '<String value: "foo bar baz">'
|
||||
assert f'{string}' == 'foo bar baz'
|
||||
|
@@ -1,27 +1,32 @@
|
||||
from stark.core import Pattern
|
||||
from stark.core.types import Word
|
||||
from stark.general.localisation import Localizer
|
||||
|
||||
|
||||
def test_pattern():
|
||||
assert Word.pattern == Pattern('*')
|
||||
|
||||
async def test_parse():
|
||||
word = (await Word.parse('foo')).obj
|
||||
async def test_parse(get_transcription):
|
||||
transcription = get_transcription('foo')
|
||||
track = transcription.best
|
||||
word = (await Word.parse(track, transcription)).obj
|
||||
assert word
|
||||
assert word.value == 'foo'
|
||||
|
||||
async def test_match():
|
||||
async def test_match(get_transcription):
|
||||
p = Pattern('foo $bar:Word baz')
|
||||
assert p
|
||||
|
||||
m = await p.match('foo qwerty baz')
|
||||
m = await p.match(get_transcription('foo qwerty baz'), Localizer())
|
||||
assert m
|
||||
assert m[0].parameters['bar'] == Word('qwerty')
|
||||
|
||||
m = await p.match('foo lorem ipsum dolor sit amet baz')
|
||||
m = await p.match(get_transcription('foo lorem ipsum dolor sit amet baz'), Localizer())
|
||||
assert not m
|
||||
|
||||
async def test_formatted():
|
||||
string = (await Word.parse('foo')).obj
|
||||
async def test_formatted(get_transcription):
|
||||
transcription = get_transcription('foo')
|
||||
track = transcription.best
|
||||
string = (await Word.parse(track, transcription)).obj
|
||||
assert str(string) == '<Word value: "foo">'
|
||||
assert f'{string}' == 'foo'
|
||||
|
Reference in New Issue
Block a user