auto functionalize base_hop #151067
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/151067
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 Unrelated Failure) As of commit 0b7695b with merge base 7a0781e. UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
# However, it's up to each hop's functionalization implementation to decide
# whether do_auto_functionalized is called when input mutation happens.
Uh what does this mean?
I just want to clarify that do_auto_functionalized is called in the hop's functionalization key implementation.
Do you have an example?
See the base_hop functionalization key: https://github.com/pytorch/pytorch/pull/151067/files#diff-193a1ca345897131c5b7bde79857e6c7c6a6b407284b82f6c680c3b3bd59d477R141-R144. Sorry, I meant do_auto_functionalized_v2, not do_auto_functionalized; is that the confusion here?
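For readers following along, here is a minimal sketch of the pattern being described, assuming a hypothetical my_hop; it mirrors the invoke_subgraph-style snippet further down in this review and is not the PR's actual base_hop code:

from torch._higher_order_ops.auto_functionalize import do_auto_functionalize_v2


# Hypothetical hop; the decorator and ctx API follow the usual HOP functionalization
# pattern, but the details here are illustrative, not taken from this PR.
@my_hop.py_functionalize_impl
def my_hop_functionalize(ctx, subgraph, *operands):
    hop_schema = my_hop.gen_schema(subgraph, *operands)
    if hop_schema.is_mutable:
        # Hand mutable calls off to the auto-functionalization machinery.
        return do_auto_functionalize_v2(ctx.mode, my_hop, (subgraph, *operands), {})

    # Otherwise functionalize by hand: unwrap, redispatch, rewrap.
    unwrapped = ctx.unwrap_tensors(operands)
    with ctx.redispatch_to_next():
        out = my_hop(subgraph, *unwrapped)
    return ctx.wrap_tensors(out)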
if isinstance(_mutable_op, HigherOrderOperator):
    # Note [materialize callable inputs as graph]
    # The code below materializes the callable inputs to the hop as graph modules.
    # kwargs may contain general callables that are not proxable, e.g. FunctionWithNoFreeVars.
    # This can happen when we auto_functionalize the backward of the hop,
    # where the backward fn is a callable that wraps the forward graph module.
    # This function materializes the callable args according to the schema of the hop.
    #
    # We cannot materialize the callables in kwargs directly because the inputs to a callable
    # vary from hop to hop. To make the materialization process generic to all hops,
    # we trace a function that wraps the hop and let each hop itself figure out how to trace
    # its callable inputs. Then we look at the schema of the traced hop node and replace the
    # callables in the original kwargs with the traced subgraphs.
    #
    # Specifically, we first trace a wrapped_fn that calls into the hop. Then we look for the
    # hop node in the traced graph and the graph module inputs to the hop. Finally, we replace
    # the original kwargs' callable with the graph module.
    all_bases = kwargs.get("_all_bases", [])
    _only_clone_these_bases = kwargs.get("_only_clone_these_bases", None)
    if _only_clone_these_bases is None:
        _only_clone_these_bases = tuple(range(len(all_bases)))

    schema = pytree.tree_unflatten([], kwargs.get("_op_schema", None)).schema
    new_kwargs, _ = _generate_new_op_kwargs_from_bases(
        schema,
        {k: v for k, v in kwargs.items() if k not in ("_all_bases", "_op_schema")},
        all_bases,
        _only_clone_these_bases,
    )

    def wrapped_fn(*args):
        return _invoke_op_with_kwargs_and_schema(_mutable_op, new_kwargs, schema)  # type: ignore[arg-type]

    # We need to trace the higher order op in order to materialize the inputs that
    # are a callable (e.g. after the functionalization key).
    gm = reenter_make_fx(wrapped_fn)(pytree.tree_leaves(new_kwargs))
    hop_node = gm.graph.find_nodes(op="call_function", target=_mutable_op)[0]
    arg_proxies = pytree.tree_leaves((hop_node.args, hop_node.kwargs))
    assert isinstance(schema, torch._C.FunctionSchema) and len(arg_proxies) == len(
        schema.arguments
    )

    # _invoke_op_with_kwargs_and_schema preserves ordering of proxies via schema
    materialized_arg = {}
    for proxy, arg in zip(arg_proxies, schema.arguments):
        if (
            isinstance(proxy, torch.fx.Node)
            and proxy.op == "get_attr"
            and isinstance(getattr(gm, proxy.target), torch.fx.GraphModule)  # type: ignore[arg-type]
        ):
            assert callable(kwargs[arg.name]), (schema, arg.name, kwargs)
            materialized_arg[arg.name] = getattr(gm, proxy.target)  # type: ignore[arg-type]

    # Update kwargs with materialized graphs
    kwargs.update(materialized_arg)
Can you shove this into a helper function? Given a HOP, some args, and some kwargs, it would return the materialized graphs.
Or maybe it just returns a new set of (args, kwargs) with the materialized graphs.
def materialize_graphs(hop, args, kwargs) -> (args, kwargs):
...
Can also go into higher_order_ops.utils
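A rough sketch of what that helper could look like if it condensed the inline logic above; the name materialize_graphs and the exact signature are the reviewer's suggestion, not code from this PR:

def materialize_graphs(hop, schema, kwargs):
    """Trace `hop` once and swap callable kwargs for the traced GraphModules.

    Sketch only: relies on the same helpers used in the diff above
    (reenter_make_fx, _invoke_op_with_kwargs_and_schema, pytree).
    """

    def wrapped_fn(*args):
        return _invoke_op_with_kwargs_and_schema(hop, kwargs, schema)

    gm = reenter_make_fx(wrapped_fn)(pytree.tree_leaves(kwargs))
    hop_node = gm.graph.find_nodes(op="call_function", target=hop)[0]
    arg_proxies = pytree.tree_leaves((hop_node.args, hop_node.kwargs))

    materialized = dict(kwargs)
    for proxy, arg in zip(arg_proxies, schema.arguments):
        if (
            isinstance(proxy, torch.fx.Node)
            and proxy.op == "get_attr"
            and isinstance(getattr(gm, proxy.target), torch.fx.GraphModule)
        ):
            materialized[arg.name] = getattr(gm, proxy.target)
    return materialized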
    if isinstance(out, tuple):
        return (*out, *all_bases_new)  # type: ignore[return-value]
    else:
        return (out, *all_bases_new)  # type: ignore[return-value]


def _invoke_op_with_kwargs_and_schema(
nit: maybe just "call_hop"? Can go into torch/_higher_order_ops.utils.
I feel like we should split call_hop vs call_op, because passing the schema for an OpOverload doesn't make too much sense (we end up ignoring it).
e.g.
if isinstance(op, OpOverload):
    return op(**kwargs)
else:
    return call_hop(hop, schema, kwargs)
(hop, schema, kwargs) seems like the right order (the schema is tied to the hop).
I see you did do the HopInstance thing. Alternatively, this can just be:
def call_op(op: Union[OpOverload, HOPInstance], args, kwargs):
    if isinstance(op, OpOverload):
        return op(*args, **kwargs)
    # HOPInstance case: pull schema from it
    schema = op._schema
    ...
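A hedged sketch of that alternative, assuming a HopInstance whose _schema and _op attributes match the naming used elsewhere in this diff, and assuming the hop path passes everything as schema-keyed kwargs as in the code above:

from typing import Union

from torch._ops import OpOverload


def call_op(op: Union[OpOverload, "HopInstance"], args, kwargs):
    if isinstance(op, OpOverload):
        # An OpOverload carries its own schema, so just call it directly.
        return op(*args, **kwargs)
    # HopInstance case: the schema travels with the instance, so pull it from
    # there and route through the hop-specific calling convention.
    schema = op._schema
    return _invoke_op_with_kwargs_and_schema(op._op, kwargs, schema)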
if isinstance(op, HigherOrderOperator):
    assert (
        len(schema.returns) > 0
    ), f"hop is expected to return at least on output {schema}."
nit: on -> one
assert can_auto_functionalize(_mutable_op)
_op_to_check: Optional[Union[OpOverload, HopInstance]] = None
if isinstance(_mutable_op, HigherOrderOperator):
    schema = pytree.tree_unflatten([], kwargs.get("_op_schema", None)).schema
What happens if _op_schema is None? Does tree_unflatten([], None) work? Maybe we wanted an assertion here?
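For example, a guard along these lines would make the failure explicit (hypothetical, mirroring the snippet above):

op_schema_spec = kwargs.get("_op_schema", None)
assert op_schema_spec is not None, "auto_functionalized HOP calls must carry _op_schema"
schema = pytree.tree_unflatten([], op_schema_spec).schema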
        type(op), method
    ) is not getattr(HigherOrderOperator, method)

    return _has_gen_schema(op._op) and op._schema.is_mutable
nit: You also probably want the other checks that we do on schema below (lines 431-450)?
@@ -520,7 +570,7 @@ def sync_update(o, orig_arg):

 def do_auto_functionalize_v2(
     mode: "torch._subclasses.functional_tensor.FunctionalTensorMode",
-    op: OpOverload,
+    op: _MutableOpType,
I'm confused, the callsite looks like this:
pytorch/torch/_subclasses/functional_tensor.py, lines 444 to 454 in daca611:
if can_auto_functionalize(
    func
) and not torch._C._dispatch_has_kernel_for_dispatch_key(
    func.name(), torch._C.DispatchKey.Functionalize
):
    import torch._inductor.config as inductor_config

    if self.export or not inductor_config.enable_auto_functionalized_v2:
        return do_auto_functionalize(self, func, args, kwargs)
    else:
        return do_auto_functionalize_v2(self, func, args, kwargs)
Since we changed can_auto_functionalize to take in a HOPInstance, we should also change do_auto_functionalize_v2 to take in a HOPInstance. So the code becomes:
func = op if isinstance(op, OpOverload) else HOPInstance(op)
if can_auto_functionalize(func):
    do_auto_functionalize_v2(func)
Okay, I see the problem...
HOPs should go through FunctionalTensorMode.__torch_dispatch__. Unfortunately, they don't.
We can still change do_auto_functionalize_v2 to accept a HOPInstance, though.
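A hedged sketch of how the callsite might look once do_auto_functionalize_v2 accepts a HopInstance; the wrapper name and the bare HopInstance(op) construction are illustrative only and may need a schema argument in practice:

def _maybe_auto_functionalize(mode, op, args, kwargs):
    # Normalize eager ops and HOPs into the types can_auto_functionalize expects.
    func = op if isinstance(op, OpOverload) else HopInstance(op)
    if can_auto_functionalize(func):
        return do_auto_functionalize_v2(mode, func, args, kwargs)
    return NotImplemented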
from torch._higher_order_ops.auto_functionalize import do_auto_functionalize_v2

# invoke_quant has a non-proxable argument of type InvokeQuant that
# we cannot generate a schema for.
if self is not torch.ops.higher_order.invoke_quant_packed:
    hop_schema = self.gen_schema(subgraph, *operands, **kwargs)
    if hop_schema.is_mutable:
        return do_auto_functionalize_v2(
            ctx.mode, self, (subgraph, *operands), kwargs
        )
gah
okay seems fine
warnings.warn(
    "Aliasing is not supported for HOP subgraph.\n"
    f"{subgraph.print_readable(print_output=False)}\n"
    f"Alias info: inp-inp alias: {inp_inp_alias}, inp-out alias: {inp_out_alias}, out-out alias: {out_out_alias}"
)
Say something like: this may lead to silent incorrectness. Since this is silent incorrectness, we should try to fix it ASAP (not in this PR).
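Something along these lines would cover it (sketch only; same variables as the snippet above, only the message changes):

warnings.warn(
    "Aliasing is not supported for HOP subgraph inputs/outputs and may lead to "
    "silent incorrectness; this should become a hard error in a follow-up.\n"
    f"{subgraph.print_readable(print_output=False)}\n"
    f"Alias info: inp-inp alias: {inp_inp_alias}, inp-out alias: {inp_out_alias}, "
    f"out-out alias: {out_out_alias}"
)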
My comments are mostly cosmetic, stamping to unblock
Stack from ghstack (oldest at bottom):
cc @ezyang @SherlockNoMad @EikanWang @jgong5 @wenzhe-nrv @voznesenskym @penguinwu @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov