Skip to content

Instantly share code, notes, and snippets.

@vanbasten23
Last active October 24, 2025 17:03
Show Gist options
  • Select an option

  • Save vanbasten23/64be65a8331dae95e95b4a9d90214691 to your computer and use it in GitHub Desktop.

Select an option

Save vanbasten23/64be65a8331dae95e95b4a9d90214691 to your computer and use it in GitHub Desktop.
# This script demonstrates that, under torchax, tensor.copy_(lora_tensor) does not change the sharding of `tensor`.
import jax
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchax
from torchax.interop import jax_view
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from torchax.interop import jax_view, torch_view
from torchax.ops.mappings import t2j
P = PartitionSpec
torchax.enable_globally()
class M(torch.nn.Module):
    """Tiny module holding one plain tensor and one ``nn.Linear`` layer.

    The two attributes are stored differently on purpose:

    * ``param`` is a bare ``torch.Tensor`` attribute. It is not registered as
      an ``nn.Parameter`` or buffer, so ``Module.to(...)`` leaves it alone.
    * ``fc1`` is a regular submodule, so its weight is a registered parameter
      and is converted by ``Module.to(...)``.

    Args:
        a: Scalar multiplier applied to the matmul result in ``forward``.
    """

    def __init__(self, a: int):
        super().__init__()
        self.a = a
        # Plain tensor of shape [4, 8]; intentionally not an nn.Parameter.
        self.param = torch.randn([4, 8])
        # nn.Linear stores weight as [out_features, in_features],
        # so fc1.weight.shape == [4, 8]. fc1 is not used by forward().
        self.fc1 = nn.Linear(8, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``(x @ self.param) * self.a``.

        ``self.param`` has shape [4, 8], so the last dimension of ``x`` must
        be 4; the result's last dimension is 8.
        """
        return x @ self.param * self.a
def _convert_to_torchax_and_shard(tensor: torch.Tensor,
                                  sharding: NamedSharding) -> torch.Tensor:
    """Place ``tensor`` on devices with ``sharding`` and return a torch view of it.

    Args:
        tensor: Either a torchax tensor or a regular ``torch.Tensor``.
        sharding: The ``NamedSharding`` passed to ``jax.device_put``.

    Returns:
        The result of ``torch_view`` applied to the device-put jax array.
    """
    is_torchax_tensor = isinstance(tensor, torchax.tensor.Tensor)
    # A torchax tensor is viewed as its jax array; a plain torch tensor is
    # converted with t2j instead.
    jax_array = jax_view(tensor) if is_torchax_tensor else t2j(tensor)
    sharded_array = jax.device_put(jax_array, sharding)
    return torch_view(sharded_array)
# --- 1. Build the model and move it to the 'jax' device. ---
m = M(2).to('jax')
print(f'after initializing the model, {type(m.param)=}, {type(m.fc1.weight)=}')
# Prints: after initializing the model, type(m.param)=<class 'torch.Tensor'>, type(m.fc1.weight)=<class 'torchax.tensor.Tensor'>
# Note that m.param is still a torch.Tensor but m.fc1.weight changed to a torchax tensor.

# --- 2. Build a ("data", "model") mesh and a sharding that splits dim 1 over "model". ---
# NOTE(review): num_devices is computed but never used; mesh_shape below is
# hard-coded to 4 devices on the "model" axis. Presumably the script expects a
# host with at least 4 devices — confirm before running elsewhere.
num_devices = len(jax.devices())
axis_names = ("data", "model")
mesh_shape = (1, 4)
mesh = jax.make_mesh(mesh_shape, axis_names=axis_names, devices=jax.devices())
# P(None, "model"): dim 0 is not partitioned, dim 1 is partitioned over the "model" axis.
sharding = NamedSharding(mesh, P(None, "model"))

# --- 3. Shard both tensors with the same sharding. ---
m.param = _convert_to_torchax_and_shard(m.param, sharding)
new_fc1_weight = _convert_to_torchax_and_shard(m.fc1.weight, sharding)
# Re-wrap the sharded weight as a Parameter (with gradients disabled) before
# assigning it back onto the Linear layer.
m.fc1.weight = torch.nn.Parameter(new_fc1_weight, requires_grad=False)
print(f"Initial sharding: {jax_view(m.param).sharding=}, {jax_view(m.fc1.weight).sharding=}")
# Prints "Initial sharding: jax_view(m.param).sharding=NamedSharding(mesh=Mesh('data': 1, 'model': 4, axis_types=(Auto, Auto)), spec=PartitionSpec(None, 'model'), memory_kind=device), jax_view(m.fc1.weight).sharding=NamedSharding(mesh=Mesh('data': 1, 'model': 4, axis_types=(Auto, Auto)), spec=PartitionSpec(None, 'model'), memory_kind=device)"

# --- 4. In-place copy_ into row 0 of each tensor, then re-check the sharding. ---
# Both tensors are [4, 8], so row 0 has 8 elements, matching the [8] source tensors.
lora_param = torch.randn([8], device='jax')
m.param[0,:].copy_(lora_param, non_blocking=True)
lora_fc1 = torch.randn([8], device='jax')
m.fc1.weight[0,:].copy_(lora_fc1, non_blocking=True)
print(f"After copy_, sharding: {jax_view(m.param).sharding=}, {jax_view(m.fc1.weight).sharding=}")
# Prints "After copy_, sharding: jax_view(m.param).sharding=NamedSharding(mesh=Mesh('data': 1, 'model': 4, axis_types=(Auto, Auto)), spec=PartitionSpec(None, 'model'), memory_kind=device), jax_view(m.fc1.weight).sharding=NamedSharding(mesh=Mesh('data': 1, 'model': 4, axis_types=(Auto, Auto)), spec=PartitionSpec(None, 'model'), memory_kind=device)"
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment