Last active
October 24, 2025 17:03
-
-
Save vanbasten23/64be65a8331dae95e95b4a9d90214691 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # This script demonstrates that under torchax, tensor.copy_(lora_tensor) does not change the sharding of `tensor`. | |
| import jax | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torchax | |
| from torchax.interop import jax_view | |
| from jax.sharding import Mesh, NamedSharding, PartitionSpec | |
| from torchax.interop import jax_view, torch_view | |
| from torchax.ops.mappings import t2j | |
| P = PartitionSpec | |
| torchax.enable_globally() | |
class M(nn.Module):
    """Toy module with one plain tensor attribute and one ``nn.Linear``.

    ``param`` is assigned as an ordinary attribute: it is neither wrapped
    in ``nn.Parameter`` nor registered as a buffer. ``fc1`` is a regular
    registered submodule.
    """

    def __init__(self, a: int):
        super().__init__()
        self.a = a  # scalar multiplier applied in forward()
        self.param = torch.randn([4, 8])
        # nn.Linear stores its weight as (out_features, in_features), so
        # fc1.weight.shape == [4, 8] (and fc1.bias.shape == [4]).
        self.fc1 = nn.Linear(8, 4)

    def forward(self, x):
        """Return ``(x @ self.param) * self.a``; ``fc1`` is not used here."""
        return x @ self.param * self.a
def _convert_to_torchax_and_shard(tensor: torch.Tensor,
                                  sharding: NamedSharding) -> torch.Tensor:
    """Place ``tensor`` on devices according to ``sharding``.

    A ``torchax.tensor.Tensor`` input is unwrapped with ``jax_view``; any
    other tensor is converted with ``t2j``. The resulting jax value is
    passed through ``jax.device_put`` with ``sharding`` and wrapped back
    with ``torch_view`` before being returned.
    """
    already_torchax = isinstance(tensor, torchax.tensor.Tensor)
    jax_value = jax_view(tensor) if already_torchax else t2j(tensor)
    placed = jax.device_put(jax_value, sharding)
    return torch_view(placed)
# --- 1. Build the model and move it to the 'jax' device. -------------------
m = M(2).to('jax')
print(f'after initializing the model, {type(m.param)=}, {type(m.fc1.weight)=}')
# Recorded output:
#   after initializing the model, type(m.param)=<class 'torch.Tensor'>, type(m.fc1.weight)=<class 'torchax.tensor.Tensor'>
# Note: m.param is still a torch.Tensor, while m.fc1.weight became a torchax
# tensor.

# --- 2. Build a (data=1, model=4) mesh and shard the last dim on "model". --
num_devices = len(jax.devices())  # kept from the original; not used below
axis_names = ("data", "model")
mesh_shape = (1, 4)
mesh = jax.make_mesh(
    mesh_shape,
    axis_names=axis_names,
    devices=jax.devices(),
)
sharding = NamedSharding(mesh, P(None, "model"))

# --- 3. Shard both tensors. -------------------------------------------------
m.param = _convert_to_torchax_and_shard(m.param, sharding)
m.fc1.weight = torch.nn.Parameter(
    _convert_to_torchax_and_shard(m.fc1.weight, sharding),
    requires_grad=False,
)
print(f"Initial sharding: {jax_view(m.param).sharding=}, {jax_view(m.fc1.weight).sharding=}")
# Recorded output:
#   Initial sharding: jax_view(m.param).sharding=NamedSharding(mesh=Mesh('data': 1, 'model': 4, axis_types=(Auto, Auto)), spec=PartitionSpec(None, 'model'), memory_kind=device), jax_view(m.fc1.weight).sharding=NamedSharding(mesh=Mesh('data': 1, 'model': 4, axis_types=(Auto, Auto)), spec=PartitionSpec(None, 'model'), memory_kind=device)

# --- 4. Overwrite row 0 of each tensor in place with fresh random data. ----
param_row = torch.randn([8], device='jax')
m.param[0, :].copy_(param_row, non_blocking=True)
fc1_row = torch.randn([8], device='jax')
m.fc1.weight[0, :].copy_(fc1_row, non_blocking=True)

# --- 5. Print the sharding again after the in-place copies. ----------------
print(f"After copy_, sharding: {jax_view(m.param).sharding=}, {jax_view(m.fc1.weight).sharding=}")
# Recorded output:
#   After copy_, sharding: jax_view(m.param).sharding=NamedSharding(mesh=Mesh('data': 1, 'model': 4, axis_types=(Auto, Auto)), spec=PartitionSpec(None, 'model'), memory_kind=device), jax_view(m.fc1.weight).sharding=NamedSharding(mesh=Mesh('data': 1, 'model': 4, axis_types=(Auto, Auto)), spec=PartitionSpec(None, 'model'), memory_kind=device)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment