@msaroufim
Created August 15, 2025 04:51
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 938cb7dd97a..d3ac1369e6a 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -7443,6 +7443,57 @@ def reference_inputs_clone_contiguous(op, device, dtype, requires_grad, **kwargs
     yield SampleInput(a, kwargs={'memory_format': torch.channels_last_3d})
 
 
+def sample_inputs_copy(op_info, device, dtype, requires_grad, **kwargs):
+    """Sample inputs for copy and copy_ operations.
+
+    copy(dst_tensor, src_tensor, non_blocking=False) copies data from src to dst.
+    For copy_: dst_tensor.copy_(src_tensor, non_blocking=False)
+    """
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+
+    # Basic same-shape copies
+    dst = make_arg((S, M))
+    src = make_arg((S, M))
+    yield SampleInput(dst, args=(src,))
+
+    # Scalar cases
+    dst_scalar = make_arg(())
+    src_scalar = make_arg(())
+    yield SampleInput(dst_scalar, args=(src_scalar,))
+
+    # Broadcasting cases - src can be broadcast to dst shape
+    dst_broadcast = make_arg((S, M, L))
+    src_broadcast1 = make_arg((1, M, L))  # broadcast in first dim
+    src_broadcast2 = make_arg((S, 1, L))  # broadcast in second dim
+    src_broadcast3 = make_arg((1,))  # single-element tensor broadcast to dst's shape
+    yield SampleInput(dst_broadcast, args=(src_broadcast1,))
+    yield SampleInput(dst_broadcast, args=(src_broadcast2,))
+    yield SampleInput(dst_broadcast, args=(src_broadcast3,))
+
+    # Different dtypes that are compatible
+    if dtype != torch.bool:  # bool has limited conversion support
+        if dtype.is_floating_point:
+            # Float to float conversions
+            other_dtype = torch.float64 if dtype != torch.float64 else torch.float32
+        elif dtype.is_complex:
+            # Complex to complex conversions
+            other_dtype = torch.complex128 if dtype != torch.complex128 else torch.complex64
+        else:
+            # Int to int conversions
+            other_dtype = torch.int64 if dtype != torch.int64 else torch.int32
+
+        if other_dtype != dtype:
+            dst_convert = make_arg((S, M))
+            src_convert = make_tensor((S, M), dtype=other_dtype, device=device, requires_grad=False)
+            yield SampleInput(dst_convert, args=(src_convert,))
+
+    # Non-blocking parameter
+    dst_nonblock = make_arg((S, M))
+    src_nonblock = make_arg((S, M))
+    yield SampleInput(dst_nonblock, args=(src_nonblock,), kwargs={'non_blocking': True})
+    yield SampleInput(dst_nonblock, args=(src_nonblock,), kwargs={'non_blocking': False})
+
+
 def sample_inputs_sum_to_size(op_info, device, dtype, requires_grad, **kwargs):
     make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
 
@@ -12847,6 +12898,31 @@ op_db: list[OpInfo] = [
                DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref'),
                DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_numpy_ref_mps'),
            ),),
+    OpInfo('copy_',
+           # copy_ is the in-place tensor method: tensor.copy_(src, non_blocking=False)
+           op=lambda x, src, non_blocking=False: x.copy_(src, non_blocking=non_blocking),
+           inplace_variant=torch.Tensor.copy_,  # must be a callable; copy_ is its own in-place variant
+           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf),
+           sample_inputs_func=sample_inputs_copy,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           supports_out=False,
+           skips=(
+               # copy_ requires compatible dtypes/devices, which may not be satisfied in all test cases
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
+           ),),
+    OpInfo('copy',
+           # aten::copy functional version: copy(dst_tensor, src_tensor, non_blocking=False)
+           op=torch.ops.aten.copy.default,
+           dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf),
+           sample_inputs_func=sample_inputs_copy,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
+           supports_out=False,
+           skips=(
+               # copy requires compatible dtypes/devices, which may not be satisfied in all test cases
+               DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
+           ),),
     OpInfo('contiguous',
            op=lambda x, *args, **kwargs: x.contiguous(*args, **kwargs),
            dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16, torch.chalf),
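
For context, a minimal standalone sketch (not part of the patch) of the copy semantics the new sample inputs exercise: same-shape copies, broadcasting of src into dst, dtype conversion, and the functional aten::copy, which is assumed here to return a new tensor rather than mutate its first argument (the behavior the 'copy' OpInfo relies on). Shapes and values are illustrative only.

import torch

# In-place: Tensor.copy_ writes src into dst and returns dst itself.
dst = torch.empty(3, 4)
src = torch.randn(3, 4)
assert dst.copy_(src) is dst
assert torch.equal(dst, src)

# Broadcasting: a (1, 4) src is expanded across dst's leading dimension.
dst = torch.zeros(3, 4)
dst.copy_(torch.arange(4.0).reshape(1, 4))
assert torch.equal(dst[0], dst[2])

# Dtype conversion: src is cast to dst's dtype during the copy.
dst = torch.empty(2, 2, dtype=torch.float32)
dst.copy_(torch.ones(2, 2, dtype=torch.int64))
assert dst.dtype == torch.float32

# Functional form: aten::copy returns a fresh tensor and leaves dst untouched.
dst = torch.zeros(2, 2)
out = torch.ops.aten.copy.default(dst, torch.ones(2, 2))
assert torch.equal(out, torch.ones(2, 2)) and torch.equal(dst, torch.zeros(2, 2))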
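
A second sketch (also outside the patch; the demo_copy_samples helper is invented for illustration) of how a harness might drive the two wrappers registered above with the kind of (dst, src) pairs sample_inputs_copy yields; it only assumes torch.Tensor.copy_ and torch.ops.aten.copy.default, which the entries already reference.

import torch
from functools import partial

def demo_copy_samples(device, dtype):
    # Stand-in for sample_inputs_copy: yields (dst, src) pairs.
    make_arg = partial(torch.randn, device=device, dtype=dtype)
    yield make_arg(5, 10), make_arg(5, 10)        # same shape
    yield make_arg(5, 10, 3), make_arg(1, 10, 3)  # src broadcast into dst

# Mirrors the op wrappers in the two OpInfo entries above.
copy_inplace = lambda x, src, non_blocking=False: x.copy_(src, non_blocking=non_blocking)
copy_functional = torch.ops.aten.copy.default

for dst, src in demo_copy_samples("cpu", torch.float32):
    out_inplace = copy_inplace(dst.clone(), src)  # mutates (and returns) the clone
    out_functional = copy_functional(dst, src)    # leaves dst untouched
    assert torch.equal(out_inplace, out_functional)

The clone() before the in-place call keeps dst reusable for the functional variant in this demo.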