diff --git a/azure/functions/__init__.py b/azure/functions/__init__.py index 9f222304..a5f08660 100644 --- a/azure/functions/__init__.py +++ b/azure/functions/__init__.py @@ -11,7 +11,7 @@ from .kafka import KafkaEvent, KafkaConverter, KafkaTriggerConverter # NoQA from ._queue import QueueMessage # NoQA from ._servicebus import ServiceBusMessage # NoQA -from ._durable_functions import OrchestrationContext # NoQA +from ._durable_functions import OrchestrationContext, EntityContext # NoQA from .meta import get_binding_registry # NoQA # Import binding implementations to register them @@ -47,6 +47,7 @@ 'KafkaConverter', 'KafkaTriggerConverter', 'OrchestrationContext', + 'EntityContext', 'QueueMessage', 'ServiceBusMessage', 'TimerRequest', diff --git a/azure/functions/_durable_functions.py b/azure/functions/_durable_functions.py index 7a14e32b..aa533679 100644 --- a/azure/functions/_durable_functions.py +++ b/azure/functions/_durable_functions.py @@ -109,3 +109,31 @@ def __repr__(self): def __str__(self): return self.__body + + +class EntityContext(_abc.OrchestrationContext): + """A durable function entity context. + + :param str body: + The body of orchestration context json. + """ + + def __init__(self, + body: Union[str, bytes]) -> None: + if isinstance(body, str): + self.__body = body + if isinstance(body, bytes): + self.__body = body.decode('utf-8') + + @property + def body(self) -> str: + return self.__body + + def __repr__(self): + return ( + f'' + ) + + def __str__(self): + return self.__body diff --git a/azure/functions/durable_functions.py b/azure/functions/durable_functions.py index 7b230135..16c4fc4c 100644 --- a/azure/functions/durable_functions.py +++ b/azure/functions/durable_functions.py @@ -39,6 +39,36 @@ def has_implicit_output(cls) -> bool: return True +class EnitityTriggerConverter(meta.InConverter, + meta.OutConverter, + binding='entityTrigger', + trigger=True): + @classmethod + def check_input_type_annotation(cls, pytype): + return issubclass(pytype, _durable_functions.EntityContext) + + @classmethod + def check_output_type_annotation(cls, pytype): + # Implicit output should accept any return type + return True + + @classmethod + def decode(cls, + data: meta.Datum, *, + trigger_metadata) -> _durable_functions.EntityContext: + return _durable_functions.EntityContext(data.value) + + @classmethod + def encode(cls, obj: typing.Any, *, + expected_type: typing.Optional[type]) -> meta.Datum: + # Durable function context should be a json + return meta.Datum(type='json', value=obj) + + @classmethod + def has_implicit_output(cls) -> bool: + return True + + # Durable Function Activity Trigger class ActivityTriggerConverter(meta.InConverter, meta.OutConverter, diff --git a/tests/test_durable_functions.py b/tests/test_durable_functions.py index 461439c1..1739cdcd 100644 --- a/tests/test_durable_functions.py +++ b/tests/test_durable_functions.py @@ -6,77 +6,93 @@ from azure.functions.durable_functions import ( OrchestrationTriggerConverter, + EnitityTriggerConverter, ActivityTriggerConverter ) -from azure.functions._durable_functions import OrchestrationContext +from azure.functions._durable_functions import ( + OrchestrationContext, + EntityContext +) from azure.functions.meta import Datum +CONTEXT_CLASSES = [OrchestrationContext, EntityContext] +CONVERTERS = [OrchestrationTriggerConverter, EnitityTriggerConverter] -class TestDurableFunctions(unittest.TestCase): - def test_orchestration_context_string_body(self): - raw_string = '{ "name": "great function" }' - context = OrchestrationContext(raw_string) - self.assertIsNotNone(getattr(context, 'body', None)) - - content = json.loads(context.body) - self.assertEqual(content.get('name'), 'great function') - - def test_orchestration_context_string_cast(self): - raw_string = '{ "name": "great function" }' - context = OrchestrationContext(raw_string) - self.assertEqual(str(context), raw_string) - - content = json.loads(str(context)) - self.assertEqual(content.get('name'), 'great function') - - def test_orchestration_context_bytes_body(self): - raw_bytes = '{ "name": "great function" }'.encode('utf-8') - context = OrchestrationContext(raw_bytes) - self.assertIsNotNone(getattr(context, 'body', None)) - - content = json.loads(context.body) - self.assertEqual(content.get('name'), 'great function') - - def test_orchestration_context_bytes_cast(self): - raw_bytes = '{ "name": "great function" }'.encode('utf-8') - context = OrchestrationContext(raw_bytes) - self.assertIsNotNone(getattr(context, 'body', None)) - content = json.loads(context.body) - self.assertEqual(content.get('name'), 'great function') - - def test_orchestration_trigger_converter(self): +class TestDurableFunctions(unittest.TestCase): + def test_context_string_body(self): + body = '{ "name": "great function" }' + for ctx in CONTEXT_CLASSES: + context = ctx(body) + self.assertIsNotNone(getattr(context, 'body', None)) + + content = json.loads(context.body) + self.assertEqual(content.get('name'), 'great function') + + def test_context_string_cast(self): + body = '{ "name": "great function" }' + for ctx in CONTEXT_CLASSES: + context = ctx(body) + self.assertEqual(str(context), body) + + content = json.loads(str(context)) + self.assertEqual(content.get('name'), 'great function') + + def test_context_bytes_body(self): + body = '{ "name": "great function" }'.encode('utf-8') + for ctx in CONTEXT_CLASSES: + context = ctx(body) + self.assertIsNotNone(getattr(context, 'body', None)) + + content = json.loads(context.body) + self.assertEqual(content.get('name'), 'great function') + + def test_context_bytes_cast(self): + # TODO: this is just like the test above + # (test_orchestration_context_bytes_body) + body = '{ "name": "great function" }'.encode('utf-8') + for ctx in CONTEXT_CLASSES: + context = ctx(body) + self.assertIsNotNone(getattr(context, 'body', None)) + + content = json.loads(context.body) + self.assertEqual(content.get('name'), 'great function') + + def test_trigger_converter(self): datum = Datum(value='{ "name": "great function" }', type=str) - otc = OrchestrationTriggerConverter.decode(datum, - trigger_metadata=None) - content = json.loads(otc.body) - self.assertEqual(content.get('name'), 'great function') + for converter in CONVERTERS: + otc = converter.decode(datum, trigger_metadata=None) + content = json.loads(otc.body) + self.assertEqual(content.get('name'), 'great function') - def test_orchestration_trigger_converter_type(self): + def test_trigger_converter_type(self): datum = Datum(value='{ "name": "great function" }'.encode('utf-8'), type=bytes) - otc = OrchestrationTriggerConverter.decode(datum, - trigger_metadata=None) - content = json.loads(otc.body) - self.assertEqual(content.get('name'), 'great function') + for converter in CONVERTERS: + otc = converter.decode(datum, trigger_metadata=None) + content = json.loads(otc.body) + self.assertEqual(content.get('name'), 'great function') - def test_orchestration_trigger_check_good_annotation(self): - for dt in (OrchestrationContext,): + def test_trigger_check_good_annotation(self): + + for converter, ctx in zip(CONVERTERS, CONTEXT_CLASSES): self.assertTrue( - OrchestrationTriggerConverter.check_input_type_annotation(dt) + converter.check_input_type_annotation(ctx) ) - def test_orchestration_trigger_check_bad_annotation(self): + def test_trigger_check_bad_annotation(self): for dt in (str, bytes, int): - self.assertFalse( - OrchestrationTriggerConverter.check_input_type_annotation(dt) - ) + for converter in CONVERTERS: + self.assertFalse( + converter.check_input_type_annotation(dt) + ) - def test_orchestration_trigger_has_implicit_return(self): - self.assertTrue( - OrchestrationTriggerConverter.has_implicit_output() - ) + def test_trigger_has_implicit_return(self): + for converter in CONVERTERS: + self.assertTrue( + converter.has_implicit_output() + ) def test_activity_trigger_inputs(self): # Activity Trigger only accept string type from durable extensions