8000 Merge pull request #3041 from minrk/requirepush · ipython/ipython@4be107e · GitHub
[go: up one dir, main page]

Skip to content
Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Appearance settings

Commit 4be107e

Browse files
committed
Merge pull request #3041 from minrk/requirepush
support non-modules in @require
2 parents 331b90a + ebaf52e commit 4be107e

File tree

3 files changed

+102
-23
lines changed

3 files changed

+102
-23
lines changed

IPython/parallel/controller/dependency.py

Lines changed: 42 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
* Min RK
66
"""
77
#-----------------------------------------------------------------------------
8-
# Copyright (C) 2010-2011 The IPython Development Team
8+
# Copyright (C) 2013 The IPython Development Team
99
#
1010
# Distributed under the terms of the BSD License. The full license is in
1111
# the file COPYING, distributed as part of this software.
@@ -17,6 +17,7 @@
1717
from IPython.parallel.error import UnmetDependency
1818
from IPython.parallel.util import interactive
1919
from IPython.utils import py3compat
20+
from IPython.utils.pickleutil import can, uncan
2021

2122
class depend(object):
2223
"""Dependency decorator, for use with tasks.
@@ -58,12 +59,12 @@ def __init__(self, f, df, *dargs, **dkwargs):
5859
self.df = df
5960
self.dargs = dargs
6061
self.dkwargs = dkwargs
61-
62-
def __call__(self, *args, **kwargs):
63-
# if hasattr(self.f, 'func_globals') and hasattr(self.df, 'func_globals'):
64-
# self.df.func_globals = self.f.func_globals
62+
63+
def check_dependency(self):
6564
if self.df(*self.dargs, **self.dkwargs) is False:
6665
raise UnmetDependency()
66+
67+
def __call__(self, *args, **kwargs):
6768
return self.f(*args, **kwargs)
6869

6970
if not py3compat.PY3:
@@ -72,41 +73,62 @@ def __name__(self):
7273
return self.func_name
7374

7475
@interactive
75-
def _require(*names):
76+
def _require(*modules, **mapping):
7677
"""Helper for @require decorator."""
7778
from IPython.parallel.error import UnmetDependency
79+
from IPython.utils.pickleutil import uncan
7880
user_ns = globals()
79-
for name in names:
80-
if name in user_ns:
81-
continue
81+
for name in modules:
8282
try:
83-
exec 'import %s'%name in user_ns
83+
exec 'import %s' % name in user_ns
8484
except ImportError:
8585
raise UnmetDependency(name)
86+
87+
for name, cobj in mapping.items():
88+
user_ns[name] = uncan(cobj, user_ns)
8689
return True
8790

88-
def require(*mods):
89-
"""Simple decorator for requiring names to be importable.
91+
def require(*objects, **mapping):
92+
"""Simple decorator for requiring local objects and modules to be available
93+
when the decorated function is called on the engine.
94+
95+
Modules specified by name or passed directly will be imported
96+
prior to calling the decorated function.
97+
98+
Objects other than modules will be pushed as a part of the task.
99+
Functions can be passed positionally,
100+
and will be pushed to the engine with their __name__.
101+
Other objects can be passed by keyword arg.
90102
91103
Examples
92104
--------
93105
94106
In [1]: @require('numpy')
95107
...: def norm(a):
96-
...: import numpy
97108
...: return numpy.linalg.norm(a,2)
109+
110+
In [2]: foo = lambda x: x*x
111+
In [3]: @require(foo)
112+
...: def bar(a):
113+
...: return foo(1-a)
98114
"""
99115
names = []
100-
for mod in mods:
101-
if isinstance(mod, ModuleType):
102-
mod = mod.__name__
116+
for obj in objects:
117+
if isinstance(obj, ModuleType):
118+
obj = obj.__name__
103119

104-
if isinstance(mod, basestring):
105-
names.append(mod)
120+
if isinstance(obj, basestring):
121+
names.append(obj)
122+
elif hasattr(obj, '__name__'):
123+
mapping[obj.__name__] = obj
106124
else:
107-
raise TypeError("names must be modules or module names, not %s"%type(mod))
125+
raise TypeError("Objects other than modules and functions "
126+
"must be passed by kwarg, but got: %s" % type(obj)
127+
)
108128

109-
return depend(_require, *names)
129+
for name, obj in mapping.items():
130+
mapping[name] = can(obj)
131+
return depend(_require, *names, **mapping)
110132

111133
class Dependency(set):
112134
"""An object for representing a set of msg_id dependencies.

