8000 Fix division on MPS · pytorch/pytorch@ec475d3 · GitHub
[go: up one dir, main page]

Skip to content

Commit ec475d3

Browse files
committed
Fix division on MPS
1 parent dc1c302 commit ec475d3

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

aten/src/ATen/native/mps/operations/BinaryOps.mm

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,8 +143,11 @@ void div_mode_template(const Tensor& self, const Tensor& other,
143143
const Tensor& output, const string op_name)
144144
{
145145
BinaryOpBlock div_mode_op_block = ^BinaryOpFn() {
146-
MPSGraphTensor* divTensor = [mpsGraph divisionWithPrimaryTensor:primary
147-
secondaryTensor:secondary
146+
// Division needs float inputs, result should also always be float32
147+
MPSGraphTensor* primaryFloat32 = [mpsGraph castTensor:primary toType:MPSDataTypeFloat32 name:@"castPrimary"];
148+
MPSGraphTensor* secondaryFloat32 = [mpsGraph castTensor:secondary toType:MPSDataTypeFloat32 name:@"castSecondary"];
149+
MPSGraphTensor* divTensor = [mpsGraph divisionWithPrimaryTensor:primaryFloat32
150+
secondaryTensor:secondaryFloat32
148151
name:nil];
149152
if (!rounding_mode.has_value()) {
150153
return divTensor;

0 commit comments

Comments
 (0)
0