@@ -177,6 +177,7 @@ mha_fwd_ck(const at::Tensor &q, // batch_size x seqlen_q x
     TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");

     const auto sizes = q.sizes();
+
     const int batch_size = sizes[0];
     int seqlen_q = sizes[1];
     int num_heads = sizes[2];
@@ -225,6 +226,7 @@ mha_fwd_ck(const at::Tensor &q, // batch_size x seqlen_q x
     CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
     CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);

+
     at::Tensor q_padded, k_padded, v_padded;
     if (head_size % 8 != 0) {
         q_padded = at::pad(temp_q, {0, 8 - head_size % 8});
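(Aside on the padding in this hunk: `at::pad` with a `{left, right}` pair applies to the last dimension only, so the head dimension is right-padded up to a multiple of 8, presumably because the CK kernel path expects that alignment. A minimal sketch of the intent; `pad_head_dim_to_multiple_of_8` is a hypothetical helper name, not part of the diff:)

```cpp
#include <ATen/ATen.h>

// Hypothetical helper mirroring the q_padded/k_padded/v_padded logic above:
// right-pad the last (head) dimension so its size becomes a multiple of 8.
static at::Tensor pad_head_dim_to_multiple_of_8(const at::Tensor& t, int head_size) {
    if (head_size % 8 == 0) {
        return t;  // already aligned, no copy needed
    }
    // {left, right} padding for the last dimension only.
    return at::pad(t, {0, 8 - head_size % 8});
}
```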
@@ -237,6 +239,7 @@ mha_fwd_ck(const at::Tensor &q, // batch_size x seqlen_q x
         v_padded = v;
     }

+
     at::Tensor out;
     if (out_.has_value()) {
         out = out_.value();
@@ -263,6 +266,7 @@ mha_fwd_ck(const at::Tensor &q, // batch_size x seqlen_q x
     auto opts = q.options();
     bool has_lse = true;
    bool has_dropout = p_dropout > 0.0f;
+
     at::Tensor softmax_lse;
     // TODO - check gradient, only training require lse
     softmax_lse = at::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
@@ -273,41 +277,46 @@ mha_fwd_ck(const at::Tensor &q, // batch_size x seqlen_q x
         p = at::empty({batch_size, num_heads, seqlen_q, seqlen_k}, opts.dtype(at::kByte));
     }
     else {
-        p = at::empty({0}, opts.dtype(at::kByte));
+        p = at::empty({0}, opts);
     }

-
-    uint64_t drop_seed = 1, drop_offset = 0;
     int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size();
+    auto rng_state = at::empty({2}, opts.dtype(at::kLong));
+    auto rng_state_ptr = reinterpret_cast<uint64_t*>(rng_state.data_ptr());

-    auto rng_state_options = at::TensorOptions().dtype(at::kUInt64).device(at::kCUDA);
-    auto rng_state = at::zeros({2}, rng_state_options.dtype(at::kUInt64));
-    auto _unused = at::empty({}, at::dtype(c10::kUInt64).device(at::kCUDA));

-    if (p_dropout > 0.0) {

+    at::Tensor seed_t, offset_t;
+
+    if (p_dropout > 0.0) {
         auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
             gen_, at::cuda::detail::getDefaultCUDAGenerator());
-
         // See Note [Acquire lock when using random generators]
         std::lock_guard<std::mutex> lock(gen->mutex_);
+
         auto philox_args = gen->philox_cuda_state(counter_offset);

-        std::tie(drop_seed, drop_offset) = at::cuda::philox::unpack(philox_args);

+
+        hipLaunchKernelGGL(
+            flash::ParsePhiloxCudaState, dim3(1), dim3(64), 0, at::hip::getCurrentHIPStreamMasqueradingAsCUDA(), philox_args, rng_state_ptr);
+        seed_t = at::scalar_tensor(at::Scalar(static_cast<uint64_t>(rng_state_ptr[0])), at::dtype(at::kLong));
+        offset_t = at::scalar_tensor(at::Scalar(static_cast<uint64_t>(rng_state_ptr[1])), at::dtype(at::kLong));
+    }
+    else
+    {
+        seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
+        offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
     }
-    rng_state[0] = *(reinterpret_cast<int64_t*>(&drop_seed));
-    rng_state[1] = *(reinterpret_cast<int64_t*>(&drop_offset));
-    auto drop_options = at::TensorOptions().dtype(at::kLong).device(at::kCUDA);

     std::optional<at::Tensor> attn_bias;
     if (attn_bias_.has_value())
     {
         attn_bias = attn_bias_;
     }
+
     if (seqlen_k > 0) {
-        auto drop_seed_offset = std::make_pair(rng_state[0].data_ptr<uint64_t>(),
-                                               rng_state[1].data_ptr<uint64_t>());
+        auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1);
         auto stream = at::cuda::getCurrentHIPStream().stream();
         ck_tile::stream_config stream_config{stream};
@@ -323,7 +332,7 @@ mha_fwd_ck(const at::Tensor &q, // batch_size x seqlen_q x
         auto args =
             get_ck_fmha_fwd_args(
                 has_lse,
-                has_dropout,
+                return_dropout_randval,
                 mask,
                 batch_size,
                 seqlen_q,
@@ -349,11 +358,12 @@ mha_fwd_ck(const at::Tensor &q, // batch_size x seqlen_q x
         out.zero_();
         softmax_lse.fill_(std::numeric_limits<float>::infinity());
     }
+
     if (seqlenq_ngroups_swapped) {
         out = out.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size});
         q_padded = q_padded.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size});
         softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});
     }
-    return {out, q_padded, k_padded, v_padded, softmax_lse, rng_state, _unused, p};
+    return {out, q_padded, k_padded, v_padded, softmax_lse, seed_t, offset_t, p};
 }
 } // namespace pytorch_flash
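(The last hunk changes the trailing return slots from `{..., rng_state, _unused, p}` to `{..., seed_t, offset_t, p}`. A hypothetical caller-side sketch, assuming the function returns a `std::vector<at::Tensor>` in the order of the return statement above; the call shape and argument list are illustrative only:)

```cpp
// Hypothetical usage: unpack the Philox seed/offset the forward pass recorded,
// e.g. to feed them back into a backward pass for a bit-identical dropout mask.
auto results = pytorch_flash::mha_fwd_ck(q, k, v, /* ... remaining args ... */);
at::Tensor seed_t   = results[5];   // scalar kLong tensor: Philox seed
at::Tensor offset_t = results[6];   // scalar kLong tensor: Philox offset
auto seed   = static_cast<uint64_t>(seed_t.item<int64_t>());
auto offset = static_cast<uint64_t>(offset_t.item<int64_t>());
```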