None on "Introduce unsafe way to mark functions as cacheable" · pytorch/pytorch@83e2d32 · GitHub

Commit 83e2d32

committed
None on "Introduce unsafe way to mark functions as cacheable"
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov [ghstack-poisoned]
2 parents a14d391 + 411b9bf commit 83e2d32

4 files changed: +98 −13 lines

test/inductor/test_codecache.py

Lines changed: 66 additions & 3 deletions
@@ -1471,11 +1471,15 @@ def backend(gm_, args_, **kwargs_):
     @config.patch({"fx_graph_cache": True})
     @config.patch({"fx_graph_remote_cache": False})
     @functorch_config.patch({"enable_autograd_cache": True})
+    @parametrize("device", (GPU_TYPE, "cpu"))
     @parametrize("format", ("binary", "unpacked"))
     @parametrize("dynamic", (False, True))
-    def test_basic(self, format: str, dynamic: bool) -> None:
-        mod = torch.nn.Linear(1, 3)
-        x = torch.randn(4, 1)
+    def test_basic(self, device: str, format: str, dynamic: bool) -> None:
+        if device == GPU_TYPE and not HAS_GPU:
+            raise unittest.SkipTest(f"requires {GPU_TYPE}")
+
+        mod = torch.nn.Linear(1, 3, device=device)
+        x = torch.randn(4, 1, device=device)
         if dynamic:
             torch._dynamo.mark_dynamic(x, 0)
 
@@ -1562,6 +1566,65 @@ def f(x):
             compiled_out = loaded(*args)[0]
             self.assertEqual(eager_out, compiled_out)
 
+    @config.patch({"fx_graph_cache": True})
+    @config.patch({"fx_graph_remote_cache": False})
+    @functorch_config.patch({"enable_autograd_cache": True})
+    @parametrize("device", (GPU_TYPE, "cpu"))
+    def test_modify_unpacked_file(self, device: str) -> None:
+        if device == GPU_TYPE and not HAS_GPU:
+            raise unittest.SkipTest(f"requires {GPU_TYPE}")
+
+        x = torch.ones(4, device=device)
+
+        def f(x):
+            with torch.no_grad():
+                return 2 * x, x.sin()
+
+        eager_out = f(x)
+
+        with tempfile.TemporaryDirectory() as temp_dir:
+            with fresh_inductor_cache():
+                gm, args, kwargs = self.capture(f)(x)
+                assert not kwargs
+
+                compiled_artifact = torch._inductor.standalone_compile(gm, args)
+                compiled_out = compiled_artifact(*args)
+                self.assertEqual(eager_out, compiled_out)
+
+                compiled_artifact.save(path=temp_dir, format="unpacked")
+
+            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)
+
+            with fresh_inductor_cache():
+                # Now modify the output file and expect to see the changes
+                for subdir in os.listdir(temp_dir):
+                    if subdir in ["aotautograd", "fxgraph"]:
+                        continue
+                    subdir_path = os.path.join(temp_dir, subdir)
+                    for file in os.listdir(subdir_path):
+                        file_path = os.path.join(subdir_path, file)
+                        assert os.path.isfile(file_path)
+                        with open(file_path) as f:
+                            file_contents = f.read()
+                        if device == GPU_TYPE:
+                            file_contents = file_contents.replace(
+                                "tmp1 = 2.0", "tmp1 = 8.0"
+                            )
+                        else:
+                            assert device == "cpu"
+                            file_contents = file_contents.replace(
+                                "auto tmp1 = static_cast<float>(2.0);",
+                                "auto tmp1 = static_cast<float>(8.0);",
+                            )
+                        with open(file_path, "w") as f:
+                            f.write(file_contents)
+
+                loaded = torch._inductor.CompiledArtifact.load(
+                    path=temp_dir, format="unpacked"
+                )
+                compiled_out = loaded(*args)
+                self.assertEqual(4 * eager_out[0], compiled_out[0])
+
         self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
 
     @unittest.skipIf(IS_FBCODE, "torch import error")
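For orientation, a minimal sketch (not part of the diff) of the round trip this test exercises, assuming you already have an FX graph `gm` and example inputs `args` (the test obtains them via its own `capture` helper; the save directory here is illustrative):

    import torch

    compiled = torch._inductor.standalone_compile(gm, args)
    compiled.save(path="/tmp/unpacked_artifacts", format="unpacked")  # editable on-disk layout
    # ... optionally hand-edit the generated output-code files under that directory ...
    loaded = torch._inductor.CompiledArtifact.load(
        path="/tmp/unpacked_artifacts", format="unpacked"
    )
    out = loaded(*args)  # runs the (possibly modified) kernels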

torch/_inductor/output_code.py

Lines changed: 10 additions & 8 deletions
@@ -654,14 +654,9 @@ def prepare_for_serialization(self) -> None:
         self.current_callable = None
         self.recursively_apply_fns = None
 
-    def after_deserialization(self, constants: CompiledFxGraphConstants) -> str:
-        from torch._dynamo.utils import counters, dynamo_timed
-        from torch._inductor.codecache import (
-            cpp_prefix_path,
-            get_path,
-            PyCodeCache,
-            write_atomic,
-        )
+    def write_to_disk(self) -> str:
+        from torch._dynamo.utils import counters
+        from torch._inductor.codecache import cpp_prefix_path, get_path, write_atomic
 
         # See _save_graph(); we don't store the callable in the cache entry so
         # recreate it here from the PyCodeCache disk cache.
@@ -682,6 +677,13 @@ def after_deserialization(self, constants: CompiledFxGraphConstants) -> str:
         self.source_code = code
 
         write_atomic(artifact_path, code, make_dirs=True)
+        return artifact_path
+
+    def after_deserialization(self, constants: CompiledFxGraphConstants) -> str:
+        from torch._dynamo.utils import dynamo_timed
+        from torch._inductor.codecache import PyCodeCache
+
+        artifact_path = self.write_to_disk()
 
         try:
             with dynamo_timed(
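The refactor splits `after_deserialization` so that writing the generated source to disk no longer requires a constants table or rebuilding the callable. A rough sketch of the new call pattern, assuming `entry_path` points at a pickled `CompiledFxGraph` cache entry (the name is illustrative; this mirrors how standalone_compile.py uses it below):

    import pickle

    with open(entry_path, "rb") as f:
        graph = pickle.load(f)              # a CompiledFxGraph
    artifact_path = graph.write_to_disk()   # materializes source_code, returns its path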

torch/_inductor/runtime/cache_dir_utils.py

Lines changed: 4 additions & 0 deletions
@@ -37,11 +37,15 @@ def triton_cache_dir(device: int) -> str:
 
 @contextmanager
 def temporary_cache_dir(directory: str) -> Generator[None, None, None]:
+    from torch._inductor.utils import clear_inductor_caches
+
     original = os.environ.get("TORCHINDUCTOR_CACHE_DIR")
     os.environ["TORCHINDUCTOR_CACHE_DIR"] = directory
     try:
+        clear_inductor_caches()
         yield
     finally:
+        clear_inductor_caches()
         if original is None:
             del os.environ["TORCHINDUCTOR_CACHE_DIR"]
         else:
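A hedged usage sketch of `temporary_cache_dir` after this change: in-memory Inductor caches are now cleared on entry and on exit, so entries keyed against the previous cache directory cannot leak across the boundary (the path below is illustrative):

    from torch._inductor.runtime.cache_dir_utils import temporary_cache_dir

    with temporary_cache_dir("/tmp/scratch_cache"):
        ...  # compilations in this scope read and write /tmp/scratch_cache
    # TORCHINDUCTOR_CACHE_DIR is restored here, with in-memory caches cleared again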

torch/_inductor/standalone_compile.py

Lines changed: 18 additions & 2 deletions
@@ -3,6 +3,7 @@
 import copy
 import logging
 import os
+import pickle
 import shutil
 from contextlib import AbstractContextManager, nullcontext
 from typing import Any, Callable, Literal, Optional, TYPE_CHECKING
@@ -69,7 +70,7 @@ def save(
                 "CompiledArtifact.save failed to save since there's no artifact to save"
             )
         artifact_bytes, cache_info = self._artifacts
-        assert len(cache_info.aot_autograd_artifacts) == 1
+        assert len(cache_info.aot_autograd_artifacts) == 1, cache_info
         key = cache_info.aot_autograd_artifacts[0]
 
         if format == "binary":
@@ -92,9 +93,24 @@ def save(
             assert os.path.isdir(path)
             shutil.rmtree(path, ignore_errors=True)
 
+            from .codecache import FxGraphCache
+
             with temporary_cache_dir(path):
                 # This function unpacks the cache artifacts to disk
-                torch.compiler.load_cache_artifacts(artifact_bytes)
+                loaded_cache_info = torch.compiler.load_cache_artifacts(
+                    artifact_bytes
+                )
+                assert loaded_cache_info is not None
+                # Now write all the output_code artifacts to disk so that
+                # they can be inspected and modified
+                for key in loaded_cache_info.inductor_artifacts:
+                    subdir = FxGraphCache._get_tmp_dir_for_key(key)
+                    assert os.path.exists(subdir)
+                    for path in sorted(os.listdir(subdir)):
+                        with open(os.path.join(subdir, path), "rb") as f:
+                            graph = pickle.load(f)
+                        output_file = graph.write_to_disk()
+                        log.info("Output code written to: %s", output_file)
 
     @staticmethod
     def load(

0 commit comments