more typing · apache/beam@2f1e9d4 · GitHub
[go: up one dir, main page]

Skip to content

Commit 2f1e9d4

Browse files
committed
more typing
1 parent 77ca5e5 commit 2f1e9d4

File tree

15 files changed

+264
-116
lines changed

15 files changed

+264
-116
lines changed

.test-infra/mypy/beam_plugin.py

Lines changed: 52 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
Plugin, MethodContext
66
)
77
from mypy.types import (
8-
Type, Instance, TypeVarType
8+
Type, Instance, TypeVarType, UninhabitedType
99
)
1010
from mypy.expandtype import (
1111
expand_type
@@ -17,11 +17,11 @@ def get_method_hook(self, fullname: str
1717
) -> Optional[Callable[[MethodContext], Type]]:
1818
if fullname in {'apache_beam.pvalue.PValue.__or__',
1919
'apache_beam.pvalue.PCollection.__or__'}:
20-
return pvalue_pipe_callback
20+
return pvalue_or_callback
2121
return None
2222

2323

24-
def pvalue_pipe_callback(ctx: MethodContext) -> Type:
24+
def pvalue_or_callback(ctx: MethodContext) -> Type:
2525
"""
2626
Callback to provide an accurate return type for
2727
apache_beam.pvalue.PValue.__or__.
@@ -31,26 +31,59 @@ def pvalue_pipe_callback(ctx: MethodContext) -> Type:
3131
'apache_beam.pvalue.PValue.__or__ should have exactly one parameter'
3232
assert len(ctx.arg_types[0]) == 1, \
3333
"apache_beam.pvalue.PValue.__or__'s parameter should not be variadic"
34-
transform_type = ctx.arg_types[0][0]
35-
36-
args = transform_type.args
37-
if len(args) == 2:
38-
print("{}.__or__() -> {}".format(ctx.type, ctx.default_return_type))
39-
print(" {}".format(ctx.type.args[0]))
40-
print(" {}".format(transform_type))
41-
print(" {}".format(ctx.context.line))
42-
43-
44-
in_arg = args[0]
45-
if isinstance(in_arg, TypeVarType):
46-
expanded = expand_type(transform_type,
47-
{in_arg.id: ctx.type.args[0]})
48-
out_arg = expanded.args[1]
49-
print(" -> {}".format(expanded))
34+
xform_type = ctx.arg_types[0][0]
35+
36+
xform_args = xform_type.args
37+
if len(xform_args) == 2:
38+
print("{}.__or__()".format(ctx.type))
39+
print(" default return: {}".format(ctx.default_return_type))
40+
print(" xform arg: {}".format(xform_type))
41+
print(" line: {}".format(ctx.context.line))
42+
43+
pvalue = ctx.type
44+
# this is the PValue InT TypeVar T`1
45+
pvalue_typevar = pvalue.type.bases[0].args[0]
46+
47+
# xform_typevar = xform_type.type
48+
# print(" {}".format(xform_typevar))
49+
50+
# this is the PTransform InT arg T`-1 (may be TypeVar or may be filled)
51+
xform_in_arg = xform_args[0]
52+
xform_out_arg = xform_args[1]
53+
54+
# in_arg may already be filled, so we need the original TypeVarId
55+
meth = pvalue.type.get('__or__').type
56+
# print(" {}".format(meth))
57+
# print(" {}".format(meth.arg_types[0]))
58+
59+
# print(" {}".format(ctx.arg_types))
60+
print(" pvalue typevar {}".format(pvalue_typevar))
61+
print(" xform typevar {} ({})".format(xform_in_arg, type(xform_in_arg)))
62+
63+
if isinstance(xform_in_arg, TypeVarType):
64+
xform_expanded = expand_type(
65+
xform_type, {xform_in_arg.id: pvalue.args[0]})
66+
out_arg = xform_expanded.args[1]
67+
print(" -> {}".format(xform_expanded))
5068
result = ctx.default_return_type.copy_modified(
5169
args=[out_arg])
5270
print(" -> {}".format(result))
5371
return result
72+
elif isinstance(xform_in_arg, UninhabitedType) and isinstance(xform_out_arg, UninhabitedType):
73+
# print(" {}".format(xform_type.type))
74+
xform_type_inhabited = xform_type.copy_modified(
75+
args=xform_type.type.bases[0].args)
76+
print(" {}".format(xform_type_inhabited))
77+
# else:
78+
#
79+
# print(" {}".format(ctx.type.type.type_vars[0]))
80+
# print(" {}".format(typevar))
81+
# print(" {}".format(type(typevar.id)))
82+
# print(" {}".format(in_arg))
83+
# print(" {}".format(type(in_arg)))
84+
# if isinstance(in_arg, Instance):
85+
# print(" {}".format(in_arg.args))
86+
5487
# for arg_type in ctx.arg_types:
5588
# print(" {}".format(arg_type))
5689
return ctx.default_return_type

mypy.ini

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,12 @@ warn_redundant_casts = true
88
warn_unused_ignores = true
99
plugins = beam_plugin
1010

11-
[mypy-luma.accruals]
11+
[mypy-apache_beam.runners.dataflow.internal.clients.dataflow.dataflow_v1b3_client]
1212
ignore_errors = true
13+
14+
[mypy-apache_beam.io.gcp.internal.clients.storage.storage_v1_client]
15+
ignore_errors = true
16+
17+
[mypy-apache_beam.io.gcp.internal.clients.bigquery.bigquery_v2_client]
18+
ignore_errors = true
19+

sdks/python/apache_beam/examples/wordcount.py

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -99,14 +99,16 @@ def run(argv=None):
9999
p = beam.Pipeline(options=pipeline_options)
100100

101101
reveal_type(p)
102+
102103
_read = ReadFromText(known_args.input)
103104
reveal_type(_read)
105+
104106
read = 'read' >> _read
105107
reveal_type(read)
108+
106109
# Read the text file[pattern] into a PCollection.
107110
lines = p | read
108-
109-
reveal_type(lines)
111+
reveal_type(lines) # PCollection[unicode*]
110112

111113
def make_ones(x):
112114
# type: (T) -> Tuple[T, int]
@@ -118,21 +120,32 @@ def count_ones(word_ones):
118120
(word, ones) = word_ones
119121
return (word, sum(ones))
120122

121-
split = lines | 'split' >> beam.ParDo(WordExtractingDoFn())
122-
reveal_type(split)
123+
split = lines | beam.ParDo(WordExtractingDoFn())
124+
reveal_type(split) # PCollection[unicode*]
123125

124126
makemap = beam.Map(make_ones)
125-
reveal_type(makemap)
126-
# breaks here: we want the type of split to affect the TypeVar in make_ones
127-
# but it doesn't...
128-
pair = split | 'pair_with_one' >> beam.Map(make_ones)
129-
reveal_type(pair)
127+
reveal_type(makemap) # ParDo[T`-1, Tuple[T`-1, int]]
128+
129+
# results in error
130+
# Unsupported operand types for | ("PCollection[unicode]" and "PTransform[T, Tuple[T, int]]")
131+
pair = split | makemap
132+
reveal_type(pair) # PCollection[Tuple[unicode*, int]]
133+
134+
# Below, beam.Map picks up unicode from split (becoming PTransform[unicode, Tuple[T, int]]), resulting in the error:
135+
# Argument 1 to "Map" has incompatible type "Callable[[T], Tuple[T, int]]"; expected "Callable[[unicode], Tuple[T, int]]"
136+
pair2 = split | beam.Map(make_ones)
137+
reveal_type(pair2) # PCollection[Tuple[unicode*, int]]
138+
139+
_group = beam.GroupByKey() # type: beam.GroupByKey[Tuple[unicode, int], Tuple[unicode, Iterable[int]]]
140+
reveal_type(_group)
141+
reveal_type(beam.GroupByKey())
142+
reveal_type(beam.GroupByKey.expand)
130143

131-
group = pair | 'group' >> beam.GroupByKey()
132-
reveal_type(group)
144+
group = pair | beam.GroupByKey()
145+
reveal_type(group) # PCollection[Tuple[unicode*, typing.Iterable[int*]]]
133146

134-
counts = group | 'count' >> beam.Map(count_ones)
135-
reveal_type(counts)
147+
counts = group | beam.Map(count_ones)
148+
reveal_type(counts) # PCollection[Tuple[K`-1, int]]
136149

137150
# Format the counts into a PCollection of strings.
138151
def format_result(word_count):
@@ -144,7 +157,7 @@ def format_result(word_count):
144157

145158
# Write the output using a "Write" transform that has side effects.
146159
# pylint: disable=expression-not-assigned
147-
output | 'write' >> WriteToText(known_args.output)
160+
output | WriteToText(known_args.output)
148161

149162
result = p.run()
150163
result.wait_until_finish()

sdks/python/apache_beam/io/gcp/pubsub.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from __future__ import absolute_import
2626

2727
import re
28+
import typing
2829
from builtins import object
2930

3031
from future.utils import iteritems
@@ -121,6 +122,7 @@ def _to_proto_str(self):
121122

122123
@staticmethod
123124
def _from_message(msg):
125+
# type: () -> PubsubMessage
124126
"""Construct from ``google.cloud.pubsub_v1.subscriber.message.Message``.
125127
126128
https://googleapis.github.io/google-cloud-python/latest/pubsub/subscriber/api/message.html
@@ -134,8 +136,13 @@ class ReadFromPubSub(PTransform):
134136
"""A ``PTransform`` for reading from Cloud Pub/Sub."""
135137
# Implementation note: This ``PTransform`` is overridden by Directrunner.
136138

137-
def __init__(self, topic=None, subscription=None, id_label=None,
138-
with_attributes=False, timestamp_attribute=None):
139+
def __init__(self,
140+
topic=None, # type: typing.Optional[str]
141+
subscription=None, # type: typing.Optional[str]
142+
id_label=None, # type: typing.Optional[str]
143+
with_attributes=False, # type: bool
144+
timestamp_attribute=None # type: typing.Union[float, str, None]
145+
):
139146
"""Initializes ``ReadFromPubSub``.
140147
141148
Args:
@@ -327,8 +334,13 @@ class _PubSubSource(dataflow_io.NativeSource):
327334
fetches ``PubsubMessage`` protobufs.
328335
"""
329336

330-
def __init__(self, topic=None, subscription=None, id_label=None,
331-
with_attributes=False, timestamp_attribute=None):
337+
def __init__(self,
338+
topic=None, # type: typing.Optional[str]
339+
subscription=None, # type: typing.Optional[str]
340+
id_label=None, # type: typing.Optional[str]
341+
with_attributes=False, # type: bool
342+
timestamp_attribute=None # type: typing.Union[float, str, None]
343+
):
332344
self.coder = coders.BytesCoder()
333345
self.full_topic = topic
334346
self.full_subscription = subscription

sdks/python/apache_beam/io/iobase.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import logging
3636
import math
3737
import random
38+
import typing
3839
import uuid
3940
from builtins import object
4041
from builtins import range
@@ -59,6 +60,8 @@
5960
__all__ = ['BoundedSource', 'RangeTracker', 'Read', 'RestrictionTracker',
6061
'Sink', 'Write', 'Writer']
6162

63+
InT = typing.TypeVar('InT')
64+
OutT = typing.TypeVar('OutT')
6265

6366
# Encapsulates information about a bundle of a source generated when method
6467
# BoundedSource.split() is invoked.
@@ -833,7 +836,7 @@ def close(self):
833836
raise NotImplementedError
834837

835838

836-
class Read(ptransform.PTransform):
839+
class Read(ptransform.PTransform[None, OutT]):
837840
"""A transform that reads a PCollection."""
838841

839842
def __init__(self, source):
@@ -847,7 +850,7 @@ def __init__(self, source):
847850
self.source = source
848851

849852
def expand(self, pbegin):
850-
# type: (pvalue.PBegin) -> pvalue.PCollection
853+
# type: (pvalue.PBegin) -> pvalue.PCollection[OutT]
851854
from apache_beam.options.pipeline_options import DebugOptions
852855
from apache_beam.transforms import util
853856

@@ -912,7 +915,7 @@ def from_runner_api_parameter(parameter, context):
912915
Read.from_runner_api_parameter)
913916

