-
Notifications
You must be signed in to change notification settings - Fork 63
Closed
Labels
allocation domain (issues related to allocation domain support)
Description
This is the repro script:
import torch
import nvfuser
from nvfuser.pytorch_utils import torch_dtype_to_nvfuser_dtype
import functools
from torch import Tensor, HalfTensor, BoolTensor
from typing import Tuple
# DEBUG: run a single iteration with no warm-up; forward() additionally prints
# tensor metadata (shape/stride/dtype/requires_grad) when this is set.
DEBUG = True
WARMUP_ITERS, ITERS = (0, 1) if DEBUG else (10, 100)

# REPRO=True selects the configuration that shows the slowdown: the ReLU output
# is cast to Half and both outputs use stride order (3, 0, 2, 1).
# Set REPRO = False for the baseline (plain set() output, stride order (3, 2, 1, 0)).
REPRO = True
FORMAT = (3, 0, 2, 1) if REPRO else (3, 2, 1, 0)
def print_tensors(tensors, tensor_names):
    """Print one line of metadata per tensor, labelled by the matching name.

    Each line is ``<name>: <shape> <stride> <dtype> <requires_grad>``.
    ``tensor_names`` is indexed positionally, so it must be at least as long
    as ``tensors`` (an IndexError is raised otherwise).
    """
    for i, t in enumerate(tensors):
        print(tensor_names[i] + ":", t.shape, t.stride(), t.dtype, t.requires_grad)
class bn_relu_jit(torch.autograd.Function):
    """Autograd wrapper around the nvFuser scale/bias + ReLU forward fusion.

    Only the forward pass is implemented; ``backward`` always raises.
    """

    # NOTE(review): torch.cuda.amp.custom_fwd/custom_bwd are deprecated in newer
    # PyTorch in favor of torch.amp.custom_fwd(device_type="cuda", ...); kept
    # here for compatibility with older releases.
    @staticmethod
    @torch.cuda.amp.custom_fwd(cast_inputs=torch.half)
    def forward(ctx, input, scale, bias):
        """Run the fused forward pass and save ``scale`` and the ReLU mask.

        With ``cast_inputs=torch.half``, ``custom_fwd`` casts floating-point
        inputs to half when called inside an autocast-enabled region.

        Returns:
            The fused ReLU output produced by ``fwd_bn_relu_nvfuser``.
        """
        bn_relu_out, relu_mask = fwd_bn_relu_nvfuser(input, scale, bias)
        # Debugging prints
        if DEBUG:
            print_tensors(
                [input, scale, bias, bn_relu_out, relu_mask],
                ["fwd_input", "fwd_scale", "fwd_bias", "fwd_bn_relu_out", "fwd_relu_mask"],
            )
        ctx.save_for_backward(scale, relu_mask)
        return bn_relu_out

    @staticmethod
    @torch.cuda.amp.custom_bwd
    def backward(ctx, grad_output):
        """Backward is not implemented; this repro only exercises the forward.

        Raises:
            NotImplementedError: always.
        """
        # `raise NotImplemented` would raise a TypeError instead, because
        # NotImplemented is a constant, not an exception class.
        raise NotImplementedError("bn_relu_jit.backward is not implemented")
### nvFuser Python frontend
def partially_contig_tensor(
    fd: "nvfuser.FusionDefinition",
    x: torch.Tensor,
) -> "nvfuser.Tensor":
    """Define an nvFuser input tensor that mirrors ``x``'s sizes, strides and dtype.

    The concrete sizes and strides are handed to ``fd.define_tensor``. In the
    dumped fusion definition, contiguity and ``stride_order`` show up derived
    from them (presumably by the frontend; verify against nvFuser docs).

    Args:
        fd: The FusionDefinition currently being built.
        x: The PyTorch tensor whose metadata the fusion input should match.

    Returns:
        The nvFuser tensor handle returned by ``fd.define_tensor``.
    """
    # Alternative (symbolic shapes) that was also tried:
    #   fd.define_tensor(
    #       shape=[-1] * x.ndim,
    #       contiguity=nvfuser.compute_contiguity(x.size(), x.stride()),
    #       dtype=torch_dtype_to_nvfuser_dtype(x.dtype),
    #   )
    return fd.define_tensor(
        sizes=x.size(),
        strides=x.stride(),
        dtype=torch_dtype_to_nvfuser_dtype(x.dtype),
    )
