@hvaara
Last active August 27, 2024 15:53

$ git log v2.4.0..HEAD --oneline
d74039f7010 (HEAD -> repro-24-134580, origin/repro-24-134580) Repro case for #134580
ccdbe084a9e Skip memory_format tests
49f0d3f1111 Update common_modules.py
630cc4ea8b2 Update test_nn.py
b702a483965 Use newer `toAccumulateType` signature in Normalization.cpp
$ git diff v2.4.0
diff --git a/aten/src/ATen/native/Normalization.cpp b/aten/src/ATen/native/Normalization.cpp
index e9e7c001837..16ada4cead5 100644
--- a/aten/src/ATen/native/Normalization.cpp
+++ b/aten/src/ATen/native/Normalization.cpp
@@ -552,7 +552,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t> _batch_norm_impl_index(
   if (input.sym_numel() == 0) {
     Tensor reserve = at::empty({0}, input.options().dtype(kByte));
     auto options = input.options().dtype(
-        at::toAccumulateType(input.scalar_type(), /*is_cuda=*/input.is_cuda()));
+        at::toAccumulateType(input.scalar_type(), input.device().type()));
     auto save_mean = at::empty_symint(c10::SymIntArrayRef({num_features}), options);
     auto save_invstd = at::empty_symint(c10::SymIntArrayRef({std::move(num_features)}), options);
diff --git a/test/test_nn.py b/test/test_nn.py
index b4283cbbad8..cfc3875538e 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -8917,7 +8917,9 @@ class TestNNDeviceType(NNTestCase):
         else:
             self.assertEqual(hx.grad, hx_device.grad)
 
-    def test_BatchNorm_empty(self, device):
+    @dtypesIfMPS(torch.float)
+    @dtypes(torch.double)
+    def test_BatchNorm_empty(self, device, dtype):
         mod = torch.nn.BatchNorm2d(3).to(device)
         inp = torch.randn(0, 3, 2, 2, device=device)
         _test_module_empty_input(self, mod, inp)
diff --git a/torch/testing/_internal/common_device_type.py b/torch/testing/_internal/common_device_type.py
index 07caa0ac3ee..72b357971f1 100644
--- a/torch/testing/_internal/common_device_type.py
+++ b/torch/testing/_internal/common_device_type.py
@@ -1398,6 +1398,10 @@ def expectedFailureXPU(fn):
 def expectedFailureMeta(fn):
     return skipIfTorchDynamo()(expectedFailure('meta')(fn))
 
 
+def expectedFailureMPS(fn):
+    return expectedFailure("mps")(fn)
+
+
 def expectedFailureXLA(fn):
     return expectedFailure('xla')(fn)
diff --git a/torch/testing/_internal/common_modules.py b/torch/testing/_internal/common_modules.py
index 5e7e3739695..eafc06ae173 100644
--- a/torch/testing/_internal/common_modules.py
+++ b/torch/testing/_internal/common_modules.py
@@ -15,8 +15,9 @@ from torch.testing._internal.common_cuda import TEST_CUDNN
 from torch.testing._internal.common_dtype import (
     floating_types, floating_and_complex_types_and, get_all_fp_dtypes)
 from torch.testing._internal.common_device_type import (
-    _TestParametrizer, _update_param_kwargs, toleranceOverride, tol,
-    skipCUDAIfCudnnVersionLessThan, skipCUDAIfRocm, precisionOverride, skipMeta, skipMPS, skipCUDAVersionIn)
+    _TestParametrizer, _update_param_kwargs, expectedFailureMPS, toleranceOverride, tol,
+    skipCUDAIfCudnnVersionLessThan, skipCUDAIfRocm, precisionOverride, skipMeta, skipMPS,
+    skipCUDAVersionIn)
 from torch.testing._internal.common_methods_invocations import DecorateInfo
 from torch.testing._internal.common_nn import (
     cosineembeddingloss_reference, cross_entropy_loss_reference, ctcloss_reference,
@@ -3433,9 +3434,6 @@ module_db: List[ModuleInfo] = [
                train_and_eval_differ=True,
                module_inputs_func=module_inputs_torch_nn_BatchNorm1d,
                skips=(
-                   # test fails on MPS backend and is being investigated.
-                   # See https://github.com/pytorch/pytorch/issues/100914
-                   DecorateInfo(skipMPS),
                    # tracking here rather than in the list in test_aotdispatch.py as eval mode passes
                    # RuntimeError: tried to get Double out of SymInt
                    DecorateInfo(
@@ -3454,9 +3452,8 @@ module_db: List[ModuleInfo] = [
                train_and_eval_differ=True,
                module_inputs_func=module_inputs_torch_nn_BatchNorm2d,
                skips=(
-                   # test fails on MPS backend and is being investigated.
-                   # See https://github.com/pytorch/pytorch/issues/100914
-                   DecorateInfo(skipMPS),
+                   # See https://github.com/pytorch/pytorch/issues/134580
+                   # DecorateInfo(expectedFailureMPS, 'TestModule', 'test_memory_format'),
                    # tracking here rather than in the list in test_aotdispatch.py as eval mode passes
                    # RuntimeError: tried to get Double out of SymInt
                    DecorateInfo(
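The common_modules.py change above drops the blanket skipMPS entries for BatchNorm1d/BatchNorm2d so that test_memory_format actually runs on MPS and surfaces the discrepancy tracked in #134580. A minimal sketch of roughly what that test checks (illustrative shapes, not the exact test_modules.py code, and it assumes an MPS-capable machine):

import torch

# BatchNorm2d in train mode on MPS: the output for a channels_last input
# should match the output for the same contiguous input.
torch.manual_seed(0)
x = torch.randn(2, 3, 6, 6, device="mps")
mod = torch.nn.BatchNorm2d(3).to("mps").train()

ref = mod(x)                                                # contiguous input
out = mod(x.contiguous(memory_format=torch.channels_last))  # channels_last input

# On an affected build this raises "Tensor-likes are not close!", mirroring
# the test failures below.
torch.testing.assert_close(out, ref)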
$ python test/test_modules.py -v -k test_memory_format_nn_BatchNorm2d_
test_memory_format_nn_BatchNorm2d_eval_mode_cpu_float32 (__main__.TestModuleCPU.test_memory_format_nn_BatchNorm2d_eval_mode_cpu_float32) ... ok
test_memory_format_nn_BatchNorm2d_eval_mode_cpu_float64 (__main__.TestModuleCPU.test_memory_format_nn_BatchNorm2d_eval_mode_cpu_float64) ... ok
test_memory_format_nn_BatchNorm2d_train_mode_cpu_float32 (__main__.TestModuleCPU.test_memory_format_nn_BatchNorm2d_train_mode_cpu_float32) ... ok
test_memory_format_nn_BatchNorm2d_train_mode_cpu_float64 (__main__.TestModuleCPU.test_memory_format_nn_BatchNorm2d_train_mode_cpu_float64) ... ok
test_memory_format_nn_BatchNorm2d_eval_mode_mps_float16 (__main__.TestModuleMPS.test_memory_format_nn_BatchNorm2d_eval_mode_mps_float16) ... ok
test_memory_format_nn_BatchNorm2d_eval_mode_mps_float32 (__main__.TestModuleMPS.test_memory_format_nn_BatchNorm2d_eval_mode_mps_float32) ... ok
test_memory_format_nn_BatchNorm2d_train_mode_mps_float16 (__main__.TestModuleMPS.test_memory_format_nn_BatchNorm2d_train_mode_mps_float16) ... FAIL
test_memory_format_nn_BatchNorm2d_train_mode_mps_float32 (__main__.TestModuleMPS.test_memory_format_nn_BatchNorm2d_train_mode_mps_float32) ... FAIL
======================================================================
FAIL: test_memory_format_nn_BatchNorm2d_train_mode_mps_float16 (__main__.TestModuleMPS.test_memory_format_nn_BatchNorm2d_train_mode_mps_float16)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/Users/hvaara/dev/pytorch/pytorch/torch/testing/_internal/common_utils.py", line 2744, in wrapper
    method(*args, **kwargs)
  File "/Users/hvaara/dev/pytorch/pytorch/torch/testing/_internal/common_device_type.py", line 419, in instantiated_test
    result = test(self, **param_kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/hvaara/dev/pytorch/pytorch/torch/testing/_internal/common_utils.py", line 1364, in wrapper
    fn(*args, **kwargs)
  File "/Users/hvaara/dev/pytorch/pytorch/torch/testing/_internal/common_modules.py", line 129, in test_wrapper
    return test(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/hvaara/dev/pytorch/pytorch/torch/testing/_internal/common_cuda.py", line 205, in wrapped
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/Users/hvaara/dev/pytorch/pytorch/test/test_modules.py", line 775, in test_memory_format
    self.assertEqual(
  File "/Users/hvaara/dev/pytorch/pytorch/torch/testing/_internal/common_utils.py", line 3636, in assertEqual
    raise error_metas.pop()[0].to_error(
AssertionError: Tensor-likes are not close!
Mismatched elements: 216 / 216 (100.0%)
Greatest absolute difference: 1.05078125 at index (1, 0, 1, 4) (up to 1e-05 allowed)
Greatest relative difference: 2576.0 at index (0, 2, 5, 0) (up to 0.001 allowed)
The failure occurred for item [0]
To execute this test, run the following from the base repo dir:
python test/test_modules.py -k TestModuleMPS.test_memory_format_nn_BatchNorm2d_train_mode_mps_float16
This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
======================================================================
FAIL: test_memory_format_nn_BatchNorm2d_train_mode_mps_float32 (__main__.TestModuleMPS.test_memory_format_nn_BatchNorm2d_train_mode_mps_float32)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/Users/hvaara/dev/pytorch/pytorch/torch/testing/_internal/common_utils.py", line 2744, in wrapper
    method(*args, **kwargs)
  File "/Users/hvaara/dev/pytorch/pytorch/torch/testing/_internal/common_device_type.py", line 419, in instantiated_test
    result = test(self, **param_kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/hvaara/dev/pytorch/pytorch/torch/testing/_internal/common_utils.py", line 1364, in wrapper
    fn(*args, **kwargs)
  File "/Users/hvaara/dev/pytorch/pytorch/torch/testing/_internal/common_modules.py", line 129, in test_wrapper
    return test(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/hvaara/dev/pytorch/pytorch/torch/testing/_internal/common_cuda.py", line 205, in wrapped
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/Users/hvaara/dev/pytorch/pytorch/test/test_modules.py", line 775, in test_memory_format
    self.assertEqual(
  File "/Users/hvaara/dev/pytorch/pytorch/torch/testing/_internal/common_utils.py", line 3636, in assertEqual
    raise error_metas.pop()[0].to_error(
AssertionError: Tensor-likes are not close!
Mismatched elements: 216 / 216 (100.0%)
Greatest absolute difference: 1.0258822441101074 at index (1, 1, 4, 4) (up to 1e-05 allowed)
Greatest relative difference: 1735.79296875 at index (0, 1, 0, 0) (up to 1e-05 allowed)
The failure occurred for item [0]
To execute this test, run the following from the base repo dir:
python test/test_modules.py -k TestModuleMPS.test_memory_format_nn_BatchNorm2d_train_mode_mps_float32
This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
----------------------------------------------------------------------
Ran 8 tests in 0.362s
FAILED (failures=2)
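For reference, the Normalization.cpp change above switches at::toAccumulateType from the is_cuda overload to the device-type overload, so the empty-input batch-norm path can pick a float32 accumulate dtype on MPS (which has no float64 support). A hedged sketch of the path that the updated test_BatchNorm_empty covers, again assuming an MPS device is available:

import torch

# Zero-sized batch: _batch_norm_impl_index still allocates save_mean and
# save_invstd in the accumulate dtype even though there is nothing to
# normalize, and on MPS that dtype has to stay float32.
mod = torch.nn.BatchNorm2d(3).to("mps")
inp = torch.randn(0, 3, 2, 2, device="mps")
out = mod(inp)
print(out.shape)  # torch.Size([0, 3, 2, 2])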