914917

915-
class Write(ptransform.PTransform):
918+
class Write(ptransform.PTransform[InT, None]):
916919
"""A ``PTransform`` that writes to a sink.
917920
918921
A sink should inherit ``iobase.Sink``. Such implementations are
@@ -941,6 +944,7 @@ class Write(ptransform.PTransform):
941944
"""
942945

943946
def __init__(self, sink):
947+
# type: (Sink) -> None
944948
"""Initializes a Write transform.
945949
946950
Args:
@@ -954,7 +958,7 @@ def display_data(self):
954958
'sink_dd': self.sink}
955959

956960
def expand(self, pcoll):
957-
# type: (pvalue.PCollection) -> pvalue.PCollection
961+
# type: (pvalue.PCollection[InT]) -> pvalue.PCollection[None]
958962
from apache_beam.runners.dataflow.native_io import iobase as dataflow_io
959963
if isinstance(self.sink, dataflow_io.NativeSink):
960964
# A native sink
@@ -970,14 +974,15 @@ def expand(self, pcoll):
970974
'or be a PTransform. Received : %r' % self.sink)
971975

972976

973-
class WriteImpl(ptransform.PTransform):
977+
class WriteImpl(ptransform.PTransform[InT, None]):
974978
"""Implements the writing of custom sinks."""
975979

976980
def __init__(self, sink):
977981
super(WriteImpl, self).__init__()
978982
self.sink = sink
979983

980984
def expand(self, pcoll):
985+
# type: (pvalue.PCollection[InT]) -> pvalue.PCollection[None]
981986
do_once = pcoll.pipeline | 'DoOnce' >> core.Create([None])
982987
init_result_coll = do_once | 'InitializeWrite' >> core.Map(
983988
lambda _, sink: sink.initialize_write(), self.sink)
@@ -1083,7 +1088,7 @@ def _finalize_write(unused_element, sink, init_result, write_results,
10831088
window.TimestampedValue(v, timestamp.MAX_TIMESTAMP) for v in outputs)
10841089

10851090

1086-
class _RoundRobinKeyFn(core.DoFn):
1091+
class _RoundRobinKeyFn(core.DoFn[InT, OutT]):
10871092

10881093
def __init__(self, count):
10891094
self.count = count
@@ -1092,6 +1097,7 @@ def start_bundle(self):
10921097
self.counter = random.randint(0, self.count - 1)
10931098

10941099
def process(self, element):
1100+
# type: (InT) -> typing.Iterator[typing.Tuple[int, OutT]]
10951101
self.counter += 1
10961102
if self.counter >= self.count:
10971103
self.counter -= self.count

sdks/python/apache_beam/pipeline.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
import re
5353
import shutil
5454
import tempfile
55+
import typing
5556
from builtins import object
5657
from builtins import zip
5758

@@ -730,7 +731,12 @@ class AppliedPTransform(object):
730731
(used internally by Pipeline for bookkeeping purposes).
731732
"""
732733

