From 2a93f9fcba0afbc9dd8acd0d83ceabe2debceb5b Mon Sep 17 00:00:00 2001 From: wengshiy Date: Fri, 9 May 2025 01:33:59 -0400 Subject: [PATCH 1/2] Add assertion to align with cuda --- aten/src/ATen/native/Normalization.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/aten/src/ATen/native/Normalization.cpp b/aten/src/ATen/native/Normalization.cpp index e74aa4fa9c3f..011e598ecd89 100644 --- a/aten/src/ATen/native/Normalization.cpp +++ b/aten/src/ATen/native/Normalization.cpp @@ -770,6 +770,10 @@ std::tuple batch_norm_update_stats_cpu( std::tuple batch_norm_cpu_out(const Tensor& self, const std::optional& weight_opt, const std::optional& bias_opt, const std::optional& running_mean_opt, const std::optional& running_var_opt, bool train, double momentum, double eps, Tensor& out, Tensor& save_mean, Tensor& save_var) { + const bool has_running_mean = (running_mean_opt.has_value() && running_mean_opt->defined()); + const bool has_running_var = (running_var_opt.has_value() && running_var_opt->defined()); + TORCH_CHECK(has_running_mean == has_running_var); + // See [Note: hacky wrapper removal for optional tensor] c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); const Tensor& weight = *weight_maybe_owned; From 136ab31906933f16248d244a1f91a1792e1af95f Mon Sep 17 00:00:00 2001 From: wengshiy Date: Sun, 11 May 2025 22:18:39 -0400 Subject: [PATCH 2/2] use TORCH_CHECK_VALUE instead of TORCH_CHECK --- aten/src/ATen/native/Normalization.cpp | 3 ++- aten/src/ATen/native/cuda/Normalization.cu | 3 ++- aten/src/ATen/native/mps/operations/Normalization.mm | 9 ++++++--- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/aten/src/ATen/native/Normalization.cpp b/aten/src/ATen/native/Normalization.cpp index 011e598ecd89..fb4ce917bf16 100644 --- a/aten/src/ATen/native/Normalization.cpp +++ b/aten/src/ATen/native/Normalization.cpp @@ -772,7 +772,8 @@ std::tuple batch_norm_cpu_out(const Tensor& self, con bool train, double momentum, double eps, Tensor& out, Tensor& save_mean, Tensor& save_var) { const bool has_running_mean = (running_mean_opt.has_value() && running_mean_opt->defined()); const bool has_running_var = (running_var_opt.has_value() && running_var_opt->defined()); - TORCH_CHECK(has_running_mean == has_running_var); + TORCH_CHECK_VALUE(has_running_mean == has_running_var, + "running_mean and running_var must either both be None or neither be None"); // See [Note: hacky wrapper removal for optional tensor] c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); diff --git a/aten/src/ATen/native/cuda/Normalization.cu b/aten/src/ATen/native/cuda/Normalization.cu index 8db7241dee13..55d848610f5d 100644 --- a/aten/src/ATen/native/cuda/Normalization.cu +++ b/aten/src/ATen/native/cuda/Normalization.cu @@ -435,7 +435,8 @@ void batch_norm_calc_invstd(const Tensor& out_invstd, const Tensor& running_var, std::tuple batch_norm_cuda_out(const Tensor& self, const std::optional& weight_opt, const std::optional& bias_opt, const std::optional& running_mean_opt, const std::optional& running_var_opt, bool train, double momentum, double epsilon, Tensor& output, Tensor& save_mean, Tensor& save_invstd) { const bool has_running_mean = (running_mean_opt.has_value() && running_mean_opt->defined()); const bool has_running_var = (running_var_opt.has_value() && running_var_opt->defined()); - TORCH_CHECK(has_running_mean == has_running_var); + TORCH_CHECK_VALUE(has_running_mean == has_running_var, + "running_mean and running_var must either both be None or neither be None"); if (train) { batch_norm_mean_var(self, save_mean, save_invstd); diff --git a/aten/src/ATen/native/mps/operations/Normalization.mm b/aten/src/ATen/native/mps/operations/Normalization.mm index a4ce5350aade..57656075b680 100644 --- a/aten/src/ATen/native/mps/operations/Normalization.mm +++ b/aten/src/ATen/native/mps/operations/Normalization.mm @@ -103,7 +103,8 @@ static void get_shapes(MPSShape* input_shape_readonly, const bool has_running_mean = (running_mean_opt.has_value() && running_mean_opt->defined()); const bool has_running_var = (running_var_opt.has_value() && running_var_opt->defined()); - TORCH_CHECK(has_running_mean == has_running_var); + TORCH_CHECK_VALUE(has_running_mean == has_running_var, + "running_mean and running_var must either both be None or neither be None"); const bool has_weight = (weight_opt.has_value() && weight_opt->defined()); const bool has_bias = (bias_opt.has_value() && bias_opt->defined()); @@ -587,10 +588,12 @@ Check if running mean exists (maybe do this check before making graph) const bool has_running_mean = (running_mean_opt.has_value() && running_mean_opt->defined()); const bool has_running_var = (running_var_opt.has_value() && running_var_opt->defined()); - TORCH_CHECK(has_running_mean == has_running_var); + TORCH_CHECK_VALUE(has_running_mean == has_running_var, + "running_mean and running_var must either both be None or neither be None"); const bool has_save_mean = (save_mean_opt.has_value() && save_mean_opt->defined()); const bool has_save_var = (save_var_opt.has_value() && save_var_opt->defined()); - TORCH_CHECK(has_save_mean == has_save_var); + TORCH_CHECK_VALUE(has_save_mean == has_save_var, + "save_mean and save_var must either both be None or neither be None"); const bool has_weight = (weight_opt.has_value() && weight_opt->defined());