Param Groups Fix
min_lr=config.min_lr,
decoupled_lr=config.decoupled_lr,
decoupled_min_lr=config.decoupled_min_lr,
)
param_groups = list(filter(filter_fn, param_groups))
param_groups.sort(key=lambda g: g.get("wd_mult", 1.0), reverse=True)
buffers = {}
for model_chunk_idx, model_chunk in enumerate(model_chunks):
    if hasattr(model_chunk, buffer_name):
        buffers[model_chunk_idx + model_chunk_offset] = getattr(model_chunk, buffer_name)
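
The fragment above can be tried in isolation with the minimal sketch below. Everything except the three statements taken from the snippet (the filter, the sort on "wd_mult", and the buffer loop) is an assumption made for illustration: the group names, the body of filter_fn, the buffer_name value, and the offset are hypothetical and are not taken from the original code.

```python
from types import SimpleNamespace

# Hypothetical param groups; only the "wd_mult" key mirrors the snippet.
param_groups = [
    {"name": "bias", "wd_mult": 0.0, "params": ["b"]},
    {"name": "weights", "wd_mult": 1.0, "params": ["w"]},
    {"name": "embeddings", "params": ["e"]},  # no "wd_mult": sorts as 1.0
    {"name": "empty", "wd_mult": 1.0, "params": []},
]


def filter_fn(group):
    # Assumed filter: keep only groups that actually hold parameters.
    return len(group["params"]) > 0


param_groups = list(filter(filter_fn, param_groups))
# Descending by weight-decay multiplier. Python's sort is stable even with
# reverse=True, so groups with equal multipliers keep their relative order.
param_groups.sort(key=lambda g: g.get("wd_mult", 1.0), reverse=True)
order = [g["name"] for g in param_groups]

# Hypothetical model chunks; the middle one lacks the buffer attribute.
buffer_name = "grad_buffer"  # assumed attribute name
model_chunk_offset = 2       # assumed offset
model_chunks = [
    SimpleNamespace(grad_buffer="buf0"),
    SimpleNamespace(),
    SimpleNamespace(grad_buffer="buf2"),
]

buffers = {}
for model_chunk_idx, model_chunk in enumerate(model_chunks):
    if hasattr(model_chunk, buffer_name):
        buffers[model_chunk_idx + model_chunk_offset] = getattr(model_chunk, buffer_name)
```

Under these assumptions, the sort gives every run the same group order regardless of how ties were produced, and the buffer dictionary is keyed by the offset chunk index, with chunks that lack the attribute skipped.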