You've already forked STARK
mirror of
https://github.com/MarkParker5/STARK.git
synced 2025-10-05 22:47:03 +02:00
add a test
This commit is contained in:
@@ -20,6 +20,8 @@ from typing import (
|
|||||||
)
|
)
|
||||||
from uuid import UUID, uuid4
|
from uuid import UUID, uuid4
|
||||||
|
|
||||||
|
from typing_extensions import get_args
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
import asyncer
|
import asyncer
|
||||||
@@ -54,6 +56,11 @@ class Command(Generic[CommandRunner]):
|
|||||||
parameters = parameters_dict or {}
|
parameters = parameters_dict or {}
|
||||||
parameters.update(kwparameters)
|
parameters.update(kwparameters)
|
||||||
|
|
||||||
|
# auto fill optionals
|
||||||
|
for param_name, param_type in self._runner.__annotations__.items():
|
||||||
|
if param_name not in parameters and type(None) in get_args(param_type):
|
||||||
|
parameters[param_name] = None
|
||||||
|
|
||||||
runner: AsyncCommandRunner
|
runner: AsyncCommandRunner
|
||||||
|
|
||||||
if inspect.iscoroutinefunction(self._runner) or inspect.isasyncgen(self._runner):
|
if inspect.iscoroutinefunction(self._runner) or inspect.isasyncgen(self._runner):
|
||||||
|
@@ -1,16 +1,24 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from types import GeneratorType, AsyncGeneratorType
|
|
||||||
from typing import Any, Protocol, runtime_checkable
|
|
||||||
from dataclasses import dataclass
|
|
||||||
import warnings
|
import warnings
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from types import AsyncGeneratorType, GeneratorType
|
||||||
|
from typing import Any, Protocol, runtime_checkable
|
||||||
|
|
||||||
import anyio
|
import anyio
|
||||||
from asyncer import syncify
|
from asyncer import syncify
|
||||||
from asyncer._main import TaskGroup
|
from asyncer._main import TaskGroup
|
||||||
|
|
||||||
from ..general.dependencies import DependencyManager, default_dependency_manager
|
from ..general.dependencies import DependencyManager, default_dependency_manager
|
||||||
|
from .command import (
|
||||||
|
AsyncResponseHandler,
|
||||||
|
Command,
|
||||||
|
CommandRunner,
|
||||||
|
Response,
|
||||||
|
ResponseHandler,
|
||||||
|
ResponseOptions,
|
||||||
|
)
|
||||||
from .commands_manager import CommandsManager, SearchResult
|
from .commands_manager import CommandsManager, SearchResult
|
||||||
from .command import Command, Response, ResponseHandler, AsyncResponseHandler, CommandRunner, ResponseOptions
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -24,18 +32,18 @@ class CommandsContextDelegate(Protocol):
|
|||||||
def remove_response(self, response: Response): pass
|
def remove_response(self, response: Response): pass
|
||||||
|
|
||||||
class CommandsContext:
|
class CommandsContext:
|
||||||
|
|
||||||
is_stopped = False
|
is_stopped = False
|
||||||
commands_manager: CommandsManager
|
commands_manager: CommandsManager
|
||||||
dependency_manager: DependencyManager
|
dependency_manager: DependencyManager
|
||||||
last_response: Response | None = None
|
last_response: Response | None = None
|
||||||
fallback_command: Command | None = None
|
fallback_command: Command | None = None
|
||||||
|
|
||||||
_delegate: CommandsContextDelegate | None = None
|
_delegate: CommandsContextDelegate | None = None
|
||||||
_response_queue: list[Response]
|
_response_queue: list[Response]
|
||||||
_context_queue: list[CommandsContextLayer]
|
_context_queue: list[CommandsContextLayer]
|
||||||
_task_group: TaskGroup
|
_task_group: TaskGroup
|
||||||
|
|
||||||
def __init__(self, task_group: TaskGroup, commands_manager: CommandsManager, dependency_manager: DependencyManager = default_dependency_manager):
|
def __init__(self, task_group: TaskGroup, commands_manager: CommandsManager, dependency_manager: DependencyManager = default_dependency_manager):
|
||||||
assert isinstance(task_group, TaskGroup), task_group
|
assert isinstance(task_group, TaskGroup), task_group
|
||||||
assert isinstance(commands_manager, CommandsManager)
|
assert isinstance(commands_manager, CommandsManager)
|
||||||
@@ -48,11 +56,11 @@ class CommandsContext:
|
|||||||
self.dependency_manager.add_dependency(None, AsyncResponseHandler, self)
|
self.dependency_manager.add_dependency(None, AsyncResponseHandler, self)
|
||||||
self.dependency_manager.add_dependency(None, ResponseHandler, SyncResponseHandler(self))
|
self.dependency_manager.add_dependency(None, ResponseHandler, SyncResponseHandler(self))
|
||||||
self.dependency_manager.add_dependency('inject_dependencies', None, self.inject_dependencies)
|
self.dependency_manager.add_dependency('inject_dependencies', None, self.inject_dependencies)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def delegate(self):
|
def delegate(self):
|
||||||
return self._delegate
|
return self._delegate
|
||||||
|
|
||||||
@delegate.setter
|
@delegate.setter
|
||||||
def delegate(self, delegate: CommandsContextDelegate | None):
|
def delegate(self, delegate: CommandsContextDelegate | None):
|
||||||
assert isinstance(delegate, CommandsContextDelegate) or delegate is None
|
assert isinstance(delegate, CommandsContextDelegate) or delegate is None
|
||||||
@@ -63,21 +71,22 @@ class CommandsContext:
|
|||||||
return CommandsContextLayer(self.commands_manager.commands, {})
|
return CommandsContextLayer(self.commands_manager.commands, {})
|
||||||
|
|
||||||
async def process_string(self, string: str):
|
async def process_string(self, string: str):
|
||||||
|
|
||||||
if not self._context_queue:
|
if not self._context_queue:
|
||||||
self._context_queue.append(self.root_context)
|
self._context_queue.append(self.root_context)
|
||||||
|
|
||||||
# search commands
|
# search commands
|
||||||
|
search_results = []
|
||||||
while self._context_queue:
|
while self._context_queue:
|
||||||
|
|
||||||
current_context = self._context_queue[0]
|
current_context = self._context_queue[0]
|
||||||
search_results = await self.commands_manager.search(string = string, commands = current_context.commands)
|
search_results = await self.commands_manager.search(string = string, commands = current_context.commands)
|
||||||
|
|
||||||
if search_results:
|
if search_results:
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
self._context_queue.pop(0)
|
self._context_queue.pop(0)
|
||||||
|
|
||||||
if not search_results and self.fallback_command and (matches := await self.fallback_command.pattern.match(string)):
|
if not search_results and self.fallback_command and (matches := await self.fallback_command.pattern.match(string)):
|
||||||
for match in matches:
|
for match in matches:
|
||||||
search_results = [SearchResult(
|
search_results = [SearchResult(
|
||||||
@@ -91,27 +100,27 @@ class CommandsContext:
|
|||||||
parameters = current_context.parameters
|
parameters = current_context.parameters
|
||||||
parameters.update(search_result.match_result.parameters)
|
parameters.update(search_result.match_result.parameters)
|
||||||
parameters.update(self.dependency_manager.resolve(search_result.command._runner))
|
parameters.update(self.dependency_manager.resolve(search_result.command._runner))
|
||||||
|
|
||||||
self.run_command(search_result.command, parameters)
|
self.run_command(search_result.command, parameters)
|
||||||
|
|
||||||
def inject_dependencies(self, runner: Command[CommandRunner] | CommandRunner) -> CommandRunner:
|
def inject_dependencies(self, runner: Command[CommandRunner] | CommandRunner) -> CommandRunner:
|
||||||
def injected_func(**kwargs) -> ResponseOptions:
|
def injected_func(**kwargs) -> ResponseOptions:
|
||||||
kwargs.update(self.dependency_manager.resolve(runner._runner if isinstance(runner, Command) else runner))
|
kwargs.update(self.dependency_manager.resolve(runner._runner if isinstance(runner, Command) else runner))
|
||||||
return runner(**kwargs) # type: ignore
|
return runner(**kwargs) # type: ignore
|
||||||
return injected_func # type: ignore
|
return injected_func # type: ignore
|
||||||
|
|
||||||
def run_command(self, command: Command, parameters: dict[str, Any] = {}):
|
def run_command(self, command: Command, parameters: dict[str, Any] = {}):
|
||||||
async def command_runner():
|
async def command_runner():
|
||||||
command_return = await command(parameters)
|
command_return = await command(parameters)
|
||||||
|
|
||||||
if isinstance(command_return, Response):
|
if isinstance(command_return, Response):
|
||||||
await self.respond(command_return)
|
await self.respond(command_return)
|
||||||
|
|
||||||
elif isinstance(command_return, AsyncGeneratorType):
|
elif isinstance(command_return, AsyncGeneratorType):
|
||||||
async for response in command_return:
|
async for response in command_return:
|
||||||
if response:
|
if response:
|
||||||
await self.respond(response)
|
await self.respond(response)
|
||||||
|
|
||||||
elif isinstance(command_return, GeneratorType):
|
elif isinstance(command_return, GeneratorType):
|
||||||
message = f'[WARNING] Command {command} is a sync GeneratorType that is not fully supported and may block the main thread. ' + \
|
message = f'[WARNING] Command {command} is a sync GeneratorType that is not fully supported and may block the main thread. ' + \
|
||||||
'Consider using the ResponseHandler.respond() or async approach instead.'
|
'Consider using the ResponseHandler.respond() or async approach instead.'
|
||||||
@@ -119,77 +128,77 @@ class CommandsContext:
|
|||||||
for response in command_return:
|
for response in command_return:
|
||||||
if response:
|
if response:
|
||||||
await self.respond(response)
|
await self.respond(response)
|
||||||
|
|
||||||
elif command_return is None:
|
elif command_return is None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise TypeError(f'Command {command} returned {command_return} of type {type(command_return)} instead of Response or AsyncGeneratorType[Response]')
|
raise TypeError(f'Command {command} returned {command_return} of type {type(command_return)} instead of Response or AsyncGeneratorType[Response]')
|
||||||
|
|
||||||
self._task_group.soonify(command_runner)()
|
self._task_group.soonify(command_runner)()
|
||||||
|
|
||||||
# ResponseHandler
|
# ResponseHandler
|
||||||
|
|
||||||
async def respond(self, response: Response): # async forces to run in main thread
|
async def respond(self, response: Response): # async forces to run in main thread
|
||||||
assert isinstance(response, Response)
|
assert isinstance(response, Response)
|
||||||
self._response_queue.append(response)
|
self._response_queue.append(response)
|
||||||
|
|
||||||
async def unrespond(self, response: Response):
|
async def unrespond(self, response: Response):
|
||||||
if response in self._response_queue:
|
if response in self._response_queue:
|
||||||
self._response_queue.remove(response)
|
self._response_queue.remove(response)
|
||||||
self.delegate.remove_response(response)
|
self.delegate.remove_response(response)
|
||||||
|
|
||||||
async def pop_context(self):
|
async def pop_context(self):
|
||||||
self._context_queue.pop(0)
|
self._context_queue.pop(0)
|
||||||
|
|
||||||
# Context
|
# Context
|
||||||
|
|
||||||
def pop_to_root_context(self):
|
def pop_to_root_context(self):
|
||||||
self._context_queue = [self.root_context]
|
self._context_queue = [self.root_context]
|
||||||
|
|
||||||
def add_context(self, context: CommandsContextLayer):
|
def add_context(self, context: CommandsContextLayer):
|
||||||
self._context_queue.insert(0, context)
|
self._context_queue.insert(0, context)
|
||||||
|
|
||||||
# ResponseQueue
|
# ResponseQueue
|
||||||
|
|
||||||
async def handle_responses(self):
|
async def handle_responses(self):
|
||||||
self.is_stopped = False
|
self.is_stopped = False
|
||||||
while not self.is_stopped:
|
while not self.is_stopped:
|
||||||
while self._response_queue:
|
while self._response_queue:
|
||||||
await self._process_response(self._response_queue.pop(0))
|
await self._process_response(self._response_queue.pop(0))
|
||||||
await anyio.sleep(0.01)
|
await anyio.sleep(0.01)
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
self.is_stopped = True
|
self.is_stopped = True
|
||||||
|
|
||||||
async def _process_response(self, response: Response):
|
async def _process_response(self, response: Response):
|
||||||
if response is Response.repeat_last and self.last_response:
|
if response is Response.repeat_last and self.last_response:
|
||||||
await self._process_response(self.last_response)
|
await self._process_response(self.last_response)
|
||||||
return
|
return
|
||||||
|
|
||||||
if not response is Response.repeat_last:
|
if response is not Response.repeat_last:
|
||||||
self.last_response = response
|
self.last_response = response
|
||||||
|
|
||||||
if response.commands:
|
if response.commands:
|
||||||
newContext = CommandsContextLayer(response.commands, response.parameters)
|
newContext = CommandsContextLayer(response.commands, response.parameters)
|
||||||
self._context_queue.insert(0, newContext)
|
self._context_queue.insert(0, newContext)
|
||||||
|
|
||||||
await self.delegate.commands_context_did_receive_response(response)
|
await self.delegate.commands_context_did_receive_response(response)
|
||||||
|
|
||||||
class SyncResponseHandler: # needs for changing thread from worker to main in commands ran with asyncify
|
class SyncResponseHandler: # needs for changing thread from worker to main in commands ran with asyncify
|
||||||
|
|
||||||
async_response_handler: ResponseHandler
|
async_response_handler: ResponseHandler
|
||||||
|
|
||||||
def __init__(self, async_response_handler: ResponseHandler):
|
def __init__(self, async_response_handler: ResponseHandler):
|
||||||
self.async_response_handler = async_response_handler
|
self.async_response_handler = async_response_handler
|
||||||
|
|
||||||
# ResponseHandler
|
# ResponseHandler
|
||||||
|
|
||||||
def respond(self, response: Response):
|
def respond(self, response: Response):
|
||||||
syncify(self.async_response_handler.respond)(response)
|
syncify(self.async_response_handler.respond)(response)
|
||||||
|
|
||||||
def unrespond(self, response: Response):
|
def unrespond(self, response: Response):
|
||||||
syncify(self.async_response_handler.unrespond)(response)
|
syncify(self.async_response_handler.unrespond)(response)
|
||||||
|
|
||||||
def pop_context(self):
|
def pop_context(self):
|
||||||
syncify(self.async_response_handler.pop_context)()
|
syncify(self.async_response_handler.pop_context)()
|
||||||
|
@@ -5,6 +5,7 @@ from dataclasses import dataclass
|
|||||||
from types import UnionType
|
from types import UnionType
|
||||||
|
|
||||||
from asyncer import SoonValue, create_task_group
|
from asyncer import SoonValue, create_task_group
|
||||||
|
from typing_extensions import get_args
|
||||||
|
|
||||||
from .command import AsyncResponseHandler, Command, CommandRunner, ResponseHandler
|
from .command import AsyncResponseHandler, Command, CommandRunner, ResponseHandler
|
||||||
from .patterns import MatchResult, Pattern
|
from .patterns import MatchResult, Pattern
|
||||||
@@ -101,13 +102,19 @@ class CommandsManager:
|
|||||||
# check that runner has all parameters from pattern
|
# check that runner has all parameters from pattern
|
||||||
|
|
||||||
error_msg = f'Command {self.name}.{runner.__name__} must have all parameters from pattern;'
|
error_msg = f'Command {self.name}.{runner.__name__} must have all parameters from pattern;'
|
||||||
pattern_params = list(
|
pattern_params = set(
|
||||||
(p.name, Pattern._parameter_types[p.type_name].type)
|
(p.name, Pattern._parameter_types[p.type_name].type.__name__)
|
||||||
for p in pattern.parameters.values()
|
for p in pattern.parameters.values()
|
||||||
)
|
)
|
||||||
difference = pattern_params - annotations.items()
|
command_params = set(
|
||||||
|
(k, get_args(v)[0].__name__
|
||||||
|
if type(None) in get_args(v)
|
||||||
|
else v.__name__)
|
||||||
|
for k, v in annotations.items()
|
||||||
|
)
|
||||||
|
difference = pattern_params - command_params
|
||||||
# TODO: handle unregistered parameter type as a separate error
|
# TODO: handle unregistered parameter type as a separate error
|
||||||
assert not difference, error_msg + f' pattern got {pattern_params}, function got {list(annotations.items())}, difference: {difference}'
|
assert not difference, error_msg + f'\n\tPattern got {dict(pattern_params)},\n\tFunction got {dict(command_params)},\n\tDifference: {dict(difference)}'
|
||||||
# assert {(p.name, p.type) for p in pattern.parameters.values()} <= annotations.items(), error_msg
|
# assert {(p.name, p.type) for p in pattern.parameters.values()} <= annotations.items(), error_msg
|
||||||
|
|
||||||
# additional checks for DI
|
# additional checks for DI
|
||||||
|
@@ -62,6 +62,8 @@ class Pattern:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def add_parameter_type(cls, object_type: ObjectType, parser: ObjectParser | None = None):
|
def add_parameter_type(cls, object_type: ObjectType, parser: ObjectParser | None = None):
|
||||||
|
from ..types import Object
|
||||||
|
assert issubclass(object_type, Object), f'Can`t add parameter type "{object_type.__name__}": it is not a subclass of Object'
|
||||||
error_msg = f'Can`t add parameter type "{object_type.__name__}": pattern parameters do not match properties annotated in class'
|
error_msg = f'Can`t add parameter type "{object_type.__name__}": pattern parameters do not match properties annotated in class'
|
||||||
# TODO: update schema and validation; handle optional parameters; handle short form where type is defined in object
|
# TODO: update schema and validation; handle optional parameters; handle short form where type is defined in object
|
||||||
# assert object_type.pattern.parameters.items() <= object_type.__annotations__.items(), error_msg
|
# assert object_type.pattern.parameters.items() <= object_type.__annotations__.items(), error_msg
|
||||||
|
@@ -1,75 +1,68 @@
|
|||||||
from typing import AsyncGenerator, Type, Any, Callable
|
import json
|
||||||
import pytest
|
from typing import Any, AsyncGenerator, Callable, Type
|
||||||
from stark.core import (
|
|
||||||
CommandsManager,
|
from stark.core import CommandsManager, Response
|
||||||
Response,
|
|
||||||
ResponseHandler,
|
|
||||||
AsyncResponseHandler,
|
|
||||||
ResponseStatus
|
|
||||||
)
|
|
||||||
from stark.core.types import Word
|
from stark.core.types import Word
|
||||||
from stark.general.json_encoder import StarkJsonEncoder
|
from stark.general.json_encoder import StarkJsonEncoder
|
||||||
import json
|
|
||||||
|
|
||||||
|
|
||||||
async def test_command_json():
|
async def test_command_json():
|
||||||
manager = CommandsManager('TestManager')
|
manager = CommandsManager('TestManager')
|
||||||
|
|
||||||
@manager.new('test pattern $word:Word')
|
@manager.new('test pattern $word:Word')
|
||||||
def test(var: str, word: Word, foo: int | None = None) -> Response:
|
def test(var: str, word: Word, foo: int | None = None) -> Response:
|
||||||
'''test command'''
|
'''test command'''
|
||||||
return Response(text=var)
|
return Response(text=var)
|
||||||
|
|
||||||
string = json.dumps(test, cls = StarkJsonEncoder)
|
string = json.dumps(test, cls = StarkJsonEncoder)
|
||||||
parsed = json.loads(string)
|
parsed = json.loads(string)
|
||||||
|
|
||||||
assert parsed['name'] == 'TestManager.test'
|
assert parsed['name'] == 'TestManager.test'
|
||||||
assert parsed['pattern']['origin'] == r'test pattern $word:Word'
|
assert parsed['pattern']['origin'] == r'test pattern $word:Word'
|
||||||
assert parsed['declaration'] == 'def test(var: str, word: Word, foo: int | None = None) -> Response'
|
assert parsed['declaration'] == 'def test(var: str, word: Word, foo: int | None = None) -> Response'
|
||||||
assert parsed['docstring'] == 'test command'
|
assert parsed['docstring'] == 'test command'
|
||||||
|
|
||||||
async def test_async_command_complicate_type_json():
|
async def test_async_command_complicate_type_json():
|
||||||
manager = CommandsManager('TestManager')
|
manager = CommandsManager('TestManager')
|
||||||
|
|
||||||
@manager.new('async test')
|
@manager.new('async test')
|
||||||
async def test2(
|
async def test2(
|
||||||
some: AsyncGenerator[
|
some: AsyncGenerator[
|
||||||
Callable[
|
Callable[
|
||||||
[Any], Type
|
[Any], Type
|
||||||
],
|
],
|
||||||
list[None]
|
list[None]
|
||||||
]
|
]
|
||||||
):
|
):
|
||||||
return Response()
|
return Response()
|
||||||
|
|
||||||
string = json.dumps(test2, cls = StarkJsonEncoder)
|
string = json.dumps(test2, cls = StarkJsonEncoder)
|
||||||
parsed = json.loads(string)
|
parsed = json.loads(string)
|
||||||
|
|
||||||
assert parsed['name'] == 'TestManager.test2'
|
assert parsed['name'] == 'TestManager.test2'
|
||||||
assert parsed['pattern']['origin'] == r'async test'
|
assert parsed['pattern']['origin'] == r'async test'
|
||||||
assert parsed['declaration'] == 'async def test2(some: AsyncGenerator)' # TODO: improve AsyncGenerator to full type
|
assert parsed['declaration'] == 'async def test2(some: AsyncGenerator)' # TODO: improve AsyncGenerator to full type
|
||||||
# assert parsed['declaration'] == 'async def test2(some: AsyncGenerator[Callable[[Any], Type], list[None], None])'
|
# assert parsed['declaration'] == 'async def test2(some: AsyncGenerator[Callable[[Any], Type], list[None], None])'
|
||||||
assert parsed['docstring'] == ''
|
assert parsed['docstring'] == ''
|
||||||
|
|
||||||
def test_manager_json():
|
def test_manager_json():
|
||||||
|
|
||||||
manager = CommandsManager('TestManager')
|
manager = CommandsManager('TestManager')
|
||||||
|
|
||||||
@manager.new('')
|
@manager.new('')
|
||||||
def test(): ...
|
def test(): ...
|
||||||
|
|
||||||
@manager.new('')
|
@manager.new('')
|
||||||
def test2(): ...
|
def test2(): ...
|
||||||
|
|
||||||
@manager.new('')
|
@manager.new('')
|
||||||
def test3(): ...
|
def test3(): ...
|
||||||
|
|
||||||
@manager.new('')
|
@manager.new('')
|
||||||
def test4(): ...
|
def test4(): ...
|
||||||
|
|
||||||
string = json.dumps(manager, cls = StarkJsonEncoder)
|
string = json.dumps(manager, cls = StarkJsonEncoder)
|
||||||
parsed = json.loads(string)
|
parsed = json.loads(string)
|
||||||
|
|
||||||
assert parsed['name'] == 'TestManager'
|
assert parsed['name'] == 'TestManager'
|
||||||
assert {c['name'] for c in parsed['commands']} == {'TestManager.test', 'TestManager.test2', 'TestManager.test3', 'TestManager.test4',}
|
assert {c['name'] for c in parsed['commands']} == {'TestManager.test', 'TestManager.test2', 'TestManager.test3', 'TestManager.test4',}
|
||||||
|
|
@@ -1,4 +1,3 @@
|
|||||||
import pytest
|
|
||||||
import anyio
|
import anyio
|
||||||
|
|
||||||
|
|
||||||
@@ -6,7 +5,7 @@ async def test_basic_search(commands_context_flow_filled, autojump_clock):
|
|||||||
async with commands_context_flow_filled() as (context, context_delegate):
|
async with commands_context_flow_filled() as (context, context_delegate):
|
||||||
assert len(context_delegate.responses) == 0
|
assert len(context_delegate.responses) == 0
|
||||||
assert len(context._context_queue) == 1
|
assert len(context._context_queue) == 1
|
||||||
|
|
||||||
await context.process_string('lorem ipsum dolor')
|
await context.process_string('lorem ipsum dolor')
|
||||||
await anyio.sleep(5)
|
await anyio.sleep(5)
|
||||||
assert len(context_delegate.responses) == 1
|
assert len(context_delegate.responses) == 1
|
||||||
@@ -15,30 +14,30 @@ async def test_basic_search(commands_context_flow_filled, autojump_clock):
|
|||||||
|
|
||||||
async def test_second_context_layer(commands_context_flow_filled, autojump_clock):
|
async def test_second_context_layer(commands_context_flow_filled, autojump_clock):
|
||||||
async with commands_context_flow_filled() as (context, context_delegate):
|
async with commands_context_flow_filled() as (context, context_delegate):
|
||||||
|
|
||||||
await context.process_string('hello world')
|
await context.process_string('hello world')
|
||||||
await anyio.sleep(5)
|
await anyio.sleep(5)
|
||||||
assert len(context_delegate.responses) == 1
|
assert len(context_delegate.responses) == 1
|
||||||
assert context_delegate.responses[0].text == 'Hello, world!'
|
assert context_delegate.responses[0].text == 'Hello, world!'
|
||||||
assert len(context._context_queue) == 2
|
assert len(context._context_queue) == 2
|
||||||
context_delegate.responses.clear()
|
context_delegate.responses.clear()
|
||||||
|
|
||||||
await context.process_string('hello')
|
await context.process_string('hello')
|
||||||
await anyio.sleep(5)
|
await anyio.sleep(5)
|
||||||
assert len(context_delegate.responses) == 1
|
assert len(context_delegate.responses) == 1
|
||||||
assert context_delegate.responses[0].text == 'Hi, world!'
|
assert context_delegate.responses[0].text == 'Hi, world!'
|
||||||
assert len(context._context_queue) == 2
|
assert len(context._context_queue) == 2
|
||||||
context_delegate.responses.clear()
|
context_delegate.responses.clear()
|
||||||
|
|
||||||
async def test_context_pop_on_not_found(commands_context_flow_filled, autojump_clock):
|
async def test_context_pop_on_not_found(commands_context_flow_filled, autojump_clock):
|
||||||
async with commands_context_flow_filled() as (context, context_delegate):
|
async with commands_context_flow_filled() as (context, context_delegate):
|
||||||
|
|
||||||
await context.process_string('hello world')
|
await context.process_string('hello world')
|
||||||
await anyio.sleep(5)
|
await anyio.sleep(5)
|
||||||
assert len(context._context_queue) == 2
|
assert len(context._context_queue) == 2
|
||||||
assert len(context_delegate.responses) == 1
|
assert len(context_delegate.responses) == 1
|
||||||
context_delegate.responses.clear()
|
context_delegate.responses.clear()
|
||||||
|
|
||||||
await context.process_string('lorem ipsum dolor')
|
await context.process_string('lorem ipsum dolor')
|
||||||
await anyio.sleep(5)
|
await anyio.sleep(5)
|
||||||
assert len(context._context_queue) == 1
|
assert len(context._context_queue) == 1
|
||||||
@@ -46,35 +45,35 @@ async def test_context_pop_on_not_found(commands_context_flow_filled, autojump_c
|
|||||||
|
|
||||||
async def test_context_pop_context_response_action(commands_context_flow_filled, autojump_clock):
|
async def test_context_pop_context_response_action(commands_context_flow_filled, autojump_clock):
|
||||||
async with commands_context_flow_filled() as (context, context_delegate):
|
async with commands_context_flow_filled() as (context, context_delegate):
|
||||||
|
|
||||||
await context.process_string('hello world')
|
await context.process_string('hello world')
|
||||||
await anyio.sleep(5)
|
await anyio.sleep(5)
|
||||||
assert len(context_delegate.responses) == 1
|
assert len(context_delegate.responses) == 1
|
||||||
assert context_delegate.responses[0].text == 'Hello, world!'
|
assert context_delegate.responses[0].text == 'Hello, world!'
|
||||||
assert len(context._context_queue) == 2
|
assert len(context._context_queue) == 2
|
||||||
context_delegate.responses.clear()
|
context_delegate.responses.clear()
|
||||||
|
|
||||||
await context.process_string('bye')
|
await context.process_string('bye')
|
||||||
await anyio.sleep(5)
|
await anyio.sleep(5)
|
||||||
assert len(context_delegate.responses) == 1
|
assert len(context_delegate.responses) == 1
|
||||||
assert context_delegate.responses[0].text == 'Bye, world!'
|
assert context_delegate.responses[0].text == 'Bye, world!'
|
||||||
assert len(context._context_queue) == 1
|
assert len(context._context_queue) == 1
|
||||||
context_delegate.responses.clear()
|
context_delegate.responses.clear()
|
||||||
|
|
||||||
await context.process_string('hello')
|
await context.process_string('hello')
|
||||||
await anyio.sleep(5)
|
await anyio.sleep(5)
|
||||||
assert len(context_delegate.responses) == 0
|
assert len(context_delegate.responses) == 0
|
||||||
|
|
||||||
async def test_repeat_last_answer_response_action(commands_context_flow_filled, autojump_clock):
|
async def test_repeat_last_answer_response_action(commands_context_flow_filled, autojump_clock):
|
||||||
async with commands_context_flow_filled() as (context, context_delegate):
|
async with commands_context_flow_filled() as (context, context_delegate):
|
||||||
|
|
||||||
await context.process_string('hello world')
|
await context.process_string('hello world')
|
||||||
await anyio.sleep(5)
|
await anyio.sleep(5)
|
||||||
assert len(context_delegate.responses) == 1
|
assert len(context_delegate.responses) == 1
|
||||||
assert context_delegate.responses[0].text == 'Hello, world!'
|
assert context_delegate.responses[0].text == 'Hello, world!'
|
||||||
context_delegate.responses.clear()
|
context_delegate.responses.clear()
|
||||||
assert len(context_delegate.responses) == 0
|
assert len(context_delegate.responses) == 0
|
||||||
|
|
||||||
await context.process_string('repeat')
|
await context.process_string('repeat')
|
||||||
await anyio.sleep(5)
|
await anyio.sleep(5)
|
||||||
assert len(context_delegate.responses) == 1
|
assert len(context_delegate.responses) == 1
|
||||||
|
57
tests/test_commands_flow/test_complex_commands.py
Normal file
57
tests/test_commands_flow/test_complex_commands.py
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
import anyio
|
||||||
|
from typing_extensions import Optional
|
||||||
|
|
||||||
|
from stark.core import Pattern, Response
|
||||||
|
from stark.core.types import Object
|
||||||
|
from stark.general.classproperty import classproperty
|
||||||
|
|
||||||
|
|
||||||
|
async def test_command_flow_optional_parameter(commands_context_flow, autojump_clock):
|
||||||
|
async with commands_context_flow() as (manager, context, context_delegate):
|
||||||
|
|
||||||
|
class Category(Object):
|
||||||
|
@classproperty
|
||||||
|
def pattern(self) -> Pattern:
|
||||||
|
return Pattern('c*')
|
||||||
|
|
||||||
|
class Device(Object):
|
||||||
|
@classproperty
|
||||||
|
def pattern(self) -> Pattern:
|
||||||
|
return Pattern('d*')
|
||||||
|
|
||||||
|
class Room(Object):
|
||||||
|
@classproperty
|
||||||
|
def pattern(self) -> Pattern:
|
||||||
|
return Pattern('r*')
|
||||||
|
|
||||||
|
Pattern.add_parameter_type(Category)
|
||||||
|
Pattern.add_parameter_type(Device)
|
||||||
|
Pattern.add_parameter_type(Room)
|
||||||
|
|
||||||
|
@manager.new('turn on ($category:Category|$device:Device|$room:Room)')
|
||||||
|
def turn_on(category: Optional[Category], device: Optional[Device], room: Optional[Room]) -> Response:
|
||||||
|
if category:
|
||||||
|
return Response(text=f'Category {category.value}')
|
||||||
|
elif device:
|
||||||
|
return Response(text=f'Device {device.value}')
|
||||||
|
elif room:
|
||||||
|
return Response(text=f'Room {room.value}')
|
||||||
|
else:
|
||||||
|
raise ValueError('No category, device or room provided')
|
||||||
|
|
||||||
|
print(turn_on.pattern.compiled)
|
||||||
|
|
||||||
|
await context.process_string('turn on cooling')
|
||||||
|
await anyio.sleep(5)
|
||||||
|
assert len(context_delegate.responses) == 1
|
||||||
|
assert context_delegate.responses[0].text == 'Category cooling'
|
||||||
|
|
||||||
|
await context.process_string('turn on dishwasher')
|
||||||
|
await anyio.sleep(5)
|
||||||
|
assert len(context_delegate.responses) == 2
|
||||||
|
assert context_delegate.responses[1].text == 'Device dishwasher'
|
||||||
|
|
||||||
|
await context.process_string('turn on restroom')
|
||||||
|
await anyio.sleep(5)
|
||||||
|
assert len(context_delegate.responses) == 3
|
||||||
|
assert context_delegate.responses[2].text == 'Room restroom'
|
@@ -1,3 +1,5 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
from stark.core import Pattern
|
from stark.core import Pattern
|
||||||
from stark.core.patterns import rules
|
from stark.core.patterns import rules
|
||||||
|
|
||||||
@@ -112,6 +114,14 @@ async def test_one_of():
|
|||||||
assert (await p.match('bbb Some bar here cccc'))[0].substring == 'Some bar here'
|
assert (await p.match('bbb Some bar here cccc'))[0].substring == 'Some bar here'
|
||||||
assert not await p.match('Some foo')
|
assert not await p.match('Some foo')
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="Not implemented as a not important feature")
|
||||||
|
async def test_one_of_with_spaces():
|
||||||
|
p = Pattern('Hello ( foo | bar | baz ) world')
|
||||||
|
print(p.compiled)
|
||||||
|
assert await p.match('Hello foo world')
|
||||||
|
assert await p.match('Hello bar world')
|
||||||
|
assert await p.match('Hello baz world')
|
||||||
|
|
||||||
async def test_optional_one_of():
|
async def test_optional_one_of():
|
||||||
p = Pattern('(foo|bar)?')
|
p = Pattern('(foo|bar)?')
|
||||||
assert p.compiled == r'(?:foo|bar)?'
|
assert p.compiled == r'(?:foo|bar)?'
|
||||||
|
@@ -362,7 +362,6 @@ async def test_slots(pattern_str, input_str, is_match, match_str, expected_token
|
|||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_slots_required_optional_cases(cls_name, slots_dict, input_str, expected_values, expected_error):
|
async def test_slots_required_optional_cases(cls_name, slots_dict, input_str, expected_values, expected_error):
|
||||||
|
|
||||||
print(f'{cls_name=} {input_str=} {expected_error=} {expected_values=} {slots_dict=}')
|
print(f'{cls_name=} {input_str=} {expected_error=} {expected_values=} {slots_dict=}')
|
||||||
|
Reference in New Issue
Block a user