Parallel device: fix variable initialization in tf.function · metacortex/tensorflow@d44cb28 · GitHub

Commit d44cb28

allenlavoie authored and tensorflower-gardener committed
Parallel device: fix variable initialization in tf.function
Switches ParallelDevice variables to be compatible with the tf.function variable creator scope, and adds a special case to handle conditional initialization of parallel variables.

Adds TPU tests for the parallel device since that's a major constraint on the implementation (no uninitialized input to tf.cond).

Rolling forward with some branching logic for Windows (may not be Windows-specific, but whatever combination of packages we test with there).

PiperOrigin-RevId: 334170699
Change-Id: I541655bd8a116d013a5a3f62b645aa7242411a40
1 parent e96a709 commit d44cb28
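
To make the fix concrete, the pattern being enabled is lazy variable creation inside a tf.function while a ParallelDevice scope is active. Below is a minimal sketch of that usage, assuming the experimental parallel_device module at this revision and two logical CPUs; the function and variable names are illustrative and not part of the commit.

import tensorflow as tf
from tensorflow.python.distribute.parallel_device import parallel_device

# Split the single physical CPU into two logical devices so the parallel
# device has two components (assumed setup, mirroring the test's setUp).
cpus = tf.config.list_physical_devices("CPU")
tf.config.set_logical_device_configuration(
    cpus[0], [tf.config.LogicalDeviceConfiguration(),
              tf.config.LogicalDeviceConfiguration()])

device = parallel_device.ParallelDevice(
    components=["/job:localhost/device:CPU:0", "/job:localhost/device:CPU:1"])

v = None

@tf.function
def f(x):
  global v
  if v is None:
    # Created on the first trace; tf.function defers initialization and
    # guards it with a conditional, which previously broke for variables
    # placed on the parallel device.
    v = tf.Variable(2.)
  return x * v

with device:
  packed = f(tf.ones([]))
print(device.unpack(packed))  # One component tensor per underlying device.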

File tree

6 files changed: +172 -52 lines

tensorflow/python/distribute/parallel_device/BUILD

Lines changed: 4 additions & 3 deletions
@@ -1,3 +1,5 @@
+load("//tensorflow/core/platform/default:distribute.bzl", "distribute_py_test")
+
 package(
     default_visibility = ["//tensorflow:internal"],
     licenses = ["notice"],  # Apache 2.0
@@ -17,6 +19,7 @@ py_library(
         ":saving",
         "//tensorflow/python:_pywrap_parallel_device",
         "//tensorflow/python/distribute:device_util",
+        "//tensorflow/python/tpu:tpu_ops",
     ],
 )
 
@@ -27,15 +30,13 @@ py_library(
     deps = ["//tensorflow/python:framework_ops"],
 )
 
-py_test(
+distribute_py_test(
     name = "parallel_device_test",
     srcs = ["parallel_device_test.py"],
     python_version = "PY3",
     tags = [
         # Dependencies aren't otherwise included in the pip package yet.
         "no_pip",
-        # MRO broken; needs investigation
-        "no_windows",
     ],
     deps = [
         ":parallel_device",

tensorflow/python/distribute/parallel_device/parallel_device.py

Lines changed: 12 additions & 0 deletions
@@ -19,6 +19,7 @@
 from __future__ import print_function
 
 import threading
+import weakref
 
 from tensorflow.python import _pywrap_parallel_device
 from tensorflow.python.distribute import device_util
@@ -32,6 +33,16 @@
 _next_device_number = 0
 _next_device_number_lock = threading.Lock()
 
+_all_parallel_devices = weakref.WeakValueDictionary()
+
+
+def unpack(tensor):
+  """Finds `tensor`'s parallel device and unpacks its components."""
+  parallel_device = _all_parallel_devices.get(tensor.device, None)
+  if parallel_device is None:
+    raise ValueError("{} is not a parallel device".format(tensor.device))
+  return parallel_device.unpack(tensor)
+
 
 # TODO(allenl): Expand this docstring once things like getting components on and
 # off the device are stable.
@@ -67,6 +78,7 @@ def __init__(self, components):
     self._device_ids = None
     self._device_scope = None
     self._saving_scope = None
+    _all_parallel_devices[self._name] = self
 
   def pack(self, tensors):
     """Create a tensor on the parallel device from a sequence of tensors.

tensorflow/python/distribute/parallel_device/parallel_device_test.py

Lines changed: 66 additions & 14 deletions
@@ -93,19 +93,30 @@ class _VirtualDeviceTestCase(test.TestCase):
 
   def setUp(self):
     super(_VirtualDeviceTestCase, self).setUp()
-    cpus = context.context().list_physical_devices("CPU")
-    # Set 4 virtual CPUs
-    context.context().set_logical_device_configuration(cpus[0], [
-        context.LogicalDeviceConfiguration(),
-        context.LogicalDeviceConfiguration(),
-        context.LogicalDeviceConfiguration(),
-        context.LogicalDeviceConfiguration()
+    ctx = context.context()
+    if ctx.list_physical_devices("TPU"):
+      self.device_type = "TPU"
+    elif ctx.list_physical_devices("GPU"):
+      self.device_type = "GPU"
+      gpus = ctx.list_physical_devices(self.device_type)
+      ctx.set_logical_device_configuration(gpus[0], [
+          context.LogicalDeviceConfiguration(memory_limit=100),
+          context.LogicalDeviceConfiguration(memory_limit=100),
+      ])
+    else:
+      self.device_type = "CPU"
+      cpus = ctx.list_physical_devices("CPU")
+      ctx.set_logical_device_configuration(cpus[0], [
+          context.LogicalDeviceConfiguration(),
+          context.LogicalDeviceConfiguration(),
+      ])
+
+    self.device = parallel_device.ParallelDevice(components=[
+        "/job:localhost/device:{}:0".format(self.device_type),
+        self.device_type + ":1"
     ])
-
-    self.device = parallel_device.ParallelDevice(
-        components=["/job:localhost/device:CPU:0", "CPU:1"])
-    self.assertIn("CPU:0", self.device.components[0])
-    self.assertIn("CPU:1", self.device.components[1])
+    self.assertIn(self.device_type + ":0", self.device.components[0])
+    self.assertIn(self.device_type + ":1", self.device.components[1])
 
 
 class ParallelDeviceTests(_VirtualDeviceTestCase):
@@ -124,10 +135,14 @@ def test_register_parallel_device(self):
   def test_device_id(self):
     device_ids = self.device.unpack(self.device.device_ids)
     self.assertAllClose([0, 1], device_ids)
-    self.assertIn(self.device.components[0], device_ids[0].backing_device)
-    self.assertIn(self.device.components[1], device_ids[1].backing_device)
+    # TODO(allenl): Should device IDs be int64 so they can be placed on GPUs?
+    # Currently backing_device is CPU.
+    self.assertIn(self.device.components[0], device_ids[0].device)
+    self.assertIn(self.device.components[1], device_ids[1].device)
 
   def test_collective_reduce(self):
+    if self.device_type == "TPU":
+      self.skipTest("ParallelDevice collectives on TPUs need work")
     with self.device:
       x = self.device.pack(
           [constant_op.constant(-1.5),
@@ -139,6 +154,8 @@ def test_collective_reduce(self):
     self.assertIn(self.device.components[1], outputs[1].backing_device)
 
   def test_collective_reduce_async_scope(self):
+    if self.device_type == "TPU":
+      self.skipTest("ParallelDevice collectives on TPUs need work")
     # Note that ops on the parallel device currently don't execute
     # asynchronously. The test is just that we don't get deadlocks.
     with context.async_scope(), self.device:
@@ -152,6 +169,8 @@ def test_collective_reduce_async_scope(self):
     self.assertIn(self.device.components[1], outputs[1].backing_device)
 
   def test_collective_reduce_async_context(self):
+    if self.device_type == "TPU":
+      self.skipTest("ParallelDevice collectives on TPUs need work")
     previous = config.get_synchronous_execution()
     try:
       context._reset_context()
@@ -173,6 +192,8 @@ def test_collective_reduce_async_context(self):
       config.set_synchronous_execution(previous)
 
   def test_collective_in_function(self):
+    if self.device_type == "TPU":
+      self.skipTest("ParallelDevice collectives on TPUs need work")
     c = constant_op.constant([2])
 
     @def_function.function
@@ -313,6 +334,33 @@ def _test_fn():
       return y, tape.gradient(y, x)
     self._assert_close_to_non_parallel(_test_fn)
 
+  def test_variable_created_in_function(self):
+
+    class M(module.Module):
+
+      def __init__(self):
+        self.v = None
+        self.w = None
+        self.x = None
+        self.z = None
+
+      @def_function.function(autograph=False)
+      def __call__(self, x):
+        if self.v is None:
+          with ops.init_scope():
+            initial_value = constant_op.constant(2.)
+            self.z = variables.Variable(initial_value)
+          self.x = variables.Variable(initial_value)
+          self.w = variables.Variable(lambda: constant_op.constant(2.))
+          self.v = variables.Variable(constant_op.constant(2.))
+        return x * self.v * self.w * self.x * self.z
+
+    with self.device:
+      m = M()
+      packed_outputs = m(array_ops.ones([]))
+      outputs = self.device.unpack(packed_outputs)
+    self.assertAllClose([16., 16.], outputs)
+
 
 class LayerTests(_VirtualDeviceTestCase):
 
@@ -340,6 +388,8 @@ def test_layer_forward(self):
     self.assertIn(self.device.components[1], outputs[1].backing_device)
 
   def test_layer_sync_training(self):
+    if self.device_type == "TPU":
+      self.skipTest("ParallelDevice collectives on TPUs need work")
     with self.device:
       layer = _Dense(5)
 
@@ -389,6 +439,8 @@ def test_layer_divergent_buffer_training(self):
     self.assertIn(self.device.components[1], final_kernels[1].backing_device)
 
   def test_training_loop(self):
+    if self.device_type == "TPU":
+      self.skipTest("ParallelDevice collectives on TPUs need work")
     for _ in range(5):
       layer = _Dense(5)
       checkpoint = tracking.Checkpoint(layer=layer)
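
As a sanity check on the expected value in test_variable_created_in_function above: v, w, x and z are each initialized to 2.0 on both components, and the input is array_ops.ones([]) packed to 1.0 per component, so each component computes 1.0 * 2 * 2 * 2 * 2 = 16.0, which is why the unpacked result is asserted to be [16., 16.].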

tensorflow/python/distribute/parallel_device/saving.py

Lines changed: 45 additions & 26 deletions
@@ -20,6 +20,8 @@
 
 import contextlib
 import functools
+import six
+import wrapt
 
 from tensorflow.python.ops import gen_resource_variable_ops
 from tensorflow.python.ops import resource_variable_ops
@@ -47,14 +49,32 @@ def restore(self, tensors, restored_shapes=None):
         resource=self._handle, value=restored_tensor)
 
 
-class ParallelSavingMixin(resource_variable_ops.BaseResourceVariable):
-  """Mixin to to override variable checkpointing, saving each component."""
+_wrapt_type = type(wrapt.ObjectProxy)
+_variable_type = type(resource_variable_ops.BaseResourceVariable)
+if issubclass(_variable_type, _wrapt_type):
+  # Some wrapt versions do not have a meta-class, which would create an invalid
+  # MRO.
+  VariableProxyMetaClass = _variable_type
+else:
+  class VariableProxyMetaClass(_wrapt_type, _variable_type):  # pylint: disable=duplicate-bases
+    """A combined MetaClasses for ParallelVariable.
 
-  def __init__(self, parallel_device, expected_shape=None, use_resource=None,
-               **kwargs):
-    del expected_shape, use_resource
-    self._parallel_device = parallel_device
-    super(ParallelSavingMixin, self).__init__(**kwargs)
+    Satisfies the requirement "the metaclass of a derived class must be a
+    (non-strict) subclass of the metaclasses of all its bases." At the time of
+    writing these two MetaClasses are compatible (overriding different methods,
+    both relatively trivial).
+    """
+    pass
+
+
+class ParallelVariable(
+    six.with_metaclass(VariableProxyMetaClass, wrapt.ObjectProxy,
+                       resource_variable_ops.BaseResourceVariable)):
+  """Overrides variable checkpointing, saving each component."""
+
+  def __init__(self, parallel_device, wrapped_variable):
+    self._self_parallel_device = parallel_device
+    super(ParallelVariable, self).__init__(wrapped_variable)
 
   # TODO(allenl): Consider either adding a boolean argument for
   # save-primary-only or looking at synchronization/aggregation properties.
@@ -63,7 +83,8 @@ def _gather_saveables_for_checkpoint(self):
     component_saveables = {}
     # Create one SaveableObject per device, each one of which looks like a
    # regular ResourceVariable saveable.
-    for index, handle in enumerate(self._parallel_device.unpack(self.handle)):
+    for index, handle in enumerate(
+        self._self_parallel_device.unpack(self.handle)):
       if index == 0:
         # This is the name regular tf.Variables use to save. Using it for the
         # component on the first device means non-parallel tf.Variable objects
@@ -80,26 +101,24 @@ def _gather_saveables_for_checkpoint(self):
     return component_saveables
 
 
-class ParallelVariable(
-    ParallelSavingMixin, resource_variable_ops.ResourceVariable):
-  pass
-
-
-class UninitializedParallelVariable(
-    ParallelSavingMixin, resource_variable_ops.UninitializedVariable):
-  pass
-
-
-def _variable_creator(next_creator, parallel_device, initial_value=None,
-                      **kwargs):
-  del next_creator
-  if initial_value is not None:
+def _variable_creator(next_creator, parallel_device, **kwargs):
+  """Wraps intercepted variables to add parallel saving."""
+  # Depending on the context (SavedModel loading, tf.function, etc.) we may get
+  # one of several different variable types. For variables placed on the
+  # parallel device we only want to affect saving and otherwise preserve
+  # behavior. This wrapping to override behavior is similar to tf.distribute's
+  # DistributedVariable, but much more limited.
+  variable = next_creator(**kwargs)
+  if variable.device == parallel_device._name:  # Friend access; pylint: disable=protected-access
     return ParallelVariable(
-        parallel_device=parallel_device, initial_value=initial_value, **kwargs)
+        parallel_device=parallel_device, wrapped_variable=variable)
   else:
-    # SavedModel loading does not pass an initial value.
-    return UninitializedParallelVariable(
-        parallel_device=parallel_device, **kwargs)
+    # Variables not placed on the handler (because of a device scope) don't
+    # need wrapping.
+    #
+    # TODO(allenl): Device scopes should merge with parallel devices rather
+    # than overriding them like this.
+    return variable
 
 
 @contextlib.contextmanager
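
The move from a saving mixin to a wrapt.ObjectProxy means ParallelVariable delegates everything to the variable it wraps and only overrides checkpointing, keeping proxy-local state under the _self_ prefix. A small generic sketch of that wrapt convention, with made-up classes rather than TensorFlow code:

import wrapt


class Counter(object):
  """Some object we want to wrap without changing its behavior."""

  def __init__(self):
    self.count = 0

  def increment(self):
    self.count += 1
    return self.count


class LabeledProxy(wrapt.ObjectProxy):
  """Forwards everything to the wrapped object, adds proxy-local state."""

  def __init__(self, wrapped, label):
    super(LabeledProxy, self).__init__(wrapped)
    # Attributes prefixed with _self_ live on the proxy rather than being
    # forwarded -- the same convention ParallelVariable uses for
    # _self_parallel_device.
    self._self_label = label


counter = Counter()
proxy = LabeledProxy(counter, label="replica-0")
print(proxy.increment())           # 1: the call is forwarded to Counter.
print(isinstance(proxy, Counter))  # True: proxies pass isinstance checks.
print(proxy._self_label)           # "replica-0": stored on the proxy only.
print(counter.count)               # 1: state lives on the wrapped object.

The metaclass branching in the diff exists only because BaseResourceVariable and some wrapt builds each bring their own metaclass; when wrapt's proxy has no custom metaclass, the variable's metaclass is used directly.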

tensorflow/python/eager/BUILD

Lines changed: 1 addition & 0 deletions
@@ -836,6 +836,7 @@ py_library(
         "//tensorflow/python:util",
         "//tensorflow/python:variable_scope",
         "//tensorflow/python:while_v2",  # TODO(b/118513001): Imported via control_flow_ops; remove.
+        "//tensorflow/python/distribute/parallel_device",
         "//tensorflow/python/profiler:trace",
         "//tensorflow/python/training/tracking:base",
     ],

tensorflow/python/eager/def_function.py

Lines changed: 44 additions & 9 deletions
@@ -27,9 +27,11 @@
 from google.protobuf import text_format as _text_format
 from google.protobuf.message import DecodeError
 from tensorflow.core.framework import attr_value_pb2
+from tensorflow.python.distribute.parallel_device import parallel_device
 from tensorflow.python.eager import context
 from tensorflow.python.eager import function as function_lib
 from tensorflow.python.eager import lift_to_graph
+from tensorflow.python.framework import errors
 from tensorflow.python.framework import func_graph as func_graph_module
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
@@ -430,6 +432,45 @@ def functions_run_eagerly():
   return RUN_FUNCTIONS_EAGERLY
 
 
+def _evaluate_var_is_initialized(variables):
+  """Compute booleans indicating whether each variable is initialized."""
+  with ops.init_scope():
+    var_is_initialized = []
+    for v in variables:
+      var_is_initialized.append(
+          resource_variable_ops.var_is_initialized_op(v.handle))
+    try:
+      # Stack all the var_is_initialized values into one tensor and interpret
+      # the numpy value. This will reduce the number of RPCs between client and
+      # worker in the remote case.
+      return array_ops.stack(var_is_initialized).numpy()
+    except errors.UnimplementedError:
+      # Some devices do not support implicit copy-off to host. Fall back to
+      # variable-by-variable processing.
+      for index, v in enumerate(variables):
+        try:
+          numpy_value = var_is_initialized[index].numpy()
+        except errors.UnimplementedError:
+          # This is a variable on a parallel device; we'll extract its value on
+          # each replica and assert that they're identical.
+          components = parallel_device.unpack(var_is_initialized[index])
+          with ops.device(None):
+            components = array_ops.stack(components)
+            all_initialized = math_ops.reduce_all(components).numpy()
+            any_initialized = math_ops.reduce_any(components).numpy()
+          if all_initialized != any_initialized:
+            raise NotImplementedError(
+                ("Some but not all components of a parallel variable {} were "
+                 "initialized between their creation in a tf.function and "
+                 "the function's trace having completed. This is not yet "
+                 "supported; consider initializing either all or none of the "
+                 "components, or moving initialization out of the function."
+                ).format(repr(v)))
+          numpy_value = all_initialized
+        var_is_initialized[index] = numpy_value
+      return var_is_initialized
+
+
 class FunctionDeleter(object):
 
   __slots__ = ["func_graph"]
@@ -1024,21 +1065,15 @@ def _initialize_uninitialized_variables(self, initializers):
     if not initializers:
       return
 
+    var_is_initialized = _evaluate_var_is_initialized(
+        [v for v, _ in initializers])
+
     # Note: using defun here avoids an infinite recursion.
     # Most of the code in this function runs eagerly with init_scope, where
     # autograph is not necessary.
     @function_lib.defun(autograph=False)
     def initialize_variables():
       op_map = object_identity.ObjectIdentityDictionary()
-      # Stack all the var_is_initialized values into one tensor and interpret
-      # the numpy value. This will reduce the number of RPCs between client and
-      # worker in the remote case.
-      with ops.init_scope():
-        var_is_initialized = []
-        for v, _ in initializers:
-          var_is_initialized.append(
-              resource_variable_ops.var_is_initialized_op(v.handle))
-        var_is_initialized = array_ops.stack(var_is_initialized).numpy()
 
       inits = []
       for (v, init), is_initialized in zip(initializers, var_is_initialized):
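
The fallback path above decides initialization per parallel variable by reducing the per-component flags two ways and requiring them to agree; disagreement means a partially initialized parallel variable, which is rejected. A standalone illustration of that all/any check, using plain TensorFlow and a hypothetical helper name:

import tensorflow as tf


def check_components(component_flags):
  """Mirrors the reduce_all/reduce_any comparison on per-replica booleans."""
  components = tf.stack([tf.constant(flag) for flag in component_flags])
  all_initialized = tf.reduce_all(components).numpy()
  any_initialized = tf.reduce_any(components).numpy()
  if all_initialized != any_initialized:
    # Mixed state: some replicas initialized, others not.
    raise NotImplementedError(
        "Some but not all components of the parallel variable are "
        "initialized; initialize all or none of them.")
  return bool(all_initialized)


print(check_components([True, True]))    # True: skip the initializer.
print(check_components([False, False]))  # False: run the initializer.
check_components([True, False])          # Raises NotImplementedError.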

0 commit comments
