Update base for Update on "Trace attention inference patterns with p=0, cleanup" · pytorch/pytorch@870bb37

Commit 870bb37

Update base for Update on "Trace attention inference patterns with p=0, cleanup"
When dropout is traced in inference, it creates a clone() instead of the training pattern of rand() etc. This was partially addressed manually in #108141; however, that did not cover all of the patterns that include dropout, and there is no reason we should have to specify them manually. This change updates the generated inference patterns to trace with dropout_p = 0.0.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx peterbell10 ipiszy ngimel yf225 chenyang78 kadeng muchulee8 aakhundov

[ghstack-poisoned]
2 parents 227e6da + d4990ad commit 870bb37
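For context, a minimal sketch of the behavior the commit message describes (illustrative only, not the repo's actual pattern-generation code): tracing an attention-style function with dropout_p = 0.0 in eval mode yields a graph with no rand()/bernoulli nodes, so a single traced inference pattern can replace the manually written variants.

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def attn(q, k, v, dropout_p=0.0):
    # Softmax attention followed by dropout. With p=0.0 and training=False,
    # dropout lowers to a no-op/clone rather than the rand()-based training
    # pattern, which is what the generated inference patterns now match.
    scores = torch.softmax(q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5, dim=-1)
    return torch.nn.functional.dropout(scores, p=dropout_p, training=False) @ v

q = k = v = torch.randn(2, 4, 8)
print(make_fx(attn)(q, k, v).graph)  # contains no rand()/bernoulli ops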

File tree

278 files changed: +13148 -11496 lines. Directories touched include:
  • variables
  • _export
  • _inductor
  • _numpy
  • _prims_common
  • _subclasses
  • ao/quantization/pt2e/representation
  • autograd
  • csrc
  • distributed
  • export
  • fx
  • jit
  • onnx/_internal/fx
  • testing/_internal

    .ci/docker/common/install_onnx.sh

    Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ pip_install \
   transformers==4.32.1

 pip_install coloredlogs packaging
-retry pip_install -i https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT-Nightly/pypi/simple/ --no-cache-dir --no-input ort-nightly==1.16.0.dev20230908001
+retry pip_install -i https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT-Nightly/pypi/simple/ --no-cache-dir --no-input ort-nightly==1.16.0.dev20230912006

 pip_install onnx==1.14.1
 pip_install onnxscript-preview==0.1.0.dev20230828 --no-deps

    .ci/pytorch/build.sh

    Lines changed: 8 additions & 0 deletions
@@ -159,6 +159,14 @@ if [[ "$BUILD_ENVIRONMENT" == *cuda* && -z "$TORCH_CUDA_ARCH_LIST" ]]; then
   exit 1
 fi

+# We only build FlashAttention files for CUDA 8.0+, and they require large amounts of
+# memory to build and will OOM
+if [[ "$BUILD_ENVIRONMENT" == *cuda* ]] && [[ "$TORCH_CUDA_ARCH_LIST" == *"8.6"* || "$TORCH_CUDA_ARCH_LIST" == *"8.0"* ]]; then
+  echo "WARNING: FlashAttention files require large amounts of memory to build and will OOM"
+  echo "Setting MAX_JOBS=(nproc-2)/3 to reduce memory usage"
+  export MAX_JOBS="$(( $(nproc --ignore=2) / 3 ))"
+fi
+
 if [[ "${BUILD_ENVIRONMENT}" == *clang* ]]; then
   export CC=clang
   export CXX=clang++
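As a rough Python rendering of that heuristic (an illustration; the CI uses the shell arithmetic above):

import os

def flash_attention_max_jobs() -> int:
    # Mirrors MAX_JOBS="$(( $(nproc --ignore=2) / 3 ))": reserve two CPUs,
    # then take a third of the rest, since each FlashAttention translation
    # unit needs several gigabytes of memory to compile.
    usable = max((os.cpu_count() or 1) - 2, 1)
    return max(usable // 3, 1)

print(flash_attention_max_jobs())  # e.g. 10 on a 32-CPU builder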

    .circleci/scripts/binary_populate_env.sh

    Lines changed: 2 additions & 2 deletions
@@ -155,8 +155,8 @@ EOL

 # nproc doesn't exist on darwin
 if [[ "$(uname)" != Darwin ]]; then
-  # Because most Circle executors only have 20 CPUs, using more causes OOMs w/ Ninja and nvcc parallelization
-  MEMORY_LIMIT_MAX_JOBS=18
+  # This was lowered from 18 to 12 to avoid OOMs when compiling FlashAttentionV2
+  MEMORY_LIMIT_MAX_JOBS=12
   NUM_CPUS=$(( $(nproc) - 2 ))

 # Defaults here for **binary** linux builds so they can be changed in one place

    .github/scripts/filter_test_configs.py

    Lines changed: 7 additions & 6 deletions
@@ -410,16 +410,17 @@ def process_jobs(
         if target_job in (TEST_JOB_NAME, BUILD_AND_TEST_JOB_NAME):
             target_cfg = m.group("cfg")

-            return _filter_jobs(
+            # NB: There can be multiple unstable configurations, i.e. inductor, inductor_huggingface
+            test_matrix = _filter_jobs(
                 test_matrix=test_matrix,
                 issue_type=issue_type,
                 target_cfg=target_cfg,
            )
-
-        warnings.warn(
-            f"Found a matching {issue_type.value} issue {target_url} for {workflow} / {job_name}, "
-            + f"but the name {target_job_cfg} is invalid"
-        )
+        else:
+            warnings.warn(
+                f"Found a matching {issue_type.value} issue {target_url} for {workflow} / {job_name}, "
+                + f"but the name {target_job_cfg} is invalid"
+            )

     # Found no matching target, return the same input test matrix
     return test_matrix
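The shape of the fix, condensed (hypothetical helper and variable names, not the script's full logic): the old code returned from the first matching target, so a second unstable config was never applied; the new code accumulates into test_matrix and only warns when the target name is invalid.

import warnings

def _filter_jobs_stub(test_matrix, target_cfg):
    # Hypothetical stand-in for _filter_jobs: mark matching configs unstable.
    for job in test_matrix["include"]:
        if job.get("config") == target_cfg:
            job["unstable"] = "unstable"
    return test_matrix

def mark_unstable(test_matrix, targets):
    for target_cfg, valid in targets:
        if valid:
            test_matrix = _filter_jobs_stub(test_matrix, target_cfg)
        else:
            warnings.warn(f"unstable target {target_cfg!r} is invalid")
    return test_matrix

matrix = {"include": [{"config": "inductor"}, {"config": "inductor_huggingface"}]}
# Both targets are applied now, instead of returning after the first one.
print(mark_unstable(matrix, [("inductor", True), ("inductor_huggingface", True)]))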

    .github/scripts/test_filter_test_configs.py

    Lines changed: 56 additions & 1 deletion
@@ -102,6 +102,30 @@
         "manywheel-py3_8-cuda11_8-build",
         "",
     ],
+    "inductor / cuda12.1-py3.10-gcc9-sm86 / test (inductor)": [
+        "pytorchbot",
+        "107079",
+        "https://github.com/pytorch/pytorch/issues/107079",
+        "inductor",
+        "cuda12.1-py3.10-gcc9-sm86",
+        "test (inductor)",
+    ],
+    "inductor / cuda12.1-py3.10-gcc9-sm86 / test (inductor_huggingface)": [
+        "pytorchbot",
+        "109153",
+        "https://github.com/pytorch/pytorch/issues/109153",
+        "inductor",
+        "cuda12.1-py3.10-gcc9-sm86",
+        "test (inductor_huggingface)",
+    ],
+    "inductor / cuda12.1-py3.10-gcc9-sm86 / test (inductor_huggingface_dynamic)": [
+        "pytorchbot",
+        "109154",
+        "https://github.com/pytorch/pytorch/issues/109154",
+        "inductor",
+        "cuda12.1-py3.10-gcc9-sm86",
+        "test (inductor_huggingface_dynamic)",
+    ],
 }

 MOCKED_PR_INFO = {
@@ -569,6 +593,37 @@ def test_mark_unstable_jobs(self, mock_download_json: Any) -> None:
             "expected": '{"include": [{"config": "default", "unstable": "unstable"}]}',
             "description": "Both binary build and test jobs are unstable",
         },
+        {
+            "workflow": "inductor",
+            "job_name": "cuda12.1-py3.10-gcc9-sm86 / build",
+            "test_matrix": """
+                { include: [
+                    { config: "inductor" },
+                    { config: "inductor_huggingface", shard: 1 },
+                    { config: "inductor_huggingface", shard: 2 },
+                    { config: "inductor_timm", shard: 1 },
+                    { config: "inductor_timm", shard: 2 },
+                    { config: "inductor_torchbench" },
+                    { config: "inductor_huggingface_dynamic" },
+                    { config: "inductor_torchbench_dynamic" },
+                    { config: "inductor_distributed" },
+                ]}
+            """,
+            "expected": """
+                { "include": [
+                    { "config": "inductor", "unstable": "unstable" },
+                    { "config": "inductor_huggingface", "shard": 1, "unstable": "unstable" },
+                    { "config": "inductor_huggingface", "shard": 2, "unstable": "unstable" },
+                    { "config": "inductor_timm", "shard": 1 },
+                    { "config": "inductor_timm", "shard": 2 },
+                    { "config": "inductor_torchbench" },
+                    { "config": "inductor_huggingface_dynamic", "unstable": "unstable" },
+                    { "config": "inductor_torchbench_dynamic" },
+                    { "config": "inductor_distributed" }
+                ]}
+            """,
+            "description": "Marking multiple unstable configurations",
+        },
     ]

     for case in testcases:
@@ -577,7 +632,7 @@ def test_mark_unstable_jobs(self, mock_download_json: Any) -> None:
         test_matrix = yaml.safe_load(case["test_matrix"])

         filtered_test_matrix = mark_unstable_jobs(workflow, job_name, test_matrix)
-        self.assertEqual(case["expected"], json.dumps(filtered_test_matrix))
+        self.assertEqual(json.loads(case["expected"]), filtered_test_matrix)

 @mock.patch("subprocess.check_output")
 def test_perform_misc_tasks(self, mocked_subprocess: Any) -> None:
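The assertEqual change at the end is a robustness fix: comparing parsed objects rather than serialized strings makes the test insensitive to key order and whitespace in the expected JSON. A minimal illustration:

import json

expected = '{"config": "inductor", "unstable": "unstable"}'
actual = {"unstable": "unstable", "config": "inductor"}

assert json.loads(expected) == actual   # structural comparison passes
assert expected != json.dumps(actual)   # string comparison would not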

    .github/workflows/build-triton-wheel.yml

    Lines changed: 3 additions & 3 deletions
@@ -131,7 +131,7 @@ jobs:
     needs: build-wheel
     container:
       image: continuumio/miniconda3:4.12.0
-    environment: ${{ (github.event_name == 'push' && (github.event.ref == 'refs/heads/nightly' || startsWith(github.event.ref, 'refs/tags/v'))) && 'conda-aws-upload' || '' }}
+    environment: ${{ (github.event_name == 'push' && (github.event.ref == 'refs/heads/nightly' || github.event.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v'))) && 'conda-aws-upload' || '' }}
     steps:
       - uses: actions/checkout@v3

@@ -244,7 +244,7 @@ jobs:
     needs: build-conda
     container:
       image: continuumio/miniconda3:4.12.0
-    environment: ${{ (github.event_name == 'push' && (github.event.ref == 'refs/heads/nightly' || startsWith(github.event.ref, 'refs/tags/v'))) && 'conda-aws-upload' || '' }}
+    environment: ${{ (github.event_name == 'push' && (github.event.ref == 'refs/heads/nightly' || github.event.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v'))) && 'conda-aws-upload' || '' }}
     steps:
       - uses: actions/checkout@v3

@@ -283,7 +283,7 @@ jobs:
       run: |
         set -ex

-        if [[ "${UPLOAD_CHANNEL}" = "nightly" ]]; then
+        if [[ "${UPLOAD_CHANNEL:-nightly}" == "nightly" ]]; then
           export ANACONDA_API_TOKEN="${CONDA_PYTORCHBOT_TOKEN}"
         else
           export ANACONDA_API_TOKEN="${CONDA_PYTORCHBOT_TOKEN_TEST}"
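The ${UPLOAD_CHANNEL:-nightly} change makes the channel default to nightly when the variable is unset or empty. A Python analogue of that shell expansion (same variable and token names as the workflow above):

import os

# ${UPLOAD_CHANNEL:-nightly}: fall back when unset or empty.
channel = os.environ.get("UPLOAD_CHANNEL") or "nightly"
token_env = "CONDA_PYTORCHBOT_TOKEN" if channel == "nightly" else "CONDA_PYTORCHBOT_TOKEN_TEST"
print(channel, token_env)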

    .lintrunner.toml

    Lines changed: 0 additions & 14 deletions
@@ -195,37 +195,23 @@ include_patterns = [
 exclude_patterns = [
     '**/fb/**',
     'torch/_inductor/index_propagation.py',
-    'torch/_inductor/coordinate_descent_tuner.py',
     'torch/_inductor/debug.py',
-    'torch/_inductor/hooks.py',
     'torch/_inductor/bounds.py',
-    'torch/_inductor/config.py',
     'torch/_inductor/ir.py',
-    'torch/_inductor/codecache.py',
-    'torch/_inductor/test_operators.py',
-    'torch/_inductor/inductor_prims.py',
     'torch/_inductor/scheduler.py',
     'torch/_inductor/exc.py',
     'torch/_inductor/sizevars.py',
-    'torch/_inductor/triton_helpers.py',
     'torch/_inductor/freezing.py',
     'torch/_inductor/pattern_matcher.py',
     'torch/_inductor/fx_utils.py',
-    'torch/_inductor/virtualized.py',
-    'torch/_inductor/cuda_properties.py',
     'torch/_inductor/codegen/triton_foreach.py',
-    'torch/_inductor/codegen/__init__.py',
     'torch/_inductor/codegen/cpp.py',
     'torch/_inductor/codegen/triton.py',
     'torch/_inductor/fx_passes/split_cat.py',
-    'torch/_inductor/fx_passes/binary_folding.py',
-    'torch/_inductor/fx_passes/replace_random.py',
     'torch/_inductor/fx_passes/joint_graph.py',
     'torch/_inductor/fx_passes/pad_mm.py',
-    'torch/_inductor/fx_passes/__init__.py',
     'torch/_inductor/fx_passes/group_batch_fusion.py',
     'torch/_inductor/fx_passes/pre_grad.py',
-    'torch/_inductor/fx_passes/freezing_patterns.py',
 ]
 command = [
     'python3',

    CMakeLists.txt

    Lines changed: 1 addition & 1 deletion
@@ -730,7 +730,7 @@ include(cmake/Dependencies.cmake)
 cmake_dependent_option(
   USE_FLASH_ATTENTION
   "Whether to build the flash_attention kernel for scaled dot product attention" ON
-  "USE_CUDA AND NOT ROCM AND NOT CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.6" OFF)
+  "USE_CUDA AND NOT ROCM AND NOT MSVC AND NOT CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.6" OFF)

 # Flash Attention2 will error while building for sm52 while Mem Eff Attention won't
 cmake_dependent_option(
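Whether the FlashAttention kernels can actually run also depends on the compute architectures the binary was built for (per the build.sh comment above, they are only built for CUDA 8.0+). A quick runtime check, assuming a CUDA build of PyTorch:

import torch

# Architectures this binary was compiled for, e.g. ['sm_80', 'sm_86'];
# empty on CPU-only builds.
print(torch.cuda.get_arch_list())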

    aten/src/ATen/CMakeLists.txt

    Lines changed: 2 additions & 0 deletions
@@ -161,6 +161,7 @@ file(GLOB native_utils_cpp "native/utils/*.cpp")

 # flash_attention sources
 file(GLOB flash_attention_cuda_cu "native/transformers/cuda/flash_attn/*.cu")
+file(GLOB flash_attention_cuda_kernels_cu "native/transformers/cuda/flash_attn/kernels/*.cu")
 file(GLOB flash_attention_cuda_cpp "native/transformers/cuda/flash_attn/*.cpp")

 #Mem_eff attention sources
@@ -170,6 +171,7 @@ file(GLOB mem_eff_attention_cuda_cpp "native/transformers/cuda/mem_eff_attention

 if(USE_FLASH_ATTENTION)
   list(APPEND native_transformers_cuda_cu ${flash_attention_cuda_cu})
+  list(APPEND native_transformers_cuda_cu ${flash_attention_cuda_kernels_cu})
   list(APPEND native_transformers_cuda_cpp ${flash_attention_cuda_cpp})
 endif()

    aten/src/ATen/core/interned_strings.h

    Lines changed: 2 additions & 1 deletion
@@ -340,7 +340,8 @@ namespace c10 {
   _(attr, output_layouts) \
   _(attr, allowzero) \
   _(attr, seen_none) \
-  _(attr, overload_name)
+  _(attr, overload_name) \
+  _(attr, node_stack_idx)

 enum class _keys : unique_t {
 #define DEFINE_KEY(ns, s) ns##_##s,
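For context, entries in this macro become interned c10 attribute symbols (attr::node_stack_idx here) used to name attributes on JIT IR nodes. From Python, attribute names on a scripted graph's nodes can be inspected like this (a toy graph; the new internal attribute will not appear in it):

import torch

@torch.jit.script
def f(x: torch.Tensor) -> torch.Tensor:
    return x + 1

# Each IR node may carry named attributes, keyed by interned symbols.
for node in f.graph.nodes():
    print(node.kind(), node.attributeNames())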
