Update on "[Cutlass] Implement Epilogue Argument emitter" · pytorch/pytorch@750b7f5 · GitHub

Commit 750b7f5

Update on "[Cutlass] Implement Epilogue Argument emitter"
This implements epilogue visitor tree (EVT) argument generation (example type [here](https://github.com/NVIDIA/cutlass/blob/3fe62887d8dd75700fdaf57f9c181878701b0802/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp#L332)).

Details: The codegen task is to implement a function that generates a tree of C++ structs, extracting the correct properties from Inductor buffers and writing them to the correct locations in the generated struct. To implement this with a minimum of code, I generate the cutlass DAGIR (the EVT internal representation), which has a dedicated pass, [pass_argument_type.py](https://github.com/NVIDIA/cutlass/blob/5e497243f7ad13a2aa842143f9b10bbb23d98292/python/cutlass/backend/evt/passes/pass_argument_type.py#L4), that generates a nested tree of custom argument types for each node in the DAGIR. This nested tree of constructors is then passed kwargs to fill in the proper values, where each node's name is used to differentiate between values in the kwarg dictionary. This is, however, non-customizable: the nested tree of EVT args is a nested tree of ctypes that expects *actual values*, so that the resulting object can be passed directly to the cutlass-python C++ runner. Inductor, on the other hand, needs to fill this struct with C++ string expressions representing the values (or extracting the values from kernel launcher args). So `_render_argument_type` implements this: it iterates over the tree of types created by pass_argument_type.py and generates a string representing the nested structs, filling in C++ expressions for the different fields.

Long term plan: I will ask NVIDIA to provide an overridable [visitor_factory](https://github.com/NVIDIA/cutlass/blob/5e497243f7ad13a2aa842143f9b10bbb23d98292/python/cutlass/backend/evt/passes/pass_argument_type.py#L82), which would allow us to override the behavior of pass_argument_type.py so it generates the string we want during DAGIR generation.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov

[ghstack-poisoned]
2 parents 2d629c2 + 10e2ace commit 750b7f5
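
To make the rendering idea above concrete, here is a minimal Python sketch of walking a nested argument-type tree and emitting a C++ brace-initializer string whose leaves are C++ expressions instead of concrete values. The `ArgType` class, field names, and `render_argument_type` helper are invented for illustration; this is not the actual `_render_argument_type` code, which walks the ctypes tree produced by pass_argument_type.py.

```python
# Hypothetical sketch of the "render a nested argument struct as C++ text" idea.
from dataclasses import dataclass, field


@dataclass
class ArgType:
    # Name of the DAGIR node this argument struct belongs to.
    name: str
    # Nested sub-structs (one per child node) and leaf field names.
    children: list["ArgType"] = field(default_factory=list)
    fields: list[str] = field(default_factory=list)


def render_argument_type(node: ArgType, exprs: dict, indent: int = 0) -> str:
    """Render a nested argument struct as a C++ aggregate initializer.

    `exprs` maps node name -> field name -> C++ expression string, e.g.
    {"bias": {"ptr": "(float*)(args[3])", "stride": "K"}}.
    """
    pad = "  " * indent
    lines = [f"{pad}{{  // {node.name}"]
    for child in node.children:
        lines.append(render_argument_type(child, exprs, indent + 1) + ",")
    for f_name in node.fields:
        cpp_expr = exprs.get(node.name, {}).get(f_name, "{}")
        lines.append(f"{pad}  /* {f_name} = */ {cpp_expr},")
    lines.append(f"{pad}}}")
    return "\n".join(lines)


if __name__ == "__main__":
    # Toy tree: an epilogue arg struct with one nested node ("bias") and a scalar.
    tree = ArgType(
        name="epilogue",
        children=[ArgType(name="bias", fields=["ptr", "stride"])],
        fields=["alpha"],
    )
    print(render_argument_type(
        tree,
        {"bias": {"ptr": "(float*)(args[3])", "stride": "K"}, "epilogue": {"alpha": "1.0f"}},
    ))
```

The important property, as described in the commit message, is that the leaves are C++ expressions (strings), so the rendered text can be pasted into the generated kernel launcher rather than requiring concrete runtime values as the cutlass-python ctypes tree does.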


44 files changed: +995 / -111 lines

aten/src/ATen/native/TensorShape.cpp

Lines changed: 1 addition & 1 deletion
@@ -3366,7 +3366,7 @@ static std::vector<Tensor> _pad_chunk(
     std::vector<int64_t> view_sizes(
         tensor_size.begin(), tensor_size.begin() + dim);
     view_sizes.insert(view_sizes.end(), {num_chunks, -1});
-    padded_tensors.push_back(padded_tensor.view(view_sizes));
+    padded_tensors.push_back(padded_tensor.reshape(view_sizes));
   }
   return padded_tensors;
 }

aten/src/ATen/native/cuda/CUDALoops.cuh

Lines changed: 29 additions & 15 deletions
@@ -612,28 +612,41 @@ struct check_binary_functor_types_for_specialization<
 };

 // The following is a list of type specializations for vectorized_templated
-// elementwise kernel. It refers to the first and second runtime types of the
-// arguments of a binary functor.
-
+// elementwise kernel. The three types refer to runtime types of the output
+// tensor, first tensor argument, and the second tensor argument used for a
+// binary functor.
 constexpr std::array rt_binary_specializations = {
-    std::array<c10::ScalarType, 2>(
+    std::array<c10::ScalarType, 3>(
         {c10::CppTypeToScalarType<float>::value,
+         c10::CppTypeToScalarType<float>::value,
          c10::CppTypeToScalarType<BFloat16>::value}),
-    std::array<c10::ScalarType, 2>(
+    std::array<c10::ScalarType, 3>(
+        {c10::CppTypeToScalarType<float>::value,
+         c10::CppTypeToScalarType<BFloat16>::value,
+         c10::CppTypeToScalarType<float>::value}),
+    std::array<c10::ScalarType, 3>(
         {c10::CppTypeToScalarType<BFloat16>::value,
+         c10::CppTypeToScalarType<BFloat16>::value,
          c10::CppTypeToScalarType<float>::value}),
-    std::array<c10::ScalarType, 2>(
+    std::array<c10::ScalarType, 3>(
         {c10::CppTypeToScalarType<float>::value,
+         c10::CppTypeToScalarType<float>::value,
          c10::CppTypeToScalarType<Half>::value}),
-    std::array<c10::ScalarType, 2>(
+    std::array<c10::ScalarType, 3>(
+        {c10::CppTypeToScalarType<float>::value,
+         c10::CppTypeToScalarType<Half>::value,
+         c10::CppTypeToScalarType<float>::value}),
+    std::array<c10::ScalarType, 3>(
         {c10::CppTypeToScalarType<Half>::value,
+         c10::CppTypeToScalarType<Half>::value,
          c10::CppTypeToScalarType<float>::value})};

 bool check_binary_rt_types_for_specialization(TensorIteratorBase& iter) {
   if (iter.ninputs() != 2)
     return false;
   for (auto spec : rt_binary_specializations)
-    if (iter.input_dtype(0) == spec[0] && iter.input_dtype(1) == spec[1])
+    if (iter.dtype(0) == spec[0] && iter.input_dtype(0) == spec[1] &&
+        iter.input_dtype(1) == spec[2])
       return true;
   return false;
 }
@@ -648,6 +661,7 @@ struct type_specialized_kernel_launcher {
       typename loader_t,
       typename storer_t>
   static void apply(
+      ScalarType ret_t,
       ScalarType arg0_t,
       ScalarType arg1_t,
       int64_t numel,
@@ -657,22 +671,22 @@
       out_calc_t output_offset_calculator,
       loader_t loader,
       storer_t storer) {
-    using traits = function_traits<func_t>;
-    using return_t = typename traits::result_type;
-    if (arg0_t == rt_binary_specializations[arg_index][0] &&
-        arg1_t == rt_binary_specializations[arg_index][1])
+    if (ret_t == rt_binary_specializations[arg_index][0] &&
+        arg0_t == rt_binary_specializations[arg_index][1] &&
+        arg1_t == rt_binary_specializations[arg_index][2])
       launch_vectorized_templated_kernel<
           func_t,
           array_t,
           inp_calc_t,
           out_calc_t,
           loader_t,
           storer_t,
-          return_t,
           decltype(c10::impl::ScalarTypeToCPPType<
                    rt_binary_specializations[arg_index][0]>::t),
           decltype(c10::impl::ScalarTypeToCPPType<
-                   rt_binary_specializations[arg_index][1]>::t)>(
+                   rt_binary_specializations[arg_index][1]>::t),
+          decltype(c10::impl::ScalarTypeToCPPType<
+                   rt_binary_specializations[arg_index][2]>::t)>(
           numel,
           f,
          data,
@@ -712,7 +726,6 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
 #ifdef USE_ROCM
   // Attempt to call specialized vectorized elementwise kernel
   // that enables interleaving.
-
   if (check_binary_rt_types_for_specialization(iter) &&
       memory::can_vectorize_up_to<func_t>(data) > 1) {
     // constexpr to reduce the amount of kernels generated for
@@ -740,6 +753,7 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
           type_specialized_kernel_launcher,
           rt_binary_specializations.size()>::
           with_args(
+              iter.dtype(0),
               iter.input_dtype(0),
               iter.input_dtype(1),
               numel,

aten/src/ATen/native/cuda/MemoryAccess.cuh

Lines changed: 2 additions & 2 deletions
@@ -407,8 +407,8 @@ struct vectorized_templated {
   // float(float,bfloat16) and functor add on float(float,float).
   template <typename scalar_t>
   __device__ inline void store(scalar_t* from, int idx) {
-    using vec_t = aligned_vector<scalar_t, vec_size>;
-    scalar_t* to = reinterpret_cast<scalar_t*>(data[0]) + block_work_size * idx;
+    using vec_t = aligned_vector<CastToT, vec_size>;
+    CastToT* to = reinterpret_cast<CastToT*>(data[0]) + block_work_size * idx;
     vec_t* to_ = reinterpret_cast<vec_t*>(to);
     int thread_idx = threadIdx.x;
 #pragma unroll

aten/src/ATen/native/cuda/TensorShape.cu

Lines changed: 4 additions & 3 deletions
@@ -422,11 +422,12 @@ static __global__ void chunk_cat_cuda_kernel(
 }

 bool all_contiguous(TensorList tensors) {
-  bool contiguous = true;
   for (const auto& t : tensors) {
-    contiguous &= t.is_non_overlapping_and_dense();
+    if (!t.is_contiguous()) {
+      return false;
+    }
   }
-  return contiguous;
+  return true;
 }

 // Get leading dimensions before `dim`-th dimension.

aten/src/ATen/native/transformers/attention.cpp

Lines changed: 1 addition & 1 deletion
@@ -449,7 +449,7 @@ REGISTER_AVX512_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp)
 REGISTER_VSX_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp)
 REGISTER_ZVECTOR_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp)
 REGISTER_SVE256_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cpp)
-REGISTER_HPU_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_meta);
+REGISTER_HPU_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_meta)

 int64_t _fused_sdp_choice_meta(
     const Tensor& query_,

cmake/Dependencies.cmake

Lines changed: 3 additions & 0 deletions
@@ -1177,6 +1177,9 @@ if(USE_DISTRIBUTED AND USE_TENSORPIPE)
     set(CMAKE_POLICY_VERSION_MINIMUM 3.5)
   endif()
   add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/tensorpipe)
+  # Suppress warning to unblock libnop comiplation by clang-17
+  # See https://github.com/pytorch/pytorch/issues/151316
+  target_compile_options_if_supported(tensorpipe -Wno-missing-template-arg-list-after-template-kw)
   if(CMAKE_VERSION VERSION_GREATER_EQUAL "4.0.0")
     unset(CMAKE_POLICY_VERSION_MINIMUM)
   endif()

test/cpp_extensions/open_registration_extension/setup.py

Lines changed: 4 additions & 0 deletions
@@ -1,5 +1,6 @@
 import distutils.command.clean
 import os
+import platform
 import shutil
 import sys
 from pathlib import Path
@@ -40,6 +41,9 @@ def run(self):
         CXX_FLAGS = ["/sdl"]
     else:
         CXX_FLAGS = ["/sdl", "/permissive-"]
+elif platform.machine() == "s390x":
+    # no -Werror on s390x due to newer compiler
+    CXX_FLAGS = {"cxx": ["-g", "-Wall"]}
 else:
     CXX_FLAGS = {"cxx": ["-g", "-Wall", "-Werror"]}

test/inductor/test_aot_inductor_package.py

Lines changed: 57 additions & 0 deletions
@@ -467,6 +467,63 @@ def forward(self, a):
         output = compiled(test_inputs)
         self.assertEqual(expected, output)

+    @skipif(
+        lambda device, package_cpp_only: device == "cpu" or package_cpp_only,
+        "No support for cpp only and cpu",
+    )
+    def test_package_user_managed_weight(self):
+        class Model(torch.nn.Module):
+            def __init__(self, n, k, device):
+                super().__init__()
+                self.linear = torch.nn.Linear(k, n, device=device)
+
+            def forward(self, a):
+                return self.linear(a)
+
+        M, N, K = 128, 4096, 4096
+        model = Model(N, K, self.device)
+        example_inputs = (torch.randn(M, K, device=self.device),)
+
+        inductor_configs = {
+            "always_keep_tensor_constants": True,
+            "aot_inductor.package_constants_in_so": False,
+        }
+        compiled = compile(model, example_inputs, inductor_configs=inductor_configs)
+
+        self.assertEqual(
+            set(compiled.get_constant_fqns()), set(model.state_dict().keys())
+        )
+
+        compiled.load_constants(
+            model.state_dict(), check_full_update=True, user_managed=False
+        )
+
+        test_inputs = torch.randn(M, K, device=self.device)
+        expected = model(test_inputs)
+        output = compiled(test_inputs)
+        self.assertEqual(expected, output)
+
+        # Let's try to modify the weight in-place, result shouldn't change.
+        model.linear.weight.data *= 3.7
+        new_output = compiled(test_inputs)
+        self.assertEqual(new_output, output)
+
+        # Recreate a new model that we will test against user_managed=True
+        new_compiled = compile(model, example_inputs, inductor_configs=inductor_configs)
+        new_compiled.load_constants(
+            model.state_dict(), check_full_update=True, user_managed=True
+        )
+
+        expected = model(test_inputs)
+        new_output = new_compiled(test_inputs)
+        self.assertEqual(expected, new_output)
+
+        # Try to modify the weight in-place, result should change.
+        model.linear.weight.data *= 3.7
+        expected = model(test_inputs)
+        new_output = new_compiled(test_inputs)
+        self.assertEqual(new_output, expected)
+
     def test_deepcopy_compiled_model(self):
         class Model(torch.nn.Module):
             def forward(self, x, y):

test/inductor/test_codecache.py

Lines changed: 130 additions & 0 deletions
@@ -3,6 +3,8 @@
 import os
 import pickle
 import shutil
+import subprocess
+import sys
 import tempfile
 import unittest
 from typing import Optional, Union
@@ -11,6 +13,7 @@
 import torch
 from torch._dynamo import reset
 from torch._dynamo.utils import counters
+from torch._functorch import config as functorch_config
 from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
 from torch._inductor import config, metrics
 from torch._inductor.codecache import (
@@ -35,6 +38,7 @@
 from torch.testing._internal.common_device_type import largeTensorTest
 from torch.testing._internal.common_utils import (
     instantiate_parametrized_tests,
+    IS_FBCODE,
     parametrize,
     TEST_WITH_ROCM,
 )
@@ -1376,6 +1380,132 @@ def forward(self, x):
         )


+@instantiate_parametrized_tests
+class TestStandaloneCompile(TestCase):
+    def setUp(self):
+        super().setUp()
+        counters.clear()
+        PatchCaches.setUp()
+        CacheArtifactManager.clear()
+
+    def tearDown(self):
+        super().tearDown()
+        PatchCaches.tearDown()
+
+    def reset(self):
+        AOTAutogradCache.clear()
+        PyCodeCache.cache_clear(purge=True)
+        torch._dynamo.reset()
+        clear_inductor_caches()
+
+    def capture(self, fn):
+        def inner(*args):
+            gm = None
+            actual_args = None
+            kwargs = None
+
+            def backend(gm_, args_, **kwargs_):
+                nonlocal gm
+                nonlocal actual_args
+                nonlocal kwargs
+                gm = gm_
+                actual_args = args_
+                kwargs = kwargs_
+                return gm
+
+            _ = torch.compile(fn, fullgraph=True, backend=backend)(*args)
+            return gm, actual_args, kwargs
+
+        return inner
+
+    @config.patch({"fx_graph_cache": True})
+    @config.patch({"fx_graph_remote_cache": False})
+    @functorch_config.patch({"enable_autograd_cache": True})
+    @parametrize("format", ("binary", "unpacked"))
+    @parametrize("dynamic", (False, True))
+    def test_basic(self, format: str, dynamic: bool) -> None:
+        mod = torch.nn.Linear(1, 3)
+        x = torch.randn(4, 1)
+        if dynamic:
+            torch._dynamo.mark_dynamic(x, 0)
+
+        def f(x):
+            with torch.no_grad():
+                return mod(x)
+
+        eager_out = f(x)
+
+        with tempfile.TemporaryDirectory() as temp_dir:
+            path = (
+                temp_dir
+                if format == "unpacked"
+                else os.path.join(temp_dir, "compiled_artifact.bin")
+            )
+            with fresh_inductor_cache():
+                gm, args, kwargs = self.capture(f)(x)
+                assert not kwargs
+
+                compiled_artifact = torch._inductor.standalone_compile(gm, args)
+                compiled_artifact.save(path=path, format=format)
+
+            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
+
+            with fresh_inductor_cache():
+                loaded = torch._inductor.CompiledArtifact.load(path=path, format=format)
+                compiled_out = loaded(*args)
+                self.assertEqual(eager_out, compiled_out)
+
+            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
+
+    @unittest.skipIf(IS_FBCODE, "torch import error")
+    @config.patch({"fx_graph_cache": True})
+    @config.patch({"fx_graph_remote_cache": False})
+    @functorch_config.patch({"enable_autograd_cache": True})
+    def test_different_process(self):
+        x = torch.ones(4, 1)
+
+        def f(x):
+            return x.sin() * 2
+
+        gm, args, kwargs = self.capture(f)(x)
+        assert not kwargs
+
+        with tempfile.TemporaryDirectory() as temp_dir:
+            path = os.path.join(temp_dir, "compiled_artifact.bin")
+
+            with fresh_inductor_cache():
+                compiled_artifact = torch._inductor.standalone_compile(gm, args)
+                compiled_artifact.save(path=path)
+
+            script = f"""
+import torch
+from torch._inductor.utils import fresh_inductor_cache
+
+arg = torch.ones(4, 1)
+with fresh_inductor_cache():
+    loaded = torch._inductor.CompiledArtifact.load(path="{path}")
+    compiled_result = loaded(arg)
+
+eager_result = arg.sin() * 2
+
+if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
+    raise RuntimeError("tensors do not match")
+"""
+            try:
+                subprocess.check_output(
+                    [sys.executable, "-c", script],
+                    stderr=subprocess.STDOUT,
+                    cwd=os.path.dirname(os.path.realpath(__file__)),
+                )
+            except subprocess.CalledProcessError as e:
+                self.fail(
+                    msg=(
+                        "Subprocess exception while attempting to run test: "
+                        + e.output.decode("utf-8")
                    )
                )


 class TestFxGraphCacheHashing(TestCase):
     def test_parameter_constants(self):
         """

0 commit comments