Adding cost estimator for scatter operations · compnerd/tensorflow@71f8d91 · GitHub
[go: up one dir, main page]

Skip to content

Commit 71f8d91

Browse files
darylng authored and tensorflower-gardener committed
Adding cost estimator for scatter operations
PiperOrigin-RevId: 252852416
1 parent 550741c commit 71f8d91

File tree

3 files changed

+114
-0
lines changed

3 files changed

+114
-0
lines changed

tensorflow/core/grappler/costs/op_level_cost_estimator.cc

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,13 @@ constexpr char kStopGradient[] = "StopGradient";
6060
constexpr char kPreventGradient[] = "PreventGradient";
6161
constexpr char kGather[] = "Gather";
6262
constexpr char kGatherV2[] = "GatherV2";
63+
// Op-name constants for the scatter family. Each of these names is mapped to
// OpLevelCostEstimator::PredictScatter in the estimator's constructor.
constexpr char kScatterAdd[] = "ScatterAdd";
64+
constexpr char kScatterDiv[] = "ScatterDiv";
65+
constexpr char kScatterMax[] = "ScatterMax";
66+
constexpr char kScatterMin[] = "ScatterMin";
67+
constexpr char kScatterMul[] = "ScatterMul";
68+
constexpr char kScatterSub[] = "ScatterSub";
69+
constexpr char kScatterUpdate[] = "ScatterUpdate";
6370
constexpr char kSlice[] = "Slice";
6471
constexpr char kMaxPool[] = "MaxPool";
6572
constexpr char kMaxPoolGrad[] = "MaxPoolGrad";
@@ -275,6 +282,14 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
275282

276283
{kGather, wrap(&OpLevelCostEstimator::PredictGatherOrSlice)},
277284
{kGatherV2, wrap(&OpLevelCostEstimator::PredictGatherOrSlice)},
285+
{kScatterAdd, wrap(&OpLevelCostEstimator::PredictScatter)},
286+
{kScatterDiv, wrap(&OpLevelCostEstimator::PredictScatter)},
287+
{kScatterMax, wrap(&OpLevelCostEstimator::PredictScatter)},
288+
{kScatterMin, wrap(&OpLevelCostEstimator::PredictScatter)},
289+
{kScatterMul, wrap(&OpLevelCostEstimator::PredictScatter)},
290+
{kScatterSub, wrap(&OpLevelCostEstimator::PredictScatter)},
291+
{kScatterUpdate, wrap(&OpLevelCostEstimator::PredictScatter)},
292+
278293
{kSlice, wrap(&OpLevelCostEstimator::PredictGatherOrSlice)},
279294

280295
{kPlaceholder, wrap(&OpLevelCostEstimator::PredictIdentity)},
@@ -1551,6 +1566,53 @@ Costs OpLevelCostEstimator::PredictGatherOrSlice(
15511566
return costs;
15521567
}
15531568