def fwd_bn_relu_nvfuser(input: HalfTensor, scale: HalfTensor, bias: HalfTensor) -> Tuple[HalfTensor, BoolTensor]:
    """Compute ``relu(input * scale + bias)`` and its positivity mask with nvFuser.

    A FusionDefinition is built from the inputs' sizes, strides and dtypes on
    every call, then executed on ``[input, scale, bias]``.

    Returns:
        ``(bn_relu, relu_mask)``. ``bn_relu`` is the ReLU output: cast to Half
        when ``REPRO`` is set, otherwise passed through ``fd.ops.set``.
        ``relu_mask`` is the boolean result of ``input * scale + bias > 0``.
        Both outputs are requested with stride order ``FORMAT``.
    """
    tensors = [input, scale, bias]
    with nvfuser.FusionDefinition() as fd:
        x, s, b = (partially_contig_tensor(fd, t) for t in tensors)
        z = fd.define_scalar(0)
        T0 = fd.ops.mul(x, s)
        T1 = fd.ops.add(T0, b)
        T2 = fd.ops.relu(T1)
        if REPRO:
            # Explicit cast to Half: the configuration reported as slow.
            T3 = fd.ops.cast(T2, dtype=nvfuser.DataType.Half)
        else:
            T3 = fd.ops.set(T2)
        T4 = fd.ops.gt(T1, z)
        fd.add_output(T3, FORMAT)
        fd.add_output(T4, FORMAT)
    # Execute after leaving the `with` block, once the definition is complete.
    bn_relu, relu_mask = fd.execute(tensors)
    return bn_relu, relu_mask
### Inputs
# The input is a (32, 1024, 25, 25) view with strides (640000, 1, 25600, 1024),
# i.e. memory order N, H, W, C (the same strides torch.channels_last gives this
# shape). Scale and bias are (1, 1024, 1, 1) views with strides
# (1024, 1, 1024, 1024), so the channel dim is innermost there too.
inputs = [
    torch.randn(32, 1024, 25, 25, device="cuda")
    .as_strided((32, 1024, 25, 25), (640000, 1, 25600, 1024))
    .half()
    .requires_grad_(),
    torch.randn(1, 1024, 1, 1, device="cuda").as_strided((1, 1024, 1, 1), (1024, 1, 1024, 1024)).half(),
    torch.randn(1, 1024, 1, 1, device="cuda").as_strided((1, 1024, 1, 1), (1024, 1, 1024, 1024)).half(),
]
### Repro code
# autograd.Function subclasses are invoked through the classmethod `apply`.
# Instantiating them is deprecated, so `model` is an alias for the class itself.
model = bn_relu_jit
for i in range(WARMUP_ITERS):
    o = model.apply(*inputs)
# Drain warm-up work so it does not spill into the profiled range.
torch.cuda.synchronize()
torch.cuda.profiler.start()
for i in range(ITERS):
    o = model.apply(*inputs)
# Wait for the timed kernels to finish before closing the profiled range.
torch.cuda.synchronize()
torch.cuda.profiler.stop()
With stride_order/alloc_domain properly plumbed through the nvFuser codegen IR, we are actually seeing plummeting performance. 😢
Running with `NVFUSER_DUMP=python_definition,dump_eff_bandwidth,scheduler_params,segmenter_logging,fusion_ir python repro.py`,
I'm getting:
# Fusion definition printed by NVFUSER_DUMP=python_definition for the repro script (REPRO=True).
def nvfuser_fusion_id0(fd : FusionDefinition) -> None :
    # Inputs in definition order: T0 = x, T1 = scale, T2 = bias. All carry
    # stride_order [3, 0, 2, 1], presumably derived from the strides passed to
    # define_tensor. NOTE(review): contiguity appears to be listed in allocation
    # (stride) order, with None for the size-1 broadcast dims; confirm.
    T0 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[True, True, True, True], dtype=DataType.Half, is_cpu=False, stride_order=[3, 0, 2, 1])
    T1 = fd.define_tensor(shape=[1, -1, 1, 1], contiguity=[None, None, None, True], dtype=DataType.Half, is_cpu=False, stride_order=[3, 0, 2, 1])
    T2 = fd.define_tensor(shape=[1, -1, 1, 1], contiguity=[None, None, None, True], dtype=DataType.Half, is_cpu=False, stride_order=[3, 0, 2, 1])
    S3 = fd.define_scalar(0, dtype=DataType.Int)
    T4 = fd.ops.mul(T0, T1)
    T5 = fd.ops.add(T4, T2)
    T6 = fd.ops.relu(T5)
    # REPRO path: the ReLU result is cast to Half before being output.
    T7 = fd.ops.cast(T6, dtype=DataType.Half)
    T8 = fd.ops.gt(T5, S3)
    fd.add_output(T7, stride_order=[3, 0, 2, 1])
    fd.add_output(T8, stride_order=[3, 0, 2, 1])
===== Pointwise Stats ========
num_elems: 20480000
elem_counts: 32 1024 25 25
max_input_dtype_size: 2
vectorize_factor: 1
broadcast_byte_multiples: (0, 9), (5, 9), (9, 5), (9, 5), LHS elems: 32768 RHS elems: 625
===== Pointwise Parameters ========
Tag: Pointwise heuristics Pointwise Characteristics:
Gridx: 1 BlckY: 1 BlckX: 79
2D Schedule
Bcast break point: 2
Unroll, Factor: 8
====================================
kernelpointwise_f0_c1_r0_g0 run in 1.13971 ms, achieved: 89.8509 GB/s
I'll take an initial look at this to see if I can fix the vectorization analysis first (the pointwise stats above report `vectorize_factor: 1`).
Metadata
Metadata
Assignees
Labels
allocation domain (issues related to allocation domain support)