Goal
We'd like to support sufficient APIs to enable a libtorch-ABI-stable flash attention 3 (FA3). We see FA3 as a simple-enough yet complicated-enough custom op: if we are able to support FA3, we open our doors to supporting the first echelon of custom op libraries.
The following tasks are based on an exercise where I went through FA3 and evaluated what the final state should look like if we swapped out the torch APIs FA3 currently uses with in-the-future stable ones. (https://docs.google.com/document/d/1hL-pMgmsj8YlEmZOdbRqW0yfwHyPlJ8usn0xW_GqjFA/edit?tab=t.0)
Setting the stage
We consider an API stable when it falls under one of two categories:
- it is header-only: it is completely free of libtorch, and using the API does not require linking against any part of libtorch
- it is stable: it is ABI stable, calling into libtorch through a C shim that is guaranteed not to change in the near future
What does FA3 need?
We label the following tasks with S, M, L not in terms of the amount of work necessary, but rather how much ambiguity we have yet to figure out. S means it is trivial, but not necessarily fast to do -- for example, migrating an API to header-only is trivial but can involve fighting lots of internal build errors.
- S (in progress) Introduce `STD_TORCH_CHECK` API -- a header-only `TORCH_CHECK` that calls into `std::runtime_error` instead of our `c10::Error`. This task requires some fundamental steps (a minimal sketch of the macro follows this item):
  - Migrating c10/macros/Export.h to header only: Add STD_TORCH_CHECK to headeronly #158377
  - Migrating c10/macros/Macros.h to header only: Add STD_TORCH_CHECK to headeronly #158377
  - Migrating portions of Exception.h to support the new STD_TORCH_CHECK: Add STD_TORCH_CHECK to headeronly #158377
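  For illustration, here is a minimal sketch of what a header-only check macro of this kind looks like. This is not the real implementation -- the actual `STD_TORCH_CHECK` lives under torch/headeronly and its message formatting and varargs handling may differ:

  ```cpp
  // Sketch only: a check macro that needs nothing beyond the C++ standard
  // library, so code using it never links against libtorch.
  #include <stdexcept>
  #include <string>

  #define SKETCH_STD_TORCH_CHECK(cond, msg)                                \
    do {                                                                   \
      if (!(cond)) {                                                       \
        throw std::runtime_error(std::string("Check failed: ") + (msg));   \
      }                                                                    \
    } while (0)

  inline void check_head_dim(int head_dim) {
    // Throws std::runtime_error rather than c10::Error on failure.
    SKETCH_STD_TORCH_CHECK(head_dim % 8 == 0, "head_dim must be a multiple of 8");
  }
  ```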
- M (in progress) A `stable::Tensor` struct with the following member functions (see the sketch after this list):
  - `data_ptr()`, `stride(-3)`, `get_device()`, `is_cuda()`, `size(-2)`: torch::stable::Tensor beginnings, mainly mem mgmt #155367
  - `is_contiguous()` (no arguments): done in Add a basic shim and stable::Tensor is_contiguous API #156228
  - S `dtype()` --> requires `kBFloat16` etc. (`caffe2::TypeMeta`s) to be header only
  - S `scalar_type()` --> requires c10 `ScalarType` to be header only: Add ScalarType -> shim conversion, add stable::Tensor.scalar_type #160557
  - L `options()` --> do we want to make `TensorOptions` stable through the C shim? Or simply rewrite a high-level C++ API that looks like TensorOptions?
  - We have called bankruptcy on `data_ptr<int>()`; as a workaround, cast the untyped pointer returned by `data_ptr()`.
  - We have called bankruptcy on `sizes()`; as a workaround we can break the implementation down into `size()` and `dim()` where relevant in https://github.com/Dao-AILab/flash-attention/blob/main/hopper/flash_api.cpp
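  To make the target surface concrete, here is a rough sketch of FA3-style argument checking written against the member functions above. The include paths and exact spellings are assumptions for illustration; only the members listed above are relied on:

  ```cpp
  // Sketch only: assumes torch::stable::Tensor lives in torch/csrc/stable/tensor.h
  // and exposes the members listed above, and that STD_TORCH_CHECK comes from
  // a torch/headeronly header (exact path assumed).
  #include <cstdint>
  #include <torch/csrc/stable/tensor.h>

  using torch::stable::Tensor;

  inline void check_inputs(const Tensor& q, const Tensor& tile_count_semaphore) {
    STD_TORCH_CHECK(q.is_cuda(), "q must be a CUDA tensor");
    STD_TORCH_CHECK(q.is_contiguous(), "q must be contiguous");

    // sizes() is bankrupt: query individual dims via size(dim) instead.
    int64_t head_dim = q.size(-1);
    STD_TORCH_CHECK(head_dim % 8 == 0, "head_dim must be a multiple of 8");

    // data_ptr<int>() is bankrupt: cast the untyped pointer instead.
    int* semaphore = static_cast<int*>(tile_count_semaphore.data_ptr());
    (void)semaphore;
  }
  ```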
- M The following aten functions, which need a high-level C++ API into a C shim'd version. Here, the C shim'd version could be manually added, generated by `fallback_ops.py`, or called through `aoti_torch_call_dispatcher`. (See the sketch after this list.)
  - S `transpose(1, 2)`
    - C shim API -- can go through `aoti_torch_call_dispatcher`
    - high-level nice UX C++ API: Add transpose to torch/csrc/stable #158160
  - S `zero_()`
    - C shim API -- can go through `aoti_torch_call_dispatcher`
    - high-level nice UX C++ API: Add zero_() and empty_like(t) to torch/csrc/stable/ops.h #158866
  - S `fill_()`
    - C shim API added: Enable generating generic c_shim that doesn't bypass dispatcher #158974
    - high-level nice UX C++ API: Enable generating generic c_shim that doesn't bypass dispatcher #158974
  - S `torch::empty_like(p)`
    - C shim API -- can go through `aoti_torch_call_dispatcher`
    - high-level nice UX C++ API: Add zero_() and empty_like(t) to torch/csrc/stable/ops.h #158866
  - S `torch::nn::functional::pad(x, torch::nn::functional::PadFuncOptions({0, alignment - x.size(-1) % alignment}))` can just go through `at::pad` instead! `at::pad_symint(x, std::vector<int64_t>{0, alignment - x.size(-1) % alignment})`. No need to support PadFuncOptions.
    - C shim API for `pad`: Add C shim for at::pad and fix some typos #155226
    - high-level nice UX C++ API
  - L `torch::empty({...}, opts.dtype(torch::kInt32))` -- L because we need to decide how to expose TensorOptions
  - L -> S We're calling bankruptcy on `tile_count_semaphore.index` and all indexing-related APIs (like `torch::indexing::slice`) in favor of `narrow`, which is much easier to support:
    - C shim API for `narrow`: Add shim fallback for narrow #156496
    - high-level nice UX C++ API
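  To show how FA3 call sites would change, here is a rough sketch using the high-level wrappers this list asks for. The wrapper names and signatures (free functions `pad` and `narrow` in torch/csrc/stable/ops.h) are assumptions based on the items above, not a confirmed API:

  ```cpp
  // Sketch only: assumes torch/csrc/stable/ops.h exposes free functions named
  // after the aten ops above, operating on torch::stable::Tensor.
  #include <cstdint>
  #include <vector>
  #include <torch/csrc/stable/ops.h>
  #include <torch/csrc/stable/tensor.h>

  using torch::stable::Tensor;

  // Replaces torch::nn::functional::pad(x, PadFuncOptions({0, pad_amount})):
  // at::pad takes a plain list of pad widths, so no PadFuncOptions is needed.
  inline Tensor pad_last_dim_to_multiple(const Tensor& x, int64_t alignment) {
    int64_t pad_amount = alignment - x.size(-1) % alignment;
    return torch::stable::pad(x, std::vector<int64_t>{0, pad_amount});
  }

  // Replaces tile_count_semaphore.index({Slice(0, 1)}) with narrow(dim, start,
  // length), which is much simpler to route through the C shim.
  inline Tensor first_element_view(const Tensor& tile_count_semaphore) {
    return torch::stable::narrow(tile_count_semaphore, /*dim=*/0, /*start=*/0, /*length=*/1);
  }
  ```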
- M CUDA utilities (see the sketch after this list)
  - M `torch::stable::accelerator::DeviceGuard`, `torch::stable::Stream`: Add beginnings of torch::stable::accelerator #159679
  - M There are sufficient cuda C shim APIs in shim.h for FA3; HOWEVER, longer term, we would like to encourage custom op writers to use the accelerator-generic DeviceGuard and Stream instead.
  - S high-level C++ APIs have to be added for all of these
  - Calling bankruptcy on `torch::stable::cuda::getCurrentDeviceProperties`, which we should just recommend boiling down to the CUDA API `cudaGetDeviceProperties`.
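  A rough sketch of the accelerator-generic pattern we'd nudge custom op writers toward, plus the plain CUDA runtime call we'd recommend instead of a shimmed getCurrentDeviceProperties. The torch::stable::accelerator header path and the getCurrentStream spelling are assumptions here:

  ```cpp
  // Sketch only: assumes torch/csrc/stable/accelerator.h provides DeviceGuard,
  // Stream, and a getCurrentStream helper; exact spellings may differ.
  #include <cuda_runtime.h>
  #include <torch/csrc/stable/accelerator.h>

  inline void launch_on(int device_index) {
    // Accelerator-generic replacement for c10::cuda::CUDAGuard / getCurrentCUDAStream.
    torch::stable::accelerator::DeviceGuard guard(device_index);
    auto stream = torch::stable::accelerator::getCurrentStream(device_index);
    (void)stream;

    // "Bankrupt" getCurrentDeviceProperties: just call the CUDA runtime directly.
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, device_index);
    (void)prop;
  }
  ```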
- lowerpri L To not have the user write boxed kernels for their STABLE_TORCH_LIBRARY_IMPL, we should be able to autogenerate the boxed kernels ourselves, given that we have access to the schema. This would require a bit of a lift from today's reality. (A sketch of the boilerplate this would remove follows below.)
  - Easy pattern-matching codegen would be M, but we may want to be able to support this with dispatcher templates, which would be L.
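  To show the boilerplate in question, here is a sketch of the kind of boxed wrapper an author writes by hand today. The stack-based signature and the `from`/`to` conversion helpers are assumptions about the current stable library surface, shown only to illustrate what schema-driven codegen would generate:

  ```cpp
  // Sketch only: the hand-written boxed shim that autogeneration (or dispatcher
  // templates) would produce from the schema instead.
  #include <cstdint>
  #include <torch/csrc/stable/library.h>  // assumed home of STABLE_TORCH_LIBRARY_IMPL
  #include <torch/csrc/stable/tensor.h>

  using torch::stable::Tensor;

  // The unboxed kernel the author actually wants to write.
  Tensor my_relu(Tensor x);

  // The boxed wrapper: unpack StableIValues, call the kernel, repack the result.
  void boxed_my_relu(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
    Tensor x = to<Tensor>(stack[0]);
    stack[0] = from(my_relu(std::move(x)));
  }

  STABLE_TORCH_LIBRARY_IMPL(myops, CompositeExplicitAutograd, m) {
    m.impl("my_relu", &boxed_my_relu);
  }
  ```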
List of APIs to make standalone/header-only
For each of the following APIs, we would like to
- Move their definition to live in torch/headeronly/blahbyblah/blah.h
- Write tests using them in a CPP file without linking against libtorch
- Add them to torch/header_only_apis.txt
- Expose them in the torch::headeronly namespace (such that referring to torch::headeronly::kBFloat16 has the same functionality as torch::kBFloat16 but with the header-only guarantee).
- caffe2 TypeMetas, which should all live in a single header --> harder to migrate and can be subsumed by ScalarType already
  - `torch::kBFloat16` / `at::kBFloat16` --> `torch::headeronly::kBFloat16`
  - `torch::kFloat8_e4m3fn` / `at::kFloat8_e4m3fn` --> `torch::headeronly::kFloat8_e4m3fn`
  - `torch::kFloat16` --> `torch::headeronly::kFloat16`
  - `torch::kInt32` --> `torch::headeronly::kInt32`
  - `at::kFloat` --> `torch::headeronly::kFloat`
- We're gonna need the following to live in torch/headeronly/ScalarType.h: Migrate ScalarType to headeronly #159416
  - `at::ScalarType` --> `torch::headeronly::ScalarType`
  - `torch::headeronly::ScalarType::Half`
  - `torch::headeronly::ScalarType::BFloat16`
  - `torch::headeronly::ScalarType::Float8_e4m3fn`
  - `torch::headeronly::ScalarType::Float`
- `torch::IntArrayRef({__VA_ARGS__})` --> `torch::headeronly::IntArrayRef({__VA_ARGS__})` in torch/headeronly/util/ArrayRef.h (based on c10/util)
  - punted to the future for a bigger discussion, as ArrayRef depends on c10/util/SmallVector.h which is NOT header only!
- `TORCH_CHECK` (make header only) -- Line 541 in c381103:
  `// TORCH_CHECK throws std::runtime_error instead of c10::Error which is`
- `AOTI_TORCH_ERROR_CODE_CHECK` -> write a generic `TORCH_ERROR_CODE_CHECK`: Cut a version of TORCH_ERROR_CODE_CHECK in headeronly from AOTI #159604 (see the sketch below)
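For context on that last item, here is a minimal sketch of what a generic error-code check does: every C shim entry point reports failure via a return code rather than an exception, and the macro converts that back into a throw at the call site. The names below are illustrative, not the final API from #159604:

```cpp
// Sketch only: a header-only analogue of AOTI_TORCH_ERROR_CODE_CHECK.
#include <cstdint>
#include <stdexcept>

using SketchTorchError = int32_t;
constexpr SketchTorchError kSketchSuccess = 0;

#define SKETCH_TORCH_ERROR_CODE_CHECK(call)                     \
  do {                                                          \
    if ((call) != kSketchSuccess) {                             \
      throw std::runtime_error("C shim call failed: " #call);   \
    }                                                           \
  } while (0)

// Usage: wrap any shim call that returns an error code, e.g.
//   SKETCH_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr(handle, &ptr));
```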