Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 44 additions & 2 deletions sdks/python/apache_beam/typehints/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,14 @@ def foo((a, b)):
properly it must appear at the top of the module where all functions are
defined, or before importing a module containing type-hinted functions.
"""

# pytype: skip-file

import inspect
import itertools
import logging
import traceback
import types
import typing
from typing import Any
from typing import Callable
from typing import Dict
Expand Down Expand Up @@ -126,6 +126,15 @@ def foo((a, b)):
_disable_from_callable = False # pylint: disable=invalid-name


def _get_type_hints(obj):
if isinstance(obj, type):
resolved_annotations = typing.get_type_hints(obj.__init__)
resolved_annotations['return'] = obj
else:
resolved_annotations = typing.get_type_hints(obj)
return resolved_annotations


def get_signature(func):
"""Like inspect.signature(), but supports Py2 as well.

Expand All @@ -146,6 +155,34 @@ def get_signature(func):

signature = inspect.Signature(params)

try:
resolved_annotations: Dict[str, Any] = _get_type_hints(func)
except NameError:
# note(jtran): if the function uses any types defined only in a
# `if typing.TYPE_CHECKING:` block, we'll get a NameError
pass
except TypeError:
# Let callable non-functions and functools.partial pass through
pass
else:
new_parameters = []
for name, param in signature.parameters.items():
# Look up the resolved annotation for the parameter
resolved_annotation = resolved_annotations.get(name, param.annotation)

# Create a new Parameter object with the resolved annotation
new_param = param.replace(annotation=resolved_annotation)
new_parameters.append(new_param)
# 4. Determine the resolved return annotation
resolved_return_annotation = resolved_annotations.get(
'return', signature.return_annotation)
if resolved_return_annotation is type(None):
# For backward compatibility, we just use None to represent the
# type of None
resolved_return_annotation = None
signature = signature.replace(
parameters=new_parameters, return_annotation=resolved_return_annotation)

# This is a specialization to hint the first argument of certain builtins,
# such as str.strip.
if isinstance(func, _MethodDescriptorType):
Expand Down Expand Up @@ -339,7 +376,7 @@ def strip_pcoll(self):
strip_pcoll_helper(self.output_types,
self.has_simple_output_type,
'output_types',
[PDone, None],
[PDone, None, type(None)],
'This output type hint will be ignored '
'and not used for type-checking purposes. '
'Typically, output type hints for a '
Expand Down Expand Up @@ -431,6 +468,11 @@ def strip_iterable(self) -> 'IOTypeHints':
origin=self._make_origin([self], tb=False, msg=['strip_iterable()']))

yielded_type = typehints.get_yielded_type(output_type)
if isinstance(yielded_type, typehints.TypeVariable):
# For backwards compatibility, we cast TypeVars to Any.
return self._replace(
output_types=((typehints.Any, ), {}),
origin=self._make_origin([self], tb=False, msg=['strip_iterable()']))
return self._replace(
output_types=((yielded_type, ), {}),
origin=self._make_origin([self], tb=False, msg=['strip_iterable()']))
Expand Down
14 changes: 13 additions & 1 deletion sdks/python/apache_beam/typehints/decorators_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,12 @@
# pytype: skip-file

import functools
import sys
import typing
import unittest

import pytest

from apache_beam import Map
from apache_beam.typehints import Any
from apache_beam.typehints import Dict
Expand All @@ -41,6 +44,15 @@
T_typing = typing.TypeVar('T') # type: ignore


@pytest.fixture(autouse=True)
def skipif310_or_lower(request):
if sys.version_info < (3, 11) and "futureannotations" in str(
request.node.name):
# NOTE(hjtran): the futureannotation tests seem to pass on py3.10
# locally but fail on the GH runner.
pytest.skip("Skipping test on Python 3.10 or lower")


class IOTypeHintsTest(unittest.TestCase):
def test_get_signature(self):
# Basic coverage only to make sure function works.
Expand All @@ -53,7 +65,7 @@ def fn(a, b=1, *c, **d):
def test_get_signature_builtin(self):
s = decorators.get_signature(list)
self.assertListEqual(list(s.parameters), ['iterable'])
self.assertEqual(s.return_annotation, List[Any])
self.assertEqual(s.return_annotation, list)

def test_from_callable_without_annotations(self):
def fn(a, b=None, *args, **kwargs):
Expand Down
Loading
Loading