733-
def __init__(self, parent, transform, full_label, inputs):
734+
def __init__(self,
735+
parent,
736+
transform, # type: ptransform.PTransform
737+
full_label,
738+
inputs
739+
):
734740
self.parent = parent
735741
self.transform = transform
736742
# Note that we want the PipelineVisitor classes to use the full_label,
@@ -741,7 +747,7 @@ def __init__(self, parent, transform, full_label, inputs):
741747
self.full_label = full_label
742748
self.inputs = inputs or ()
743749
self.side_inputs = () if transform is None else tuple(transform.side_inputs)
744-
self.outputs = {}
750+
self.outputs = {} # type: typing.Dict[str, pvalue.PValue]
745751
self.parts = []
746752

747753
def __repr__(self):
@@ -762,7 +768,11 @@ def replace_output(self, output, tag=None):
762768
else:
763769
raise TypeError("Unexpected output type: %s" % output)
764770

765-
def add_output(self, output, tag=None):
771+
def add_output(self,
772+
output, # type: typing.Union[pvalue.DoOutputsTuple, pvalue.PValue]
773+
tag=None # type: typing.Optional[str]
774+
):
775+
# type: (...) -> None
766776
if isinstance(output, pvalue.DoOutputsTuple):
767777
self.add_output(output[output._main_tag])
768778
elif isinstance(output, pvalue.PValue):

0 commit comments

Comments
 (0)
0