[dynamo, logging] Move extra graph_code logging to a verbose artifact · Issue #153646 · pytorch/pytorch · GitHub
@williamwen42

Description


TORCH_LOGS=graph_code has become too spammy. It should only log the graphs produced by _dynamo/output_graph.py, but some FX passes also log to the graph_code artifact. We should move those extra logs to a new logging artifact, e.g. graph_code_verbose.
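A minimal sketch of the proposed split, using only Python's standard logging module and not torch._logging: each artifact is modeled as a separate child logger, so enabling graph_code through a TORCH_LOGS-style string leaves graph_code_verbose silent. The helpers artifact_logger, enable_artifacts and dump_graph and the parent logger name are hypothetical stand-ins, not PyTorch's actual API.

```python
import logging

# Hypothetical sketch, NOT torch._logging: each artifact is a child logger of
# one parent, and a TORCH_LOGS-style string decides which ones are enabled.
ARTIFACTS = ("graph_code", "graph_code_verbose")
PARENT = "artifact_sketch"  # hypothetical parent logger name


def artifact_logger(name: str) -> logging.Logger:
    """Return the logger backing one registered artifact."""
    if name not in ARTIFACTS:
        raise ValueError(f"unknown artifact: {name}")
    return logging.getLogger(f"{PARENT}.{name}")


def enable_artifacts(setting: str) -> None:
    """Enable only the artifacts named in a comma-separated setting string."""
    enabled = {s.strip() for s in setting.split(",") if s.strip()}
    for name in ARTIFACTS:
        # Disabled artifacts get a level above CRITICAL, so debug() is dropped.
        level = logging.DEBUG if name in enabled else logging.CRITICAL + 1
        artifact_logger(name).setLevel(level)


def dump_graph(header: str, code: str, verbose: bool) -> None:
    """Log a graph dump to graph_code, or to graph_code_verbose if verbose."""
    name = "graph_code_verbose" if verbose else "graph_code"
    artifact_logger(name).debug("TRACED GRAPH\n ===== %s =====\n%s", header, code)


if __name__ == "__main__":
    logging.getLogger(PARENT).addHandler(logging.StreamHandler())
    enable_artifacts("graph_code")  # analogous to TORCH_LOGS="graph_code"
    # Intermediate FX-pass dump: suppressed unless graph_code_verbose is on.
    dump_graph("pre insert_deferred_runtime_asserts __compiled_fn_2", "...", verbose=True)
    # Final Dynamo graph: printed.
    dump_graph("__compiled_fn_2", "...", verbose=False)
```

Under this model, an FX pass that dumps an intermediate graph would log with verbose=True, so its output appears only when graph_code_verbose is explicitly requested.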

Example reproducer:

import torch

def f1(x):
    return x + 1

def f2(x):
    torch._dynamo.graph_break()
    return x + 2

def f3(x):
    return f1(x) + f2(x)

torch.compile(f3, backend="eager")(torch.ones(3))

Today's output:

$ TORCH_LOGS_FORMAT="" TORCH_LOGS="graph_code" python playground2.py 
TRACED GRAPH
 ===== pre insert_deferred_runtime_asserts __compiled_fn_2 =====
 <eval_with_key>.0 class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[3]"):
        l_x_ = L_x_
        
         # File: /data/users/williamwen/pytorch2/playground2.py:4 in f1, code: return x + 1
        add: "f32[3]" = l_x_ + 1;  l_x_ = None
        return (add,)
        

TRACED GRAPH
 ===== __compiled_fn_2 =====
 /data/users/williamwen/pytorch2/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[3][1]cpu"):
        l_x_ = L_x_
        
         # File: /data/users/williamwen/pytorch2/playground2.py:4 in f1, code: return x + 1
        add: "f32[3][1]cpu" = l_x_ + 1;  l_x_ = None
        return (add,)
        

TRACED GRAPH
 ===== pre insert_deferred_runtime_asserts __compiled_fn_8 =====
 <eval_with_key>.2 class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[3]"):
        l_x_ = L_x_
        
         # File: /data/users/williamwen/pytorch2/playground2.py:8 in torch_dynamo_resume_in_f2_at_7, code: return x + 2
        add: "f32[3]" = l_x_ + 2;  l_x_ = None
        return (add,)
        

TRACED GRAPH
 ===== __compiled_fn_8 =====
 /data/users/williamwen/pytorch2/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[3][1]cpu"):
        l_x_ = L_x_
        
         # File: /data/users/williamwen/pytorch2/playground2.py:8 in torch_dynamo_resume_in_f2_at_7, code: return x + 2
        add: "f32[3][1]cpu" = l_x_ + 2;  l_x_ = None
        return (add,)
        

TRACED GRAPH
 ===== pre insert_deferred_runtime_asserts __compiled_fn_10 =====
 <eval_with_key>.4 class GraphModule(torch.nn.Module):
    def forward(self, L_stack0_: "f32[3]", L_stack1_: "f32[3]"):
        l_stack0_ = L_stack0_
        l_stack1_ = L_stack1_
        
         # File: /data/users/williamwen/pytorch2/playground2.py:11 in torch_dynamo_resume_in_f3_at_11, code: return f1(x) + f2(x)
        add: "f32[3]" = l_stack0_ + l_stack1_;  l_stack0_ = l_stack1_ = None
        return (add,)
        

TRACED GRAPH
 ===== __compiled_fn_10 =====
 /data/users/williamwen/pytorch2/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, L_stack0_: "f32[3][1]cpu", L_stack1_: "f32[3][1]cpu"):
        l_stack0_ = L_stack0_
        l_stack1_ = L_stack1_
        
         # File: /data/users/williamwen/pytorch2/playground2.py:11 in torch_dynamo_resume_in_f3_at_11, code: return f1(x) + f2(x)
        add: "f32[3][1]cpu" = l_stack0_ + l_stack1_;  l_stack0_ = l_stack1_ = None
        return (add,)

I only want to see the ===== __compiled_fn_# ===== graphs logged to graph_code; everything else (such as the ===== pre insert_deferred_runtime_asserts __compiled_fn_# ===== graphs) should go to graph_code_verbose.
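To make the desired split concrete, here is a small sketch of the routing rule applied to captured log text. It is a hypothetical post-processing helper (route and split_dumps are illustrative names, not part of PyTorch) that classifies each graph header line by the artifact it should belong to.

```python
import re
from typing import Dict, List

# Hypothetical routing rule from this issue: only headers of the exact form
# "===== __compiled_fn_<N> =====" stay in graph_code; anything else is verbose.
FINAL_GRAPH = re.compile(r"^\s*===== __compiled_fn_\d+ =====\s*$")
ANY_HEADER = re.compile(r"^\s*===== (.+?) =====")


def route(header_line: str) -> str:
    """Return the artifact a graph dump with this header line belongs to."""
    return "graph_code" if FINAL_GRAPH.match(header_line) else "graph_code_verbose"


def split_dumps(log_text: str) -> Dict[str, List[str]]:
    """Group graph headers found in a graph_code log by their target artifact."""
    routed: Dict[str, List[str]] = {"graph_code": [], "graph_code_verbose": []}
    for line in log_text.splitlines():
        match = ANY_HEADER.match(line)
        if match:
            routed[route(line)].append(match.group(1))
    return routed


if __name__ == "__main__":
    sample = "\n".join([
        "TRACED GRAPH",
        " ===== pre insert_deferred_runtime_asserts __compiled_fn_2 =====",
        "TRACED GRAPH",
        " ===== __compiled_fn_2 =====",
    ])
    print(split_dumps(sample))
```

Applied to the output above, this rule would keep the three __compiled_fn_2/8/10 graphs in graph_code and send the three pre insert_deferred_runtime_asserts graphs to graph_code_verbose. An actual fix would more likely change which artifact logger the FX passes log to, rather than filter text after the fact.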

cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames

Metadata

Assignees: No one assigned

Labels: good first issue, module: dynamo, module: logging, oncall: pt2, triaged
