
[Tracker] Support flash attention fa3 ABI stable w/ libtorch #154908

@janeyx99

Description

Goal

We'd like to support sufficient APIs to enable a libtorch-ABI-stable flash attention 3. We see flash attention 3 as a custom op that is both simple enough and complicated enough that, if we are able to support FA3, we open the door to supporting the first echelon of custom op libraries.

The following tasks are based on an exercise where I went through FA3 and evaluated what the final state should look like if we swapped out the torch APIs FA3 currently uses for their future stable equivalents. (https://docs.google.com/document/d/1hL-pMgmsj8YlEmZOdbRqW0yfwHyPlJ8usn0xW_GqjFA/edit?tab=t.0)

Setting the stage

We consider an API stable when it falls under one of 2 categories:

  • it is headeronly = it is header-only: it is completely free of libtorch, and using the API does not require linking against any part of libtorch
  • it is stable = it is ABI stable: it calls into libtorch through a C shim, which is guaranteed not to change in the near future
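
As a rough illustration of the difference between the two categories, here is a minimal sketch. All names in it (ho_example::round_up, shim_get_device_index, stable_device_index) are made up for illustration and are not actual PyTorch or shim APIs:

```cpp
#include <cstdint>
#include <stdexcept>

// Category 1: headeronly -- the whole implementation lives in a header.
// Nothing here needs any libtorch symbol at compile or link time.
namespace ho_example {
constexpr int64_t round_up(int64_t x, int64_t multiple) {
  return ((x + multiple - 1) / multiple) * multiple;
}
}  // namespace ho_example

// Category 2: stable -- the C++ wrapper is header-only, but it forwards to a
// C symbol exported by libtorch whose signature is frozen (the C shim).
// `shim_get_device_index` is a hypothetical stand-in for such a shim entry.
extern "C" int32_t shim_get_device_index(void* tensor_handle, int32_t* out);

inline int32_t stable_device_index(void* tensor_handle) {
  int32_t idx = 0;
  if (shim_get_device_index(tensor_handle, &idx) != 0) {
    throw std::runtime_error("shim_get_device_index failed");
  }
  return idx;
}
```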

What does FA3 need?

We label the following tasks with S, M, or L, not in terms of the amount of work required, but in terms of how much ambiguity we still have to figure out. S means the task is trivial, but not necessarily fast to do--for example, migrating an API to header-only is trivial but can involve fighting lots of internal build errors.

List of APIs to make standalone/header-only

For each of the following APIs, we would like to

  1. Move their definition to live in torch/headeronly/blahbyblah/blah.h
  2. Write tests using them in a CPP file without linking against libtorch (a minimal sketch of such a test appears after the API list below)
  3. Add them to torch/header_only_apis.txt
  4. Expose them in the torch::headeronly namespace (such that referring to torch::headeronly::kBFloat16 has the same functionality as torch::kBFloat16 but with the header-only guarantee).
  • caffe2 TypeMetas, which should all live in a single header --> harder to migrate and can already be subsumed by ScalarType
    • torch::kBFloat16 / at::kBFloat16 —> torch::headeronly::kBFloat16
    • torch::kFloat8_e4m3fn / at::kFloat8_e4m3fn —> torch::headeronly::kFloat8_e4m3fn
    • torch::kFloat16 —> torch::headeronly::kFloat16
    • torch::kInt32 —> torch::headeronly::kInt32
    • at::kFloat —> torch::headeronly::kFloat
  • We're going to need the following to live in torch/headeronly/ScalarType.h (Migrate ScalarType to headeronly #159416)
    • at::ScalarType —> torch::headeronly::ScalarType
    • torch::headeronly::ScalarType::Half
    • torch::headeronly::ScalarType::BFloat16
    • torch::headeronly::ScalarType::Float8_e4m3fn
    • torch::headeronly::ScalarType::Float
  • torch::IntArrayRef({__VA_ARGS__}) —> torch::headeronly::IntArrayRef({__VA_ARGS__})
    in torch/headeronly/util/ArrayRef.h (based on c10/util)
    • punted to the future for a bigger discussion, since ArrayRef depends on c10/util/SmallVector.h, which is NOT header only!
  • TORCH_CHECK -> make a header-only TORCH_CHECK that throws std::runtime_error (which is header-only) instead of c10::Error (see the sketch after this list)
  • AOTI_TORCH_ERROR_CODE_CHECK -> write a generic TORCH_ERROR_CODE_CHECK (Cut a version of TORCH_ERROR_CODE_CHECK in headeronly from AOTI #159604; also sketched below)
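
For the last two items, the rough idea might look something like the sketch below: a check macro whose failure path only touches header-only types (std::runtime_error), plus a generic error-code wrapper layered on top of it. The macro names and details here are illustrative, not the actual implementations landing in PyTorch or in #159604:

```cpp
#include <cstdint>
#include <sstream>
#include <stdexcept>
#include <string>

namespace ho_sketch {
// Concatenate an arbitrary message into a string using only the standard
// library, so the failure path stays header-only.
template <typename... Args>
std::string concat_message(Args&&... args) {
  std::ostringstream oss;
  (oss << ... << args);  // C++17 fold; an empty pack just yields ""
  return oss.str();
}
}  // namespace ho_sketch

// Header-only check: throw std::runtime_error instead of c10::Error.
#define HEADERONLY_TORCH_CHECK(cond, ...)                                   \
  do {                                                                      \
    if (!(cond)) {                                                          \
      throw std::runtime_error(std::string("Check failed: " #cond " ") +    \
                               ho_sketch::concat_message(__VA_ARGS__));     \
    }                                                                       \
  } while (0)

// Generic error-code check: wrap a C-shim call that returns a nonzero error
// code, mirroring what AOTI_TORCH_ERROR_CODE_CHECK does today.
#define HEADERONLY_TORCH_ERROR_CODE_CHECK(expr)                             \
  do {                                                                      \
    const int32_t err_code_ = (expr);                                       \
    HEADERONLY_TORCH_CHECK(err_code_ == 0,                                  \
                           "shim call returned error code ", err_code_);    \
  } while (0)
```

A caller could then wrap a shim call as HEADERONLY_TORCH_ERROR_CODE_CHECK(some_shim_call(...)) and get a std::runtime_error on failure, with no libtorch dependency in the error path.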

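As mentioned in step 2 of the checklist above, here is a minimal sketch of a standalone test translation unit that exercises the hoped-for header-only dtype APIs without linking against libtorch. The header path and the torch::headeronly::ScalarType / torch::headeronly::kBFloat16 names are assumptions based on the plan above (and #159416), not APIs that necessarily exist yet:

```cpp
// test_headeronly_scalartype.cpp
// Build without linking libtorch, e.g.:
//   g++ -std=c++17 -I<pytorch include dir> test_headeronly_scalartype.cpp
#include <torch/headeronly/ScalarType.h>  // assumed location per the plan above

#include <cstdlib>
#include <iostream>

int main() {
  using torch::headeronly::ScalarType;

  // Header-only enum values should behave like the existing at::ScalarType
  // ones, just without pulling in any libtorch symbols.
  static_assert(ScalarType::BFloat16 != ScalarType::Half,
                "distinct dtypes must compare unequal");

  if (torch::headeronly::kBFloat16 != ScalarType::BFloat16) {
    std::cerr << "kBFloat16 alias does not match ScalarType::BFloat16\n";
    return EXIT_FAILURE;
  }
  std::cout << "header-only ScalarType smoke test passed\n";
  return EXIT_SUCCESS;
}
```
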
cc @jbschlosser @desertfire @swolchok @tridao

Labels

  • module: abi (libtorch C++ ABI related problems)
  • module: cpp (Related to C++ API)
  • triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
