[Easy] Fix the format by fffrog · Pull Request #158450 · pytorch/pytorch · GitHub
2 changes: 1 addition & 1 deletion torch/_inductor/fx_passes/fuse_attention.py
@@ -584,7 +584,7 @@ def _sfdp_replacement_20(query, key, value, attn_mask, dropout_p):
 def _sfdp_pattern_24(query, key, value, attention_mask):
     """
     this pattern is for MBartForCausalLM/PLBartForCausalLM.
-    attn_mask has a differnt dtype with QKV.
+    attn_mask has a different dtype with QKV.
     there is no scale in sdpa.
     """
     bs = query.size(0)
5 changes: 3 additions & 2 deletions torch/fx/traceback.py
@@ -81,8 +81,8 @@ def __init__(
         self.from_node = []
 
         # cache the action string and dict representation for performance.
-        self._action_string = None
-        self._dict = None
+        self._action_string: Optional[str] = None
+        self._dict: Optional[dict[str, Any]] = None
 
     @property
     def name(self) -> str:
@@ -130,6 +130,7 @@ def to_dict(self) -> dict:
                 "from_node": [node.to_dict() for node in self.from_node],
             }
 
+        assert self._dict is not None
         return self._dict
 
     def __eq__(self, other: object):
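For context, the annotations and the added assert in torch/fx/traceback.py follow a common lazily-cached-attribute pattern: the cache field starts as None, is filled on first use, and the assertion narrows the Optional type back to a concrete dict for static checkers such as mypy. Below is a minimal standalone sketch of that pattern; the CachedSource class and its payload are illustrative stand-ins, not the actual NodeSource implementation.

from typing import Any, Optional


class CachedSource:
    """Illustrative stand-in for the lazy-cache pattern used in the PR."""

    def __init__(self, name: str) -> None:
        self.name = name
        # The cache starts empty; Optional tells the type checker that None
        # is a legal state before the first to_dict() call.
        self._dict: Optional[dict[str, Any]] = None

    def to_dict(self) -> dict:
        if self._dict is None:
            # Build the dictionary once and reuse it on later calls.
            self._dict = {"name": self.name}

        # Narrow Optional[dict[str, Any]] to dict so the return type checks.
        assert self._dict is not None
        return self._dict


source = CachedSource("relu_1")
print(source.to_dict())  # {'name': 'relu_1'}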