IPython/parallel/tests/test_dependency.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@ def wait(n):
3737
time.sleep(n)
3838
return n
3939

40+
@pmod.interactive
41+
def func(x):
42+
return x*x
43+
4044
mixed = map(str, range(10))
4145
completed = map(str, range(0,10,2))
4246
failed = map(str, range(1,10,2))
@@ -104,3 +108,29 @@ def test_failure_only(self):
104108
dep.all=False
105109
self.assertUnmet(dep)
106110
self.assertUnreachable(dep)
111+
112+
def test_require_function(self):
113+
114+
@pmod.interactive
115+
def bar(a):
116+
return func(a)
117+
118+
@pmod.require(func)
119+
@pmod.interactive
120+
def bar2(a):
121+
return func(a)
122+
123+
self.client[:].clear()
124+
self.assertRaisesRemote(NameError, self.view.apply_sync, bar, 5)
125+
ar = self.view.apply_async(bar2, 5)
126+
self.assertEqual(ar.get(5), func(5))
127+
128+
def test_require_object(self):
129+
130+
@pmod.require(foo=func)
131+
@pmod.interactive
132+
def bar(a):
133+
return foo(a)
134+
135+
ar = self.view.apply_async(bar, 5)
136+
self.assertEqual(ar.get(5), func(5))

IPython/utils/pickleutil.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,26 @@
4949

5050

5151
class CannedObject(object):
52-
def __init__(self, obj, keys=[]):
52+
def __init__(self, obj, keys=[], hook=None):
53+
"""can an object for safe pickling
54+
55+
Parameters
56+
==========
57+
58+
obj:
59+
The object to be canned
60+
keys: list (optional)
61+
list of attribute names that will be explicitly canned / uncanned
62+
hook: callable (optional)
63+
An optional extra callable,
64+
which can do additional processing of the uncanned object.
65+
66+
large data may be offloaded into the buffers list,
67+
used for zero-copy transfers.
68+
"""
5369
self.keys = keys
5470
self.obj = copy.copy(obj)
71+
self.hook = can(hook)
5572
for key in keys:
5673
setattr(self.obj, key, can(getattr(obj, key)))
5774

@@ -60,8 +77,13 @@ def __init__(self, obj, keys=[]):
6077
def get_object(self, g=None):
6178
if g is None:
6279
g = {}
80+
obj = self.obj
6381
for key in self.keys:
64-
setattr(self.obj, key, uncan(getattr(self.obj, key), g))
82+
setattr(obj, key, uncan(getattr(obj, key), g))
83+
84+
if self.hook:
85+
self.hook = uncan(self.hook, g)
86+
self.hook(obj, g)
6587
return self.obj
6688

6789

@@ -302,6 +324,11 @@ def uncan_sequence(obj, g=None):
302324
else:
303325
return obj
304326

327+
def _uncan_dependent_hook(dep, g=None):
328+
dep.check_dependency()
329+
330+
def can_dependent(obj):
331+
return CannedObject(obj, keys=('f', 'df'), hook=_uncan_dependent_hook)
305332

306333
#-------------------------------------------------------------------------------
307334
# API dictionaries
@@ -310,7 +337,7 @@ def uncan_sequence(obj, g=None):
310337
# These dicts can be extended for custom serialization of new objects
311338

312339
can_map = {
313-
'IPython.parallel.dependent' : lambda obj: CannedObject(obj, keys=('f','df')),
340+
'IPython.parallel.dependent' : can_dependent,
314341
'numpy.ndarray' : CannedArray,
315342
FunctionType : CannedFunction,
316343
bytes : CannedBytes,

0 commit comments

Comments
 (0)
0