diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 938cb7dd97a..d3ac1369e6a 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -7443,6 +7443,57 @@ def reference_inputs_clone_contiguous(op, device, dtype, requires_grad, **kwargs
     yield SampleInput(a, kwargs={'memory_format': torch.channels_last_3d})
 
 
+def sample_inputs_copy(op_info, device, dtype, requires_grad, **kwargs):
+    """Sample inputs for the copy and copy_ operations.
+
+    Functional copy(dst_tensor, src_tensor, non_blocking=False) returns a copy of dst
+    carrying src's values; copy_ is the in-place variant: dst.copy_(src, non_blocking=False).
+    """
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+
+    # Basic same-shape copies
+    dst = make_arg((S, M))
+    src = make_arg((S, M))
+    yield SampleInput(dst, args=(src,))
+
+    # Scalar cases
+    dst_scalar = make_arg(())
+    src_scalar = make_arg(())
+    yield SampleInput(dst_scalar, args=(src_scalar,))
+
+    # Broadcasting cases: src is broadcast to dst's shape.
+    # Use a fresh dst per sample, since copy_ writes into dst in place.
+    src_broadcast1 = make_arg((1, M, L))  # broadcast along the first dim
+    src_broadcast2 = make_arg((S, 1, L))  # broadcast along the second dim
+    src_broadcast3 = make_arg((1,))  # single-element src broadcast to the full shape
+    yield SampleInput(make_arg((S, M, L)), args=(src_broadcast1,))
+    yield SampleInput(make_arg((S, M, L)), args=(src_broadcast2,))
+    yield SampleInput(make_arg((S, M, L)), args=(src_broadcast3,))
+
+    # Different dtypes that are compatible
+    if dtype != torch.bool:  # bool has limited conversion support
+        if dtype.is_floating_point:
+            # Float to float conversions
+            other_dtype = torch.float64 if dtype != torch.float64 else torch.float32
+        elif dtype.is_complex:
+            # Complex to complex conversions
+            other_dtype = torch.complex128 if dtype != torch.complex128 else torch.complex64
+        else:
+            # Int to int conversions
+            other_dtype = torch.int64 if dtype != torch.int64 else torch.int32
+
+        if other_dtype != dtype:
+            dst_convert = make_arg((S, M))
+            src_convert = make_tensor((S, M), dtype=other_dtype, device=device, requires_grad=False)
+            yield SampleInput(dst_convert, args=(src_convert,))
+
+    # non_blocking parameter (fresh tensors per sample, since copy_ mutates dst)
+    dst_nonblock = make_arg((S, M))
+    src_nonblock = make_arg((S, M))
+    yield SampleInput(dst_nonblock, args=(src_nonblock,), kwargs={'non_blocking': True})
+    yield SampleInput(make_arg((S, M)), args=(make_arg((S, M)),), kwargs={'non_blocking': False})
+
+
 def sample_inputs_sum_to_size(op_info, device, dtype, requires_grad, **kwargs):
     make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
 
@@ -12847,6 +12898,31 @@ op_db: list[OpInfo] = [
                DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref'),
                DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref_mps'),
            ),),
+    OpInfo('copy_',
+           # in-place method; `op` clones dst so shared sample tensors are not mutated
+           op=lambda x, src, non_blocking=False: x.clone().copy_(src, non_blocking=non_blocking),
+           inplace_variant=torch.Tensor.copy_,
+           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf),
+           sample_inputs_func=sample_inputs_copy,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           supports_out=False,
+           skips=(
+               # copy_ requires compatible dtypes/devices which may not be satisfied in all test cases
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
+           ),),
+    OpInfo('copy',
+           # aten::copy functional version: copy(dst_tensor, src_tensor, non_blocking=False)
+           op=torch.ops.aten.copy.default,
+           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf),
+           sample_inputs_func=sample_inputs_copy,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           supports_out=False,
+           skips=(
+               # copy requires compatible dtypes/devices which may not be satisfied in all test cases
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
+           ),),
     OpInfo('contiguous',
            op=lambda x, *args, **kwargs: x.contiguous(*args, **kwargs),
            dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf),
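
For reference (not part of the patch): a minimal standalone sketch of the copy semantics these sample inputs exercise, namely same-shape copy, broadcasting, dtype conversion, and the non_blocking flag. Shapes and values below are illustrative only.

import torch

# Same-shape copy: dst takes src's values; dst's dtype and device are kept.
dst = torch.zeros(3, 4)
src = torch.randn(3, 4)
dst.copy_(src)
assert torch.equal(dst, src)

# Broadcasting: a (1, 4) src row is expanded across dst's first dimension.
dst = torch.zeros(3, 4)
row = torch.randn(1, 4)
dst.copy_(row)
assert torch.equal(dst, row.expand(3, 4))

# Dtype conversion: values are cast to dst's dtype (floats truncate toward zero).
dst_int = torch.zeros(2, dtype=torch.int64)
dst_int.copy_(torch.tensor([2.7, -1.9]))
assert dst_int.tolist() == [2, -1]

# non_blocking only has an effect for cross-device copies (e.g. pinned CPU -> CUDA);
# on same-device copies it is accepted and ignored.
dst = torch.zeros(3, 4)
dst.copy_(torch.ones(3, 4), non_blocking=True)

# The functional overload used by OpInfo('copy') leaves its first argument
# untouched and returns a new tensor carrying the copied values.
a = torch.zeros(2, 2)
out = torch.ops.aten.copy.default(a, torch.ones(2, 2))
assert torch.equal(out, torch.ones(2, 2)) and torch.equal(a, torch.zeros(2, 2))

With the patch applied, the samples would be consumed through the usual OpInfo path, e.g. next(op for op in op_db if op.name == 'copy').sample_inputs('cpu', torch.float32), which wraps sample_inputs_copy above.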