You've already forked STARK
mirror of
https://github.com/MarkParker5/STARK.git
synced 2025-10-05 22:47:03 +02:00
add a test
This commit is contained in:
@@ -20,6 +20,8 @@ from typing import (
|
||||
)
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from typing_extensions import get_args
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
import asyncer
|
||||
@@ -54,6 +56,11 @@ class Command(Generic[CommandRunner]):
|
||||
parameters = parameters_dict or {}
|
||||
parameters.update(kwparameters)
|
||||
|
||||
# auto fill optionals
|
||||
for param_name, param_type in self._runner.__annotations__.items():
|
||||
if param_name not in parameters and type(None) in get_args(param_type):
|
||||
parameters[param_name] = None
|
||||
|
||||
runner: AsyncCommandRunner
|
||||
|
||||
if inspect.iscoroutinefunction(self._runner) or inspect.isasyncgen(self._runner):
|
||||
|
@@ -1,16 +1,24 @@
|
||||
from __future__ import annotations
|
||||
from types import GeneratorType, AsyncGeneratorType
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
from dataclasses import dataclass
|
||||
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from types import AsyncGeneratorType, GeneratorType
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
|
||||
import anyio
|
||||
from asyncer import syncify
|
||||
from asyncer._main import TaskGroup
|
||||
|
||||
from ..general.dependencies import DependencyManager, default_dependency_manager
|
||||
from .command import (
|
||||
AsyncResponseHandler,
|
||||
Command,
|
||||
CommandRunner,
|
||||
Response,
|
||||
ResponseHandler,
|
||||
ResponseOptions,
|
||||
)
|
||||
from .commands_manager import CommandsManager, SearchResult
|
||||
from .command import Command, Response, ResponseHandler, AsyncResponseHandler, CommandRunner, ResponseOptions
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -24,18 +32,18 @@ class CommandsContextDelegate(Protocol):
|
||||
def remove_response(self, response: Response): pass
|
||||
|
||||
class CommandsContext:
|
||||
|
||||
|
||||
is_stopped = False
|
||||
commands_manager: CommandsManager
|
||||
dependency_manager: DependencyManager
|
||||
last_response: Response | None = None
|
||||
fallback_command: Command | None = None
|
||||
|
||||
|
||||
_delegate: CommandsContextDelegate | None = None
|
||||
_response_queue: list[Response]
|
||||
_context_queue: list[CommandsContextLayer]
|
||||
_task_group: TaskGroup
|
||||
|
||||
|
||||
def __init__(self, task_group: TaskGroup, commands_manager: CommandsManager, dependency_manager: DependencyManager = default_dependency_manager):
|
||||
assert isinstance(task_group, TaskGroup), task_group
|
||||
assert isinstance(commands_manager, CommandsManager)
|
||||
@@ -48,11 +56,11 @@ class CommandsContext:
|
||||
self.dependency_manager.add_dependency(None, AsyncResponseHandler, self)
|
||||
self.dependency_manager.add_dependency(None, ResponseHandler, SyncResponseHandler(self))
|
||||
self.dependency_manager.add_dependency('inject_dependencies', None, self.inject_dependencies)
|
||||
|
||||
|
||||
@property
|
||||
def delegate(self):
|
||||
return self._delegate
|
||||
|
||||
|
||||
@delegate.setter
|
||||
def delegate(self, delegate: CommandsContextDelegate | None):
|
||||
assert isinstance(delegate, CommandsContextDelegate) or delegate is None
|
||||
@@ -63,21 +71,22 @@ class CommandsContext:
|
||||
return CommandsContextLayer(self.commands_manager.commands, {})
|
||||
|
||||
async def process_string(self, string: str):
|
||||
|
||||
|
||||
if not self._context_queue:
|
||||
self._context_queue.append(self.root_context)
|
||||
|
||||
# search commands
|
||||
search_results = []
|
||||
while self._context_queue:
|
||||
|
||||
|
||||
current_context = self._context_queue[0]
|
||||
search_results = await self.commands_manager.search(string = string, commands = current_context.commands)
|
||||
|
||||
|
||||
if search_results:
|
||||
break
|
||||
else:
|
||||
self._context_queue.pop(0)
|
||||
|
||||
|
||||
if not search_results and self.fallback_command and (matches := await self.fallback_command.pattern.match(string)):
|
||||
for match in matches:
|
||||
search_results = [SearchResult(
|
||||
@@ -91,27 +100,27 @@ class CommandsContext:
|
||||
parameters = current_context.parameters
|
||||
parameters.update(search_result.match_result.parameters)
|
||||
parameters.update(self.dependency_manager.resolve(search_result.command._runner))
|
||||
|
||||
|
||||
self.run_command(search_result.command, parameters)
|
||||
|
||||
|
||||
def inject_dependencies(self, runner: Command[CommandRunner] | CommandRunner) -> CommandRunner:
|
||||
def injected_func(**kwargs) -> ResponseOptions:
|
||||
kwargs.update(self.dependency_manager.resolve(runner._runner if isinstance(runner, Command) else runner))
|
||||
return runner(**kwargs) # type: ignore
|
||||
return injected_func # type: ignore
|
||||
|
||||
|
||||
def run_command(self, command: Command, parameters: dict[str, Any] = {}):
|
||||
async def command_runner():
|
||||
command_return = await command(parameters)
|
||||
|
||||
|
||||
if isinstance(command_return, Response):
|
||||
await self.respond(command_return)
|
||||
|
||||
|
||||
elif isinstance(command_return, AsyncGeneratorType):
|
||||
async for response in command_return:
|
||||
if response:
|
||||
await self.respond(response)
|
||||
|
||||
|
||||
elif isinstance(command_return, GeneratorType):
|
||||
message = f'[WARNING] Command {command} is a sync GeneratorType that is not fully supported and may block the main thread. ' + \
|
||||
'Consider using the ResponseHandler.respond() or async approach instead.'
|
||||
@@ -119,77 +128,77 @@ class CommandsContext:
|
||||
for response in command_return:
|
||||
if response:
|
||||
await self.respond(response)
|
||||
|
||||
elif command_return is None:
|
||||
|
||||
elif command_return is None:
|
||||
pass
|
||||
|
||||
|
||||
else:
|
||||
raise TypeError(f'Command {command} returned {command_return} of type {type(command_return)} instead of Response or AsyncGeneratorType[Response]')
|
||||
|
||||
|
||||
self._task_group.soonify(command_runner)()
|
||||
|
||||
# ResponseHandler
|
||||
|
||||
|
||||
async def respond(self, response: Response): # async forces to run in main thread
|
||||
assert isinstance(response, Response)
|
||||
self._response_queue.append(response)
|
||||
|
||||
|
||||
async def unrespond(self, response: Response):
|
||||
if response in self._response_queue:
|
||||
self._response_queue.remove(response)
|
||||
self.delegate.remove_response(response)
|
||||
|
||||
|
||||
async def pop_context(self):
|
||||
self._context_queue.pop(0)
|
||||
|
||||
|
||||
# Context
|
||||
|
||||
|
||||
def pop_to_root_context(self):
|
||||
self._context_queue = [self.root_context]
|
||||
|
||||
|
||||
def add_context(self, context: CommandsContextLayer):
|
||||
self._context_queue.insert(0, context)
|
||||
|
||||
|
||||
# ResponseQueue
|
||||
|
||||
|
||||
async def handle_responses(self):
|
||||
self.is_stopped = False
|
||||
while not self.is_stopped:
|
||||
while self._response_queue:
|
||||
await self._process_response(self._response_queue.pop(0))
|
||||
await anyio.sleep(0.01)
|
||||
|
||||
|
||||
def stop(self):
|
||||
self.is_stopped = True
|
||||
|
||||
|
||||
async def _process_response(self, response: Response):
|
||||
if response is Response.repeat_last and self.last_response:
|
||||
await self._process_response(self.last_response)
|
||||
return
|
||||
|
||||
if not response is Response.repeat_last:
|
||||
|
||||
if response is not Response.repeat_last:
|
||||
self.last_response = response
|
||||
|
||||
|
||||
if response.commands:
|
||||
newContext = CommandsContextLayer(response.commands, response.parameters)
|
||||
self._context_queue.insert(0, newContext)
|
||||
|
||||
|
||||
await self.delegate.commands_context_did_receive_response(response)
|
||||
|
||||
class SyncResponseHandler: # needs for changing thread from worker to main in commands ran with asyncify
|
||||
|
||||
|
||||
async_response_handler: ResponseHandler
|
||||
|
||||
|
||||
def __init__(self, async_response_handler: ResponseHandler):
|
||||
self.async_response_handler = async_response_handler
|
||||
|
||||
|
||||
# ResponseHandler
|
||||
|
||||
|
||||
def respond(self, response: Response):
|
||||
syncify(self.async_response_handler.respond)(response)
|
||||
|
||||
|
||||
def unrespond(self, response: Response):
|
||||
syncify(self.async_response_handler.unrespond)(response)
|
||||
|
||||
|
||||
def pop_context(self):
|
||||
syncify(self.async_response_handler.pop_context)()
|
||||
|
@@ -5,6 +5,7 @@ from dataclasses import dataclass
|
||||
from types import UnionType
|
||||
|
||||
from asyncer import SoonValue, create_task_group
|
||||
from typing_extensions import get_args
|
||||
|
||||
from .command import AsyncResponseHandler, Command, CommandRunner, ResponseHandler
|
||||
from .patterns import MatchResult, Pattern
|
||||
@@ -101,13 +102,19 @@ class CommandsManager:
|
||||
# check that runner has all parameters from pattern
|
||||
|
||||
error_msg = f'Command {self.name}.{runner.__name__} must have all parameters from pattern;'
|
||||
pattern_params = list(
|
||||
(p.name, Pattern._parameter_types[p.type_name].type)
|
||||
pattern_params = set(
|
||||
(p.name, Pattern._parameter_types[p.type_name].type.__name__)
|
||||
for p in pattern.parameters.values()
|
||||
)
|
||||
difference = pattern_params - annotations.items()
|
||||
command_params = set(
|
||||
(k, get_args(v)[0].__name__
|
||||
if type(None) in get_args(v)
|
||||
else v.__name__)
|
||||
for k, v in annotations.items()
|
||||
)
|
||||
difference = pattern_params - command_params
|
||||
# TODO: handle unregistered parameter type as a separate error
|
||||
assert not difference, error_msg + f' pattern got {pattern_params}, function got {list(annotations.items())}, difference: {difference}'
|
||||
assert not difference, error_msg + f'\n\tPattern got {dict(pattern_params)},\n\tFunction got {dict(command_params)},\n\tDifference: {dict(difference)}'
|
||||
# assert {(p.name, p.type) for p in pattern.parameters.values()} <= annotations.items(), error_msg
|
||||
|
||||
# additional checks for DI
|
||||
|
@@ -62,6 +62,8 @@ class Pattern:
|
||||
|
||||
@classmethod
|
||||
def add_parameter_type(cls, object_type: ObjectType, parser: ObjectParser | None = None):
|
||||
from ..types import Object
|
||||
assert issubclass(object_type, Object), f'Can`t add parameter type "{object_type.__name__}": it is not a subclass of Object'
|
||||
error_msg = f'Can`t add parameter type "{object_type.__name__}": pattern parameters do not match properties annotated in class'
|
||||
# TODO: update schema and validation; handle optional parameters; handle short form where type is defined in object
|
||||
# assert object_type.pattern.parameters.items() <= object_type.__annotations__.items(), error_msg
|
||||
|
@@ -1,75 +1,68 @@
|
||||
from typing import AsyncGenerator, Type, Any, Callable
|
||||
import pytest
|
||||
from stark.core import (
|
||||
CommandsManager,
|
||||
Response,
|
||||
ResponseHandler,
|
||||
AsyncResponseHandler,
|
||||
ResponseStatus
|
||||
)
|
||||
import json
|
||||
from typing import Any, AsyncGenerator, Callable, Type
|
||||
|
||||
from stark.core import CommandsManager, Response
|
||||
from stark.core.types import Word
|
||||
from stark.general.json_encoder import StarkJsonEncoder
|
||||
import json
|
||||
|
||||
|
||||
async def test_command_json():
|
||||
manager = CommandsManager('TestManager')
|
||||
|
||||
|
||||
@manager.new('test pattern $word:Word')
|
||||
def test(var: str, word: Word, foo: int | None = None) -> Response:
|
||||
'''test command'''
|
||||
return Response(text=var)
|
||||
|
||||
|
||||
string = json.dumps(test, cls = StarkJsonEncoder)
|
||||
parsed = json.loads(string)
|
||||
|
||||
|
||||
assert parsed['name'] == 'TestManager.test'
|
||||
assert parsed['pattern']['origin'] == r'test pattern $word:Word'
|
||||
assert parsed['declaration'] == 'def test(var: str, word: Word, foo: int | None = None) -> Response'
|
||||
assert parsed['docstring'] == 'test command'
|
||||
|
||||
|
||||
async def test_async_command_complicate_type_json():
|
||||
manager = CommandsManager('TestManager')
|
||||
|
||||
|
||||
@manager.new('async test')
|
||||
async def test2(
|
||||
some: AsyncGenerator[
|
||||
Callable[
|
||||
[Any], Type
|
||||
],
|
||||
],
|
||||
list[None]
|
||||
]
|
||||
):
|
||||
return Response()
|
||||
|
||||
|
||||
string = json.dumps(test2, cls = StarkJsonEncoder)
|
||||
parsed = json.loads(string)
|
||||
|
||||
|
||||
assert parsed['name'] == 'TestManager.test2'
|
||||
assert parsed['pattern']['origin'] == r'async test'
|
||||
assert parsed['declaration'] == 'async def test2(some: AsyncGenerator)' # TODO: improve AsyncGenerator to full type
|
||||
# assert parsed['declaration'] == 'async def test2(some: AsyncGenerator[Callable[[Any], Type], list[None], None])'
|
||||
assert parsed['docstring'] == ''
|
||||
|
||||
|
||||
def test_manager_json():
|
||||
|
||||
|
||||
manager = CommandsManager('TestManager')
|
||||
|
||||
|
||||
@manager.new('')
|
||||
def test(): ...
|
||||
|
||||
|
||||
@manager.new('')
|
||||
def test2(): ...
|
||||
|
||||
|
||||
@manager.new('')
|
||||
def test3(): ...
|
||||
|
||||
|
||||
@manager.new('')
|
||||
def test4(): ...
|
||||
|
||||
|
||||
string = json.dumps(manager, cls = StarkJsonEncoder)
|
||||
parsed = json.loads(string)
|
||||
|
||||
assert parsed['name'] == 'TestManager'
|
||||
assert {c['name'] for c in parsed['commands']} == {'TestManager.test', 'TestManager.test2', 'TestManager.test3', 'TestManager.test4',}
|
||||
|
@@ -1,4 +1,3 @@
|
||||
import pytest
|
||||
import anyio
|
||||
|
||||
|
||||
@@ -6,7 +5,7 @@ async def test_basic_search(commands_context_flow_filled, autojump_clock):
|
||||
async with commands_context_flow_filled() as (context, context_delegate):
|
||||
assert len(context_delegate.responses) == 0
|
||||
assert len(context._context_queue) == 1
|
||||
|
||||
|
||||
await context.process_string('lorem ipsum dolor')
|
||||
await anyio.sleep(5)
|
||||
assert len(context_delegate.responses) == 1
|
||||
@@ -15,30 +14,30 @@ async def test_basic_search(commands_context_flow_filled, autojump_clock):
|
||||
|
||||
async def test_second_context_layer(commands_context_flow_filled, autojump_clock):
|
||||
async with commands_context_flow_filled() as (context, context_delegate):
|
||||
|
||||
|
||||
await context.process_string('hello world')
|
||||
await anyio.sleep(5)
|
||||
assert len(context_delegate.responses) == 1
|
||||
assert context_delegate.responses[0].text == 'Hello, world!'
|
||||
assert len(context._context_queue) == 2
|
||||
context_delegate.responses.clear()
|
||||
|
||||
|
||||
await context.process_string('hello')
|
||||
await anyio.sleep(5)
|
||||
assert len(context_delegate.responses) == 1
|
||||
assert context_delegate.responses[0].text == 'Hi, world!'
|
||||
assert len(context._context_queue) == 2
|
||||
context_delegate.responses.clear()
|
||||
|
||||
|
||||
async def test_context_pop_on_not_found(commands_context_flow_filled, autojump_clock):
|
||||
async with commands_context_flow_filled() as (context, context_delegate):
|
||||
|
||||
|
||||
await context.process_string('hello world')
|
||||
await anyio.sleep(5)
|
||||
assert len(context._context_queue) == 2
|
||||
assert len(context_delegate.responses) == 1
|
||||
context_delegate.responses.clear()
|
||||
|
||||
|
||||
await context.process_string('lorem ipsum dolor')
|
||||
await anyio.sleep(5)
|
||||
assert len(context._context_queue) == 1
|
||||
@@ -46,35 +45,35 @@ async def test_context_pop_on_not_found(commands_context_flow_filled, autojump_c
|
||||
|
||||
async def test_context_pop_context_response_action(commands_context_flow_filled, autojump_clock):
|
||||
async with commands_context_flow_filled() as (context, context_delegate):
|
||||
|
||||
|
||||
await context.process_string('hello world')
|
||||
await anyio.sleep(5)
|
||||
assert len(context_delegate.responses) == 1
|
||||
assert context_delegate.responses[0].text == 'Hello, world!'
|
||||
assert len(context._context_queue) == 2
|
||||
context_delegate.responses.clear()
|
||||
|
||||
|
||||
await context.process_string('bye')
|
||||
await anyio.sleep(5)
|
||||
assert len(context_delegate.responses) == 1
|
||||
assert context_delegate.responses[0].text == 'Bye, world!'
|
||||
assert len(context._context_queue) == 1
|
||||
context_delegate.responses.clear()
|
||||
|
||||
|
||||
await context.process_string('hello')
|
||||
await anyio.sleep(5)
|
||||
assert len(context_delegate.responses) == 0
|
||||
|
||||
|
||||
async def test_repeat_last_answer_response_action(commands_context_flow_filled, autojump_clock):
|
||||
async with commands_context_flow_filled() as (context, context_delegate):
|
||||
|
||||
|
||||
await context.process_string('hello world')
|
||||
await anyio.sleep(5)
|
||||
assert len(context_delegate.responses) == 1
|
||||
assert context_delegate.responses[0].text == 'Hello, world!'
|
||||
context_delegate.responses.clear()
|
||||
assert len(context_delegate.responses) == 0
|
||||
|
||||
|
||||
await context.process_string('repeat')
|
||||
await anyio.sleep(5)
|
||||
assert len(context_delegate.responses) == 1
|
||||
|
57
tests/test_commands_flow/test_complex_commands.py
Normal file
57
tests/test_commands_flow/test_complex_commands.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import anyio
|
||||
from typing_extensions import Optional
|
||||
|
||||
from stark.core import Pattern, Response
|
||||
from stark.core.types import Object
|
||||
from stark.general.classproperty import classproperty
|
||||
|
||||
|
||||
async def test_command_flow_optional_parameter(commands_context_flow, autojump_clock):
|
||||
async with commands_context_flow() as (manager, context, context_delegate):
|
||||
|
||||
class Category(Object):
|
||||
@classproperty
|
||||
def pattern(self) -> Pattern:
|
||||
return Pattern('c*')
|
||||
|
||||
class Device(Object):
|
||||
@classproperty
|
||||
def pattern(self) -> Pattern:
|
||||
return Pattern('d*')
|
||||
|
||||
class Room(Object):
|
||||
@classproperty
|
||||
def pattern(self) -> Pattern:
|
||||
return Pattern('r*')
|
||||
|
||||
Pattern.add_parameter_type(Category)
|
||||
Pattern.add_parameter_type(Device)
|
||||
Pattern.add_parameter_type(Room)
|
||||
|
||||
@manager.new('turn on ($category:Category|$device:Device|$room:Room)')
|
||||
def turn_on(category: Optional[Category], device: Optional[Device], room: Optional[Room]) -> Response:
|
||||
if category:
|
||||
return Response(text=f'Category {category.value}')
|
||||
elif device:
|
||||
return Response(text=f'Device {device.value}')
|
||||
elif room:
|
||||
return Response(text=f'Room {room.value}')
|
||||
else:
|
||||
raise ValueError('No category, device or room provided')
|
||||
|
||||
print(turn_on.pattern.compiled)
|
||||
|
||||
await context.process_string('turn on cooling')
|
||||
await anyio.sleep(5)
|
||||
assert len(context_delegate.responses) == 1
|
||||
assert context_delegate.responses[0].text == 'Category cooling'
|
||||
|
||||
await context.process_string('turn on dishwasher')
|
||||
await anyio.sleep(5)
|
||||
assert len(context_delegate.responses) == 2
|
||||
assert context_delegate.responses[1].text == 'Device dishwasher'
|
||||
|
||||
await context.process_string('turn on restroom')
|
||||
await anyio.sleep(5)
|
||||
assert len(context_delegate.responses) == 3
|
||||
assert context_delegate.responses[2].text == 'Room restroom'
|
@@ -1,3 +1,5 @@
|
||||
import pytest
|
||||
|
||||
from stark.core import Pattern
|
||||
from stark.core.patterns import rules
|
||||
|
||||
@@ -112,6 +114,14 @@ async def test_one_of():
|
||||
assert (await p.match('bbb Some bar here cccc'))[0].substring == 'Some bar here'
|
||||
assert not await p.match('Some foo')
|
||||
|
||||
@pytest.mark.skip(reason="Not implemented as a not important feature")
|
||||
async def test_one_of_with_spaces():
|
||||
p = Pattern('Hello ( foo | bar | baz ) world')
|
||||
print(p.compiled)
|
||||
assert await p.match('Hello foo world')
|
||||
assert await p.match('Hello bar world')
|
||||
assert await p.match('Hello baz world')
|
||||
|
||||
async def test_optional_one_of():
|
||||
p = Pattern('(foo|bar)?')
|
||||
assert p.compiled == r'(?:foo|bar)?'
|
||||
|
@@ -362,7 +362,6 @@ async def test_slots(pattern_str, input_str, is_match, match_str, expected_token
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_slots_required_optional_cases(cls_name, slots_dict, input_str, expected_values, expected_error):
|
||||
|
||||
print(f'{cls_name=} {input_str=} {expected_error=} {expected_values=} {slots_dict=}')
|
||||
|
Reference in New Issue
Block a user