UnslothGRPOTrainer excerpt
# Imports assumed by this excerpt (the gist shows only a fragment of the generated file)
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformers import Trainer


class _UnslothGRPOTrainer(Trainer):
    ...

    def _sync_fsdp1_params_to_vllm(self, module: nn.Module, prefix: str = "", visited=None):
        """Memory-efficient post-order traversal of FSDP modules to extract full parameters and sync with vLLM."""
        # For FSDP1, we need to recurse into children and also use summon_full_params
        if visited is None:
            visited = set()
        for child_name, child_module in module.named_children():
            child_prefix = f"{prefix}.{child_name}" if prefix else child_name
            self._sync_fsdp1_params_to_vllm(
                child_module, prefix=child_prefix, visited=visited
            )  # recurse into the child (post-order: leaf FSDP units are synced first)
        if isinstance(module, FSDP):
            # Gather the full parameters of this FSDP unit only (recurse=False);
            # writeback=False skips copying the gathered weights back into the shards
            with FSDP.summon_full_params(module, recurse=False, writeback=False):
                for param_name, param in module.named_parameters():
                    full_name = f"{prefix}.{param_name}" if prefix else param_name
                    full_name = self._fix_param_name_to_vllm(full_name, extra_prefixes=["_fsdp_wrapped_module."])
                    if full_name in visited:
                        continue  # skip FSDP subtrees already traversed
                    visited.add(full_name)
                    if self.vllm_mode == "server" and self.accelerator.is_main_process:
                        self.vllm_client.update_named_param(full_name, param.data)
                    elif self.vllm_mode == "colocate":
                        pass
        pass  # stray `pass` statements like this are typical of Unsloth's generated code
    def _sync_fsdp2_params_to_vllm(self, module: nn.Module):
        # For FSDP2 the module's state_dict already covers all parameters, so no
        # recursion is needed. (`nn.Module` has no `.items()` method; the gist's
        # `module.items()` would raise AttributeError, so iterate the state_dict instead.)
        for name, param in module.state_dict().items():
            if param.is_cpu:
                param = param.to(torch.device("cuda"))
            param = param.full_tensor()  # gather the sharded DTensor into a full tensor
            if self.vllm_mode == "server" and self.accelerator.is_main_process:
                self.vllm_client.update_named_param(name, param)
            elif self.vllm_mode == "colocate":
                pass
        pass
    def _move_model_to_vllm(self, *args, **kwargs):
        # Stubbed to a no-op in this generated trainer; TRL's default
        # weight-moving path is bypassed
        return None
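For context, here is a minimal, hypothetical sketch of how these helpers could be dispatched, assuming `trainer` is an instance of the class above wrapping an FSDP-sharded policy model. The `sync_policy_to_vllm` name and the isinstance-based dispatch are illustrative assumptions, not code from the gist (in TRL, the equivalent dispatch lives in `_move_model_to_vllm`, which this excerpt stubs out).

# Hypothetical usage sketch -- not part of the gist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def sync_policy_to_vllm(trainer, model: nn.Module) -> None:
    if isinstance(model, FSDP):
        # FSDP1: wrapper modules exist in the tree; walk it post-order,
        # summoning full params one FSDP unit at a time to bound peak memory
        trainer._sync_fsdp1_params_to_vllm(model)
    else:
        # FSDP2 (fully_shard): parameters are DTensors on the plain module,
        # so a single pass over the state_dict suffices
        trainer._sync_fsdp2_params_to_vllm(model)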