@@ -60,6 +60,13 @@ constexpr char kStopGradient[] = "StopGradient";
60
60
constexpr char kPreventGradient [] = " PreventGradient" ;
61
61
constexpr char kGather [] = " Gather" ;
62
62
constexpr char kGatherV2 [] = " GatherV2" ;
63
+ constexpr char kScatterAdd [] = " ScatterAdd" ;
64
+ constexpr char kScatterDiv [] = " ScatterDiv" ;
65
+ constexpr char kScatterMax [] = " ScatterMax" ;
66
+ constexpr char kScatterMin [] = " ScatterMin" ;
67
+ constexpr char kScatterMul [] = " ScatterMul" ;
68
+ constexpr char kScatterSub [] = " ScatterSub" ;
69
+ constexpr char kScatterUpdate [] = " ScatterUpdate" ;
63
70
constexpr char kSlice [] = " Slice" ;
64
71
constexpr char kMaxPool [] = " MaxPool" ;
65
72
constexpr char kMaxPoolGrad [] = " MaxPoolGrad" ;
@@ -275,6 +282,14 @@ OpLevelCostEstimator::OpLevelCostEstimator() {
275
282
276
283
{kGather , wrap (&OpLevelCostEstimator::PredictGatherOrSlice)},
277
284
{kGatherV2 , wrap (&OpLevelCostEstimator::PredictGatherOrSlice)},
285
+ {kScatterAdd , wrap (&OpLevelCostEstimator::PredictScatter)},
286
+ {kScatterDiv , wrap (&OpLevelCostEstimator::PredictScatter)},
287
+ {kScatterMax , wrap (&OpLevelCostEstimator::PredictScatter)},
288
+ {kScatterMin , wrap (&OpLevelCostEstimator::PredictScatter)},
289
+ {kScatterMul , wrap (&OpLevelCostEstimator::PredictScatter)},
290
+ {kScatterSub , wrap (&OpLevelCostEstimator::PredictScatter)},
291
+ {kScatterUpdate , wrap (&OpLevelCostEstimator::PredictScatter)},
292
+
278
293
{kSlice , wrap (&OpLevelCostEstimator::PredictGatherOrSlice)},
279
294
280
295
{kPlaceholder , wrap (&OpLevelCostEstimator::PredictIdentity)},
@@ -1551,6 +1566,53 @@ Costs OpLevelCostEstimator::PredictGatherOrSlice(
1551
1566
return costs;
1552
1567
}
1553
1568
1569
+ Costs OpLevelCostEstimator::PredictScatter (const OpContext& op_context) const {
1570
+ // Scatter ops sparsely access a reference input and output tensor.
1571
+ const auto & op_info = op_context.op_info ;
1572
+ bool found_unknown_shapes = false ;
1573
+
1574
+ // input[0]: ref tensor that will be sparsely accessed
1575
+ // input[1]: indices - A tensor of indices into the first dimension of ref.
1576
+ // input[2]: updates where updates.shape = indices.shape + ref.shape[1:]
1577
+ // See
1578
+ // https://www.tensorflow.org/api_docs/python/tf/scatter_add and
1579
+ // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/ops/state_ops.cc#L146
1580
+
1581
+ const int64 num_indices =
1582
+ CalculateTensorElementCount (op_info.inputs (1 ), &found_unknown_shapes);
1583
+
1584
+ int64 num_elems_in_ref_per_index = 1 ;
1585
+ auto ref_tensor_shape = MaybeGetMinimumShape (
1586
+ op_info.inputs (0 ).shape (), op_info.inputs (0 ).shape ().dim_size (),
1587
+ &found_unknown_shapes);
1588
+ for (int i = 1 ; i < ref_tensor_shape.dim ().size (); ++i) {
1589
+ num_elems_in_ref_per_index *= ref_tensor_shape.dim (i).size ();
1590
+ }
1591
+ const int64 op_count = num_indices * num_elems_in_ref_per_index;
1592
+
1593
+ // Sparsely access ref so input size depends on the number of operations
1594
+ int64 ref_input_size =
1595
+ op_count * DataTypeSize (BaseType (op_info.inputs (0 ).dtype ()));
1596
+ int64 indices_input_size =
1597
+ CalculateTensorSize (op_info.inputs (1 ), &found_unknown_shapes);
1598
+ int64 updates_input_size =
1599
+ CalculateTensorSize (op_info.inputs (2 ), &found_unknown_shapes);
1600
+
1601
+ double total_input_size =
1602
+ ref_input_size + indices_input_size + updates_input_size;
1603
+
1604
+ // Sparsely access ref so output size depends on the number of operations
1605
+ double total_output_size =
1606
+ op_count * DataTypeSize (BaseType (op_info.outputs (0 ).dtype ()));
1607
+
1608
+ auto costs = PredictOpCountBasedCost (op_count, total_input_size,
1609
+ total_output_size, op_info);
1610
+ costs.inaccurate = found_unknown_shapes;
1611
+ costs.num_ops_with_unknown_shapes = found_unknown_shapes;
1612
+
1613
+ return costs;
1614
+ }
1615
+
1554
1616
Costs OpLevelCostEstimator::PredictFusedOp (
1555
1617
const OpContext& op_context,
1556
1618
const std::vector<OpContext>& fused_op_contexts) const {
0 commit comments