Update · pytorch/pytorch@25763ba · GitHub


Commit 25763ba

committed
Update
[ghstack-poisoned]
2 parents 9f3d4de + 921956b commit 25763ba

File tree

106 files changed, +2399 −871 lines changed
  • distributed
  • fx
  • nn/parallel
  • profiler
  • testing/_internal
  • utils

    .ci/aarch64_linux/aarch64_wheel_ci_build.py

    Lines changed: 4 additions & 0 deletions
    @@ -99,10 +99,14 @@ def update_wheel(wheel_path, desired_cuda) -> None:
         if "126" in desired_cuda:
             libs_to_copy += [
                 "/usr/local/cuda/lib64/libnvrtc-builtins.so.12.6",
    +            "/usr/local/cuda/lib64/libcufile.so.0",
    +            "/usr/local/cuda/lib64/libcufile_rdma.so.1",
             ]
         elif "128" in desired_cuda:
             libs_to_copy += [
                 "/usr/local/cuda/lib64/libnvrtc-builtins.so.12.8",
    +            "/usr/local/cuda/lib64/libcufile.so.0",
    +            "/usr/local/cuda/lib64/libcufile_rdma.so.1",
             ]
         else:
             libs_to_copy += [

    aten/src/ATen/native/TensorShape.cpp

    Lines changed: 1 addition & 2 deletions
    @@ -3059,8 +3059,7 @@ Tensor slice(
       }
       auto storage_offset = self.storage_offset() + start_val * strides[dim];
       auto len = end_val - start_val;
    -  sizes[dim] =
    -      (len == 0) ? 0 : (1 + (len - 1) / step); // round-up, avoiding overflow
    +  sizes[dim] = (len + step - 1) / step; // round-up
       strides[dim] *= step;

       Tensor result;
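
    The new formula is a plain ceiling division over the slice length. A minimal Python sketch of the resulting sizes (assuming a positive step, as in the code above):

        import torch

        # size along the sliced dim becomes (len + step - 1) // step, with len = end - start
        x = torch.arange(10)
        print(x[1:8:3])          # tensor([1, 4, 7]) -> (7 + 3 - 1) // 3 == 3 elements
        print(x[1:1:3].numel())  # len == 0 -> 0 elements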

    aten/src/ATen/native/cpu/Activation.cpp

    Lines changed: 10 additions & 10 deletions
    @@ -832,9 +832,9 @@ void hardswish_backward_kernel(TensorIterator& iter) {
         cpu_kernel_vec(
             iter,
             [&](scalar_t grad_val, scalar_t self_val) -> scalar_t {
    -          if (float(self_val) < neg_three) {
    +          if (float(self_val) <= neg_three) {
                 return zero;
    -          } else if (float(self_val) <= three) {
    +          } else if (float(self_val) < three) {
                 return float(grad_val) * ((float(self_val) / three) + one_half);
               } else {
                 return grad_val;
    @@ -847,19 +847,19 @@ void hardswish_backward_kernel(TensorIterator& iter) {
                   Vec::blendv(
                       grad_val0 * ((self_val0 / kThreeVec) + kOneHalfVec),
                       grad_val0,
    -                  self_val0 > kThreeVec
    +                  self_val0 >= kThreeVec
                   ),
                   kZeroVec,
    -              self_val0 < kNegThreeVec
    +              self_val0 <= kNegThreeVec
               );
               self_val1 = Vec::blendv(
                   Vec::blendv(
                       grad_val1 * ((self_val1 / kThreeVec) + kOneHalfVec),
                       grad_val1,
    -                  self_val1 > kThreeVec
    +                  self_val1 >= kThreeVec
                   ),
                   kZeroVec,
    -              self_val1 < kNegThreeVec
    +              self_val1 <= kNegThreeVec
               );
               return convert_from_float<scalar_t>(self_val0, self_val1);
           });
    @@ -878,9 +878,9 @@ void hardswish_backward_kernel(TensorIterator& iter) {
         cpu_kernel_vec(
             iter,
             [&](scalar_t grad_val, scalar_t self_val) {
    -          if (self_val < neg_three) {
    +          if (self_val <= neg_three) {
                 return zero;
    -          } else if (self_val <= three) {
    +          } else if (self_val < three) {
                 return grad_val * ((self_val / three) + one_half);
               } else {
                 return grad_val;
    @@ -891,10 +891,10 @@ void hardswish_backward_kernel(TensorIterator& iter) {
                   Vec::blendv(
                       grad_val * ((self_val / kThreeVec) + kOneHalfVec),
                       grad_val,
    -                  self_val > kThreeVec
    +                  self_val >= kThreeVec
                   ),
                   kZeroVec,
    -              self_val < kNegThreeVec
    +              self_val <= kNegThreeVec
               );
             }
         );

    aten/src/ATen/native/cuda/ActivationHardswishKernel.cu

    Lines changed: 2 additions & 2 deletions
    @@ -45,9 +45,9 @@ void hardswish_backward_kernel(TensorIterator& iter) {
         [zero, three, neg_three, one_half]GPU_LAMBDA(scalar_t grad_val_, scalar_t self_val_) -> scalar_t {
           opmath_t grad_val = static_cast<opmath_t>(grad_val_);
           opmath_t self_val = static_cast<opmath_t>(self_val_);
    -      if (self_val < neg_three) {
    +      if (self_val <= neg_three) {
             return zero;
    -      } else if (self_val <= three) {
    +      } else if (self_val < three) {
             return grad_val * ((self_val / three) + one_half);
           } else {
             return grad_val;
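
    Together with the CPU kernel above, this moves the kink points x = -3 and x = 3 into the outer branches of the hardswish backward. A small Python check of the intended boundary gradients under this change (a sketch, not an exhaustive test):

        import torch
        import torch.nn.functional as F

        # At the kinks the backward now returns the outer one-sided derivatives
        # (0 below -3, 1 above 3) instead of x/3 + 0.5 (-0.5 and 1.5).
        x = torch.tensor([-3.0, 3.0], requires_grad=True)
        F.hardswish(x).sum().backward()
        print(x.grad)  # expected with this change: tensor([0., 1.])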

    aten/src/ATen/native/cuda/Blas.cpp

    Lines changed: 2 additions & 0 deletions
    @@ -1149,9 +1149,11 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
       TORCH_CHECK(!out_dtype || *out_dtype == out.scalar_type(), "out_dtype must match output matrix type");
       TORCH_CHECK(isFloat8Type(mat1.scalar_type()), "Expected mat1 to be Float8 matrix got ", mat1.scalar_type());
       TORCH_CHECK(isFloat8Type(mat2.scalar_type()), "Expected mat2 to be Float8 matrix got ", mat2.scalar_type());
    +#ifndef USE_ROCM
       // Type restrictions imposed by CuBLASLt as of CUDA-12.1
       TORCH_CHECK(mat1.scalar_type() != ScalarType::Float8_e5m2 || mat2.scalar_type() != ScalarType::Float8_e5m2,
            "Multiplication of two Float8_e5m2 matrices is not supported");
    +#endif
       if (bias) {
         TORCH_CHECK(out.scalar_type() != kFloat, "Bias is not supported when out_dtype is set to Float32");
         TORCH_CHECK(bias->scalar_type() == ScalarType::BFloat16 || bias->scalar_type() == ScalarType::Half,
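
    The #ifndef keeps the cuBLASLt-only e5m2 × e5m2 restriction out of ROCm builds. A hedged Python sketch of how the check surfaces to callers (the scale_a/scale_b form of torch._scaled_mm assumed here comes from recent releases and needs a float8-capable GPU):

        import torch

        # On CUDA builds the check above rejects e5m2 x e5m2; ROCm builds skip it.
        a = torch.randn(32, 32, device="cuda").to(torch.float8_e5m2)
        b = torch.randn(32, 32, device="cuda").to(torch.float8_e5m2).t()  # column-major operand
        scale = torch.tensor(1.0, device="cuda")
        try:
            torch._scaled_mm(a, b, scale_a=scale, scale_b=scale, out_dtype=torch.bfloat16)
        except RuntimeError as e:
            print(e)  # "Multiplication of two Float8_e5m2 matrices is not supported"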

    aten/src/ATen/native/cuda/CUDAScalar.cu

    Lines changed: 1 addition & 0 deletions
    @@ -26,6 +26,7 @@ namespace at::native {

     Scalar _local_scalar_dense_cuda(const Tensor& self) {
       Scalar r;
    +  TORCH_CHECK(self.numel() > 0, "_local_scalar_dense: Empty tensor not supported");
     #if defined(USE_ROCM)
       if (!use_sync_mode()){
     #endif
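
    Tensor.item() already rejects tensors that do not have exactly one element one level higher, so the new TORCH_CHECK mainly guards direct calls into the op. A minimal Python illustration:

        import torch

        t = torch.empty(0, device="cuda")
        try:
            torch.ops.aten._local_scalar_dense(t)
        except RuntimeError as e:
            print(e)  # "_local_scalar_dense: Empty tensor not supported"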

    aten/src/ATen/native/mps/operations/BitwiseOps.mm

    Lines changed: 8 additions & 43 deletions
    @@ -103,22 +103,14 @@ kernel void bitwise_not(device {0} *out [[buffer(0)]],
       return getMetalType(t.scalar_type());
     }

    -static inline std::string getMetalType(const c10::Scalar& s) {
    -  return getMetalType(s.type());
    -}
    -
    -template <typename ScalarOrTensor>
     static id<MTLComputePipelineState> getCPLState(const Tensor& t1,
                                                    const Tensor& t2,
    -                                               const ScalarOrTensor& t3,
    +                                               const Tensor& t3,
                                                    const std::string& fname) {
       return lib.getPipelineStateForFunc(fname, {getMetalType(t1), getMetalType(t2), getMetalType(t3)});
     }

    -static void handle_tensor_tensor_binary_op(const Tensor& self,
    -                                           const Tensor& other,
    -                                           Tensor& output,
    -                                           const std::string& kernel_name) {
    +static void handle_binary_op(const Tensor& self, const Tensor& other, Tensor& output, const std::string& kernel_name) {
       using namespace at::mps;
       MPSStream* stream = getCurrentMPSStream();
       auto cplState = getCPLState(output, self, other, kernel_name);
    @@ -142,33 +134,6 @@ static void handle_tensor_tensor_binary_op(const Tensor& self,
       });
     }

    -static void handle_tensor_scalar_binary_op(const Tensor& self,
    -                                           const Scalar& other,
    -                                           Tensor& output,
    -                                           const std::string& kernel_name) {
    -  using namespace at::mps;
    -  MPSStream* stream = getCurrentMPSStream();
    -  auto cplState = getCPLState(output, self, other, kernel_name);
    -  uint64_t sval = other.to<int64_t>();
    -  uint32_t length = output.numel();
    -  if (length == 0) {
    -    return;
    -  }
    -
    -  dispatch_sync(stream->queue(), ^() {
    -    getMPSProfiler().beginProfileKernel(cplState, kernel_name, {self});
    -
    -    id<MTLComputeCommandEncoder> commandEncoder = stream->commandEncoder();
    -
    -    [commandEncoder pushDebugGroup:[NSString stringWithFormat:@"Dispatch %s kernel", kernel_name.c_str()]];
    -    [commandEncoder setComputePipelineState:cplState];
    -    mtl_setArgs(commandEncoder, output, self, sval);
    -    mtl_dispatch1DJob(commandEncoder, cplState, length);
    -
    -    getMPSProfiler().endProfileKernel(cplState);
    -  });
    -}
    -
     static void _bitwise_op_out_mps(const Tensor& self,
                                     const Tensor& other,
                                     const Tensor& output_,
    @@ -201,14 +166,14 @@ static void _bitwise_op_out_mps(const Tensor& self,
           TORCH_CHECK(false, "Unknown operation to be performed over scalars ", op_name);
         }
       } else if (is_other_scalar) {
    -    handle_tensor_scalar_binary_op(self.contiguous(), other.item(), output, fmt::format("bitwise_{}_scalar", op_name));
    +    handle_binary_op(self.contiguous(), other, output, fmt::format("bitwise_{}_scalar", op_name));
       } else if (is_self_scalar) {
    -    handle_tensor_scalar_binary_op(other.contiguous(), self.item(), output, fmt::format("bitwise_{}_scalar", op_name));
    +    handle_binary_op(other.contiguous(), self, output, fmt::format("bitwise_{}_scalar", op_name));
       } else {
    -    handle_tensor_tensor_binary_op(self.expand(output_size).contiguous(),
    -                                   other.expand(output_size).contiguous(),
    -                                   output,
    -                                   fmt::format("bitwise_{}_tensor", op_name));
    +    handle_binary_op(self.expand(output_size).contiguous(),
    +                     other.expand(output_size).contiguous(),
    +                     output,
    +                     fmt::format("bitwise_{}_tensor", op_name));
       }
       if (needs_output_copy) {
         output_.copy_(output);

    aten/src/ATen/native/mps/operations/Scalar.mm

    Lines changed: 1 addition & 0 deletions
    @@ -15,6 +15,7 @@

     Scalar _local_scalar_dense_mps(const Tensor& self) {
       Scalar r;
    +  TORCH_CHECK(self.numel() > 0, "_local_scalar_dense: Empty tensor not supported");

       auto output = at::empty_like(self, TensorOptions(kCPU));
       mps::mps_copy_(output, self, false);

    aten/src/ATen/native/transformers/cuda/attention.cu

    Lines changed: 20 additions & 36 deletions
    @@ -1246,22 +1246,19 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt> _efficient_
       auto ret = aotriton::v2::flash::check_gpu(stream);
       if (hipSuccess != ret) {
         TORCH_CHECK(false,
    -                "[AOTriton] Accelerated SDPA only supports MI200/MI300X/7900XTX/9070XT GPUs"
    -                " (gfx90a/gfx942/gfx1100/gfx1201)")
    +                "[AOTriton] Accelerated SDPA only supports MI200/MI300X/Navi31 GPUs"
    +                " (gfx90a:sramecc+:xnack-/gfx942:sramecc+:xnack-/gfx1100)")
       }

       // AOTriton may accept aligned on logsumexp tensor in the future for better
       // performance, but for now it requires compact logsumexp tensor, even if
       // compute_logsumexp is false
       constexpr int kAlignLSE = 1;
       res = at::empty({B, M, num_heads, Kv}, query.options());
    -  at::Tensor softmax_lse;
       logsumexp = at::empty(
    -      { B, num_heads, compute_logsumexp ? max_seqlen_q : 0},
    +      { B, num_heads, max_seqlen_q },
           query.options().dtype(at::ScalarType::Float));
    -  if (compute_logsumexp) {
    -    softmax_lse = logsumexp.view({B * num_heads, max_seqlen_q});
    -  }
    +  at::Tensor softmax_lse = logsumexp.view({B * num_heads, max_seqlen_q});
       at::Tensor q_t = query.transpose(1, 2);
       at::Tensor k_t = key.transpose(1, 2);
       at::Tensor v_t = value.transpose(1, 2);
    @@ -1277,40 +1274,32 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt> _efficient_

       const auto softmax_scale = sdp::calculate_scale(query, scale).expect_float();

    -  at::Tensor atomic_counter;
    -  if (is_causal) {
    -    atomic_counter = at::zeros({1}, query.options().dtype(at::kInt));
    -  }
    -
       using aotriton::v2::flash::attn_fwd;
       using aotriton::v2::flash::attn_fwd_compact_varlen;
       using sdp::aotriton_adapter::mk_aotensor;
       using sdp::aotriton_adapter::mk_aoscalartensor;
       using sdp::aotriton_adapter::mk_philoxtensor;
    -  using sdp::aotriton_adapter::mk_atomictensor;
       aotriton::TensorView<4> empty_t4(0, {0, 0, 0, 0}, {0, 0, 0, 0}, aotriton::DType::kFloat16);
    -  aotriton::TensorView<2> empty_t2(0, {0, 0}, {0, 0}, aotriton::DType::kFloat32);
       at::Tensor softmax_fa_t = at::empty({ 0, 0, 0, 0 }, query.options());
       const bool use_philox_state = in_capture_stream;
       auto seed = use_philox_state ? mk_philoxtensor(philox_state.seed_.ptr) : mk_aoscalartensor(seed_t);
       auto offset1 = use_philox_state ? mk_philoxtensor(philox_state.offset_.ptr) : mk_aoscalartensor(offset_t);
       auto offset2 = use_philox_state ? philox_state.offset_intragraph_ : 0;
    -  auto seed_output = mk_philoxtensor(use_philox_state ? seed_t.data_ptr<int64_t>() : nullptr);
    -  auto offset_output = mk_philoxtensor(use_philox_state ? offset_t.data_ptr<int64_t>() : nullptr);
    -  auto persistent_counter = mk_atomictensor(is_causal ? atomic_counter.data_ptr<int32_t>() : nullptr);
    +  auto seed_output = use_philox_state ? mk_philoxtensor(seed_t.data_ptr<int64_t>()) : mk_philoxtensor(nullptr);
    +  auto offset_output = use_philox_state ? mk_philoxtensor(offset_t.data_ptr<int64_t>()) : mk_philoxtensor(nullptr);
       hipError_t err; // TODO: Error handling
       if (seqstart_q.has_value()) {
         // varlen aka nested tensor
         err = attn_fwd_compact_varlen(mk_aotensor(q_t, "q"),
                                       mk_aotensor(k_t, "k"),
                                       mk_aotensor(v_t, "v"),
    -                                  bias.has_value() ? mk_aotensor(bias.value(), "bias"): empty_t4,
                                       mk_aotensor<1>(seqstart_q.value(), "cu_seqlens_q"),
                                       mk_aotensor<1>(seqstart_k.value(), "cu_seqlens_k"),
                                       max_seqlen_q,
                                       max_seqlen_k,
    +                                  bias.has_value() ? mk_aotensor(bias.value(), "bias"): empty_t4,
                                       softmax_scale,
    -                                  compute_logsumexp ? mk_aotensor<2>(softmax_lse, "M") : empty_t2,
    +                                  mk_aotensor<2>(softmax_lse, "M"),
                                       mk_aotensor(output_t, "Out"),
                                       dropout_p,
                                       seed,
    @@ -1320,15 +1309,14 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt> _efficient_
                                       offset_output,
                                       mk_aotensor(softmax_fa_t, "encoded_softmax"),
                                       is_causal,
    -                                  persistent_counter,
                                       stream);
       } else {
         err = attn_fwd(mk_aotensor(q_t, "q"),
                        mk_aotensor(k_t, "k"),
                        mk_aotensor(v_t, "v"),
                        bias.has_value() ? mk_aotensor(bias.value(), "bias"): empty_t4,
                        softmax_scale,
    -                   compute_logsumexp ? mk_aotensor<2>(softmax_lse, "M") : empty_t2,
    +                   mk_aotensor<2>(softmax_lse, "M"),
                        mk_aotensor(output_t, "Out"),
                        dropout_p,
                        seed,
    @@ -1338,9 +1326,14 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt> _efficient_
                        offset_output,
                        mk_aotensor(softmax_fa_t, "encoded_softmax"),
                        is_causal,
    -                   persistent_counter,
                        stream);
       }
    +  if (!compute_logsumexp) {
    +    // Set the tensor to empty when compute_logsumexp is false
    +    logsumexp = at::empty(
    +        { B * num_heads, max_seqlen_q, 0 },
    +        query.options().dtype(at::ScalarType::Float));
    +  }
     #else
       // CUDA Implementation
       cudaDeviceProp* p = at::cuda::getDeviceProperties(query.device().index());
    @@ -1602,24 +1595,15 @@ at::Tensor& _fill_mem_eff_dropout_mask_(
     #if defined(USE_MEM_EFF_ATTENTION)

     #ifdef USE_ROCM
    -  using aotriton::v2::flash::debug_simulate_encoded_softmax;
    +  using aotriton::v2::flash::debug_fill_dropout_rng;
       using sdp::aotriton_adapter::mk_aotensor;
    -  using sdp::aotriton_adapter::mk_aoscalartensor;
    -  at::cuda::CUDAGuard device_guard(self.device());
       cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    -
    -  at::Tensor seed_t, offset_t;
    -  const auto options = at::dtype(at::kLong).device(at::kCUDA);
    -  seed_t = at::scalar_tensor(at::Scalar(seed), options);
    -  offset_t = at::scalar_tensor(at::Scalar(offset), options);
       hipError_t err; // TODO: Error handling

    -  err = debug_simulate_encoded_softmax(mk_aotensor(self, "r"),
    -                                       dropout_p,
    -                                       mk_aoscalartensor(seed_t),
    -                                       mk_aoscalartensor(offset_t),
    -                                       0,
    -                                       stream);
    +  err = debug_fill_dropout_rng(mk_aotensor(self, "r"),
    +                               static_cast<uint64_t>(seed),
    +                               static_cast<uint64_t>(offset),
    +                               stream);
     #else
       at::PhiloxCudaState rng_engine_inputs;
       rng_engine_inputs = at::PhiloxCudaState(seed, offset);