You've already forked STARK
mirror of
https://github.com/MarkParker5/STARK.git
synced 2025-09-16 09:36:24 +02:00
improve greedy params matching
This commit is contained in:
@@ -100,8 +100,12 @@ class CommandsManager:
|
||||
|
||||
# check that runner has all parameters from pattern
|
||||
|
||||
error_msg = f'Command {self.name}.{runner.__name__} must have all parameters from pattern; ' # TODO: show difference
|
||||
assert {(p.name, p.type) for p in pattern.parameters.values()} <= annotations.items(), error_msg
|
||||
error_msg = f'Command {self.name}.{runner.__name__} must have all parameters from pattern;'
|
||||
pattern_params = list((p.name, p.type) for p in pattern.parameters.values())
|
||||
difference = pattern_params - annotations.items()
|
||||
# TODO: handle unregistered parameter type as a separate error
|
||||
assert not difference, error_msg + f' pattern got {pattern_params}, function got {list(annotations.items())}, difference: {difference}'
|
||||
# assert {(p.name, p.type) for p in pattern.parameters.values()} <= annotations.items(), error_msg
|
||||
|
||||
# additional checks for DI
|
||||
|
||||
|
@@ -64,25 +64,13 @@ class Pattern:
|
||||
|
||||
logger.debug(f"Starting looking for \"{self.compiled}\" in \"{string}\"")
|
||||
|
||||
# forces non-greedy regex of greedy objects to stretch till the end, but has fallback tail capturing mechanism
|
||||
# TODO: triple check this and whether tail capturing group is necessary
|
||||
# trailing_anchor = r'(.*?)$'
|
||||
# trailing_anchor = r'$'
|
||||
trailing_anchor = r''
|
||||
|
||||
for match in sorted(re.finditer(self.compiled+trailing_anchor, string), key = lambda match: match.start()):
|
||||
|
||||
for match in sorted(re.finditer(self.compiled, string), key = lambda match: match.start()):
|
||||
if match.start() == match.end():
|
||||
continue # skip empty
|
||||
|
||||
# start and end in string, not in match.group(0)
|
||||
# TODO: consider moving it inside the loop
|
||||
# match_start = match.start()
|
||||
# match_end = match.end()
|
||||
command_str = string[match.start():match.end()].strip()
|
||||
match_str_groups = match.groupdict()
|
||||
|
||||
# parsed_parameters: dict[str, ParameterMatch | None] = {}
|
||||
parsed_parameters: dict[str, ParameterMatch] = {}
|
||||
|
||||
logger.debug(f'Captured candidate "{command_str}"')
|
||||
@@ -97,7 +85,7 @@ class Pattern:
|
||||
prefill = {name: parameter.parsed_substr for name, parameter in parsed_parameters.items()}
|
||||
|
||||
# re-run regex only in the current command_str
|
||||
compiled = self._compile(prefill=prefill)+trailing_anchor
|
||||
compiled = self._compile(prefill=prefill)
|
||||
new_matches = list(re.finditer(compiled, command_str))
|
||||
|
||||
logger.debug(f'Recapturing parameters command_str={command_str} prefill={prefill} compiled={compiled}')
|
||||
@@ -264,14 +252,19 @@ class Pattern:
|
||||
|
||||
for parameter in self.parameters.values():
|
||||
|
||||
arg_declaration = f'\\${parameter.name}\\:{parameter.type.__name__}'
|
||||
arg_declaration = f'\\${parameter.name}\\:{parameter.type.__name__}' # NOTE: special chars escaped for regex
|
||||
if parameter.name in prefill:
|
||||
arg_pattern = re.escape(prefill[parameter.name])
|
||||
else:
|
||||
arg_pattern = parameter.type.pattern.compiled.replace('\\', r'\\')
|
||||
|
||||
if parameter.type.greedy and arg_pattern[-1] in {'*', '+', '}', '?'}:
|
||||
arg_pattern += '?' # compensate greedy did_parse with non-greedy regex TODO: review
|
||||
# compensate greedy did_parse with non-greedy regex, so it won't consume next params during initial regex
|
||||
arg_pattern += '?'
|
||||
if pattern.endswith(f'${parameter.name}:{parameter.type.__name__}'): # NOTE: regex chars are NOT escaped
|
||||
# compensate the last non_greedy regex to allow consuming till the end of the string
|
||||
# middle greedy params are limited/stretched by neighbor params
|
||||
arg_pattern += '$'
|
||||
|
||||
pattern = re.sub(arg_declaration, f'(?P<{parameter.name}>{arg_pattern})', pattern)
|
||||
|
||||
|
@@ -1,23 +1,24 @@
|
||||
from typing import Callable
|
||||
# from typing import Callable
|
||||
|
||||
|
||||
class classproperty[GetterReturnType]:
|
||||
def __init__(self, func: Callable[..., GetterReturnType]):
|
||||
if isinstance(func, (classmethod, staticmethod)):
|
||||
fget = func
|
||||
else:
|
||||
fget = classmethod(func)
|
||||
self.fget = fget
|
||||
# class classproperty[GetterReturnType]:
|
||||
# def __init__(self, func: Callable[..., GetterReturnType]):
|
||||
# if isinstance(func, (classmethod, staticmethod)):
|
||||
# fget = func
|
||||
# else:
|
||||
# fget = classmethod(func)
|
||||
# self.fget = fget
|
||||
|
||||
def __get__(self, obj, klass=None) -> GetterReturnType:
|
||||
if klass is None:
|
||||
klass = type(obj)
|
||||
return self.fget.__get__(obj, klass)()
|
||||
# def __get__(self, obj, klass=None) -> GetterReturnType:
|
||||
# if klass is None:
|
||||
# klass = type(obj)
|
||||
# return self.fget.__get__(obj, klass)()
|
||||
|
||||
# from typing import Any, Callable
|
||||
from typing import Any, Callable
|
||||
|
||||
# def classproperty[T](fget: Callable[[type[Any]], T]) -> T:
|
||||
# class _ClassProperty(property):
|
||||
# def __get__(self, cls: type[Any], owner: type[Any]) -> T: # type: ignore
|
||||
# return classmethod(fget).__get__(None, owner)()
|
||||
# return _ClassProperty(fget) # type: ignore
|
||||
|
||||
def classproperty[T](fget: Callable[[type[Any]], T]) -> T:
|
||||
class _ClassProperty(property):
|
||||
def __get__(self, cls: type[Any], owner: type[Any]) -> T: # type: ignore
|
||||
return classmethod(fget).__get__(None, owner)()
|
||||
return _ClassProperty(fget) # type: ignore
|
||||
|
@@ -150,7 +150,7 @@ class VoskSpeechRecognizer(SpeechRecognizer):
|
||||
try:
|
||||
result = KaldiMBR.parse_raw(raw_json)
|
||||
text = result.text
|
||||
# print('\nConfidence:', result.confidence)
|
||||
# print('\nConfidence:', result.confidence) # TODO: log or os.getenv("STARK_VOICE_CLI", "0") == "1"
|
||||
except ValidationError:
|
||||
try:
|
||||
result = KaldiResult.parse_raw(raw_json)
|
||||
@@ -162,7 +162,7 @@ class VoskSpeechRecognizer(SpeechRecognizer):
|
||||
if text:
|
||||
# if result.spk:
|
||||
# speaker, similarity = self._get_speaker(result.spk)
|
||||
# print(f'\nSpeaker: {speaker} ({similarity * 100:.2f}%)\n')
|
||||
# print(f'\nSpeaker: {speaker} ({similarity * 100:.2f}%)\n') # TODO: log or os.getenv("STARK_VOICE_CLI", "0") == "1"
|
||||
|
||||
self.last_result = text
|
||||
await delegate.speech_recognizer_did_receive_final_result(text)
|
||||
@@ -197,7 +197,7 @@ class VoskSpeechRecognizer(SpeechRecognizer):
|
||||
if not best_similarity or best_similarity < self._speaker_trashold:
|
||||
matched_speaker_id = len(self._stored_speakers)
|
||||
self._stored_speakers[matched_speaker_id] = vector
|
||||
# print(f'New speaker: {matched_speaker_id}, similarity: {best_similarity * 100:.2f}%')
|
||||
# print(f'New speaker: {matched_speaker_id}, similarity: {best_similarity * 100:.2f}%') # TODO: log or os.getenv("STARK_VOICE_CLI", "0") == "1"
|
||||
best_similarity = 1
|
||||
|
||||
return cast(int, matched_speaker_id), best_similarity
|
||||
|
@@ -1,38 +1,73 @@
|
||||
import pytest
|
||||
import random
|
||||
|
||||
import anyio
|
||||
from stark.core import Pattern, Response, CommandsManager
|
||||
import pytest
|
||||
|
||||
from stark.core import CommandsManager, Pattern, Response
|
||||
from stark.core.types import Object
|
||||
from stark.general.classproperty import classproperty
|
||||
import random
|
||||
|
||||
|
||||
async def test_multiple_commands(commands_context_flow, autojump_clock):
|
||||
async with commands_context_flow() as (manager, context, context_delegate):
|
||||
|
||||
|
||||
@manager.new('foo bar')
|
||||
def foobar():
|
||||
return Response(text = 'foo!')
|
||||
|
||||
def foobar():
|
||||
return Response(text='foo!')
|
||||
|
||||
@manager.new('lorem * dolor')
|
||||
def lorem():
|
||||
return Response(text = 'lorem!')
|
||||
|
||||
def lorem():
|
||||
return Response(text='lorem!')
|
||||
|
||||
# original test
|
||||
await context.process_string('foo bar lorem ipsum dolor')
|
||||
await anyio.sleep(5)
|
||||
|
||||
assert len(context_delegate.responses) == 2
|
||||
assert {context_delegate.responses[0].text, context_delegate.responses[1].text} == {'foo!', 'lorem!'}
|
||||
|
||||
async def test_two_commands_greedy_param(commands_context_flow, autojump_clock):
|
||||
async with commands_context_flow() as (manager, context, context_delegate):
|
||||
|
||||
class AnotherGreedy(Object):
|
||||
|
||||
@classproperty
|
||||
def greedy(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classproperty
|
||||
def pattern(cls):
|
||||
return Pattern('**')
|
||||
|
||||
async def did_parse(self, from_string: str) -> str:
|
||||
# print(f'Parsing Greedy from "{from_string}"')
|
||||
self.value = from_string
|
||||
return from_string
|
||||
|
||||
Pattern.add_parameter_type(AnotherGreedy)
|
||||
|
||||
@manager.new('command1 $g:AnotherGreedy')
|
||||
def cmd1(g: AnotherGreedy):
|
||||
return Response(text=f'cmd1:{g.value}')
|
||||
|
||||
@manager.new('command2')
|
||||
def cmd2():
|
||||
return Response(text='cmd2!')
|
||||
|
||||
await context.process_string('command1 some words command2')
|
||||
await anyio.sleep(1)
|
||||
texts = {resp.text for resp in context_delegate.responses[-2:]}
|
||||
assert texts == {'cmd1:some words', 'cmd2!'}
|
||||
|
||||
async def test_repeating_command(commands_context_flow, autojump_clock):
|
||||
async with commands_context_flow() as (manager, context, context_delegate):
|
||||
|
||||
|
||||
@manager.new('lorem * dolor')
|
||||
def lorem():
|
||||
def lorem():
|
||||
return Response(text = 'lorem!')
|
||||
|
||||
|
||||
await context.process_string('lorem pisum dolor lorem ipsutest_repeating_commanduum dolor sit amet')
|
||||
await anyio.sleep(5)
|
||||
|
||||
|
||||
assert len(context_delegate.responses) == 2
|
||||
assert context_delegate.responses[0].text == 'lorem!'
|
||||
assert context_delegate.responses[1].text == 'lorem!'
|
||||
@@ -41,13 +76,13 @@ async def test_overlapping_commands_less_priority_cut(commands_context_flow, aut
|
||||
manager = CommandsManager()
|
||||
|
||||
@manager.new('foo bar *')
|
||||
def foobar():
|
||||
def foobar():
|
||||
return Response(text = 'foo!')
|
||||
|
||||
|
||||
@manager.new('* baz')
|
||||
def baz():
|
||||
def baz():
|
||||
return Response(text = 'baz!')
|
||||
|
||||
|
||||
result = await manager.search('foo bar test baz')
|
||||
assert len(result) == 2
|
||||
assert result[0].match_result.substring == 'foo bar test'
|
||||
@@ -55,83 +90,83 @@ async def test_overlapping_commands_less_priority_cut(commands_context_flow, aut
|
||||
|
||||
async def test_overlapping_commands_priority_cut(commands_context_flow, autojump_clock):
|
||||
manager = CommandsManager()
|
||||
|
||||
|
||||
@manager.new('foo bar *')
|
||||
def foobar():
|
||||
def foobar():
|
||||
return Response(text = 'foo!')
|
||||
|
||||
|
||||
@manager.new('*t baz')
|
||||
def baz():
|
||||
def baz():
|
||||
return Response(text = 'baz!')
|
||||
|
||||
|
||||
result = await manager.search('foo bar test baz')
|
||||
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0].match_result.substring == 'foo bar'
|
||||
assert result[1].match_result.substring == 'test baz'
|
||||
|
||||
async def test_overlapping_commands_remove(commands_context_flow, autojump_clock):
|
||||
manager = CommandsManager()
|
||||
|
||||
|
||||
@manager.new('foo bar')
|
||||
def foobar():
|
||||
def foobar():
|
||||
return Response(text = 'foo!')
|
||||
|
||||
|
||||
@manager.new('bar baz')
|
||||
def barbaz():
|
||||
def barbaz():
|
||||
return Response(text = 'baz!')
|
||||
|
||||
|
||||
result = await manager.search('foo bar baz')
|
||||
assert len(result) == 1
|
||||
assert result[0].command == foobar
|
||||
|
||||
|
||||
async def test_overlapping_commands_remove_inverse(commands_context_flow, autojump_clock):
|
||||
manager = CommandsManager()
|
||||
|
||||
|
||||
@manager.new('bar baz')
|
||||
def barbaz():
|
||||
def barbaz():
|
||||
return Response(text = 'baz!')
|
||||
|
||||
|
||||
@manager.new('foo bar')
|
||||
def foobar():
|
||||
def foobar():
|
||||
return Response(text = 'foo!')
|
||||
|
||||
|
||||
result = await manager.search('foo bar baz')
|
||||
assert len(result) == 1
|
||||
assert result[0].command == barbaz
|
||||
|
||||
|
||||
@pytest.mark.skip(reason = 'Cache is deprecated and not working properly anymore because of new concurrent algorithm; need new async lru cache implementation')
|
||||
async def test_objects_parse_caching(commands_context_flow, autojump_clock):
|
||||
class Mock(Object):
|
||||
|
||||
|
||||
parsing_counter = 0
|
||||
|
||||
|
||||
@classproperty
|
||||
def pattern(cls):
|
||||
return Pattern('*')
|
||||
|
||||
|
||||
async def did_parse(self, from_string: str) -> str:
|
||||
Mock.parsing_counter += 1
|
||||
return from_string
|
||||
|
||||
|
||||
mock_name = f'Mock{random.randint(0, 10**10)}'
|
||||
Mock.__name__ = mock_name # prevent name collision on paralell tests
|
||||
|
||||
|
||||
manager = CommandsManager()
|
||||
Pattern.add_parameter_type(Mock)
|
||||
|
||||
|
||||
@manager.new(f'hello $mock:{mock_name}')
|
||||
def hello(mock: Mock): pass
|
||||
|
||||
|
||||
@manager.new(f'hello $mock:{mock_name} 2')
|
||||
def hello2(mock: Mock): pass
|
||||
|
||||
|
||||
@manager.new(f'hello $mock:{mock_name} 22')
|
||||
def hello22(mock: Mock): pass
|
||||
|
||||
|
||||
@manager.new(f'test $mock:{mock_name}')
|
||||
async def test(mock: Mock): pass
|
||||
|
||||
|
||||
assert Mock.parsing_counter == 0
|
||||
await manager.search('hello foobar 22')
|
||||
assert Mock.parsing_counter == 1
|
||||
|
@@ -104,41 +104,34 @@ Pattern.add_parameter_type(Bar)
|
||||
Pattern.add_parameter_type(Baz)
|
||||
Pattern.add_parameter_type(Greedy)
|
||||
|
||||
@pytest.mark.parametrize('string', ['foo bar', 'hey foo bar two']) #, 'hey foo one bar two']) TODO: add support for just enum of param w/o exact pattern structure
|
||||
async def test_complex_parsing__wildcard_params(string):
|
||||
print('Testing:', string)
|
||||
pattern = Pattern('$f:Foo $b:Bar')
|
||||
matches = await pattern.match(string)
|
||||
expected = {'f': 'foo', 'b': 'bar'}
|
||||
assert matches
|
||||
assert {name: obj.value if obj else None for name, obj in matches[0].parameters.items()} == expected
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_string,expected",
|
||||
"pattern_string,input_string,expected_params",
|
||||
[
|
||||
("foo", {"f": "foo", "b": None, "z": None}),
|
||||
("foo bar", {"f": "foo", "b": "bar", "z": None}),
|
||||
("foo bar baz", {"f": "foo", "b": "bar", "z": "baz"}),
|
||||
("bar baz", {"f": None, "b": "bar", "z": "baz"}),
|
||||
("foo baz", {"f": "foo", "b": None, "z": "baz"}),
|
||||
# wildcard regex params
|
||||
('$f:Foo $b:Bar', 'foo bar', {'f': 'foo', 'b': 'bar'}),
|
||||
('$f:Foo $b:Bar', 'hey foo bar two', {'f': 'foo', 'b': 'bar'}),
|
||||
#, 'hey foo one bar two' TODO: add support for just enum of param w/o exact pattern structure
|
||||
# optional params with wildcard regex
|
||||
('$f:Foo? ?$b:Bar? ?$z:Baz?', 'foo', {"f": "foo", "b": None, "z": None}),
|
||||
('$f:Foo? ?$b:Bar? ?$z:Baz?', 'foo bar', {"f": "foo", "b": "bar", "z": None}),
|
||||
('$f:Foo? ?$b:Bar? ?$z:Baz?', 'foo bar baz', {"f": "foo", "b": "bar", "z": "baz"}),
|
||||
('$f:Foo? ?$b:Bar? ?$z:Baz?', 'bar baz', {"f": None, "b": "bar", "z": "baz"}),
|
||||
('$f:Foo? ?$b:Bar? ?$z:Baz?', 'foo baz', {"f": "foo", "b": None, "z": "baz"}),
|
||||
# greedy and trailing anchor
|
||||
('command1 $g:Greedy end', 'command1 a few words of greedy end', {"g": "a few words of greedy"}),
|
||||
('command1 $g:Greedy', 'command1 a few words of greedy', {"g": "a few words of greedy"}),
|
||||
# greedy with other params
|
||||
('command1 $g:Greedy $f:Foo', 'command1 a few words of greedy foo', {"g": "a few words of greedy", "f": "foo"}),
|
||||
('command1 $g:Greedy $ag:Greedy', 'command1 a few words of greedy another greedy words', {"g": "a", "ag": "few words of greedy another greedy words"}), # TODO: review
|
||||
# greedy with optional params, note optional spaces
|
||||
('$g:Greedy ?$f:Foo? ?$b:Bar?$', 'one two three', {"g": "one two three", "f": None, "b": None}),
|
||||
('$g:Greedy ?$f:Foo? ?$b:Bar?$', 'one two foo bar', {"g": "one two", "f": "foo", "b": "bar"}),
|
||||
]
|
||||
)
|
||||
async def test_complex_parsing__optional_wildcard(input_string: str, expected: dict[str, str | None]) -> None:
|
||||
print('Expected:', input_string, expected)
|
||||
pattern = Pattern('$f:Foo? ?$b:Bar? ?$z:Baz?') # TODO: better solution for optional space
|
||||
async def test_complex_parsing__parametrized(pattern_string: str, input_string: str, expected_params: dict[str, str | None]) -> None:
|
||||
pattern = Pattern(pattern_string)
|
||||
matches = await pattern.match(input_string)
|
||||
print(f'Pattern: {pattern_string} "{pattern.compiled}", Input: {input_string}, Expected Params: {expected_params}')
|
||||
assert matches
|
||||
assert {name: obj.value if obj else None for name, obj in matches[0].parameters.items()} == expected
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_string,expected",
|
||||
[
|
||||
("one two three", {"g": "one two three", "f": None, "b": None}),
|
||||
("one two foo bar", {"g": "one two", "f": "foo", "b": "bar"}),
|
||||
]
|
||||
)
|
||||
async def test_complex_parsing__greedy_and_optional_wildcard(input_string: str, expected: dict[str, str | None]) -> None:
|
||||
print('Expected:', input_string, expected)
|
||||
matches = await Pattern('$g:Greedy ?$f:Foo? ?$b:Bar?$').match(input_string)
|
||||
assert matches
|
||||
assert {name: obj.value if obj else None for name, obj in matches[0].parameters.items()} == expected
|
||||
print(f'Match: {matches[0].substring}, Got Params: {matches[0].parameters}')
|
||||
assert {name: obj.value if obj else None for name, obj in matches[0].parameters.items()} == expected_params
|
||||
|
Reference in New Issue
Block a user