Fix Fusion IR cloning (#567) · mcarilli/pytorch@c6d8c4a · GitHub
[go: up one dir, main page]

Skip to content

Commit c6d8c4a

Browse files
authored
Fix Fusion IR cloning (pytorch#567)
Fixes pytorch#566
1 parent cd1242b commit c6d8c4a

File tree

5 files changed

+30
-35
lines changed

5 files changed

+30
-35
lines changed

torch/csrc/jit/codegen/cuda/fusion.cpp

Lines changed: 11 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -72,41 +72,25 @@ Fusion::Fusion(const Fusion& other) {
7272
val_set_.insert(ir_cloner.clone(val));
7373
}
7474

75+
for (auto expr : other.expr_set_) {
76+
expr_set_.insert(ir_cloner.clone(expr));
77+
}
78+
7579
for (auto val : other.val_deque_) {
7680
val_deque_.push_back(ir_cloner.clone(val));
7781
}
7882

79-
for (auto old_expr : other.expr_set_) {
80-
auto new_expr = ir_cloner.clone(old_expr);
81-
expr_set_.insert(new_expr);
82-
83-
// ir_cloner doesn't go through registerStmt, so we need to "Register Expr"
84-
// we would similarly need to do to val if there was in that pass that is
85-
// also not covered here.
86-
for (Val* input : new_expr->inputs()) {
87-
auto uses_copy = input->uses();
88-
if (std::find(uses_copy.begin(), uses_copy.end(), new_expr) ==
89-
uses_copy.end()) {
90-
uses_copy.push_back(new_expr);
91-
input->setUses(uses_copy);
92-
}
93-
}
83+
// Fixup potentially cyclic pointers
84+
for (auto val : val_set_) {
85+
val->definition_ = ir_cloner.clone(val->definition_);
86+
val->uses_ = ir_cloner.clone(val->uses_);
9487
}
9588

9689
val_type_name_map_ = other.val_type_name_map_;
9790
expr_name_counter_ = other.expr_name_counter_;
9891

9992
inputs_ = ir_cloner.clone(other.inputs_);
10093
outputs_ = ir_cloner.clone(other.outputs_);
101-
102-
for (auto inp : inputs_) {
103-
inp->setIsFusionInput(true);
104-
}
105-
for (auto out : outputs_) {
106-
out->setIsFusionOutput(true);
107-
}
108-
109-
resetTvUses();
11094
}
11195

11296
Fusion::Fusion(Fusion&& other) noexcept {
@@ -421,16 +405,16 @@ void Fusion::resetTvUses() {
421405
// remove dead exprs, this could reinsert them. getExprs is also bounded by
422406
// inputs as registered inputs will return nullptr as their definition.
423407
const auto all_tvs = ir_utils::filterByType<TensorView>(val_set_);
424-
auto used_exprs = ExprSort::getExprs(this);
408+
const auto used_exprs = ExprSort::getExprs(this);
425409

426410
for (auto tv : all_tvs) {
427-
tv->setUses(std::deque<Expr*>());
411+
tv->setUses({});
428412
}
429413

430414
// Same as in register expr
431415
for (auto expr : used_exprs) {
432416
for (Val* input : expr->inputs()) {
433-
std::deque<Expr*> uses_copy = input->uses();
417+
auto uses_copy = input->uses();
434418
if (std::find(uses_copy.begin(), uses_copy.end(), expr) ==
435419
uses_copy.end()) {
436420
uses_copy.push_back(expr);

torch/csrc/jit/codegen/cuda/ir_base_nodes.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,11 +53,19 @@ Val::Val(ValType _vtype, DataType _dtype, bool register_val)
5353
}
5454
}
5555

56+
// NOTE: we don't clone the definition_ and uses_ here
57+
// since they may introduce cloning cycles. Instead, we copy
58+
// the original pointers and we'll fix them up later as part of the
59+
// Fusion copy
60+
//
5661
Val::Val(const Val* src, IrCloner* ir_cloner)
5762
: Statement(src, ir_cloner),
5863
vtype_(src->vtype_),
5964
dtype_(src->dtype_),
60-
definition_(ir_cloner->clone(src->definition())) {}
65+
is_fusion_input_(src->is_fusion_input_),
66+
is_fusion_output_(src->is_fusion_output_),
67+
definition_(src->definition_),
68+
uses_(src->uses_) {}
6169

6270
namespace {
6371

torch/csrc/jit/codegen/cuda/ir_base_nodes.h

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
#include <torch/csrc/jit/codegen/cuda/utils.h>
1010

1111
#include <cstdint>
12-
#include <deque>
1312
#include <iostream>
1413
#include <limits>
1514
#include <memory>
@@ -214,7 +213,7 @@ class TORCH_CUDA_API Val : public Statement {
214213
return definition_;
215214
}
216215

217-
const std::deque<Expr*>& uses() const {
216+
const auto& uses() const {
218217
return uses_;
219218
}
220219

@@ -272,7 +271,7 @@ class TORCH_CUDA_API Val : public Statement {
272271
is_fusion_output_ = is_fusion_output;
273272
}
274273

275-
void setUses(std::deque<Expr*> uses) {
274+
void setUses(const std::vector<Expr*>& uses) {
276275
uses_ = uses;
277276
}
278277

@@ -282,7 +281,7 @@ class TORCH_CUDA_API Val : public Statement {
282281
bool is_fusion_output_ = false;
283282

284283
Expr* definition_ = nullptr;
285-
std::deque<Expr*> uses_;
284+
std::vector<Expr*> uses_;
286285
};
287286

288287
//! A Expr represents a "computation." These are functions that takes inputs

torch/csrc/jit/codegen/cuda/ir_cloner.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,11 @@ namespace cuda {
1313

1414
class Fusion;
1515

16-
// Clones nodes from an exiting Fusion
16+
//! Clones nodes from an existing Fusion
17+
//!
18+
//! \warning IrCloner machinery is a specialized helper for implementing
19+
//! Fusion copy operations and it's not intended for any other uses
20+
//!
1721
class TORCH_CUDA_API IrCloner : private OptInConstDispatch {
1822
friend class Statement;
1923

torch/csrc/jit/codegen/cuda/root_domain_map.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ std::unordered_map<IterDomain*, IterDomain*> PairwiseRootDomainMap::map(
7171
TORCH_INTERNAL_ASSERT(producer_tv_->domain() == producer);
7272
TORCH_INTERNAL_ASSERT(consumer_tv_->domain() == consumer);
7373

74-
if (consumer_tv_->getOrigin()->isA<TransposeOp>()) {
74+
if (consumer_tv_->definition()->isA<TransposeOp>()) {
7575
return mapTranspose(
7676
producer, consumer, root_dims_to_map, producer_to_consumer);
7777
}
@@ -126,7 +126,7 @@ std::unordered_map<IterDomain*, IterDomain*> PairwiseRootDomainMap::
126126

127127
std::unordered_map<IterDomain*, IterDomain*> dom_map;
128128

129-
TransposeOp* top = dynamic_cast<TransposeOp*>(consumer_tv_->getOrigin());
129+
TransposeOp* top = dynamic_cast<TransposeOp*>(consumer_tv_->definition());
130130
TORCH_INTERNAL_ASSERT(top != nullptr);
131131

132132
const auto& new2old = top->new2old();

0 commit comments

Comments
 (0)
0