@thomwolf
thomwolf / gradient_accumulation.py
Last active November 23, 2024 20:53
PyTorch gradient accumulation training loop
model.zero_grad()                                   # Reset gradients tensors
for i, (inputs, labels) in enumerate(training_set):
    predictions = model(inputs)                     # Forward pass
    loss = loss_function(predictions, labels)       # Compute loss function
    loss = loss / accumulation_steps                # Normalize our loss (if averaged)
    loss.backward()                                 # Backward pass
    if (i+1) % accumulation_steps == 0:             # Wait for several backward steps
        optimizer.step()                            # Now we can do an optimizer step
        model.zero_grad()                           # Reset gradients tensors
        if (i+1) % evaluation_steps == 0:           # Evaluate the model when we...
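The division by accumulation_steps in the loop above is easiest to justify with a small numeric check. The following is a minimal sketch in plain Python, not part of the gist: the toy model y = w * x, the helper `mse_grad`, and the sample data are all hypothetical and chosen only for illustration. It assumes a mean-reduced loss and equally sized micro-batches, in which case summing the gradients of `loss / accumulation_steps` should match the gradient of the loss over the combined batch.

```python
# Toy 1-D linear model y = w * x with a mean-squared-error loss.
# Everything here is illustrative; no PyTorch is involved.

def mse_grad(w, batch):
    """Gradient of mean((w * x - y) ** 2) over `batch` with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)

data = [(1.0, 2.0), (2.0, 4.5), (3.0, 5.5), (4.0, 8.5)]  # hypothetical (x, y) pairs
w = 0.5
accumulation_steps = 2
micro_batches = [data[0:2], data[2:4]]  # equally sized micro-batches

# Mimic repeated loss.backward() calls: gradients add up across micro-batches,
# and each contribution is scaled by 1 / accumulation_steps.
accumulated_grad = 0.0
for batch in micro_batches:
    accumulated_grad += mse_grad(w, batch) / accumulation_steps

# Gradient of the same loss computed on the full batch in one go.
full_batch_grad = mse_grad(w, data)

print(accumulated_grad, full_batch_grad)
```

Without the scaling, the accumulated gradient would be accumulation_steps times larger than the full-batch gradient, which acts like an unintended increase in the learning rate. If the loss were summed rather than averaged over a batch, the scaling would not be needed, which is presumably what the "(if averaged)" comment refers to.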
@cgoldberg
cgoldberg / geckodriver-install.sh
Last active August 18, 2023 10:20
download and install latest geckodriver for linux or mac (selenium webdriver)
#!/bin/bash
# download and install latest geckodriver for linux or mac.
# required for selenium to drive a firefox browser.
install_dir="/usr/local/bin"
json=$(curl -s https://api.github.com/repos/mozilla/geckodriver/releases/latest)
if [[ $(uname) == "Darwin" ]]; then
    url=$(echo "$json" | jq -r '.assets[].browser_download_url | select(contains("macos"))')
elif [[ $(uname) == "Linux" ]]; then
    url=$(echo "$json" | jq -r '.assets[].browser_download_url | select(contains("linux64"))')