1569+
// Predicts the cost of a scatter op described by op_context.op_info.
//
// The operation count is (number of elements in the indices input) multiplied
// by (the product of the ref tensor's dimensions after the first one). The
// ref input size and the output size are both derived from that operation
// count rather than from the full ref shape; the indices and updates inputs
// are counted at their full tensor size. The returned Costs is marked
// inaccurate when any of the shape helpers reported an unknown shape.
//
// NOTE(review): inputs(1), inputs(2) and outputs(0) are accessed without
// checking how many inputs/outputs op_info actually has - confirm that callers
// always supply a well-formed scatter op.
Costs OpLevelCostEstimator::PredictScatter(const OpContext& op_context) const {
1570+
// Scatter ops sparsely access a reference input and output tensor.
1571+
const auto& op_info = op_context.op_info;
1572+
// Set by the shape/size helpers below; folded into the result at the end.
bool found_unknown_shapes = false;
1573+
1574+
// input[0]: ref tensor that will be sparsely accessed
1575+
// input[1]: indices - A tensor of indices into the first dimension of ref.
1576+
// input[2]: updates where updates.shape = indices.shape + ref.shape[1:]
1577+
// See
1578+
// https://www.tensorflow.org/api_docs/python/tf/scatter_add and
1579+
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/ops/state_ops.cc#L146
1580+
1581+
// Each element of indices selects one slice of ref along its first dimension.
const int64 num_indices =
1582+
CalculateTensorElementCount(op_info.inputs(1), &found_unknown_shapes);
1583+
1584+
// Number of ref elements touched per index: product of ref dims 1..rank-1.
int64 num_elems_in_ref_per_index = 1;
1585+
auto ref_tensor_shape = MaybeGetMinimumShape(
1586+
op_info.inputs(0).shape(), op_info.inputs(0).shape().dim_size(),
1587+
&found_unknown_shapes);
1588+
// Starts at 1 to skip the first dimension, which is the one being indexed.
for (int i = 1; i < ref_tensor_shape.dim().size(); ++i) {
1589+
num_elems_in_ref_per_index *= ref_tensor_shape.dim(i).size();
1590+
}
1591+
const int64 op_count = num_indices * num_elems_in_ref_per_index;
1592+
1593+
// Sparsely access ref so input size depends on the number of operations
1594+
int64 ref_input_size =
1595+
op_count * DataTypeSize(BaseType(op_info.inputs(0).dtype()));
1596+
// Indices and updates are accounted for at their full tensor size.
int64 indices_input_size =
1597+
CalculateTensorSize(op_info.inputs(1), &found_unknown_shapes);
1598+
int64 updates_input_size =
1599+
CalculateTensorSize(op_info.inputs(2), &found_unknown_shapes);
1600+
1601+
double total_input_size =
1602+
ref_input_size + indices_input_size + updates_input_size;
1603+
1604+
// Sparsely access ref so output size depends on the number of operations
1605+
double total_output_size =
1606+
op_count * DataTypeSize(BaseType(op_info.outputs(0).dtype()));
1607+
1608+
auto costs = PredictOpCountBasedCost(op_count, total_input_size,
1609+
total_output_size, op_info);
1610+
costs.inaccurate = found_unknown_shapes;
1611+
// The bool is stored as a count: 1 if any shape was unknown, otherwise 0.
costs.num_ops_with_unknown_shapes = found_unknown_shapes;
1612+
1613+
return costs;
1614+
}
1615+
15541616
Costs OpLevelCostEstimator::PredictFusedOp(
15551617
const OpContext& op_context,
15561618
const std::vector<OpContext>& fused_op_contexts) const {

tensorflow/core/grappler/costs/op_level_cost_estimator.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ class OpLevelCostEstimator {
141141
Costs PredictBatchMatMul(const OpContext& op_context) const;
142142
Costs PredictMetadata(const OpContext& op_context) const;
143143
Costs PredictGatherOrSlice(const OpContext& op_context) const;
144+
Costs PredictScatter(const OpContext& op_context) const;
144145
Costs PredictMaxPool(const OpContext& op_context) const;
145146
Costs PredictMaxPoolGrad(const OpContext& op_context) const;
146147
Costs PredictAvgPool(const OpContext& op_context) const;

tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -612,6 +612,57 @@ TEST_F(OpLevelCostEstimatorTest, TestSliceCosts) {
612612
EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
613613
}
614614

615+
// Checks the predicted costs for every scatter op name in two configurations:
// updates shaped like indices.shape + ref.shape[1:], and scalar-shaped updates
// with INT32 indices. Both configurations use the same ref shape and the same
// number of indices, and both expect the same compute_time.
TEST_F(OpLevelCostEstimatorTest, TestScatterOps) {
616+
std::vector<string> scatter_ops = {"ScatterAdd", "ScatterDiv", "ScatterMax",
617+
"ScatterMin", "ScatterMul", "ScatterSub",
618+
"ScatterUpdate"};
619+
// The same expectations are applied to every op name in the list.
for (const auto& op : scatter_ops) {
620+
// Test updates.shape = indices.shape + ref.shape[1:]
621+
{
622+
OpContext op_context;
623+
SetCpuDevice(&op_context.op_info);
624+
op_context.op_info.set_op(op);
625+
// Huge first dimension in input shouldn't affect Scatter execution and
626+
// memory costs.
627+
// Inputs are described in order: ref, indices, updates.
DescribeArbitraryRankInput({10000000, 10}, DT_FLOAT, &op_context.op_info);
628+
DescribeArbitraryRankInput({16}, DT_INT64, &op_context.op_info);
629+
DescribeArbitraryRankInput({16, 10}, DT_FLOAT, &op_context.op_info);
630+
DescribeArbitraryRankOutput({10000000, 10}, DT_FLOAT,
631+
&op_context.op_info);
632+
633+
auto cost = estimator_.PredictCosts(op_context);
634+
// The expected execution_time equals memory_time + compute_time (205 + 16).
EXPECT_EQ(Costs::Duration(205), cost.memory_time);
635+
EXPECT_EQ(Costs::Duration(16), cost.compute_time);
636+
EXPECT_EQ(Costs::Duration(221), cost.execution_time);
637+
EXPECT_EQ(1, cost.num_ops_total);
638+
// All shapes are fully specified, so nothing should be flagged as unknown.
EXPECT_FALSE(cost.inaccurate);
639+
EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
640+
}
641+
642+
// Test updates.shape = [] and INT32 indices
643+
{
644+
OpContext op_context;
645+
SetCpuDevice(&op_context.op_info);
646+
op_context.op_info.set_op(op);
647+
// Huge first dimension in input shouldn't affect Scatter execution and
648+
// memory costs.
649+
// Inputs are described in order: ref, indices, updates (empty shape here).
DescribeArbitraryRankInput({10000000, 10}, DT_FLOAT, &op_context.op_info);
650+
DescribeArbitraryRankInput({16}, DT_INT32, &op_context.op_info);
651+
DescribeArbitraryRankInput({}, DT_FLOAT, &op_context.op_info);
652+
DescribeArbitraryRankOutput({10000000, 10}, DT_FLOAT,
653+
&op_context.op_info);
654+
655+
auto cost = estimator_.PredictCosts(op_context);
656+
// A lower memory_time is expected than in the first case; compute_time is
// the same. The expected execution_time again equals their sum (135 + 16).
EXPECT_EQ(Costs::Duration(135), cost.memory_time);
657+
EXPECT_EQ(Costs::Duration(16), cost.compute_time);
658+
EXPECT_EQ(Costs::Duration(151), cost.execution_time);
659+
EXPECT_EQ(1, cost.num_ops_total);
660+
EXPECT_FALSE(cost.inaccurate);
661+
EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
662+
}
663+
}
664+
}
665+
615666
TEST_F(OpLevelCostEstimatorTest, BiasAddExecutionTime) {
616667
auto cost = PredictCosts(DescribeBiasAdd(1000, 10));
617668
EXPECT_EQ(Costs::Duration(8400), cost.memory_time);

0 commit comments

Comments
 (0)
0