Skip to content

Instantly share code, notes, and snippets.

View xiabingquan's full-sized avatar

xiabingquan xiabingquan

  • Beijing, China
View GitHub Profile
@xiabingquan
xiabingquan / selective_sort.rs
Created February 11, 2025 14:37
A minimal example of selection sort in Rust
fn selective_sort<T: Ord + Clone>(arr: &mut [T]) {
let n = arr.len();
for i in 0..n {
let mut max_val = arr[i].clone();
let mut max_idx = i;
for j in i..n {
if arr[j] > max_val {
max_val = arr[j].clone();
max_idx = j;
}
@xiabingquan
xiabingquan / insert_sort.rs
Created February 11, 2025 15:14
A minimal example of insertion sort in Rust
fn insert_sort<T: Ord + Clone>(arr: &mut [T]) {
for i in 0..arr.len() - 1 {
let val = arr[i + 1].clone();
let mut j = i as isize;
while j >= 0 && val > arr[j as usize] {
j -= 1;
}
for k in ((j+1) as usize..i+1).rev() {
arr[k + 1] = arr[k].clone();
}
@xiabingquan
xiabingquan / time_decorator.py
Created February 17, 2025 04:28
A universal Python decorator function to record the runtime of arbitrary functions (Copied from MegatronLM)
import time
import logging
from typing import Callable
from functools import wraps
# Reference: https://github.com/NVIDIA/Megatron-LM/blob/9a496c976e12a62ce8e39e14496e52a985588730/megatron/core/dist_checkpointing/strategies/two_stage.py#L35
def timed(verbose=True) -> Callable:
def timed_dec(fn):
name = fn.__name__
@xiabingquan
xiabingquan / interleaved_print.py
Created March 11, 2025 10:48
A print function for use in distributed environments
# References: https://github.com/huggingface/picotron_tutorial/blob/master/step1_modeling/utils.py
def print(*args, is_print_rank=True, **kwargs):
    """Print while holding an exclusive file lock so that output from
    several processes does not interleave.

    Args:
        *args: Positional arguments forwarded to ``builtins.print``.
        is_print_rank: When false, the call returns without printing or
            taking the lock.
        **kwargs: Keyword arguments forwarded to ``builtins.print``.
    """
    if is_print_rank:
        # This source file doubles as the lock target: every process running
        # the same script opens the same path.
        with open(__file__, "r") as lock_file:
            fcntl.flock(lock_file, fcntl.LOCK_EX)
            try:
                builtins.print(*args, **kwargs)
            finally:
                # Release the lock even if the underlying print raises.
                fcntl.flock(lock_file, fcntl.LOCK_UN)
@xiabingquan
xiabingquan / calc_tokens_by_filesize.py
Created March 11, 2025 12:10
Given the size of a binary token file in Megatron format, calculate its number of tokens. Assumes each token occupies exactly 4 bytes.
import re
def calc_num_token_of_bin(filesize: str) -> float:
    """Compute a token count from a human-readable file size string.

    Args:
        filesize: A string containing exactly one ``<digits>G`` or
            ``<digits>T`` group, e.g. ``'10G'`` or ``'1T'``.

    Returns:
        ``digits / 4 * (1024 / 1000) ** k`` as a float, where ``k`` is 2 for
        ``G`` and 3 for ``T``. The division by 4 corresponds to 4 bytes per
        token.
        NOTE(review): the unit/scale of the returned number is not stated in
        the code -- confirm with the author before relying on it.

    Raises:
        AssertionError: If ``filesize`` does not contain exactly one
            ``<digits>G`` / ``<digits>T`` group.
    """
    # `[GT]` rather than `[G|T]`: inside a character class `|` is a literal
    # pipe, so the old pattern accepted inputs such as "10|" and treated them
    # as terabytes.
    m = re.findall(r"(\d+)([GT])", filesize)
    assert len(m) == 1, f"Expect string like '10G', '1T', but got {filesize}"
    digits, unit = m[0]
    # Named `exponent` to avoid shadowing the builtin `pow`.
    exponent = 2 if unit == 'G' else 3
    return int(digits) / 4 * (1024 / 1000) ** exponent