From a5fcdc05214b81abbedc42ac9f8b7fd852fe8d1b Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 27 Jan 2026 08:27:46 -0500 Subject: [PATCH 1/5] Add tagged typehint support. --- sdks/python/apache_beam/pvalue.py | 12 +- sdks/python/apache_beam/transforms/core.py | 24 +- .../apache_beam/transforms/ptransform.py | 28 +- .../apache_beam/typehints/decorators.py | 217 +++++++++-- .../apache_beam/typehints/decorators_test.py | 180 +++++++++ .../typehints/tagged_output_typehints_test.py | 356 ++++++++++++++++++ .../apache_beam/typehints/typehints_test.py | 2 +- 7 files changed, 781 insertions(+), 38 deletions(-) create mode 100644 sdks/python/apache_beam/typehints/tagged_output_typehints_test.py diff --git a/sdks/python/apache_beam/pvalue.py b/sdks/python/apache_beam/pvalue.py index ca9a662d399e..1cd220cc2566 100644 --- a/sdks/python/apache_beam/pvalue.py +++ b/sdks/python/apache_beam/pvalue.py @@ -246,6 +246,8 @@ def __init__( self._tags = tags self._main_tag = main_tag self._transform = transform + self._tagged_output_types = ( + transform.get_type_hints().tagged_output_types() if transform else {}) self._allow_unknown_tags = ( not tags if allow_unknown_tags is None else allow_unknown_tags) # The ApplyPTransform instance for the application of the multi FlatMap @@ -303,7 +305,7 @@ def __getitem__(self, tag: Union[int, str, None]) -> PCollection: pcoll = PCollection( self._pipeline, tag=tag, - element_type=typehints.Any, + element_type=self._tagged_output_types.get(tag, typehints.Any), is_bounded=is_bounded) # Transfer the producer from the DoOutputsTuple to the resulting # PCollection. @@ -323,7 +325,11 @@ def __getitem__(self, tag: Union[int, str, None]) -> PCollection: return pcoll -class TaggedOutput(object): +TagType = TypeVar('TagType', bound=str) +ValueType = TypeVar('ValueType') + + +class TaggedOutput(Generic[TagType, ValueType]): """An object representing a tagged value. ParDo, Map, and FlatMap transforms can emit values on multiple outputs which @@ -331,7 +337,7 @@ class TaggedOutput(object): if it wants to emit on the main output and TaggedOutput objects if it wants to emit a value on a specific tagged output. """ - def __init__(self, tag: str, value: Any) -> None: + def __init__(self, tag: TagType, value: ValueType) -> None: if not isinstance(tag, str): raise TypeError( 'Attempting to create a TaggedOutput with non-string tag %s' % diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py index ea11bca9474d..19c96fa51f2b 100644 --- a/sdks/python/apache_beam/transforms/core.py +++ b/sdks/python/apache_beam/transforms/core.py @@ -1834,6 +1834,17 @@ def with_outputs(self, *tags, main=None, allow_unknown_tags=None): raise ValueError( 'Main output tag %r must be different from side output tags %r.' % (main, tags)) + type_hints = self.get_type_hints() + declared_tags = set(type_hints.tagged_output_types().keys()) + requested_tags = set(tags) + + unknown = requested_tags - declared_tags + if unknown and declared_tags: # Only warn if type hints exist + logging.warning( + "Tags %s requested in with_outputs() but not declared " + "in type hints. Declared tags: %s", + unknown, + declared_tags) return _MultiParDo(self, tags, main, allow_unknown_tags) def _do_fn_info(self): @@ -2120,8 +2131,10 @@ def Map(fn, *args, **kwargs): # pylint: disable=invalid-name wrapper) output_hint = type_hints.simple_output_type(label) if output_hint: + tagged_output_types = type_hints.tagged_output_types() wrapper = with_output_types( - typehints.Iterable[_strip_output_annotations(output_hint)])( + typehints.Iterable[_strip_output_annotations(output_hint)], + **tagged_output_types)( wrapper) # pylint: disable=protected-access wrapper._argspec_fn = fn @@ -2189,8 +2202,10 @@ def MapTuple(fn, *args, **kwargs): # pylint: disable=invalid-name pass output_hint = type_hints.simple_output_type(label) if output_hint: + tagged_output_types = type_hints.tagged_output_types() wrapper = with_output_types( - typehints.Iterable[_strip_output_annotations(output_hint)])( + typehints.Iterable[_strip_output_annotations(output_hint)], + **tagged_output_types)( wrapper) # Replace the first (args) component. @@ -2261,7 +2276,10 @@ def FlatMapTuple(fn, *args, **kwargs): # pylint: disable=invalid-name pass output_hint = type_hints.simple_output_type(label) if output_hint: - wrapper = with_output_types(_strip_output_annotations(output_hint))(wrapper) + tagged_output_types = type_hints.tagged_output_types() + wrapper = with_output_types( + _strip_output_annotations(output_hint), **tagged_output_types)( + wrapper) # Replace the first (args) component. modified_arg_names = ['tuple_element'] + arg_names[-num_defaults:] diff --git a/sdks/python/apache_beam/transforms/ptransform.py b/sdks/python/apache_beam/transforms/ptransform.py index 94e9a0644d04..d5985b6212df 100644 --- a/sdks/python/apache_beam/transforms/ptransform.py +++ b/sdks/python/apache_beam/transforms/ptransform.py @@ -414,12 +414,15 @@ def with_input_types(self, input_type_hint): input_type_hint, 'Type hints for a PTransform') return super().with_input_types(input_type_hint) - def with_output_types(self, type_hint): + def with_output_types(self, type_hint, **tagged_type_hints): """Annotates the output type of a :class:`PTransform` with a type-hint. Args: type_hint (type): An instance of an allowed built-in type, a custom class, - or a :class:`~apache_beam.typehints.typehints.TypeConstraint`. + or a :class:`~apache_beam.typehints.typehints.TypeConstraint`. This is + the type hint for the main output. + **tagged_type_hints: Type hints for tagged outputs. Each keyword argument + specifies the type for a tagged output e.g., ``errors=str``. Raises: TypeError: If **type_hint** is not a valid type-hint. See @@ -430,10 +433,22 @@ def with_output_types(self, type_hint): PTransform: A reference to the instance of this particular :class:`PTransform` object. This allows chaining type-hinting related methods. + + Example:: + result = pcoll | beam.ParDo(MyDoFn()).with_output_types( + int, # main output type + errors=str, # 'errors' tagged output type + warnings=str # 'warnings' tagged output type + ).with_outputs('errors', 'warnings', main='main') """ type_hint = native_type_compatibility.convert_to_beam_type(type_hint) validate_composite_type_param(type_hint, 'Type hints for a PTransform') - return super().with_output_types(type_hint) + for tag, hint in tagged_type_hints.items(): + tagged_type_hints[tag] = native_type_compatibility.convert_to_beam_type( + hint) + validate_composite_type_param( + tagged_type_hints[tag], f'Tagged output type hint for {tag!r}') + return super().with_output_types(type_hint, **tagged_type_hints) def with_resource_hints(self, **kwargs): # type: (...) -> PTransform """Adds resource hints to the :class:`PTransform`. @@ -479,10 +494,11 @@ def type_check_inputs_or_outputs(self, pvalueish, input_or_output): if hints is None or not any(hints): return arg_hints, kwarg_hints = hints - if arg_hints and kwarg_hints: + # Output types can have kwargs for tagged output types. + if arg_hints and kwarg_hints and input_or_output != 'output': raise TypeCheckError( - 'PTransform cannot have both positional and keyword type hints ' - 'without overriding %s._type_check_%s()' % + 'PTransform cannot have both positional and keyword input type hints' + ' without overriding %s._type_check_%s()' % (self.__class__, input_or_output)) root_hint = ( arg_hints[0] if len(arg_hints) == 1 else arg_hints or kwarg_hints) diff --git a/sdks/python/apache_beam/typehints/decorators.py b/sdks/python/apache_beam/typehints/decorators.py index 2d2f7981dd29..6bb6259bcb5e 100644 --- a/sdks/python/apache_beam/typehints/decorators.py +++ b/sdks/python/apache_beam/typehints/decorators.py @@ -79,6 +79,7 @@ def foo((a, b)): # pytype: skip-file +import collections.abc import inspect import itertools import logging @@ -89,12 +90,16 @@ def foo((a, b)): from typing import Dict from typing import Iterable from typing import List +from typing import Literal from typing import NamedTuple from typing import Optional from typing import Tuple from typing import TypeVar from typing import Union +from typing import get_args +from typing import get_origin +from apache_beam.pvalue import TaggedOutput from apache_beam.typehints import native_type_compatibility from apache_beam.typehints import typehints from apache_beam.typehints.native_type_compatibility import convert_to_beam_type @@ -182,6 +187,140 @@ def disable_type_annotations(): TRACEBACK_LIMIT = 5 +def _tag_and_type(t): + """Extract tag name and value type from TaggedOutput[Literal['tag'], Type]. + + Returns raw Python types - conversion to beam types happens in + _extract_output_types. + """ + args = get_args(t) + if len(args) != 2: + raise TypeError( + f"TaggedOutput expects 2 type parameters, got {len(args)}: {t}") + + literal_type, value_type = args + + if get_origin(literal_type) is not Literal: + raise TypeError( + f"First type parameter of TaggedOutput must be Literal['tag_name'], " + f"got {literal_type}. Example: TaggedOutput[Literal['errors'], str]") + + tag_string = get_args(literal_type)[0] + return tag_string, value_type + + +def _contains_tagged_output(t): + """Check if type contains TaggedOutput at a meaningful position. + + TaggedOutput only makes sense in these patterns: + - TaggedOutput[...] + - X | TaggedOutput[...] + - Iterable[TaggedOutput[...]] + - Iterable[X | TaggedOutput[...]] + """ + def _is_tagged(typ): + return get_origin(typ) is TaggedOutput or typ is TaggedOutput + + # TaggedOutput[...] + if _is_tagged(t): + return True + + origin = get_origin(t) + args = get_args(t) + + # X | TaggedOutput[...] + if origin is Union: + return any(_is_tagged(arg) for arg in args) + + # Iterable[...] + if origin is collections.abc.Iterable and len(args) == 1: + inner = args[0] + # Iterable[TaggedOutput[...]] + if _is_tagged(inner): + return True + # Iterable[X | TaggedOutput[...]] + if get_origin(inner) is Union: + return any(_is_tagged(arg) for arg in get_args(inner)) + + return False + + +def _extract_main_and_tagged(t): + """Extract main type and tagged types from a type annotation. + + Returns: + (main_type, tagged_dict) where main_type is the type without TaggedOutput + annotations (or None if no main type), and tagged_dict maps tag names to + their types. + """ + if get_origin(t) is TaggedOutput: + tag, typ = _tag_and_type(t) + return None, {tag: typ} + + if t is TaggedOutput: + raise TypeError( + "TaggedOutput in return type must include type parameters: " + "TaggedOutput[Literal['tag_name'], ValueType]") + + if get_origin(t) is not Union: + return t, {} + + main_types = [] + tagged_types = {} + for arg in get_args(t): + if get_origin(arg) is TaggedOutput: + tag, typ = _tag_and_type(arg) + tagged_types[tag] = typ + elif arg is TaggedOutput: + raise TypeError( + "TaggedOutput in return type must include type parameters: " + "TaggedOutput[Literal['tag_name'], ValueType]") + else: + main_types.append(arg) + + if len(main_types) == 0: + main_type = None + elif len(main_types) == 1: + main_type = main_types[0] + else: + main_type = Union[tuple(main_types)] + + return main_type, tagged_types + + +def _extract_output_types(return_annotation): + """Parse return annotation into (main_types, tagged_types). + + For tagged outputs to be extracted from generator/iterator functions, + users must explicitly use Iterable[T | TaggedOutput[...]] as return type. + + Returns raw Python types. Conversion to beam types happens in from_callable. + """ + if return_annotation == inspect.Signature.empty: + return [Any], {} + + # Early return if no TaggedOutput + if not _contains_tagged_output(return_annotation): + return [return_annotation], {} + + # Iterable[T | TaggedOutput[...]] + if get_origin(return_annotation) is collections.abc.Iterable: + yield_type = get_args(return_annotation)[0] + clean_yield, tagged_types = _extract_main_and_tagged(yield_type) + clean_main = clean_yield if clean_yield else Any + return [Iterable[clean_main]], tagged_types + + # TaggedOutput + if get_origin(return_annotation) is TaggedOutput: + tag, typ = _tag_and_type(return_annotation) + return [Any], {tag: typ} + + # T | TaggedOutput + main_type, tagged_types = _extract_main_and_tagged(return_annotation) + main = main_type if main_type else Any + return [main], tagged_types + + class IOTypeHints(NamedTuple): """Encapsulates all type hint information about a Beam construct. @@ -273,11 +412,14 @@ def from_callable(cls, fn: Callable) -> Optional['IOTypeHints']: param.VAR_POSITIONAL], \ 'Unsupported Parameter kind: %s' % param.kind input_args.append(convert_to_beam_type(param.annotation)) - output_args = [] - if signature.return_annotation != signature.empty: - output_args.append(convert_to_beam_type(signature.return_annotation)) - else: - output_args.append(typehints.Any) + + output_args, output_kwargs = _extract_output_types( + signature.return_annotation) + output_args = [convert_to_beam_type(t) for t in output_args] + output_kwargs = { + k: convert_to_beam_type(v) + for k, v in output_kwargs.items() + } name = getattr(fn, '__name__', '') msg = ['from_callable(%s)' % name, ' signature: %s' % signature] @@ -287,7 +429,7 @@ def from_callable(cls, fn: Callable) -> Optional['IOTypeHints']: (fn.__code__.co_filename, fn.__code__.co_firstlineno)) return IOTypeHints( input_types=(tuple(input_args), input_kwargs), - output_types=(tuple(output_args), {}), + output_types=(tuple(output_args), output_kwargs), origin=cls._make_origin([], tb=False, msg=msg)) def with_input_types(self, *args, **kwargs) -> 'IOTypeHints': @@ -308,18 +450,24 @@ def with_output_types_from(self, other: 'IOTypeHints') -> 'IOTypeHints': def simple_output_type(self, context): if self._has_output_types(): - args, kwargs = self.output_types - if len(args) != 1 or kwargs: + args, _ = self.output_types + # Note: kwargs may contain tagged output types, which are ignored here. + # Use tagged_output_types() to access those. + if len(args) != 1: raise TypeError( 'Expected single output type hint for %s but got: %s' % (context, self.output_types)) return args[0] + def tagged_output_types(self): + if not self._has_output_types(): + return {} + _, tagged_output_types = self.output_types + return tagged_output_types + def has_simple_output_type(self): """Whether there's a single positional output type.""" - return ( - self.output_types and len(self.output_types[0]) == 1 and - not self.output_types[1]) + return (self.output_types and len(self.output_types[0]) == 1) def strip_pcoll(self): from apache_beam.pipeline import Pipeline @@ -413,6 +561,7 @@ def strip_iterable(self) -> 'IOTypeHints': if self.output_types is None or not self.has_simple_output_type(): return self output_type = self.output_types[0][0] + tagged_output_types = self.output_types[1] if output_type is None or isinstance(output_type, type(None)): return self # If output_type == Optional[T]: output_type = T. @@ -427,12 +576,12 @@ def strip_iterable(self) -> 'IOTypeHints': if isinstance(output_type, typehints.TypeVariable): # We don't know what T yields, so we just assume Any. return self._replace( - output_types=((typehints.Any, ), {}), + output_types=((typehints.Any, ), tagged_output_types), origin=self._make_origin([self], tb=False, msg=['strip_iterable()'])) yielded_type = typehints.get_yielded_type(output_type) return self._replace( - output_types=((yielded_type, ), {}), + output_types=((yielded_type, ), tagged_output_types), origin=self._make_origin([self], tb=False, msg=['strip_iterable()'])) def with_defaults(self, hints: Optional['IOTypeHints']) -> 'IOTypeHints': @@ -782,7 +931,7 @@ def annotate_input_types(f): def with_output_types(*return_type_hint: Any, - **kwargs: Any) -> Callable[[T], T]: + **tagged_type_hints: Any) -> Callable[[T], T]: """A decorator that type-checks defined type-hints for return values(s). This decorator will type-check the return value(s) of the decorated function. @@ -822,18 +971,34 @@ def parse_ints(ints): def negate(p): return not p if p else p + For DoFns with tagged outputs, you can specify type hints for each tag: + + .. testcode:: + from apache_beam.typehints import with_input_types, with_output_types + @with_output_types(int, errors=str, warnings=str) + class MyDoFn(beam.DoFn): + def process(self, element): + if element < 0: + yield beam.pvalue.TaggedOutput('errors', 'Negative value') + elif element == 0: + yield beam.pvalue.TaggedOutput('warnings', 'Zero value') + else: + yield element + Args: *return_type_hint: A type-hint specifying the proper return type of the function. This argument should either be a built-in Python type or an instance of a :class:`~apache_beam.typehints.typehints.TypeConstraint` created by 'indexing' a :class:`~apache_beam.typehints.typehints.CompositeTypeHint`. - **kwargs: Not used. + **tagged_type_hints: Type hints for tagged outputs. Each keyword argument + specifies the type for a tagged output, e.g., ``errors=str``. + Raises: - :class:`ValueError`: If any kwarg parameters are passed in, - or the length of **return_type_hint** is greater than ``1``. Or if the - inner wrapper function isn't passed a function object. + :class:`ValueError`: If the length of **return_type_hint** is greater + than ``1``. Or if the inner wrapper function isn't passed a function + object. :class:`TypeCheckError`: If the **return_type_hint** object is in invalid type-hint. @@ -841,11 +1006,6 @@ def negate(p): The original function decorated such that it enforces type-hint constraints for all return values. """ - if kwargs: - raise ValueError( - "All arguments for the 'returns' decorator must be " - "positional arguments.") - if len(return_type_hint) != 1: raise ValueError( "'returns' accepts only a single positional argument. In " @@ -854,13 +1014,20 @@ def negate(p): return_type_hint = native_type_compatibility.convert_to_beam_type( return_type_hint[0]) - validate_composite_type_param( return_type_hint, error_msg_prefix='All type hint arguments') + converted_tag_hints = {} + for tag, hint in tagged_type_hints.items(): + converted_hint = native_type_compatibility.convert_to_beam_type(hint) + validate_composite_type_param( + converted_hint, 'Tagged output type hint for %r' % tag) + converted_tag_hints[tag] = converted_hint + def annotate_output_types(f): th = getattr(f, '_type_hints', IOTypeHints.empty()) - f._type_hints = th.with_output_types(return_type_hint) # pylint: disable=protected-access + f._type_hints = th.with_output_types( # pylint: disable=protected-access + return_type_hint, **converted_tag_hints) return f return annotate_output_types diff --git a/sdks/python/apache_beam/typehints/decorators_test.py b/sdks/python/apache_beam/typehints/decorators_test.py index a2909b4e545f..570f71803d65 100644 --- a/sdks/python/apache_beam/typehints/decorators_test.py +++ b/sdks/python/apache_beam/typehints/decorators_test.py @@ -24,6 +24,7 @@ import unittest from apache_beam import Map +from apache_beam.pvalue import TaggedOutput from apache_beam.typehints import Any from apache_beam.typehints import Dict from apache_beam.typehints import List @@ -262,6 +263,75 @@ def fn(a: int) -> int: th = decorators.IOTypeHints.from_callable(fn) self.assertRegex(th.debug_str(), r'unknown') + def test_from_callable_no_tagged_output(self): + def fn(x: int) -> str: + return str(x) + + th = decorators.IOTypeHints.from_callable(fn) + self.assertEqual(th.input_types, ((int, ), {})) + self.assertEqual(th.output_types, ((str, ), {})) + + def fn2(x: int) -> typing.Iterable[str]: + yield str(x) + + th = decorators.IOTypeHints.from_callable(fn2) + self.assertEqual(th.input_types, ((int, ), {})) + self.assertEqual(th.output_types, ((typehints.Iterable[str], ), {})) + + def test_from_callable_tagged_output_union(self): + def fn( + x: int + ) -> int | str | TaggedOutput[typing.Literal['errors'], float + | str] | TaggedOutput[ + typing.Literal['warnings'], str]: + return x + + th = decorators.IOTypeHints.from_callable(fn) + self.assertEqual(th.input_types, ((int, ), {})) + self.assertEqual( + th.output_types, + ((typehints.Union[int, str], ), { + 'errors': typehints.Union[float, str], 'warnings': str + })) + + def test_from_callable_tagged_output_iterable(self): + def fn( + x: int + ) -> typing.Iterable[int | TaggedOutput[typing.Literal['errors'], str]]: + yield x + + th = decorators.IOTypeHints.from_callable(fn) + self.assertEqual(th.input_types, ((int, ), {})) + self.assertEqual( + th.output_types, ((typehints.Iterable[int], ), { + 'errors': str + })) + + def test_from_callable_tagged_output_multiple_tags(self): + def fn( + x: int + ) -> ( + int | TaggedOutput[typing.Literal['errors'], str] | + TaggedOutput[typing.Literal['warnings'], str]): + return x + + th = decorators.IOTypeHints.from_callable(fn) + self.assertEqual(th.input_types, ((int, ), {})) + self.assertEqual( + th.output_types, ((int, ), { + 'errors': str, 'warnings': str + })) + + def test_from_callable_tagged_output_only(self): + def fn(x: int) -> TaggedOutput[typing.Literal['errors'], str]: + pass + + th = decorators.IOTypeHints.from_callable(fn) + self.assertEqual(th.input_types, ((int, ), {})) + self.assertEqual(th.output_types, ((Any, ), { + 'errors': str + })) + def test_getcallargs_forhints(self): def fn( a: int, @@ -426,5 +496,115 @@ def fn2(a: int) -> int: _ = ['a', 'b', 'c'] | Map(fn2) # Doesn't raise - no input type hints. +class TaggedOutputExtractionTest(unittest.TestCase): + """Tests for TaggedOutput extraction helper functions.""" + def test_contains_tagged_output_true_direct(self): + t = TaggedOutput[typing.Literal['errors'], str] + self.assertTrue(decorators._contains_tagged_output(t)) + + def test_contains_tagged_output_true_in_union(self): + t = int | TaggedOutput[typing.Literal['errors'], str] + self.assertTrue(decorators._contains_tagged_output(t)) + + def test_contains_tagged_output_true_in_iterable(self): + t = typing.Iterable[int | TaggedOutput[typing.Literal['errors'], str]] + self.assertTrue(decorators._contains_tagged_output(t)) + + def test_contains_tagged_output_false_simple_type(self): + self.assertFalse(decorators._contains_tagged_output(int)) + self.assertFalse(decorators._contains_tagged_output(str)) + + def test_contains_tagged_output_false_union_no_tagged(self): + t = int | str + self.assertFalse(decorators._contains_tagged_output(t)) + + def test_contains_tagged_output_false_iterable_no_tagged(self): + t = typing.Iterable[int] + self.assertFalse(decorators._contains_tagged_output(t)) + + def test_contains_tagged_output_false_deeply_nested(self): + t = typing.List[typing.Tuple[TaggedOutput[typing.Literal['errors'], str]]] + self.assertFalse(decorators._contains_tagged_output(t)) + + def test_extract_main_and_tagged_simple_type(self): + main, tagged = decorators._extract_main_and_tagged(int) + self.assertEqual(main, int) + self.assertEqual(tagged, {}) + + def test_extract_main_and_tagged_tagged_output_only(self): + t = TaggedOutput[typing.Literal['errors'], str] + main, tagged = decorators._extract_main_and_tagged(t) + self.assertIsNone(main) + self.assertEqual(tagged, {'errors': str}) + + def test_extract_main_and_tagged_union(self): + t = int | TaggedOutput[typing.Literal['errors'], str] + main, tagged = decorators._extract_main_and_tagged(t) + self.assertEqual(main, int) + self.assertEqual(tagged, {'errors': str}) + + def test_extract_main_and_tagged_union_multiple_tagged(self): + t = ( + int | TaggedOutput[typing.Literal['errors'], str] + | TaggedOutput[typing.Literal['warnings'], str]) + main, tagged = decorators._extract_main_and_tagged(t) + self.assertEqual(main, int) + self.assertEqual(tagged, {'errors': str, 'warnings': str}) + + def test_extract_main_and_tagged_union_multiple_main_types(self): + t = (int | str | TaggedOutput[typing.Literal['errors'], bytes]) + main, tagged = decorators._extract_main_and_tagged(t) + # Main type should be Union[int, str] + self.assertEqual(typing.get_origin(main), typing.Union) + self.assertIn(int, typing.get_args(main)) + self.assertIn(str, typing.get_args(main)) + self.assertEqual(tagged, {'errors': bytes}) + + def test_extract_output_types_empty_signature(self): + import inspect + main, tagged = decorators._extract_output_types(inspect.Signature.empty) + self.assertEqual(main, [typing.Any]) + self.assertEqual(tagged, {}) + + def test_extract_output_types_simple_type(self): + main, tagged = decorators._extract_output_types(int) + self.assertEqual(main, [int]) + self.assertEqual(tagged, {}) + + def test_extract_output_types_union_with_tagged(self): + t = int | TaggedOutput[typing.Literal['errors'], str] + main, tagged = decorators._extract_output_types(t) + self.assertEqual(main, [int]) + self.assertEqual(tagged, {'errors': str}) + + def test_extract_output_types_iterable_with_tagged(self): + t = typing.Iterable[int | TaggedOutput[typing.Literal['errors'], str]] + main, tagged = decorators._extract_output_types(t) + self.assertEqual(main, [typing.Iterable[int]]) + self.assertEqual(tagged, {'errors': str}) + + def test_extract_output_types_list_with_tagged_not_extracted(self): + t = typing.List[int | TaggedOutput[typing.Literal['errors'], str]] + _, tagged = decorators._extract_output_types(t) + # The whole type is converted as-is. Users should use Iterable instead. + self.assertEqual(tagged, {}) + + def test_extract_output_types_tagged_only(self): + t = TaggedOutput[typing.Literal['errors'], str] + main, tagged = decorators._extract_output_types(t) + self.assertEqual(main, [typing.Any]) + self.assertEqual(tagged, {'errors': str}) + + def test_extract_output_types_iterable_tagged_only(self): + t = typing.Iterable[TaggedOutput[typing.Literal['errors'], str]] + main, tagged = decorators._extract_output_types(t) + self.assertEqual(main, [typing.Iterable[typing.Any]]) + self.assertEqual(tagged, {'errors': str}) + + def test_extract_main_and_tagged_bare_tagged_output_raises(self): + with self.assertRaises(TypeError): + decorators._extract_main_and_tagged(TaggedOutput) + + if __name__ == '__main__': unittest.main() diff --git a/sdks/python/apache_beam/typehints/tagged_output_typehints_test.py b/sdks/python/apache_beam/typehints/tagged_output_typehints_test.py new file mode 100644 index 000000000000..5dfae1b7e3dd --- /dev/null +++ b/sdks/python/apache_beam/typehints/tagged_output_typehints_test.py @@ -0,0 +1,356 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Tests for tagged output type hints. + +This tests the implementation of type hints for tagged outputs via three styles: + +1. Decorator style: + @with_output_types(int, errors=str, warnings=str) + class MyDoFn(beam.DoFn): + ... + +2. Method chain style: + beam.ParDo(MyDoFn()).with_output_types(int, errors=str) + +3. Function annotation style: + def fn(element) -> int | TaggedOutput[Literal['errors'], str]: + ... +""" + +# pytype: skip-file + +import unittest +from typing import Iterable +from typing import Literal +from typing import Union + +import apache_beam as beam +from apache_beam.pvalue import TaggedOutput +from apache_beam.typehints import with_output_types +from apache_beam.typehints.decorators import IOTypeHints + + +class IOTypeHintsTaggedOutputTest(unittest.TestCase): + """Tests for IOTypeHints.tagged_output_types() accessor.""" + def test_empty_hints_returns_empty_dict(self): + empty = IOTypeHints.empty() + self.assertEqual(empty.tagged_output_types(), {}) + + def test_with_tagged_types(self): + hints = IOTypeHints.empty().with_output_types(int, errors=str, warnings=str) + self.assertEqual( + hints.tagged_output_types(), { + 'errors': str, 'warnings': str + }) + + def test_simple_output_type_with_tagged_types(self): + """simple_output_type() should still return main type when tags present.""" + hints = IOTypeHints.empty().with_output_types(int, errors=str, warnings=str) + self.assertEqual(hints.simple_output_type('test'), int) + + hints = IOTypeHints.empty().with_output_types( + Union[int, str], errors=str, warnings=str) + self.assertEqual(hints.simple_output_type('test'), Union[int, str]) + + def test_without_tagged_types(self): + """Without tagged types, tagged_output_types() returns empty dict.""" + hints = IOTypeHints.empty().with_output_types(int) + self.assertEqual(hints.tagged_output_types(), {}) + self.assertEqual(hints.simple_output_type('test'), int) + + +class DecoratorStyleTaggedOutputTest(unittest.TestCase): + """Tests for @with_output_types decorator style across all transforms.""" + def test_pardo_decorator_pipeline(self): + """Test that tagged types propagate through ParDo pipeline.""" + @with_output_types(int, errors=str) + class MyDoFn(beam.DoFn): + def process(self, element): + if element < 0: + yield beam.pvalue.TaggedOutput('errors', f'Negative: {element}') + else: + yield element * 2 + + with beam.Pipeline() as p: + results = ( + p + | beam.Create([-1, 0, 1, 2]) + | beam.ParDo(MyDoFn()).with_outputs('errors', main='main')) + + self.assertEqual(results.main.element_type, int) + self.assertEqual(results.errors.element_type, str) + + def test_map_decorator_pipeline(self): + """Test that tagged types propagate through Map.""" + @with_output_types(int, errors=str) + def mapfn(element): + if element < 0: + return beam.pvalue.TaggedOutput('errors', f'Negative: {element}') + else: + return element * 2 + + with beam.Pipeline() as p: + results = ( + p + | beam.Create([-1, 0, 1, 2]) + | beam.Map(mapfn).with_outputs('errors', main='main')) + + self.assertEqual(results.main.element_type, int) + self.assertEqual(results.errors.element_type, str) + + def test_flatmap_decorator_pipeline(self): + """Test that tagged types propagate through FlatMap.""" + @with_output_types(Iterable[int], errors=str) + def flatmapfn(element): + if element < 0: + yield beam.pvalue.TaggedOutput('errors', f'Negative: {element}') + else: + yield element * 2 + + with beam.Pipeline() as p: + results = ( + p + | beam.Create([-1, 0, 1, 2]) + | beam.FlatMap(flatmapfn).with_outputs('errors', main='main')) + + self.assertEqual(results.main.element_type, int) + self.assertEqual(results.errors.element_type, str) + + def test_maptuple_decorator_pipeline(self): + """Test that tagged types propagate through MapTuple.""" + @with_output_types(int, errors=str) + def maptuplefn(key, value): + if value < 0: + return beam.pvalue.TaggedOutput('errors', f'Negative: {key}={value}') + else: + return value * 2 + + with beam.Pipeline() as p: + results = ( + p + | beam.Create([('a', -1), ('b', 2), ('c', 3)]) + | beam.MapTuple(maptuplefn).with_outputs('errors', main='main')) + + self.assertEqual(results.main.element_type, int) + self.assertEqual(results.errors.element_type, str) + + def test_flatmaptuple_decorator_pipeline(self): + """Test that tagged types propagate through FlatMapTuple.""" + @with_output_types(Iterable[int], errors=str) + def flatmaptuplefn(key, value): + if value < 0: + yield beam.pvalue.TaggedOutput('errors', f'Negative: {key}={value}') + else: + yield value * 2 + + with beam.Pipeline() as p: + results = ( + p + | beam.Create([('a', -1), ('b', 2), ('c', 3)]) + | beam.FlatMapTuple(flatmaptuplefn).with_outputs( + 'errors', main='main')) + + self.assertEqual(results.main.element_type, int) + self.assertEqual(results.errors.element_type, str) + + +class ChainStyleTaggedOutputTest(unittest.TestCase): + """Tests for .with_output_types() method chain style across all transforms.""" + def test_pardo_chain_pipeline(self): + """Test ParDo with chained type hints.""" + class SimpleDoFn(beam.DoFn): + def process(self, element): + if element < 0: + yield beam.pvalue.TaggedOutput('errors', f'Negative: {element}') + else: + yield element * 2 + + with beam.Pipeline() as p: + results = ( + p + | beam.Create([-1, 0, 1, 2]) + | beam.ParDo(SimpleDoFn()).with_output_types( + int, errors=str).with_outputs('errors', main='main')) + + self.assertEqual(results.main.element_type, int) + self.assertEqual(results.errors.element_type, str) + + def test_map_chain_pipeline(self): + """Test Map with chained type hints.""" + def mapfn(element): + if element < 0: + return beam.pvalue.TaggedOutput('errors', f'Negative: {element}') + else: + return element * 2 + + with beam.Pipeline() as p: + results = ( + p + | beam.Create([-1, 0, 1, 2]) + | beam.Map(mapfn).with_output_types(int, errors=str).with_outputs( + 'errors', main='main')) + + self.assertEqual(results.main.element_type, int) + self.assertEqual(results.errors.element_type, str) + + def test_flatmap_chain_pipeline(self): + """Test FlatMap with chained type hints. + + Note: For FlatMap.with_output_types(), specify the element type directly + (int), not wrapped in Iterable. The transform handles iteration internally. + """ + def flatmapfn(element): + if element < 0: + yield beam.pvalue.TaggedOutput('errors', f'Negative: {element}') + else: + yield element * 2 + + with beam.Pipeline() as p: + results = ( + p + | beam.Create([-1, 0, 1, 2]) + | beam.FlatMap(flatmapfn).with_output_types( + int, errors=str).with_outputs('errors', main='main')) + + self.assertEqual(results.main.element_type, int) + self.assertEqual(results.errors.element_type, str) + + def test_maptuple_chain_pipeline(self): + """Test MapTuple with chained type hints.""" + def maptuplefn(key, value): + if value < 0: + return beam.pvalue.TaggedOutput('errors', f'Negative: {key}={value}') + else: + return value * 2 + + with beam.Pipeline() as p: + results = ( + p + | beam.Create([('a', -1), ('b', 2), ('c', 3)]) + | beam.MapTuple(maptuplefn).with_output_types( + int, errors=str).with_outputs('errors', main='main')) + + self.assertEqual(results.main.element_type, int) + self.assertEqual(results.errors.element_type, str) + + def test_flatmaptuple_chain_pipeline(self): + """Test FlatMapTuple with chained type hints. + + Note: For FlatMapTuple.with_output_types(), specify the element type + directly (int), not wrapped in Iterable. + """ + def flatmaptuplefn(key, value): + if value < 0: + yield beam.pvalue.TaggedOutput('errors', f'Negative: {key}={value}') + else: + yield value * 2 + + with beam.Pipeline() as p: + results = ( + p + | beam.Create([('a', -1), ('b', 2), ('c', 3)]) + | beam.FlatMapTuple(flatmaptuplefn).with_output_types( + int, errors=str).with_outputs('errors', main='main')) + + self.assertEqual(results.main.element_type, int) + self.assertEqual(results.errors.element_type, str) + + +class AnnotationStyleTaggedOutputTest(unittest.TestCase): + """Tests for function annotation style across all transforms.""" + def test_map_annotation_union(self): + """Test Map with Union[int, TaggedOutput[...]] annotation.""" + def mapfn(element: int) -> int | TaggedOutput[Literal['errors'], str]: + if element < 0: + return beam.pvalue.TaggedOutput('errors', f'Negative: {element}') + else: + return element * 2 + + with beam.Pipeline() as p: + results = ( + p + | beam.Create([-1, 0, 1, 2]) + | beam.Map(mapfn).with_outputs('errors', main='main')) + + self.assertEqual(results.main.element_type, int) + self.assertEqual(results.errors.element_type, str) + + def test_map_annotation_multiple_tags(self): + """Test Map with multiple TaggedOutput types in annotation.""" + def mapfn( + element: int + ) -> int | TaggedOutput[Literal['errors'], + str] | TaggedOutput[Literal['warnings'], str]: + if element < 0: + return beam.pvalue.TaggedOutput('errors', f'Negative: {element}') + elif element == 0: + return beam.pvalue.TaggedOutput('warnings', 'Zero value') + else: + return element * 2 + + with beam.Pipeline() as p: + results = ( + p + | beam.Create([-1, 0, 1, 2]) + | beam.Map(mapfn).with_outputs('errors', 'warnings', main='main')) + + self.assertEqual(results.main.element_type, int) + self.assertEqual(results.errors.element_type, str) + self.assertEqual(results.warnings.element_type, str) + + def test_flatmap_annotation_iterable(self): + """Test FlatMap with Iterable[int | TaggedOutput[...]] annotation.""" + def flatmapfn( + element: int) -> Iterable[int | TaggedOutput[Literal['errors'], str]]: + if element < 0: + yield beam.pvalue.TaggedOutput('errors', f'Negative: {element}') + else: + yield element * 2 + + with beam.Pipeline() as p: + results = ( + p + | beam.Create([-1, 0, 1, 2]) + | beam.FlatMap(flatmapfn).with_outputs('errors', main='main')) + + self.assertEqual(results.main.element_type, int) + self.assertEqual(results.errors.element_type, str) + + def test_pardo_annotation_process_method(self): + """Test DoFn with process method annotation.""" + class AnnotatedDoFn(beam.DoFn): + def process( + self, + element: int) -> Iterable[int | TaggedOutput[Literal['errors'], str]]: + if element < 0: + yield beam.pvalue.TaggedOutput('errors', f'Negative: {element}') + else: + yield element * 2 + + with beam.Pipeline() as p: + results = ( + p + | beam.Create([-1, 0, 1, 2]) + | beam.ParDo(AnnotatedDoFn()).with_outputs('errors', main='main')) + + self.assertEqual(results.main.element_type, int) + self.assertEqual(results.errors.element_type, str) + + +if __name__ == '__main__': + unittest.main() diff --git a/sdks/python/apache_beam/typehints/typehints_test.py b/sdks/python/apache_beam/typehints/typehints_test.py index 0bbc21f6739c..cec830380087 100644 --- a/sdks/python/apache_beam/typehints/typehints_test.py +++ b/sdks/python/apache_beam/typehints/typehints_test.py @@ -1421,7 +1421,7 @@ def unused_foo(): return 5, 'bar' def test_no_kwargs_accepted(self): - with self.assertRaisesRegex(ValueError, r'must be positional'): + with self.assertRaisesRegex(ValueError, r'single positional argument'): @with_output_types(m=int) def unused_foo(): From 4039780b38f670b83b2799ca28f944dc548c7503 Mon Sep 17 00:00:00 2001 From: Claude Date: Thu, 29 Jan 2026 17:20:25 -0500 Subject: [PATCH 2/5] Just warn when bare tagged output --- .../apache_beam/typehints/decorators.py | 24 ++++++++++++++----- .../apache_beam/typehints/decorators_test.py | 11 ++++++--- 2 files changed, 26 insertions(+), 9 deletions(-) diff --git a/sdks/python/apache_beam/typehints/decorators.py b/sdks/python/apache_beam/typehints/decorators.py index 6bb6259bcb5e..1bd455a61ead 100644 --- a/sdks/python/apache_beam/typehints/decorators.py +++ b/sdks/python/apache_beam/typehints/decorators.py @@ -187,6 +187,11 @@ def disable_type_annotations(): TRACEBACK_LIMIT = 5 +def _is_union_type(origin): + """Check if a type origin is a Union (typing.Union or types.UnionType).""" + return origin is Union or origin is types.UnionType + + def _tag_and_type(t): """Extract tag name and value type from TaggedOutput[Literal['tag'], Type]. @@ -219,7 +224,11 @@ def _contains_tagged_output(t): - Iterable[X | TaggedOutput[...]] """ def _is_tagged(typ): - return get_origin(typ) is TaggedOutput or typ is TaggedOutput + if typ is TaggedOutput: + logging.warning( + "TaggedOutput in return type must include type parameters: " + "TaggedOutput[Literal['tag_name'], ValueType]") + return get_origin(typ) is TaggedOutput # TaggedOutput[...] if _is_tagged(t): @@ -229,7 +238,7 @@ def _is_tagged(typ): args = get_args(t) # X | TaggedOutput[...] - if origin is Union: + if _is_union_type(origin): return any(_is_tagged(arg) for arg in args) # Iterable[...] @@ -239,7 +248,7 @@ def _is_tagged(typ): if _is_tagged(inner): return True # Iterable[X | TaggedOutput[...]] - if get_origin(inner) is Union: + if _is_union_type(get_origin(inner)): return any(_is_tagged(arg) for arg in get_args(inner)) return False @@ -258,11 +267,11 @@ def _extract_main_and_tagged(t): return None, {tag: typ} if t is TaggedOutput: - raise TypeError( + logging.warning( "TaggedOutput in return type must include type parameters: " "TaggedOutput[Literal['tag_name'], ValueType]") - if get_origin(t) is not Union: + if not _is_union_type(get_origin(t)): return t, {} main_types = [] @@ -272,9 +281,12 @@ def _extract_main_and_tagged(t): tag, typ = _tag_and_type(arg) tagged_types[tag] = typ elif arg is TaggedOutput: - raise TypeError( + logging.warning( "TaggedOutput in return type must include type parameters: " "TaggedOutput[Literal['tag_name'], ValueType]") + # Append to main types to maintain backwards compatibility. The result + # will be a union type that maps to FastPrimitivesCoder. + main_types.append(arg) else: main_types.append(arg) diff --git a/sdks/python/apache_beam/typehints/decorators_test.py b/sdks/python/apache_beam/typehints/decorators_test.py index 570f71803d65..1f3ca16cbda3 100644 --- a/sdks/python/apache_beam/typehints/decorators_test.py +++ b/sdks/python/apache_beam/typehints/decorators_test.py @@ -601,9 +601,14 @@ def test_extract_output_types_iterable_tagged_only(self): self.assertEqual(main, [typing.Iterable[typing.Any]]) self.assertEqual(tagged, {'errors': str}) - def test_extract_main_and_tagged_bare_tagged_output_raises(self): - with self.assertRaises(TypeError): - decorators._extract_main_and_tagged(TaggedOutput) + def test_extract_output_types_bare_tagged_to_main(self): + with self.assertLogs(level='WARNING') as cm: + main, tagged = decorators._extract_output_types(str | TaggedOutput) + self.assertIn( + 'TaggedOutput in return type must include type parameters', + cm.output[0]) + self.assertEqual(main, [str | TaggedOutput]) + self.assertEqual(tagged, {}) if __name__ == '__main__': From 5692c120086c0ad069024f0820559126cb393bcb Mon Sep 17 00:00:00 2001 From: Claude Date: Thu, 29 Jan 2026 17:46:54 -0500 Subject: [PATCH 3/5] Remove contains tagged output check. --- .../apache_beam/typehints/decorators.py | 64 +++---------------- .../apache_beam/typehints/decorators_test.py | 30 +-------- 2 files changed, 9 insertions(+), 85 deletions(-) diff --git a/sdks/python/apache_beam/typehints/decorators.py b/sdks/python/apache_beam/typehints/decorators.py index 1bd455a61ead..2d07ac918f50 100644 --- a/sdks/python/apache_beam/typehints/decorators.py +++ b/sdks/python/apache_beam/typehints/decorators.py @@ -185,6 +185,7 @@ def disable_type_annotations(): TRACEBACK_LIMIT = 5 +_NO_MAIN_TYPE = object() def _is_union_type(origin): @@ -214,57 +215,17 @@ def _tag_and_type(t): return tag_string, value_type -def _contains_tagged_output(t): - """Check if type contains TaggedOutput at a meaningful position. - - TaggedOutput only makes sense in these patterns: - - TaggedOutput[...] - - X | TaggedOutput[...] - - Iterable[TaggedOutput[...]] - - Iterable[X | TaggedOutput[...]] - """ - def _is_tagged(typ): - if typ is TaggedOutput: - logging.warning( - "TaggedOutput in return type must include type parameters: " - "TaggedOutput[Literal['tag_name'], ValueType]") - return get_origin(typ) is TaggedOutput - - # TaggedOutput[...] - if _is_tagged(t): - return True - - origin = get_origin(t) - args = get_args(t) - - # X | TaggedOutput[...] - if _is_union_type(origin): - return any(_is_tagged(arg) for arg in args) - - # Iterable[...] - if origin is collections.abc.Iterable and len(args) == 1: - inner = args[0] - # Iterable[TaggedOutput[...]] - if _is_tagged(inner): - return True - # Iterable[X | TaggedOutput[...]] - if _is_union_type(get_origin(inner)): - return any(_is_tagged(arg) for arg in get_args(inner)) - - return False - - def _extract_main_and_tagged(t): """Extract main type and tagged types from a type annotation. Returns: (main_type, tagged_dict) where main_type is the type without TaggedOutput - annotations (or None if no main type), and tagged_dict maps tag names to - their types. + annotations (or _NO_MAIN_TYPE if no main type), and tagged_dict maps tag + names to their types. """ if get_origin(t) is TaggedOutput: tag, typ = _tag_and_type(t) - return None, {tag: typ} + return _NO_MAIN_TYPE, {tag: typ} if t is TaggedOutput: logging.warning( @@ -291,7 +252,7 @@ def _extract_main_and_tagged(t): main_types.append(arg) if len(main_types) == 0: - main_type = None + main_type = _NO_MAIN_TYPE elif len(main_types) == 1: main_type = main_types[0] else: @@ -311,25 +272,16 @@ def _extract_output_types(return_annotation): if return_annotation == inspect.Signature.empty: return [Any], {} - # Early return if no TaggedOutput - if not _contains_tagged_output(return_annotation): - return [return_annotation], {} - # Iterable[T | TaggedOutput[...]] if get_origin(return_annotation) is collections.abc.Iterable: yield_type = get_args(return_annotation)[0] clean_yield, tagged_types = _extract_main_and_tagged(yield_type) - clean_main = clean_yield if clean_yield else Any + clean_main = Any if clean_yield is _NO_MAIN_TYPE else clean_yield return [Iterable[clean_main]], tagged_types - # TaggedOutput - if get_origin(return_annotation) is TaggedOutput: - tag, typ = _tag_and_type(return_annotation) - return [Any], {tag: typ} - - # T | TaggedOutput + # T | TaggedOutput (or plain type with no tags) main_type, tagged_types = _extract_main_and_tagged(return_annotation) - main = main_type if main_type else Any + main = Any if main_type is _NO_MAIN_TYPE else main_type return [main], tagged_types diff --git a/sdks/python/apache_beam/typehints/decorators_test.py b/sdks/python/apache_beam/typehints/decorators_test.py index 1f3ca16cbda3..4f0aed7a264e 100644 --- a/sdks/python/apache_beam/typehints/decorators_test.py +++ b/sdks/python/apache_beam/typehints/decorators_test.py @@ -498,34 +498,6 @@ def fn2(a: int) -> int: class TaggedOutputExtractionTest(unittest.TestCase): """Tests for TaggedOutput extraction helper functions.""" - def test_contains_tagged_output_true_direct(self): - t = TaggedOutput[typing.Literal['errors'], str] - self.assertTrue(decorators._contains_tagged_output(t)) - - def test_contains_tagged_output_true_in_union(self): - t = int | TaggedOutput[typing.Literal['errors'], str] - self.assertTrue(decorators._contains_tagged_output(t)) - - def test_contains_tagged_output_true_in_iterable(self): - t = typing.Iterable[int | TaggedOutput[typing.Literal['errors'], str]] - self.assertTrue(decorators._contains_tagged_output(t)) - - def test_contains_tagged_output_false_simple_type(self): - self.assertFalse(decorators._contains_tagged_output(int)) - self.assertFalse(decorators._contains_tagged_output(str)) - - def test_contains_tagged_output_false_union_no_tagged(self): - t = int | str - self.assertFalse(decorators._contains_tagged_output(t)) - - def test_contains_tagged_output_false_iterable_no_tagged(self): - t = typing.Iterable[int] - self.assertFalse(decorators._contains_tagged_output(t)) - - def test_contains_tagged_output_false_deeply_nested(self): - t = typing.List[typing.Tuple[TaggedOutput[typing.Literal['errors'], str]]] - self.assertFalse(decorators._contains_tagged_output(t)) - def test_extract_main_and_tagged_simple_type(self): main, tagged = decorators._extract_main_and_tagged(int) self.assertEqual(main, int) @@ -534,7 +506,7 @@ def test_extract_main_and_tagged_simple_type(self): def test_extract_main_and_tagged_tagged_output_only(self): t = TaggedOutput[typing.Literal['errors'], str] main, tagged = decorators._extract_main_and_tagged(t) - self.assertIsNone(main) + self.assertIs(main, decorators._NO_MAIN_TYPE) self.assertEqual(tagged, {'errors': str}) def test_extract_main_and_tagged_union(self): From c4e4d1068d7d9d8ce8875de01ecb812b87998696 Mon Sep 17 00:00:00 2001 From: Claude Date: Thu, 29 Jan 2026 18:06:13 -0500 Subject: [PATCH 4/5] Mapped bare TaggedOutput to Any --- sdks/python/apache_beam/typehints/decorators.py | 10 +++++----- sdks/python/apache_beam/typehints/decorators_test.py | 8 +++----- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/sdks/python/apache_beam/typehints/decorators.py b/sdks/python/apache_beam/typehints/decorators.py index 2d07ac918f50..106765d76ba4 100644 --- a/sdks/python/apache_beam/typehints/decorators.py +++ b/sdks/python/apache_beam/typehints/decorators.py @@ -230,7 +230,9 @@ def _extract_main_and_tagged(t): if t is TaggedOutput: logging.warning( "TaggedOutput in return type must include type parameters: " - "TaggedOutput[Literal['tag_name'], ValueType]") + "TaggedOutput[Literal['tag_name'], ValueType]. " + "Bare TaggedOutput falling back to Any.") + return _NO_MAIN_TYPE, {} if not _is_union_type(get_origin(t)): return t, {} @@ -244,10 +246,8 @@ def _extract_main_and_tagged(t): elif arg is TaggedOutput: logging.warning( "TaggedOutput in return type must include type parameters: " - "TaggedOutput[Literal['tag_name'], ValueType]") - # Append to main types to maintain backwards compatibility. The result - # will be a union type that maps to FastPrimitivesCoder. - main_types.append(arg) + "TaggedOutput[Literal['tag_name'], ValueType]. " + "Bare TaggedOutput falling back to Any.") else: main_types.append(arg) diff --git a/sdks/python/apache_beam/typehints/decorators_test.py b/sdks/python/apache_beam/typehints/decorators_test.py index 4f0aed7a264e..626c2ffbd497 100644 --- a/sdks/python/apache_beam/typehints/decorators_test.py +++ b/sdks/python/apache_beam/typehints/decorators_test.py @@ -573,13 +573,11 @@ def test_extract_output_types_iterable_tagged_only(self): self.assertEqual(main, [typing.Iterable[typing.Any]]) self.assertEqual(tagged, {'errors': str}) - def test_extract_output_types_bare_tagged_to_main(self): + def test_extract_output_types_bare_tagged_excluded(self): with self.assertLogs(level='WARNING') as cm: main, tagged = decorators._extract_output_types(str | TaggedOutput) - self.assertIn( - 'TaggedOutput in return type must include type parameters', - cm.output[0]) - self.assertEqual(main, [str | TaggedOutput]) + self.assertIn('Bare TaggedOutput falling back to Any', cm.output[0]) + self.assertEqual(main, [str]) self.assertEqual(tagged, {}) From 789ab66fa1fa1b8ae40aef26c5be2d321daefe94 Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 31 Jan 2026 12:32:36 -0500 Subject: [PATCH 5/5] Extract tagged outputs after strip_iterable. --- sdks/python/apache_beam/transforms/core.py | 33 ++-- .../apache_beam/typehints/decorators.py | 145 +++++++++-------- .../apache_beam/typehints/decorators_test.py | 146 +++++++----------- .../typehints/tagged_output_typehints_test.py | 4 +- 4 files changed, 162 insertions(+), 166 deletions(-) diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py index 19c96fa51f2b..7392a74317a3 100644 --- a/sdks/python/apache_beam/transforms/core.py +++ b/sdks/python/apache_beam/transforms/core.py @@ -824,6 +824,7 @@ def default_type_hints(self): process_type_hints = process_type_hints.strip_iterable() except ValueError as e: raise ValueError('Return value not iterable: %s: %s' % (self, e)) + process_type_hints = process_type_hints.extract_tagged_outputs() # Prefer class decorator type hints for backwards compatibility. return get_type_hints(self.__class__).with_defaults(process_type_hints) @@ -1039,6 +1040,7 @@ def default_type_hints(self): raise TypeCheckError( 'Return value not iterable: %s: %s' % (self.display_data()['fn'].value, e)) + type_hints = type_hints.extract_tagged_outputs() return type_hints def infer_output_type(self, input_type): @@ -2131,10 +2133,14 @@ def Map(fn, *args, **kwargs): # pylint: disable=invalid-name wrapper) output_hint = type_hints.simple_output_type(label) if output_hint: - tagged_output_types = type_hints.tagged_output_types() + tagged = { + k: typehints.Iterable[v] + for k, v in type_hints.tagged_output_types().items() + } wrapper = with_output_types( - typehints.Iterable[_strip_output_annotations(output_hint)], - **tagged_output_types)( + typehints.Iterable[_strip_output_annotations( + output_hint, strip_tagged_output=False)], + **tagged)( wrapper) # pylint: disable=protected-access wrapper._argspec_fn = fn @@ -2202,10 +2208,14 @@ def MapTuple(fn, *args, **kwargs): # pylint: disable=invalid-name pass output_hint = type_hints.simple_output_type(label) if output_hint: - tagged_output_types = type_hints.tagged_output_types() + tagged = { + k: typehints.Iterable[v] + for k, v in type_hints.tagged_output_types().items() + } wrapper = with_output_types( - typehints.Iterable[_strip_output_annotations(output_hint)], - **tagged_output_types)( + typehints.Iterable[_strip_output_annotations( + output_hint, strip_tagged_output=False)], + **tagged)( wrapper) # Replace the first (args) component. @@ -2276,9 +2286,9 @@ def FlatMapTuple(fn, *args, **kwargs): # pylint: disable=invalid-name pass output_hint = type_hints.simple_output_type(label) if output_hint: - tagged_output_types = type_hints.tagged_output_types() wrapper = with_output_types( - _strip_output_annotations(output_hint), **tagged_output_types)( + _strip_output_annotations(output_hint, strip_tagged_output=False), + **type_hints.tagged_output_types())( wrapper) # Replace the first (args) component. @@ -4239,12 +4249,15 @@ def from_runner_api_parameter( return Impulse() -def _strip_output_annotations(type_hint): +def _strip_output_annotations(type_hint, strip_tagged_output=True): # TODO(robertwb): These should be parameterized types that the # type inferencer understands. # Then we can replace them with the correct element types instead of # using Any. Refer to typehints.WindowedValue when doing this. - annotations = (TimestampedValue, WindowedValue, pvalue.TaggedOutput) + annotations = [TimestampedValue, WindowedValue] + if strip_tagged_output: + annotations.append(pvalue.TaggedOutput) + annotations = tuple(annotations) contains_annotation = False diff --git a/sdks/python/apache_beam/typehints/decorators.py b/sdks/python/apache_beam/typehints/decorators.py index 106765d76ba4..e393113c002e 100644 --- a/sdks/python/apache_beam/typehints/decorators.py +++ b/sdks/python/apache_beam/typehints/decorators.py @@ -79,7 +79,6 @@ def foo((a, b)): # pytype: skip-file -import collections.abc import inspect import itertools import logging @@ -188,11 +187,6 @@ def disable_type_annotations(): _NO_MAIN_TYPE = object() -def _is_union_type(origin): - """Check if a type origin is a Union (typing.Union or types.UnionType).""" - return origin is Union or origin is types.UnionType - - def _tag_and_type(t): """Extract tag name and value type from TaggedOutput[Literal['tag'], Type]. @@ -215,74 +209,58 @@ def _tag_and_type(t): return tag_string, value_type -def _extract_main_and_tagged(t): - """Extract main type and tagged types from a type annotation. +def _extract_tagged_from_type(beam_type): + """Extract tagged output types from a Beam type (post-convert_to_beam_type). + + Called after the Iterable wrapper has been removed. + At this point, the type has already been through convert_to_beam_type, so + unions are typehints.UnionConstraint (not typing.Union), but + TaggedOutput[Literal['tag'], T] passes through unchanged as a typing + generic alias. Returns: - (main_type, tagged_dict) where main_type is the type without TaggedOutput - annotations (or _NO_MAIN_TYPE if no main type), and tagged_dict maps tag - names to their types. + (clean_type, tagged_dict) where clean_type is the type without TaggedOutput + members (or _NO_MAIN_TYPE if no main type), and tagged_dict maps tag names + to their Beam types. """ - if get_origin(t) is TaggedOutput: - tag, typ = _tag_and_type(t) - return _NO_MAIN_TYPE, {tag: typ} + # Single TaggedOutput[Literal['tag'], Type] + if get_origin(beam_type) is TaggedOutput: + tag, typ = _tag_and_type(beam_type) + return _NO_MAIN_TYPE, {tag: convert_to_beam_type(typ)} - if t is TaggedOutput: + # Bare TaggedOutput (unparameterized) + if beam_type is TaggedOutput: logging.warning( "TaggedOutput in return type must include type parameters: " "TaggedOutput[Literal['tag_name'], ValueType]. " - "Bare TaggedOutput falling back to Any.") + "Bare TaggedOutput will be ignored.") return _NO_MAIN_TYPE, {} - if not _is_union_type(get_origin(t)): - return t, {} + if not isinstance(beam_type, typehints.UnionHint.UnionConstraint): + return beam_type, {} + # UnionConstraint containing TaggedOutput members main_types = [] - tagged_types = {} - for arg in get_args(t): - if get_origin(arg) is TaggedOutput: - tag, typ = _tag_and_type(arg) - tagged_types[tag] = typ - elif arg is TaggedOutput: + tagged = {} + for member in beam_type.union_types: + if get_origin(member) is TaggedOutput: + tag, typ = _tag_and_type(member) + tagged[tag] = convert_to_beam_type(typ) + elif member is TaggedOutput: logging.warning( "TaggedOutput in return type must include type parameters: " "TaggedOutput[Literal['tag_name'], ValueType]. " - "Bare TaggedOutput falling back to Any.") + "Bare TaggedOutput will be ignored.") else: - main_types.append(arg) - - if len(main_types) == 0: - main_type = _NO_MAIN_TYPE + main_types.append(member) + if not tagged and len(main_types) == len(beam_type.union_types): + return beam_type, {} + if not main_types: + return _NO_MAIN_TYPE, tagged elif len(main_types) == 1: - main_type = main_types[0] + return main_types[0], tagged else: - main_type = Union[tuple(main_types)] - - return main_type, tagged_types - - -def _extract_output_types(return_annotation): - """Parse return annotation into (main_types, tagged_types). - - For tagged outputs to be extracted from generator/iterator functions, - users must explicitly use Iterable[T | TaggedOutput[...]] as return type. - - Returns raw Python types. Conversion to beam types happens in from_callable. - """ - if return_annotation == inspect.Signature.empty: - return [Any], {} - - # Iterable[T | TaggedOutput[...]] - if get_origin(return_annotation) is collections.abc.Iterable: - yield_type = get_args(return_annotation)[0] - clean_yield, tagged_types = _extract_main_and_tagged(yield_type) - clean_main = Any if clean_yield is _NO_MAIN_TYPE else clean_yield - return [Iterable[clean_main]], tagged_types - - # T | TaggedOutput (or plain type with no tags) - main_type, tagged_types = _extract_main_and_tagged(return_annotation) - main = Any if main_type is _NO_MAIN_TYPE else main_type - return [main], tagged_types + return typehints.Union[tuple(main_types)], tagged class IOTypeHints(NamedTuple): @@ -377,13 +355,11 @@ def from_callable(cls, fn: Callable) -> Optional['IOTypeHints']: 'Unsupported Parameter kind: %s' % param.kind input_args.append(convert_to_beam_type(param.annotation)) - output_args, output_kwargs = _extract_output_types( - signature.return_annotation) - output_args = [convert_to_beam_type(t) for t in output_args] - output_kwargs = { - k: convert_to_beam_type(v) - for k, v in output_kwargs.items() - } + output_args = [] + if signature.return_annotation != signature.empty: + output_args.append(convert_to_beam_type(signature.return_annotation)) + else: + output_args.append(typehints.Any) name = getattr(fn, '__name__', '') msg = ['from_callable(%s)' % name, ' signature: %s' % signature] @@ -393,7 +369,7 @@ def from_callable(cls, fn: Callable) -> Optional['IOTypeHints']: (fn.__code__.co_filename, fn.__code__.co_firstlineno)) return IOTypeHints( input_types=(tuple(input_args), input_kwargs), - output_types=(tuple(output_args), output_kwargs), + output_types=(tuple(output_args), {}), origin=cls._make_origin([], tb=False, msg=msg)) def with_input_types(self, *args, **kwargs) -> 'IOTypeHints': @@ -544,10 +520,47 @@ def strip_iterable(self) -> 'IOTypeHints': origin=self._make_origin([self], tb=False, msg=['strip_iterable()'])) yielded_type = typehints.get_yielded_type(output_type) + + # Also strip Iterable from tagged output types (e.g. from Map/MapTuple + # which wrap both main and tagged types in Iterable). + stripped_tags = { + tag: typehints.get_yielded_type(hint) + for tag, hint in tagged_output_types.items() + } + return self._replace( - output_types=((yielded_type, ), tagged_output_types), + output_types=((yielded_type, ), stripped_tags), origin=self._make_origin([self], tb=False, msg=['strip_iterable()'])) + def extract_tagged_outputs(self): + """Extract TaggedOutput types from the main output type into kwargs. + + For annotation style (e.g. -> Iterable[int | TaggedOutput[...]]), + TaggedOutput stays embedded in the main type through convert_to_beam_type + and strip_iterable. This method extracts those TaggedOutput members into + the tagged output kwargs dict. + + Should be called after strip_iterable(). + + Returns: + A copy of this instance with TaggedOutput members moved from the main + output type into the output kwargs dict. + """ + if self.output_types is None or not self.has_simple_output_type(): + return self + output_type = self.output_types[0][0] + + clean_type, extracted_tags = _extract_tagged_from_type(output_type) + if not extracted_tags: + return self + if clean_type is _NO_MAIN_TYPE: + clean_type = typehints.Any + return self._replace( + output_types=((clean_type, ), extracted_tags), + origin=self._make_origin([self], + tb=False, + msg=['extract_tagged_outputs()'])) + def with_defaults(self, hints: Optional['IOTypeHints']) -> 'IOTypeHints': if not hints: return self diff --git a/sdks/python/apache_beam/typehints/decorators_test.py b/sdks/python/apache_beam/typehints/decorators_test.py index 626c2ffbd497..95745f4e3d88 100644 --- a/sdks/python/apache_beam/typehints/decorators_test.py +++ b/sdks/python/apache_beam/typehints/decorators_test.py @@ -34,6 +34,7 @@ from apache_beam.typehints import WithTypeHints from apache_beam.typehints import decorators from apache_beam.typehints import typehints +from apache_beam.typehints.native_type_compatibility import convert_to_beam_type T = TypeVariable('T') # Name is 'T' so it converts to a beam type with the same name. @@ -279,6 +280,8 @@ def fn2(x: int) -> typing.Iterable[str]: self.assertEqual(th.output_types, ((typehints.Iterable[str], ), {})) def test_from_callable_tagged_output_union(self): + """Tagged types are NOT extracted in from_callable. They stay embedded + in the main type and are extracted later in strip_iterable().""" def fn( x: int ) -> int | str | TaggedOutput[typing.Literal['errors'], float @@ -288,13 +291,13 @@ def fn( th = decorators.IOTypeHints.from_callable(fn) self.assertEqual(th.input_types, ((int, ), {})) - self.assertEqual( - th.output_types, - ((typehints.Union[int, str], ), { - 'errors': typehints.Union[float, str], 'warnings': str - })) + # TaggedOutput members are preserved in the union no extraction yet. + output_type = th.output_types[0][0] + self.assertIsInstance(output_type, typehints.UnionConstraint) + self.assertEqual(th.output_types[1], {}) def test_from_callable_tagged_output_iterable(self): + """Tagged types inside Iterable are preserved until strip_iterable.""" def fn( x: int ) -> typing.Iterable[int | TaggedOutput[typing.Literal['errors'], str]]: @@ -302,35 +305,21 @@ def fn( th = decorators.IOTypeHints.from_callable(fn) self.assertEqual(th.input_types, ((int, ), {})) - self.assertEqual( - th.output_types, ((typehints.Iterable[int], ), { - 'errors': str - })) - - def test_from_callable_tagged_output_multiple_tags(self): - def fn( - x: int - ) -> ( - int | TaggedOutput[typing.Literal['errors'], str] | - TaggedOutput[typing.Literal['warnings'], str]): - return x - - th = decorators.IOTypeHints.from_callable(fn) - self.assertEqual(th.input_types, ((int, ), {})) - self.assertEqual( - th.output_types, ((int, ), { - 'errors': str, 'warnings': str - })) + # The full Iterable[Union[int, TaggedOutput[...]]] is preserved. + output_type = th.output_types[0][0] + self.assertIsInstance(output_type, typehints.IterableTypeConstraint) + self.assertEqual(th.output_types[1], {}) def test_from_callable_tagged_output_only(self): + """A standalone TaggedOutput annotation passes through from_callable.""" def fn(x: int) -> TaggedOutput[typing.Literal['errors'], str]: pass th = decorators.IOTypeHints.from_callable(fn) self.assertEqual(th.input_types, ((int, ), {})) - self.assertEqual(th.output_types, ((Any, ), { - 'errors': str - })) + # TaggedOutput[...] passes through convert_to_beam_type unchanged. + self.assertIs(typing.get_origin(th.output_types[0][0]), TaggedOutput) + self.assertEqual(th.output_types[1], {}) def test_getcallargs_forhints(self): def fn( @@ -496,88 +485,69 @@ def fn2(a: int) -> int: _ = ['a', 'b', 'c'] | Map(fn2) # Doesn't raise - no input type hints. -class TaggedOutputExtractionTest(unittest.TestCase): - """Tests for TaggedOutput extraction helper functions.""" - def test_extract_main_and_tagged_simple_type(self): - main, tagged = decorators._extract_main_and_tagged(int) +class ExtractTaggedFromTypeTest(unittest.TestCase): + """Tests for _extract_tagged_from_type (Beam-level type extraction).""" + def test_simple_type_no_extraction(self): + main, tagged = decorators._extract_tagged_from_type(int) self.assertEqual(main, int) self.assertEqual(tagged, {}) - def test_extract_main_and_tagged_tagged_output_only(self): + def test_beam_union_no_tagged(self): + t = typehints.Union[int, str] + main, tagged = decorators._extract_tagged_from_type(t) + self.assertEqual(main, t) + self.assertEqual(tagged, {}) + + def test_standalone_tagged_output(self): t = TaggedOutput[typing.Literal['errors'], str] - main, tagged = decorators._extract_main_and_tagged(t) + main, tagged = decorators._extract_tagged_from_type(t) self.assertIs(main, decorators._NO_MAIN_TYPE) self.assertEqual(tagged, {'errors': str}) - def test_extract_main_and_tagged_union(self): - t = int | TaggedOutput[typing.Literal['errors'], str] - main, tagged = decorators._extract_main_and_tagged(t) + def test_beam_union_with_tagged(self): + t = convert_to_beam_type(int | TaggedOutput[typing.Literal['errors'], str]) + main, tagged = decorators._extract_tagged_from_type(t) self.assertEqual(main, int) self.assertEqual(tagged, {'errors': str}) - def test_extract_main_and_tagged_union_multiple_tagged(self): - t = ( + def test_beam_union_multiple_tagged(self): + t = convert_to_beam_type( int | TaggedOutput[typing.Literal['errors'], str] | TaggedOutput[typing.Literal['warnings'], str]) - main, tagged = decorators._extract_main_and_tagged(t) + main, tagged = decorators._extract_tagged_from_type(t) self.assertEqual(main, int) self.assertEqual(tagged, {'errors': str, 'warnings': str}) - def test_extract_main_and_tagged_union_multiple_main_types(self): - t = (int | str | TaggedOutput[typing.Literal['errors'], bytes]) - main, tagged = decorators._extract_main_and_tagged(t) - # Main type should be Union[int, str] - self.assertEqual(typing.get_origin(main), typing.Union) - self.assertIn(int, typing.get_args(main)) - self.assertIn(str, typing.get_args(main)) + def test_beam_union_multiple_main_types(self): + t = convert_to_beam_type( + int | str | TaggedOutput[typing.Literal['errors'], bytes]) + main, tagged = decorators._extract_tagged_from_type(t) + self.assertIsInstance(main, typehints.UnionConstraint) + self.assertIn(int, main.union_types) + self.assertIn(str, main.union_types) self.assertEqual(tagged, {'errors': bytes}) - def test_extract_output_types_empty_signature(self): - import inspect - main, tagged = decorators._extract_output_types(inspect.Signature.empty) - self.assertEqual(main, [typing.Any]) - self.assertEqual(tagged, {}) - - def test_extract_output_types_simple_type(self): - main, tagged = decorators._extract_output_types(int) - self.assertEqual(main, [int]) - self.assertEqual(tagged, {}) - - def test_extract_output_types_union_with_tagged(self): - t = int | TaggedOutput[typing.Literal['errors'], str] - main, tagged = decorators._extract_output_types(t) - self.assertEqual(main, [int]) - self.assertEqual(tagged, {'errors': str}) - - def test_extract_output_types_iterable_with_tagged(self): - t = typing.Iterable[int | TaggedOutput[typing.Literal['errors'], str]] - main, tagged = decorators._extract_output_types(t) - self.assertEqual(main, [typing.Iterable[int]]) - self.assertEqual(tagged, {'errors': str}) + def test_beam_union_tagged_only(self): + t = convert_to_beam_type( + TaggedOutput[typing.Literal['errors'], str] + | TaggedOutput[typing.Literal['warnings'], int]) + main, tagged = decorators._extract_tagged_from_type(t) + self.assertIs(main, decorators._NO_MAIN_TYPE) + self.assertEqual(tagged, {'errors': str, 'warnings': int}) - def test_extract_output_types_list_with_tagged_not_extracted(self): - t = typing.List[int | TaggedOutput[typing.Literal['errors'], str]] - _, tagged = decorators._extract_output_types(t) - # The whole type is converted as-is. Users should use Iterable instead. + def test_bare_tagged_output_standalone(self): + with self.assertLogs(level='WARNING') as cm: + main, tagged = decorators._extract_tagged_from_type(TaggedOutput) + self.assertIn('Bare TaggedOutput will be ignored', cm.output[0]) + self.assertIs(main, decorators._NO_MAIN_TYPE) self.assertEqual(tagged, {}) - def test_extract_output_types_tagged_only(self): - t = TaggedOutput[typing.Literal['errors'], str] - main, tagged = decorators._extract_output_types(t) - self.assertEqual(main, [typing.Any]) - self.assertEqual(tagged, {'errors': str}) - - def test_extract_output_types_iterable_tagged_only(self): - t = typing.Iterable[TaggedOutput[typing.Literal['errors'], str]] - main, tagged = decorators._extract_output_types(t) - self.assertEqual(main, [typing.Iterable[typing.Any]]) - self.assertEqual(tagged, {'errors': str}) - - def test_extract_output_types_bare_tagged_excluded(self): + def test_bare_tagged_output_in_union(self): with self.assertLogs(level='WARNING') as cm: - main, tagged = decorators._extract_output_types(str | TaggedOutput) - self.assertIn('Bare TaggedOutput falling back to Any', cm.output[0]) - self.assertEqual(main, [str]) + t = convert_to_beam_type(str | TaggedOutput) + main, tagged = decorators._extract_tagged_from_type(t) + self.assertIn('Bare TaggedOutput will be ignored', cm.output[0]) + self.assertEqual(main, str) self.assertEqual(tagged, {}) diff --git a/sdks/python/apache_beam/typehints/tagged_output_typehints_test.py b/sdks/python/apache_beam/typehints/tagged_output_typehints_test.py index 5dfae1b7e3dd..c06f68fb88a4 100644 --- a/sdks/python/apache_beam/typehints/tagged_output_typehints_test.py +++ b/sdks/python/apache_beam/typehints/tagged_output_typehints_test.py @@ -115,7 +115,7 @@ def mapfn(element): def test_flatmap_decorator_pipeline(self): """Test that tagged types propagate through FlatMap.""" - @with_output_types(Iterable[int], errors=str) + @with_output_types(Iterable[int], errors=Iterable[str]) def flatmapfn(element): if element < 0: yield beam.pvalue.TaggedOutput('errors', f'Negative: {element}') @@ -151,7 +151,7 @@ def maptuplefn(key, value): def test_flatmaptuple_decorator_pipeline(self): """Test that tagged types propagate through FlatMapTuple.""" - @with_output_types(Iterable[int], errors=str) + @with_output_types(Iterable[int], errors=Iterable[str]) def flatmaptuplefn(key, value): if value < 0: yield beam.pvalue.TaggedOutput('errors', f'Negative: {key}={value}')