Merge branch 'pytorch:main' into fix_sort_doc_error · pytorch/pytorch@08b1020 · GitHub

Commit 08b1020

Merge branch 'pytorch:main' into fix_sort_doc_error
2 parents 4f7b736 + a4fb657 · commit 08b1020

54 files changed: +1899, -1409 lines changed

.github/workflows/_link_check.yml

Lines changed: 2 additions & 0 deletions
@@ -13,6 +13,7 @@ jobs:
     if: ${{ github.event_name != 'pull_request' || !contains(github.event.pull_request.labels.*.name, 'skip-url-lint') }}
     uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
     with:
+      job-name: lint-urls
       timeout: 120
       runner: ${{ inputs.runner }}linux.2xlarge
       docker-image: ci-image:pytorch-linux-jammy-linter
@@ -38,6 +39,7 @@ jobs:
     if: ${{ github.event_name != 'pull_request' || !contains(github.event.pull_request.labels.*.name, 'skip-xref-lint') }}
     uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
    with:
+      job-name: lint-xrefs
       timeout: 60
       runner: ${{ inputs.runner }}linux.2xlarge
       docker-image: ci-image:pytorch-linux-jammy-linter

aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_bwd_ck.hip

Lines changed: 5 additions & 10 deletions
@@ -388,16 +388,11 @@ mha_bwd_ck(const at::Tensor &dout, // batch_size x seqlen_q x
         dv_expanded = dv;
     }

-    auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
-        std::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
-
-    uint64_t* drop_seed, drop_offset;
-    int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size();
-    std::pair<uint64_t*, uint64_t*> drop_seed_offset = {nullptr,nullptr};
-    if(is_dropout) {
-        drop_seed_offset.first = philox_seed[0].data_ptr<uint64_t>();
-        drop_seed_offset.second = philox_seed[1].data_ptr<uint64_t>();
-    }
+    uint64_t drop_seed = 1, drop_offset = 0;
+    drop_seed = *philox_seed.data_ptr<int64_t>();
+    drop_offset = *philox_offset.data_ptr<int64_t>();
+    auto drop_seed_offset = std::make_pair(&drop_seed, &drop_offset);
+

     if (seqlen_q > 0) {
         ck_tile::stream_config stream_config{stream};
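The net effect of this hunk is that the backward pass no longer constructs a CUDA generator: it reads the dropout seed and offset directly from the philox_seed and philox_offset tensor arguments and hands the CK kernel a pair of host pointers to two local uint64_t values. Below is a minimal standalone sketch of that pattern, assuming one int64_t value per tensor readable from the host; plain int64_t variables stand in for the at::Tensor arguments, so the concrete values are illustrative only.

#include <cstdint>
#include <iostream>
#include <utility>

int main() {
    // Stand-ins for *philox_seed.data_ptr<int64_t>() and
    // *philox_offset.data_ptr<int64_t>() from the diff.
    int64_t philox_seed_storage = 42;
    int64_t philox_offset_storage = 64;

    // Copy the values into local uint64_t variables, as the new code does.
    uint64_t drop_seed = static_cast<uint64_t>(philox_seed_storage);
    uint64_t drop_offset = static_cast<uint64_t>(philox_offset_storage);

    // The kernel arguments receive a pair of host pointers to those locals,
    // so the locals must stay alive until the launch arguments are built.
    auto drop_seed_offset = std::make_pair(&drop_seed, &drop_offset);

    std::cout << *drop_seed_offset.first << " "
              << *drop_seed_offset.second << "\n";
    return 0;
}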

aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_fwd_ck.hip

Lines changed: 26 additions & 16 deletions
@@ -177,6 +177,7 @@ mha_fwd_ck(const at::Tensor &q, // batch_size x seqlen_q x
     TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");

     const auto sizes = q.sizes();
+
     const int batch_size = sizes[0];
     int seqlen_q = sizes[1];
     int num_heads = sizes[2];
@@ -225,6 +226,7 @@ mha_fwd_ck(const at::Tensor &q, // batch_size x seqlen_q x
     CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
     CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);

+
     at::Tensor q_padded, k_padded, v_padded;
     if (head_size % 8 != 0) {
         q_padded = at::pad(temp_q, {0, 8 - head_size % 8});
@@ -237,6 +239,7 @@ mha_fwd_ck(const at::Tensor &q, // batch_size x seqlen_q x
         v_padded = v;
     }

+
     at::Tensor out;
     if (out_.has_value()) {
         out = out_.value();
@@ -263,6 +266,7 @@ mha_fwd_ck(const at::Tensor &q, // batch_size x seqlen_q x
     auto opts = q.options();
     bool has_lse = true;
     bool has_dropout = p_dropout > 0.0f;
+
     at::Tensor softmax_lse;
     // TODO - check gradient, only training require lse
     softmax_lse = at::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
@@ -273,41 +277,46 @@ mha_fwd_ck(const at::Tensor &q, // batch_size x seqlen_q x
         p = at::empty({batch_size, num_heads, seqlen_q, seqlen_k}, opts.dtype(at::kByte));
     }
     else {
-        p = at::empty({ 0 }, opts.dtype(at::kByte));
+        p = at::empty({ 0 }, opts);
     }

-
-    uint64_t drop_seed = 1, drop_offset = 0;
     int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size();
+    auto rng_state = at::empty({2}, opts.dtype(at::kLong));
+    auto rng_state_ptr = reinterpret_cast<uint64_t*>(rng_state.data_ptr());

-    auto rng_state_options = at::TensorOptions().dtype(at::kUInt64).device(at::kCUDA);
-    auto rng_state = at::zeros({2}, rng_state_options.dtype(at::kUInt64));
-    auto _unused = at::empty({}, at::dtype(c10::kUInt64).device(at::kCUDA));

-    if (p_dropout > 0.0)  {

+    at::Tensor seed_t, offset_t;
+
+    if (p_dropout > 0.0)  {
         auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
             gen_, at::cuda::detail::getDefaultCUDAGenerator());
-
         // See Note [Acquire lock when using random generators]
         std::lock_guard<std::mutex> lock(gen->mutex_);
+
         auto philox_args = gen->philox_cuda_state(counter_offset);

-        std::tie(drop_seed, drop_offset) = at::cuda::philox::unpack(philox_args);

+
+        hipLaunchKernelGGL(
+            flash::ParsePhiloxCudaState, dim3(1), dim3(64), 0, at::hip::getCurrentHIPStreamMasqueradingAsCUDA(), philox_args, rng_state_ptr);
+        seed_t = at::scalar_tensor(at::Scalar(static_cast<uint64_t>(rng_state_ptr[0])), at::dtype(at::kLong));
+        offset_t = at::scalar_tensor(at::Scalar(static_cast<uint64_t>(rng_state_ptr[1])), at::dtype(at::kLong));
+    }
+    else
+    {
+        seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
+        offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
     }
-    rng_state[0] = *(reinterpret_cast<int64_t*>(&drop_seed));
-    rng_state[1] = *(reinterpret_cast<int64_t*>(&drop_offset));
-    auto drop_options = at::TensorOptions().dtype(at::kLong).device(at::kCUDA);

     std::optional<at::Tensor> attn_bias;
     if( attn_bias_.has_value())
     {
         attn_bias = attn_bias_;
     }
+
     if (seqlen_k > 0) {
-        auto drop_seed_offset = std::make_pair(rng_state[0].data_ptr<uint64_t>(),
-                                               rng_state[1].data_ptr<uint64_t>());
+        auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1);
         auto stream = at::cuda::getCurrentHIPStream().stream();
         ck_tile::stream_config stream_config{stream};

@@ -323,7 +332,7 @@ mha_fwd_ck(const at::Tensor &q, // batch_size x seqlen_q x
         auto args =
             get_ck_fmha_fwd_args(
                 has_lse,
-                has_dropout,
+                return_dropout_randval,
                 mask,
                 batch_size,
                 seqlen_q,
@@ -349,11 +358,12 @@ mha_fwd_ck(const at::Tensor &q, // batch_size x seqlen_q x
         out.zero_();
         softmax_lse.fill_(std::numeric_limits<float>::infinity());
     }
+
     if (seqlenq_ngroups_swapped) {
         out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size});
         q_padded = q_padded.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size});
         softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});
     }
-    return {out, q_padded, k_padded, v_padded, softmax_lse, rng_state, _unused, p};
+    return {out, q_padded, k_padded, v_padded, softmax_lse, seed_t, offset_t, p};
 }
 } //namespace pytorch_flash
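On the forward side, the RNG state now lives in a two-element kLong tensor whose storage is reinterpreted as uint64_t; when dropout is active, the flash::ParsePhiloxCudaState kernel unpacks the Philox state into that buffer, seed_t and offset_t are materialized from it, and the CK kernel receives the pair (rng_state_ptr, rng_state_ptr + 1). The sketch below is a standalone illustration of that buffer layout and of the counter_offset formula from the diff; a plain host array stands in for the rng_state tensor, and the warp size of 64 is an assumption made only for this example (the real code queries ck_tile::get_warp_size()).

#include <cstdint>
#include <iostream>
#include <utility>

int main() {
    // counter_offset as computed in the diff; warp_size = 64 assumed here.
    const int64_t batch_size = 2, num_heads = 8, warp_size = 64;
    int64_t counter_offset = batch_size * num_heads * warp_size;

    // Stand-in for rng_state = at::empty({2}, opts.dtype(at::kLong)) viewed
    // through reinterpret_cast<uint64_t*>(rng_state.data_ptr()).
    uint64_t rng_state[2] = {123u /* seed */, 0u /* offset */};
    uint64_t* rng_state_ptr = rng_state;

    // The attention kernel sees the seed and offset as two adjacent slots
    // of the same buffer: (rng_state_ptr, rng_state_ptr + 1).
    auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1);

    std::cout << "counter_offset=" << counter_offset
              << " seed=" << *drop_seed_offset.first
              << " offset=" << *drop_seed_offset.second << "\n";
    return 0;
}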

0 commit comments