Last active
January 11, 2025 21:14
-
-
Save Mistobaan/04a81ad19b64434c1bf8ae69dd38d1ba to your computer and use it in GitHub Desktop.
Financial model estimating the GPU cost of training a small math-reasoning model
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import math | |
def compute_cost(num_nodes, gpus_per_node, cost_per_gpu_hour, days):
    """
    Return the total USD cost of renting a GPU cluster for a number of days.

    The bill is the total GPU count (``num_nodes * gpus_per_node``) times the
    per-GPU hourly price, times the number of hours in ``days`` (24 per day).

    :param num_nodes: Number of nodes
    :param gpus_per_node: Number of GPUs in each node
    :param cost_per_gpu_hour: Cost per GPU-hour in USD
    :param days: Total number of days the job runs
    :return: Total cost in USD
    """
    gpu_count = num_nodes * gpus_per_node
    runtime_hours = 24 * days
    # Multiply in the same order as before: (GPUs * price) first, then hours.
    return gpu_count * cost_per_gpu_hour * runtime_hours
def main():
    """
    Print an estimated GPU rental budget for the self-evolution rounds.

    Each round is described once, as data, in ``rounds``. Its cost comes from
    ``compute_cost``, and its printed description is generated from the same
    values. Editing a round's node count, GPUs per node or duration therefore
    updates both the number and the label shown in the report. The report
    lists every round's cost followed by the total across all rounds.
    """
    # ----------------------------------------------------------------------------
    # Assumed GPU-hour costs (these numbers are examples; adjust as needed):
    # from: https://lambdalabs.com/service/gpu-cloud#pricing
    # ----------------------------------------------------------------------------
    cost_per_gpu_hour_h100 = 2.99
    cost_per_gpu_hour_a100_40gb = 1.29

    # ----------------------------------------------------------------------------
    # Round schedule, one tuple per line of the report:
    #   (label, GPU model, num_nodes, gpus_per_node, cost_per_gpu_hour, days)
    # ----------------------------------------------------------------------------
    rounds = [
        # Round 1: 10 nodes of 8x80GB H100 GPUs for 2 weeks (14 days).
        ("Round 1", "H100", 10, 8, cost_per_gpu_hour_h100, 14),
        # Rounds 2–4: 15 nodes of 4x40GB A100 GPUs, 3 rounds of 3 days each.
        ("Rounds 2–4", "A100 40GB", 15, 4, cost_per_gpu_hour_a100_40gb, 3 * 3),
        # Final round: same 15 A100 nodes for 1 week. It runs 64 MCTS rollouts,
        # so higher usage is expected, but for cost modeling we simply
        # multiply time * GPU count * cost/hour.
        ("Final Round", "A100 40GB", 15, 4, cost_per_gpu_hour_a100_40gb, 7),
    ]

    rows = []
    for label, gpu_model, nodes, gpus_per_node, rate, days in rounds:
        cost = compute_cost(
            num_nodes=nodes,
            gpus_per_node=gpus_per_node,
            cost_per_gpu_hour=rate,
            days=days,
        )
        # The label is built from the values actually used for `cost`, so it
        # cannot drift out of sync with the computation.
        description = (
            f"{label} ({gpu_model} x {gpus_per_node}, {nodes} nodes, {days} days):"
        )
        rows.append((description, cost))

    # Sum up the total cost across all rounds, in schedule order.
    total_cost = sum(cost for _, cost in rows)
    total_row = ("Total estimated cost:", total_cost)

    # Pad descriptions and amounts to common widths so the dollar figures
    # line up in a single column.
    all_rows = [*rows, total_row]
    label_width = max(len(description) for description, _ in all_rows)
    amount_width = max(len(f"{cost:,.2f}") for _, cost in all_rows)

    def _format_row(description, cost):
        # One aligned report line: left-padded label, right-aligned amount.
        return f"{description:<{label_width}} ${cost:>{amount_width},.2f}"

    # Print summary
    print("============================================================")
    print(" FINANCIAL MODEL FOR SELF-EVOLUTION INFERENCE COSTS ")
    print("============================================================")
    for description, cost in rows:
        print(_format_row(description, cost))
    print("------------------------------------------------------------")
    print(_format_row(*total_row))
    print("============================================================")


# Run the report when executed as a script; previously main() was never called.
if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment