10000 Adds cudaMallocAsync as an alternative backend for the CUDA allocator by mcarilli · Pull Request #65365 · pytorch/pytorch · GitHub
[go: up one dir, main page]

Skip to content

Adds cudaMallocAsync as an alternative backend for the CUDA allocator #65365

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 73 commits into from

Conversation

mcarilli
Copy link
Collaborator
@mcarilli mcarilli commented Sep 20, 2021

Benefits

The main benefit of adding cudaMallocAsync to Pytorch is "ecosystem composability":
It allows transparent, efficient co-use of GPU memory with other libraries in the same process that also use cudaMallocAsync.

User Exposure

This PR exposes cudaMallocAsync through the already-existing environment variable PYTORCH_CUDA_ALLOC_CONF, i.e.

export PYTORCH_CUDA_ALLOC_CONF=backend:cudaMallocAsync or native

Should we use default or Pytorch-private cudaMemPool_ts?

One goal of cudaMallocAsync is to facilitate memory sharing in scenarios where pytorch shares the GPU with other libraries in the same process that also use cudaMallocAsync under the hood. I see 3 options:

  1. Pytorch cudaMallocAsyncs from the (per-device) default pool.
    • Pros: If other libraries also use the default pool, sharing memory is fast and efficient.
    • Cons: Flags that Pytorch sets on the pool (e.g. cudaMemPoolReuseAllowOpportunistic, cudaMemPoolAttrReleaseThreshold) silently affect the other libraries, and vice versa.
  2. Pytorch creates and cudaMallocAsyncs from its own private cudaMemPool_ts.
    • Pros: Flags are sandboxed.
    • Cons: Memory sharing with other libraries gets ugly. Different libraries hit memory high water marks at different times, so their pools want to reserve more memory at different times. Vivek Kini (a cudaMallocAsync author) says pools ARE allowed to steal unused reserved memory from other pools (such that the entire process hopefully won't go OOM), but it's slow: memory must be released to the system by pool A then reserved by pool B. This repeated rebalancing/inter-pool thrashing of reserved memory would be a difficult-to-diagnose, potentially nondeterministic performance hit.
  3. Allow the user to choose 1 or 2.

The PR currently implements 1. (Implementing 2 or 3 instead, ie, optionally creating per-device pools instead of using default pools, would be straightforward, so the choice between 1, 2, 3 should be motivated purely by high-level pros and cons.)

Internal (c10/cuda) interface

Each allocator interface function in c10/cuda/CUDACachingAllocator.h calls a function pointer that's populated from either the Native namespace or cudaMallocAsync namespace at load time (static initialization).

This design is mainly motivated by not needing any thread-safety gunk in the hot path, which would be the case if, for example, the pointers were populated as function-static variables on first call (compare the code for "interface() in https://godbolt.org/z/3GzaK65dv with the much simpler code for interface() in the load-time design, https://godbolt.org/z/crKGhb3o9).

I think it also avoids deepening the current call chain, ie, maintains existing inlining opportunities and doesn't add new out-of-line calls on the hot path. (I assumed "existing opportunities" does NOT include link-time inlining from LTO, which @malfet told me isn't used when linking lib10_cuda.so.)

Testing

Testing will be a challenge. I propose we add one cuda CI build that enables cudaMallocAsync end to end, but first we need a CI build that uses cuda 11.4+. Afaik such a build doesn't exist yet.

Future goals (for followup PRs)

IPC support for cudaMallocAsynced memory

^

Pluggable external allocators

In the future, pluggable external allocators could be enabled with the same approach (see #43144, rapidsai/rmm#501) . For example, the user could say

export PYTORCH_CUDA_ALLOC_CONF=backend:mylib.so

where mylib.so has symbols with the right signatures. The allocator API functions in CUDACachingAllocator.cpp could then set their static function pointers using dlopen and dlsym on mylib.so (huge security hole if anyone runs pytorch as a setuid process, but that's probably true for a million things in pytorch.)

cc @ngimel

@facebook-github-bot
Copy link
Contributor
facebook-github-bot commented Sep 20, 2021

🔗 Helpful links

💊 CI failures summary and remediations

As of commit 2b1e0b2 (more details on the Dr. CI page):


  • 2/3 failures introduced in this PR
  • 1/3 broken upstream at merge base e5a1a78 on Apr 05 from 11:50am to 5:17pm

🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages

See GitHub Actions build pull / linux-bionic-rocm5.0-py3.7 / build (1/1)

Step: "Build" (full log | diagnosis details | 🔁 rerun)

2022-04-05T21:25:09.5714140Z �[36;1m echo "ERR...t available for the merge-base of your branch"�[0m
2022-04-05T21:25:09.5711223Z �[36;1mfi�[0m
2022-04-05T21:25:09.5711447Z �[36;1m# Covers the case where a previous tag doesn't exist for the tree�[0m
2022-04-05T21:25:09.5711785Z �[36;1m# this is only really applicable on trees that don't have `.circleci/docker` at its merge base, i.e. nightly�[0m
2022-04-05T21:25:09.5712105Z �[36;1mif ! git rev-parse "$MERGE_BASE:.circleci/docker"; then�[0m
2022-04-05T21:25:09.5712446Z �[36;1m  echo "Directory '.circleci/docker' not found in commit $MERGE_BASE, you should probably rebase onto a more recent commit"�[0m
2022-04-05T21:25:09.5712722Z �[36;1m  exit 1�[0m
2022-04-05T21:25:09.5712888Z �[36;1mfi�[0m
2022-04-05T21:25:09.5713108Z �[36;1mPREVIOUS_DOCKER_TAG=$(git rev-parse "$MERGE_BASE:.circleci/docker")�[0m
2022-04-05T21:25:09.5713435Z �[36;1m# If no image exists but the hash is the same as the previous hash then we should error out here�[0m
2022-04-05T21:25:09.5713811Z �[36;1mif [[ "${PREVIOUS_DOCKER_TAG}" = "${DOCKER_TAG}" ]]; then�[0m
2022-04-05T21:25:09.5714140Z �[36;1m  echo "ERROR: Something has gone wrong and the previous image isn't available for the merge-base of your branch"�[0m
2022-04-05T21:25:09.5714469Z �[36;1m  echo "       contact the PyTorch team to restore the original images"�[0m
2022-04-05T21:25:09.5714699Z �[36;1m  exit 1�[0m
2022-04-05T21:25:09.5714915Z �[36;1mfi�[0m
2022-04-05T21:25:09.5715264Z �[36;1mecho ::set-output name=rebuild::yes�[0m
2022-04-05T21:25:09.5725716Z shell: /usr/bin/bash --noprofile --norc -e -o pipefail {0}
2022-04-05T21:25:09.5725934Z env:
2022-04-05T21:25:09.5726080Z   IN_CI: 1
2022-04-05T21:25:09.5726239Z   IS_GHA: 1
2022-04-05T21:25:09.5726451Z   BASE_REVISION: e5a1a78045e66aad9e763bf69c2455f8136c1eef
2022-04-05T21:25:09.5726824Z   DOCKER_IMAGE: 308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-bionic-rocm5.0-py3.7:3c8505af7f950e5107f15c680d2fb85c7eb425ef

🕵️‍♀️ 1 failure not recognized by patterns:

The following CI failures may be due to changes from the PR
Job Step Action
GitHub Actions Lint / clang-format Run clang-format 🔁 rerun

🚧 1 fixed upstream failure:

These were probably caused by upstream breakages that were already fixed.

Please rebase on the viable/strict branch (expand for instructions)

If your commit is older than viable/strict, run these commands:

git fetch https://github.com/pytorch/pytorch viable/strict
git rebase FETCH_HEAD

This comment was automatically generated by Dr. CI (expand for details).

Please report bugs/suggestions to the (internal) Dr. CI Users group.

Click here to manually regenerate this comment.

@mcarilli mcarilli requested review from ngimel and ezyang September 20, 2021 19:31
@mcarilli mcarilli added the module: cuda Related to torch.cuda, and CUDA support in general label Sep 20, 2021
@vadimkantorov
Copy link
Contributor

Custom allocators may be useful for tracing as well... Related: #1529

@ezyang
Copy link
Contributor
ezyang commented Sep 21, 2021

Since you only plan it to be toggleable once, environment variable seems better. We got a PYTORCH_CUDA_ALLOC_CONF, consider using that?

@mcarilli
Copy link
Collaborator Author

Since you only plan it to be toggleable once, environment variable seems better. We got a PYTORCH_CUDA_ALLOC_CONF, consider using that?

sounds good, my only minor objection would be, i don't think people instinctively use the environment variables as often as the torch.backends.

@ezyang
Copy link
Contributor
ezyang commented Sep 21, 2021

I'd probably say that you'd be on the hook for making sure that people can toggle it arbitrarily, in that case

@mcarilli
Copy link
Collaborator Author

I'd probably say that you'd be on the hook for making sure that people can toggle it arbitrarily, in that case

alright, env variable it is. we don't want them switching it mid-application.

@pytorch-probot
Copy link
pytorch-probot bot commented Oct 11, 2021
CI Flow Status

⚛️ CI Flow

Ruleset - Version: v1
Ruleset - File: https://github.com/mcarilli/pytorch/blob/7e7c12b60afd5ab269c4a3d1d9daaa3d31eab4df/.github/generated-ciflow-ruleset.json
PR ciflow labels: ciflow/default

Workflows Labels (bold enabled) Status
Triggered Workflows
linux-bionic-py3.7-clang9 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/noarch, ciflow/trunk ✅ triggered
linux-docs ciflow/all, ciflow/cpu, ciflow/default, ciflow/docs, ciflow/linux, ciflow/trunk ✅ triggered
linux-vulkan-bionic-py3.7-clang9 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk, ciflow/vulkan ✅ triggered
linux-xenial-cuda11.3-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
linux-xenial-cuda11.3-py3.7-gcc7-bazel-test ciflow/all, ciflow/bazel, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
linux-xenial-py3-clang5-mobile-build ciflow/all, ciflow/default, ciflow/linux, ciflow/mobile, ciflow/trunk ✅ triggered
linux-xenial-py3-clang5-mobile-custom-build-static ciflow/all, ciflow/default, ciflow/linux, ciflow/mobile, ciflow/trunk ✅ triggered
linux-xenial-py3.7-clang7-asan ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/sanitizers, ciflow/trunk ✅ triggered
linux-xenial-py3.7-clang7-onnx ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/onnx, ciflow/trunk ✅ triggered
linux-xenial-py3.7-gcc5.4 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
linux-xenial-py3.7-gcc7 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
linux-xenial-py3.7-gcc7-no-ops ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single ciflow/all, ciflow/android, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-custom-build-single-full-jit ciflow/all, ciflow/android, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/trunk ✅ triggered
win-vs2019-cpu-py3 ciflow/all, ciflow/cpu, ciflow/default, ciflow/trunk, ciflow/win ✅ triggered
win-vs2019-cuda11.3-py3 ciflow/all, ciflow/cuda, ciflow/default, ciflow/trunk, ciflow/win ✅ triggered
Skipped Workflows
caffe2-linux-xenial-py3.7-gcc5.4 ciflow/all, ciflow/cpu, ciflow/linux, ciflow/trunk 🚫 skipped
docker-builds ciflow/all, ciflow/trunk 🚫 skipped
ios-12-5-1-arm64 ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
ios-12-5-1-arm64-coreml ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
ios-12-5-1-arm64-custom-ops ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
ios-12-5-1-arm64-full-jit ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
ios-12-5-1-arm64-metal ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
ios-12-5-1-x86-64 ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
ios-12-5-1-x86-64-coreml ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
ios-12-5-1-x86-64-full-jit ciflow/all, ciflow/ios, ciflow/macos, ciflow/trunk 🚫 skipped
libtorch-linux-xenial-cuda10.2-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/trunk 🚫 skipped
libtorch-linux-xenial-cuda11.3-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/trunk 🚫 skipped
linux-binary-conda ciflow/binaries, ciflow/binaries/conda 🚫 skipped
linux-binary-libtorch-cxx11-abi ciflow/binaries, ciflow/binaries/libtorch 🚫 skipped
linux-binary-libtorch-pre-cxx11 ciflow/binaries, ciflow/binaries/libtorch 🚫 skipped
linux-binary-manywheel ciflow/binaries, ciflow/binaries/wheel 🚫 skipped
linux-bionic-cuda10.2-py3.9-gcc7 ciflow/all, ciflow/cuda, ciflow/linux, ciflow/slow, ciflow/trunk 🚫 skipped
linux-bionic-py3.6-clang9 ciflow/xla 🚫 skipped
linux-docs-push ciflow/all, ciflow/cpu, ciflow/linux, ciflow/scheduled 🚫 skipped
linux-xenial-cuda11.3-py3.7-gcc7-no-ops ciflow/all, ciflow/cuda, ciflow/linux, ciflow/trunk 🚫 skipped
macos-10-15-py3-arm64 ciflow/all, ciflow/macos, ciflow/trunk 🚫 skipped
macos-10-15-py3-lite-interpreter-x86-64 ciflow/all, ciflow/macos, ciflow/trunk 🚫 skipped
macos-11-py3-x86-64 ciflow/all, ciflow/macos, ciflow/trunk 🚫 skipped
parallelnative-linux-xenial-py3.7-gcc5.4 ciflow/all, ciflow/cpu, ciflow/linux, ciflow/trunk 🚫 skipped
periodic-libtorch-linux-bionic-cuda11.5-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/scheduled 🚫 skipped
periodic-libtorch-linux-xenial-cuda11.1-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/scheduled 🚫 skipped
periodic-linux-bionic-cuda11.5-py3.7-gcc7 ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled 🚫 skipped
periodic-linux-xenial-cuda10.2-py3-gcc7-slow-gradcheck ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled, ciflow/slow, ciflow/slow-gradcheck 🚫 skipped
periodic-linux-xenial-cuda11.1-py3.7-gcc7-debug ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled 🚫 skipped
periodic-win-vs2019-cuda11.1-py3 ciflow/all, ciflow/cuda, ciflow/scheduled, ciflow/win 🚫 skipped
periodic-win-vs2019-cuda11.5-py3 ciflow/all, ciflow/cuda, ciflow/scheduled, ciflow/win 🚫 skipped
pytorch-linux-xenial-py3-clang5-android-ndk-r19c-build ciflow/all, ciflow/android, ciflow/cpu, ciflow/linux, ciflow/trunk 🚫 skipped

You can add a comment to the PR and tag @pytorchbot with the following commands:
# ciflow rerun, "ciflow/default" will always be added automatically
@pytorchbot ciflow rerun

# ciflow rerun with additional labels "-l <ciflow/label_name>", which is equivalent to adding these labels manually and trigger the rerun
@pytorchbot ciflow rerun -l ciflow/scheduled -l ciflow/slow

For more information, please take a look at the CI Flow Wiki.

@mcarilli mcarilli changed the title [WIP] add cudaMallocAsync as an alternative backend for the CUDA allocator [WIP] Adds cudaMallocAsync as an alternative backend for the CUDA allocator Nov 15, 2021
}
}
};

namespace Native {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fwiw capitalized namespace is pretty unusual, it will make people confused if they're looking at a class as opposed to a namespace

// to CachingAllocatorConfig's runtime doublecheck.
// If this works, maybe we should move all of CachingAllocatorConfig here?
AllocatorBackend parseEnvForBackend() {
const char* val = getenv("PYTORCH_CUDA_ALLOC_CONF");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@swolchok how does parsing things out of envvars in static initializers make you feel

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seems fine as long as you don't mind failing to respect changes to the environment via putenv() etc. I don't think I have ever minded.

@github-actions
Copy link
Contributor
github-actions bot commented Jul 2, 2022

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label Jul 2, 2022
@ezyang ezyang removed the Stale label Jul 4, 2022
@ezyang
Copy link
Contributor
ezyang commented Jul 4, 2022

we should find someone else to push this over the finish line

@ngimel
Copy link
Collaborator
ngimel commented Jul 4, 2022

cc @ptrblck, who can take this over?

@emcastillo
Copy link
Collaborator

@ezyang @ngimel We have several internal use cases that can benefit from the direction this PR is taking, so I would like to take over it if possible. Thanks!

@ptrblck
Copy link
Collaborator
ptrblck commented Jul 5, 2022

I don't know if @mcarilli is interested in finishing this PR, but if that's not the case then @eqy could take a look at it from our side.

@zarzen
Copy link
zarzen commented Aug 1, 2022

Hi,
I am trying this PR but got hanging issues when using torch.distributed.* operations. For example, calling PYTORCH_CUDA_ALLOC_CONF=backend:cudaMallocAsync python3 pytorch/test/run_test.py --verbose -i distributed/test_c10d_nccl raise timeout errors for all collective communication operators.

After debugging, I found the code hangs at the recordStream stage at here

Any suggestions for a fix?

@ngimel
Copy link
Collaborator
ngimel commented Aug 2, 2022

cc @ptrblck, who is taking over cudaMallocAsync?

@zarzen
Copy link
zarzen commented Aug 2, 2022

I found the cause for deadlock is due to the lock operation in the AutoNcclGroup constructor, which locks a mutex from the allocator. In current implementation cudaMallocAsyncAllocator has a single general_mutex, which will be returned from getFreeMutex() at here

To fix the deadlock issue, I removed the lock_guard at the beginning of recordStream function in CudaMallocAsync at here. But not sure if this would bring other bad consequences.

@ptrblck
Copy link
Collaborator
ptrblck commented Aug 2, 2022

@ngimel @eqy started to rebase the PR today in order to reproduce the issue @zarzen described.
@zarzen could you give us more information about your setup, i.e. especially used NCCL version, please?

@zarzen
Copy link
zarzen commented Aug 2, 2022

@ngimel @eqy started to rebase the PR today in order to reproduce the issue @zarzen described. @zarzen could you give us more information about your setup, i.e. especially used NCCL version, please?

I was using NCCL-2.10.3, which is compiled by PyTorch 1.12.

@eqy
Copy link
Collaborator
eqy commented Aug 5, 2022

With the rebased PR (#82682) I can reproduce failures in
test_all_reduce_coalesced_nccl
test_broadcast_coalesced_nccl
test_nccl_barrier_device_ids
test_nccl_warn_not_in_group_debug_detail
test_nccl_warn_not_in_group_debug_info
test_nccl_warn_not_in_group_debug_off
test_pass_nccl_options_high_priority_stream
test_sequence_num_set_default_pg_nccl
test_sequence_num_incremented_nccl_default
test_sequence_num_set_nccl_new_group
test_accumulate_gradients_module
test_accumulate_gradients_module_with_grad_is_view
test_arbitrary_forward_return_value
and so on...
will investigate further.

@zarzen
Copy link
zarzen commented Aug 8, 2022

With the rebased PR (#82682) I can reproduce failures in test_all_reduce_coalesced_nccl test_broadcast_coalesced_nccl test_nccl_barrier_device_ids test_nccl_warn_not_in_group_debug_detail test_nccl_warn_not_in_group_debug_info test_nccl_warn_not_in_group_debug_off test_pass_nccl_options_high_priority_stream test_sequence_num_set_default_pg_nccl test_sequence_num_incremented_nccl_default test_sequence_num_set_nccl_new_group test_accumulate_gradients_module test_accumulate_gradients_module_with_grad_is_view test_arbitrary_forward_return_value and so on... will investigate further.

You can probably comment out this line https://github.com/mcarilli/pytorch/blob/cudaMallocAsync/c10/cuda/CUDAMallocAsyncAllocator.cpp#L552

@eqy
Copy link
Collaborator
eqy commented Aug 8, 2022

I believe a "safer" fix for this is to replace the std::mutex in the current implementation with a std::recursive_mutex (which is done in the default allocator backend), as it seems the mutex acquisition within a single thread is fast and loose without any strict convention between when caller/callee are expected to already be holding the lock.

Currently testing this fix.

@eqy
Copy link
Collaborator
eqy commented Aug 8, 2022

Using std::recursive_mutex looks good for the distributed NCCL tests locally. Added another minor change for handling nullptr/"empty" tensors, and test_cuda.py and 10000 test_nn.py doesn't seem to show any relevant failures locally on 2xA6000.

I've updated #82682 with these changes.

@eqy
Copy link
Collaborator
eqy commented Aug 8, 2022

Looking closer, the root cause appears to be AutoNcclGroup calling getFreeMutex which is reused for the public API functions in the async malloc implementation. Just using a separate std::mutex for getFreeMutex (as is done in the current CUDACachingAllocator) should be sufficient without having to reach for std::recursive_mutex.

@facebook-github-bot
Copy link
Contributor

/easycla

As part of the transition to the PyTorch Foundation, this project now requires contributions be covered under the new CLA. See #85559 for additional details.

This comment will trigger a new check of this PR. If you are already covered, you will simply see a new "EasyCLA" check that passes. If you are not covered, a bot will leave a new comment with a link to sign.

pytorchmergebot pushed a commit that referenced this pull request Oct 12, 2022
… allocator (#82682)

Rebased version of @mcarilli 's cudaMallocAsync #65365 for continued testing
Pull Request resolved: #82682
Approved by: https://github.com/ngimel
IvanYashchuk added a commit to csarofeen/pytorch that referenced this pull request Oct 13, 2022
commit f925b26
Author: Ivan Yashchuk <ivan.yashchuk@aalto.fi>
Date:   Thu Oct 13 21:45:09 2022 +0300

    Allow skipping view with skip_ops

commit ddb769e
Author: Ivan Yashchuk <ivan.yashchuk@aalto.fi>
Date:   Thu Oct 13 21:38:04 2022 +0300

    Add varargs support for view

commit a9cdefa
Author: Ivan Yashchuk <ivan.yashchuk@aalto.fi>
Date:   Wed Oct 12 18:46:46 2022 +0300

    Use ops.view name

commit 986d76b
Author: Ivan Yashchuk <ivan.yashchuk@aalto.fi>
Date:   Wed Oct 12 18:27:37 2022 +0300

    Fix duplicate

commit 1c9c9c6
Author: Ivan Yashchuk <ivan.yashchuk@aalto.fi>
Date:   Wed Oct 12 16:49:52 2022 +0300

    Add print for ViewOpRecord

commit a67e6c2
Merge: b07eeb0 2344135
Author: Ivan Yashchuk <ivan.yashchuk@aalto.fi>
Date:   Wed Oct 12 16:43:53 2022 +0300

    Merge remote-tracking branch 'upstream/viable/strict' into nvprims-view

commit 2344135
Author: Khushi <khushiagrawal411@gmail.com>
Date:   Wed Oct 12 07:00:40 2022 +0000

    [primTorch] special: entr, expit (pytorch#86592)

    Add _refs for `entr` & `expit`.

    cc @mruberry @kshitij12345!
    Pull Request resolved: pytorch#86592
    Approved by: https://github.com/mruberry

commit a47f93b
Author: Sherlock Huang <bahuang@fb.com>
Date:   Wed Oct 12 02:26:02 2022 +0000

    Add type and shape annotation for gm.print_readable() (pytorch#86562)

    For
    ```
    def f(a, b):
        dim0 = a.shape[0] + b.shape[0]
        dim1 = a.shape[1] + b.shape[1]
        d = a.new_empty(dim0, dim1)
        return d

    fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(5, 3), torch.randn(4, 3))
    fx_g.print_readable()
    ```

    Tracing with 'real' and 'fake' mode yields
    ```
    class f(torch.nn.Module):
        def forward(self, a_1: Tensor<f32>[5, 3], b_1: Tensor<f32>[4, 3]):

            # No stacktrace found for following nodes
            new_empty: Tensor<f32>[9, 6] = torch.ops.aten.new_empty.default(a_1, [9, 6], dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False);  a_1 = None
            return new_empty
    ```

    Tracing with 'symbolic' mode yields
    ```
        def forward(self, a_1: Tensor<f32>[t0.size(0), t0.size(1)], b_1: Tensor<f32>[t1.size(0), t0.size(1)]):

            # No stacktrace found for following nodes
            sym_size: Symint(t0.size(0)) = torch.ops.aten.sym_size(a_1, 0)
            sym_size_1: Symint(t1.size(0)) = torch.ops.aten.sym_size(b_1, 0)
            add: Symint(t0.size(0) + t1.size(0)) = sym_size + sym_size_1;  sym_size = sym_size_1 = None
            sym_size_2: Symint(t0.size(1)) = torch.ops.aten.sym_size(a_1, 1)
            sym_size_3: Symint(t0.size(1)) = torch.ops.aten.sym_size(b_1, 1);  b_1 = None
            add_1: Symint(2*t0.size(1)) = sym_size_2 + sym_size_3;  sym_size_2 = sym_size_3 = None
            new_empty: Tensor<f32>[t0.size(0) + t1.size(0), 2*t0.size(1)] = torch.ops.aten.new_empty.default(a_1, [add, add_1], dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False);  a_1 = add = add_1 = None
            return new_empty
    ```

    Pull Request resolved: pytorch#86562
    Approved by: https://github.com/Chillee

commit e0d6898
Author: PyTorch MergeBot <pytorchmergebot@users.noreply.github.com>
Date:   Wed Oct 12 04:12:43 2022 +0000

    Revert "Backport currently dont work with some models if: (pytorch#86510)"

    This reverts commit 4bfb734.

    Reverted pytorch#86510 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally

commit 25725fd
Author: Eddie Yan <eddiey@nvidia.com>
Date:   Wed Oct 12 03:44:21 2022 +0000

    (Re-open) Adds cudaMallocAsync as an alternative backend for the CUDA allocator (pytorch#82682)

    Rebased version of @mcarilli 's cudaMallocAsync pytorch#65365 for continued testing
    Pull Request resolved: pytorch#82682
    Approved by: https://github.com/ngimel

commit a216f47
Author: Nikita Shulga <nshulga@fb.com>
Date:   Wed Oct 12 01:45:21 2022 +0000

     Add  testing on A10G GPU to periodic workflow (pytorch#85524)

    This enables testing on lots of modern CUDA features on sm_86 capable GPU

    While migrating to that platform, discovered that `functorch` tests for `nn.functional.conv.transpose3d` produce garbage on sm_80+ as well as 2 `nvfuser` tests unexpectedly pass and one unexpectedly fails.

    TODO:
     - Investigate unexpected success for `test_vmapvjp_linalg_householder_product_cuda_float32` and add `functorch` shard

    Pull Request resolved: pytorch#85524
    Approved by: https://github.com/ngimel

commit c4f0b93
Author: Elias Ellison <elias.ellison@gmail.com>
Date:   Tue Oct 11 01:24:48 2022 +0000

    Disable autocast in aot autograd (pytorch#86515)

    Fix for pytorch/torchdynamo#1368

    From comment:
    > When we invoke a Composite Implicit autograd operator that has an autocast rule, such as Einsum,
    autocast is disabled during its invocation. When we trace out the operators in an implicit op,
    re-applying on autocast rules on those operators might yield divergence from what was executed at runtime.
    This pass checks for divergence. If divergence is found, we will disable autocast.
    We would like to avoid disabling autocast if possible because accessing TLS is slow.

    Concretely, the problem found was when invoked `sum` in `einsum`:

    As seen by the following divergence:
    ```
    >>> with torch.cuda.amp.autocast(enabled=True):
    ...     print(torch.ops.aten.sum.dim_IntList(torch.rand([2, 2, 2], device="cuda", dtype=torch.half), [1, 2]).dtype)
    ...
    torch.float32
    >>> print(torch.ops.aten.sum.dim_IntList(torch.rand([2, 2, 2], device="cuda", dtype=torch.half), [1, 2]).dtype)
    torch.float16
    ```

    Edit: we've decided to accept the overhead of universally disabling autocast instead
    Pull Request resolved: pytorch#86515
    Approved by: https://github.com/bdhirsh, https://github.com/Chillee

commit d598290
Author: Christian Puhrsch <cpuhrsch@fb.com>
Date:   Wed Oct 12 01:27:57 2022 +0000

    Basic SDP benchmark harness (pytorch#86729)

    Basic benchmark for reference and discussion.
    Pull Request resolved: pytorch#86729
    Approved by: https://github.com/drisspg

commit 4bfb734
Author: Han Qi (qihqi) <qihan@fb.com>
Date:   Wed Oct 12 00:39:25 2022 +0000

    Backport currently dont work with some models if: (pytorch#86510)

    Backport currently dont work with some models if:

    * model is originally exported with interface call enabled (backport would disable it)
    * model is flatbuffer (flatbuffer support is soft enabled via link time registry), so we manually trigger it

    Fixes #ISSUE_NUMBER

    Pull Request resolved: pytorch#86510
    Approved by: https://github.com/cccclai

commit ce48df9
Author: Bin Bao <binbao@fb.com>
Date:   Tue Oct 11 20:31:12 2022 +0000

    Re-enable torchdynamo unit tests (pytorch#86658)

    Pull Request resolved: pytorch#86658
    Approved by: https://github.com/jansel

commit 692b525
Author: Nikita Shulga <nshulga@fb.com>
Date:   Wed Oct 12 00:32:53 2022 +0000

    [MPS] Extend unary ops to int64 (pytorch#86615)

    Most of them are already supported for `int64` except for:
     - rounding operations (`floor`, `ceil` and `round`), which are no-ops for integral types anyway
     - sign operation, when it can be emulated by clamping it tensor to [-1, 1] range

    Test new types by test MPS

    Fixes pytorch#86319

    Pull Request resolved: pytorch#86615
    Approved by: https://github.com/DenisVieriu97, https://github.com/huydhn

commit f912b58
Author: PyTorch MergeBot <pytorchmergebot@users.noreply.github.com>
Date:   Tue Oct 11 23:53:12 2022 +0000

    Revert "Enable max.unary_out (pytorch#85926)"

    This reverts commit 16a0fa1.

    Reverted pytorch#85926 on behalf of https://github.com/osalpekar due to The internal diff for this commit shows a number of pytorch quantization test failures. Here is a sample output: AssertionError: Tensor-likes are not close! Mismatched elements: 319 / 320 (99.7%). Greatest absolute difference: 0.056652069091796875 at index (0, 0, 4, 5) (up to 1e-05 allowed). Link to the diff: [D40232598](https://www.internalfb.com/diff/D40232598). Link to the Sandcastle job that is failing: https://www.internalfb.com/intern/sandcastle/job/18014399302908587/

commit 2aa981a
Author: PyTorch MergeBot <pytorchmergebot@users.noreply.github.com>
Date:   Tue Oct 11 23:39:50 2022 +0000

    Revert "Reland 2 of Merge more symbolic meta kernels and symint changes from branch (pytorch#86334) (pytorch#86488)"

    This reverts commit 978b46d.

    Reverted pytorch#86488 on behalf of https://github.com/osalpekar due to Broke executorch builds internally with the following message: RuntimeError: Missing out variant for functional op: aten::split.Tensor(Tensor(a -> *) self, SymInt split_size, int dim=0) -> Tensor(a)[] . Make sure you have loaded your custom_ops_generated_lib

commit 9eb4f9d
Author: Nikita Shulga <nshulga@fb.com>
Date:   Tue Oct 11 19:49:23 2022 +0000

    Tweak test tolerances to be compatible with A10G (pytorch#86538)

    Pull Request resolved: pytorch#86538
    Approved by: https://github.com/ngimel

commit 7fa601b
Author: Nikita Shulga <nshulga@fb.com>
Date:   Tue Oct 11 23:27:30 2022 +0000

    Skip chalf.mean in  test_reductions_large_half_tensors (pytorch#86747)

    As `mean_reduce` is not implemented for complex half

    Fixes pytorch#86743 and unblock A10G testing

    Pull Request resolved: pytorch#86747
    Approved by: https://github.com/ngimel

commit 811b8e0
Author: PyTorch MergeBot <pytorchmergebot@users.noreply.github.com>
Date:   Tue Oct 11 23:12:40 2022 +0000

    Revert "min/max support for SymInt/Floats, finish as_strided/scatter/squeeze() backward symint support (pytorch#86643)"

    This reverts commit 86f914e.

    Reverted pytorch#86643 on behalf of https://github.com/osalpekar due to Need to revert this to cleanly revert pytorch#86488. This should be safe to re-land later

commit f1fdb6e
Author: Jason Ansel <jansel@fb.com>
Date:   Tue Oct 11 23:01:21 2022 +0000

    Manual changes for moving dynamo to core (pytorch#86621)

    This is the subset of the changes in pytorch#86461 not auto-generated by `copy_to_core.sh`.
    Pull Request resolved: pytorch#86621
    Approved by: https://github.com/albanD

commit 09364f4
Author: Nikita Shulga <nshulga@fb.com>
Date:   Tue Oct 11 22:39:58 2022 +0000

    Compile C10 with `Wshadow` (pytorch#86666)

    This should prevent further regressions like pytorch#86646
    Update `fmt` to `7.1.0` to fix variable shadowing in that library

    Pull Request resolved: pytorch#86666
    Approved by: https://github.com/seemethere

commit 0337f0a
Author: Zain Rizvi <zainr@fb.com>
Date:   Tue Oct 11 21:56:01 2022 +0000

    Add error checking to flaky test bot platform parser (pytorch#86632)

    If an invalid platform is specified when disabling a test with flaky test bot, the CI crashes, skipping all tests that come after it.

    This turns it into a console message instead.  Not erroring out here since it'll affect random PRs.  Actual error message should go into the bot that parses the original issue so that it can respond on that issue directly
    Pull Request resolved: pytorch#86632
    Approved by: https://github.com/huydhn

commit 42bd275
Author: Parth
F438
o <parthodas6176@gmail.com>
Date:   Tue Oct 11 21:41:48 2022 +0000

    [doc] LR scheduler example fix (pytorch#86629)

    Fixes issue pytorch#86208
    As suggested in the issue, updated the LR scheduler example to use a regular nn.Module like the other examples on the same page.
    Pull Request resolved: pytorch#86629
    Approved by: https://github.com/soulitzer

commit 32152ce
Author: jimku9 <jimku.tw@yahoo.com.tw>
Date:   Tue Oct 11 21:21:53 2022 +0000

    Add original sources/references to Wishart.py in distributions (pytorch#86543)

    @fritzo As discussed, add original sources/references to Wishart.py in distributions and corrected typos in the error messages.

    Pull Request resolved: pytorch#86543
    Approved by: https://github.com/fritzo

commit 50af1ac
Author: Sherlock Huang <bahuang@fb.com>
Date:   Tue Oct 11 17:56:59 2022 +0000

    Mark aten ops as canonical (pytorch#86215)

    This is the first batch of canonical aten ops. 87 in total. More to come in the future PRs.

    native_dropout
    abs
    add.Tensor
    add.Scalar
    arange.start_step
    bitwise_not
    bmm
    cat
    clamp
    constant_pad_nd
    convolution
    convolution_backward
    div.Tensor
    div.Scalar
    embedding_dense_backward
    erf
    exp
    expand
    fill.Scalar
    grid_sampler_2d
    native_group_norm
    native_group_norm_backward
    native_layer_norm
    native_layer_norm_backward
    log
    _log_softmax
    max.dim
    amax
    mean.dim
    min.dim
    amin
    mm
    mul.Tensor
    mul.Scalar
    native_batch_norm
    permute
    scalar_tensor
    reciprocal
    neg
    repeat
    relu
    gelu
    rsqrt
    sigmoid
    slice.Tensor
    slice_scatter
    _softmax
    squeeze.dim
    sum.dim_IntList
    sqrt
    tanh
    unsqueeze
    var.dim
    where.self
    clone
    sub.Tensor
    sub.Scalar
    addmm
    _to_copy
    view
    scatter_add
    bitwise_and.Tensor
    bitwise_or.Tensor
    eq.Scalar
    ge.Scalar
    le.Scalar
    gt.Scalar
    lt.Scalar
    index_select
    nonzero
    gather
    maximum
    minimum
    pow.Tensor_Scalar
    hardtanh
    leaky_relu
    _adaptive_avg_pool2d
    _adaptive_avg_pool2d_backward
    avg_pool2d
    avg_pool2d_backward
    max_pool2d_with_indices
    max_pool2d_with_indices_backward
    upsample_bilinear2d.vec
    upsample_bilinear2d_backward.vec
    upsample_nearest2d.vec
    upsample_nearest2d_backward.vec
    col2im

    Pull Request resolved: pytorch#86215
    Approved by: https://github.com/suo, https://github.com/anjali411

commit 8db3025
Author: Jeff Daily <jeff.daily@amd.com>
Date:   Tue Oct 11 20:55:58 2022 +0000

    [ROCm] set nvfuser default to disabled, keep CI (pytorch#86369)

    Bug fix. nvfuser is functional for ROCm on gfx906, but some tests are failing for other gfx targets. Disable nvfuser until all features are verified. Users may still opt-in by setting the known env var PYTORCH_JIT_ENABLE_NVFUSER=1. This PR sets this env var for the github actions workflow for ROCm since all current CI hosts are gfx906.
    Pull Request resolved: pytorch#86369
    Approved by: https://github.com/huydhn

commit 5ffe24f
Author: Stephen Jia <ssjia@meta.com>
Date:   Tue Oct 11 20:16:56 2022 +0000

    [vulkan][ez] fix always printing out a warning when retrieving the global context (pytorch#86697)

    Summary: D40151818 (pytorch@82ed5ca) replaces the `TORCH_CHECK` with a `TORCH_WARN` but since it does not check if the context is valid the message gets printed every time. This diff fixes that.

    Test Plan:
    Referring to [Pytorch Vulkan Testing Procedures](https://fb.quip.com/fZALAc9zhlcU)

    On Mac:
    1. `vulkan_api_test` on Mac
    2. model comparison binary on Mac

    On Android:
    1. `vulkan_api_test` on Android
    2. benchmark binary on Android

    Reviewed By: salilsdesai

    Differential Revision: D40266820

    Pull Request resolved: pytorch#86697
    Approved by: https://github.com/kirklandsign

commit f32aeea
Author: Han Qi (qihqi) <qihan@meta.com>
Date:   Tue Oct 11 20:07:58 2022 +0000

    Set interface_call to true be default (pytorch#86668)

    Summary: ASR models need it

    Test Plan: existing unit tests

    Reviewed By: cccclai

    Differential Revision: D40251788

    Pull Request resolved: pytorch#86668
    Approved by: https://github.com/cccclai

commit 7f02f2a
Author: Huy Do <huydhn@gmail.com>
Date:   Tue Oct 11 19:34:44 2022 +0000

    [Experimentation] Add TSAN build and test (pytorch#85313)

    Some parts of the PR are adopted from the previously abandoned pytorch#36694.  This PR is the first part to setup TSAN jobs in the CI.  The data race warnings from TSAN will need to be reviewed later in a separate PR.
    Pull Request resolved: pytorch#85313
    Approved by: https://github.com/osalpekar

commit 9256204
Author: 胡玮文 <sehuww@mail.scut.edu.cn>
Date:   Tue Oct 11 19:03:43 2022 +0000

    Optimize __dlpack_device__ performance (pytorch#86665)

    This can be critical when processing a large number of tensors

    ```bash
    python -m timeit --setup 'import torch; t = torch.empty(1000, device="cuda")' 't.__dlpack_device__()'
    ```

    based on 1.12.1:
    before:
    100000 loops, best of 5: 2.32 usec per loop
    after:
    500000 loops, best of 5: 844 nsec per loop

    Pull Request resolved: pytorch#86665
    Approved by: https://github.com/SunDoge, https://github.com/soulitzer

commit c12f829
Author: Jerry Zhang <jerryzh@meta.com>
Date:   Tue Oct 11 18:49:09 2022 +0000

    [nn] Add remove_duplicate flag to named_buffers (#674) (pytorch#85903)

    Summary:
    X-link: pytorch/torchrec#674

    Pull Request resolved: pytorch#84984

    this is to allow named_buffers to return the same buffer objects with different names multiple times, needed by internal use cases
    ghstack-source-id: 168589597

    Test Plan:
    python test/test_nn.py -k test_buffers_and_named_buffers

    Imported from OSS

    Reviewed By: albanD

    Differential Revision: D39493161

    Pull Request resolved: pytorch#85903
    Approved by: https://github.com/albanD

commit 693250a
Author: David <cherrywoods@posteo.org>
Date:   Tue Oct 11 18:05:53 2022 +0000

    Docs: fx.Node docs incorrectly state that the self argument is included in args for module calls (pytorch#86685)

    It seems like the [torch.fx.Node docs](https://pytorch.org/docs/stable/fx.html#torch.fx.Node) are incorrect regarding the inclusion of the self argument for module call nodes.
    While the docs state that self (the module) is included in `args`, it is in fact not, as demonstrated by this code:
    ```python
    import torch
    from torch import fx, nn

    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.submod = nn.Linear(10, 10)
        def forward(self, x):
            x = x.flatten()
            return self.submod(x)

    graph_module = fx.symbolic_trace(Net())
    print(graph_module.graph)  # doesn't show self for the submodule call
    submod_node = list(graph_module.graph.nodes)[2]
    print(submod_node.op)  # call_module
    print(submod_node.args)  # (flatten,) => would need to have len 2 if self was included

    flatten_node = list(graph_module.graph.nodes)[1]
    print(flatten_node.op)  # call_method
    print(flatten_node.args)  # (x,) => here self is included (and docs are correct)
    ```

    Since [torch.fx.Interpreter also uses `args` as if self was is not included](https://github.com/pytorch/pytorch/blob/2fe580859012d2d24a54e452195ccbc7f3191036/torch/fx/interpreter.py#L288), I assume the docs are incorrect.
    Pull Request resolved: pytorch#86685
    Approved by: https://github.com/soulitzer

commit 160118d
Author: Fang Wang <fangwangcn@fb.com>
Date:   Tue Oct 11 17:52:18 2022 +0000

    Add test case for matrix multiply-add with large inputs (pytorch#85550)

    Summary:
    - Added test case for addmm, baddbmm and linear with large inputs
    - Testing with torch types: float32, float16, bfloat16

    Test Plan:
    Run unit tests with:
    `buck2 run mode/opt //caffe2/test:linalg_re_cuda`

    ```
    ...
    test_addmm_baddbmm_large_input_1_10000_10000_10000_cpu_bfloat16 (test_linalg_re_cuda.TestLinalgReCudaCPU) ... skipped 'Only runs on cuda'
    test_addmm_baddbmm_large_input_1_10000_10000_10000_cpu_float16 (test_linalg_re_cuda.TestLinalgReCudaCPU) ... skipped 'Only runs on cuda'
    test_addmm_baddbmm_large_input_1_10000_10000_10000_cpu_float32 (test_linalg_re_cuda.TestLinalgReCudaCPU) ... skipped 'Only runs on cuda'
    test_addmm_baddbmm_large_input_1_10000_1000_10000_cpu_bfloat16 (test_linalg_re_cuda.TestLinalgReCudaCPU) ... skipped 'Only runs on cuda'
    test_addmm_baddbmm_large_input_1_10000_1000_10000_cpu_float16 (test_linalg_re_cuda.TestLinalgReCudaCPU) ... skipped 'Only runs on cuda'
    test_addmm_baddbmm_large_input_1_10000_1000_10000_cpu_float32 (test_linalg_re_cuda.TestLinalgReCudaCPU) ... skipped 'Only runs on cuda'
    test_addmm_baddbmm_large_input_2_1000_1000_1000_cpu_bfloat16 (test_linalg_re_cuda.TestLinalgReCudaCPU) ... skipped 'Only runs on cuda'
    test_addmm_baddbmm_large_input_2_1000_1000_1000_cpu_float16 (test_linalg_re_cuda.TestLinalgReCudaCPU) ... skipped 'Only runs on cuda'
    test_addmm_baddbmm_large_input_2_1000_1000_1000_cpu_float32 (test_linalg_re_cuda.TestLinalgReCudaCPU) ... skipped 'Only runs on cuda'
    test_addmm_baddbmm_large_input_2_100_100_100_cpu_bfloat16 (test_linalg_re_cuda.TestLinalgReCudaCPU) ... skipped 'Only runs on cuda'
    test_addmm_baddbmm_large_input_2_100_100_100_cpu_float16 (test_linalg_re_cuda.TestLinalgReCudaCPU) ... skipped 'Only runs on cuda'
    test_addmm_baddbmm_large_input_2_100_100_100_cpu_float32 (test_linalg_re_cuda.TestLinalgReCudaCPU) ... skipped 'Only runs on cuda'
    test_addmm_baddbmm_large_input_1_10000_10000_10000_cuda_bfloat16 (test_linalg_re_cuda.TestLinalgReCudaCUDA) ... ok
    test_addmm_baddbmm_large_input_1_10000_10000_10000_cuda_float16 (test_linalg_re_cuda.TestLinalgReCudaCUDA) ... ok
    test_addmm_baddbmm_large_input_1_10000_10000_10000_cuda_float32 (test_linalg_re_cuda.TestLinalgReCudaCUDA) ... ok
    test_addmm_baddbmm_large_input_1_10000_1000_10000_cuda_bfloat16 (test_linalg_re_cuda.TestLinalgReCudaCUDA) ... ok
    test_addmm_baddbmm_large_input_1_10000_1000_10000_cuda_float16 (test_linalg_re_cuda.TestLinalgReCudaCUDA) ... ok
    test_addmm_baddbmm_large_input_1_10000_1000_10000_cuda_float32 (test_linalg_re_cuda.TestLinalgReCudaCUDA) ... ok
    test_addmm_baddbmm_large_input_2_1000_1000_1000_cuda_bfloat16 (test_linalg_re_cuda.TestLinalgReCudaCUDA) ... ok
    test_addmm_baddbmm_large_input_2_1000_1000_1000_cuda_float16 (test_linalg_re_cuda.TestLinalgReCudaCUDA) ... ok
    test_addmm_baddbmm_large_input_2_1000_1000_1000_cuda_float32 (test_linalg_re_cuda.TestLinalgReCudaCUDA) ... ok
    test_addmm_baddbmm_large_input_2_100_100_100_cuda_bfloat16 (test_linalg_re_cuda.TestLinalgReCudaCUDA) ... ok
    test_addmm_baddbmm_large_input_2_100_100_100_cuda_float16 (test_linalg_re_cuda.TestLinalgReCudaCUDA) ... ok
    test_addmm_baddbmm_large_input_2_100_100_100_cuda_float32 (test_linalg_re_cuda.TestLinalgReCudaCUDA) ... ok

    ----------------------------------------------------------------------
    Ran 24 tests in 63.224s

    OK (skipped=12)
    ```

    Differential Revision: D39718256

    Pull Request resolved: pytorch#85550
    Approved by: https://github.com/IvanYashchuk, https://github.com/malfet

commit 212fa87
Author: vfdev <vfdev.5@gmail.com>
Date:   Tue Oct 11 17:52:16 2022 +0000

    Fix torch histogramdd docstring (pytorch#86593)

    Fixed torch histogramdd docsting with missing common_args

    Pull Request resolved: pytorch#86593
    Approved by: https://github.com/soulitzer

commit f26292d
Author: Jane Xu <janeyx@fb.com>
Date:   Tue Oct 11 17:42:51 2022 +0000

    [BE] Fix python docs typos up till torch.chunk (pytorch#86642)

    Was doing the Views lab linked https://github.com/pytorch/pytorch/wiki/Tensor-and-Operator-Basics and noticed a few typos, which led to this PR.

    Test plan:
    verified in preview
    Pull Request resolved: pytorch#86642
    Approved by: https://github.com/soulitzer

commit 86f914e
Author: albanD <desmaison.alban@gmail.com>
Date:   Tue Oct 11 10:35:18 2022 -0400

    min/max support for SymInt/Floats, finish as_strided/scatter/squeeze() backward symint support (pytorch#86643)

    Pull Request resolved: pytorch#86643
    Approved by: https://github.com/anjali411

commit b07eeb0
Author: Ivan Yashchuk <ivan.yashchuk@aalto.fi>
Date:   Thu Sep 29 17:01:50 2022 +0300

    Use string names for matching view-like functions

commit d8c005a
Merge: 59cb4be ad87365
Author: Ivan Yashchuk <ivan.yashchuk@aalto.fi>
Date:   Thu Sep 29 17:01:03 2022 +0300

    Merge remote-tracking branch 'upstream/viable/strict' into nvprims-view

commit 59cb4be
Author: Ivan Yashchuk <ivan.yashchuk@aalto.fi>
Date:   Thu Sep 22 18:37:59 2022 +0300

    lint

commit 92edd1a
Author: Ivan Yashchuk <ivan.yashchuk@aalto.fi>
Date:   Thu Sep 22 18:15:35 2022 +0300

    Add view_copy

commit 79c18da
Author: Ivan Yashchuk <ivan.yashchuk@aalto.fi>
Date:   Thu Sep 22 18:08:25 2022 +0300

    Add _unsafe_view to list

commit 254161d
Author: Ivan Yashchuk <ivan.yashchuk@aalto.fi>
Date:   Thu Sep 22 18:07:51 2022 +0300

    Add _unsafe_view to tests

commit 487a7a8
Author: Ivan Yashchuk <ivan.yashchuk@aalto.fi>
Date:   Thu Sep 22 18:00:30 2022 +0300

    Use func == torch.ops.aten.view.default

commit 24e61bf
Author: Ivan Yashchuk <IvanYashchuk@users.noreply.github.com>
Date:   Thu Sep 22 17:57:48 2022 +0300

    Update torch/_prims/nvfuser_prims.py

commit abad276
Author: Ivan Yashchuk <ivan.yashchuk@aalto.fi>
Date:   Thu Sep 22 17:53:42 2022 +0300

    Modify python frontend according latest changes

commit 712447f
Merge: a135db1 0c46e3e
Author: Ivan Yashchuk <ivan.yashchuk@aalto.fi>
Date:   Thu Sep 22 17:22:44 2022 +0300

    Merge remote-tracking branch 'upstream/viable/strict' into nvprims-view

commit a135db1
Author: Ivan Yashchuk <ivan.yashchuk@aalto.fi>
Date:   Wed Sep 7 17:06:30 2022 +0300

    Add interception of view for TorchRefsNvfuserCapabilityMode

commit f0c039e
Author: Ivan Yashchuk <ivan.yashchuk@aalto.fi>
Date:   Wed Sep 7 17:06:07 2022 +0300

    Add test for view -> nvprims.view lowering

commit 246c999
Author: Ivan Yashchuk <ivan.yashchuk@aalto.fi>
Date:   Wed Sep 7 16:40:13 2022 +0300

    Add tests

commit c48ba8e
Author: Ivan Yashchuk <ivan.yashchuk@aalto.fi>
Date:   Wed Sep 7 16:39:59 2022 +0300

    Add nvprims.view

commit 3980f32
Author: Ivan Yashchuk <ivan.yashchuk@aalto.fi>
Date:   Wed Sep 7 16:39:38 2022 +0300

    Add fd.ops.view
@yzs981130
Copy link

Are there any milestones for using cudaMallocAsync in a stable version, considering the recent reopen pr #82682?

@ngimel
Copy link
Collaborator
ngimel commented Nov 1, 2022

You can use it in the nightlies by setting env var.

@github-actions
Copy link
Contributor

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label Dec 31, 2022
@github-actions github-actions bot closed this Jan 30, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
cla signed module: cuda Related to torch.cuda, and CUDA support in general open source Stale triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module
Projects
None yet
Development

Successfully merging this pull request may close these issues.

0