Skip to content

Instantly share code, notes, and snippets.

View kouroshHakha's full-sized avatar

kourosh hakhamaneshi kouroshHakha

View GitHub Profile
@kouroshHakha
kouroshHakha / bw_test.py
Last active January 30, 2025 23:15
Test peer-to-peer GPU bandwidth
import ray
import torch
import time
import numpy as np
from ray.util.collective.types import Backend
from ray.util.collective.const import get_store_name
import ray.util.collective as collective
import os
@ray.remote(num_gpus=1)
@kouroshHakha
kouroshHakha / create_test_dataset.py
Created December 11, 2023 05:06
JSON Mode and Function-calling on Open LLMs Blogpost
import datasets
import re
import json
import tqdm
ds = datasets.load_dataset("glaiveai/glaive-function-calling-v2", split="train")
out_ds_size = 100
class UserAssistantNotFoundError(Exception):
@kouroshHakha
kouroshHakha / fp16_vs_bf16_model_loading.py
Created October 20, 2023 17:31
Studies the diff on precision when loading in fp16 or bf16
from safetensors import safe_open
import torch
import numpy as np
import matplotlib.pyplot as plt
tensors = {}
model_ckpt = "/home/ray/default/7b-chat-lora-ckpt/adapter_model.safetensors"
with safe_open(model_ckpt, framework="pt") as f:
for k in f.keys():
tensors[k] = f.get_tensor(k)
import torch
import torch.nn.functional as F
# Create random inputs for testing
batch_size = 128
seq_length = 512
embed_dim = 64
enable_math = False
query = torch.rand(batch_size, seq_length, embed_dim, device="cuda", requires_grad=True)
@kouroshHakha
kouroshHakha / bm_attn.py
Created August 22, 2023 03:58
benchmark_flash
# Install the newest triton version with
# pip install "git+https://github.com/openai/triton.git#egg=triton&subdirectory=python"
import pickle
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
import numpy as np
import pandas as pd
import os
from ray.train.huggingface import HuggingFacePredictor
import pandas as pd
import re
(ray) kourosh@kourosh-JRFKXJ33VL auto_prompting % python main.py
============== Trial 1 ===============
Current Prompt format:
I want you to act as a linux terminal. I will type commands and you will reply with what the terminal should show.
Failed Example:
Input:
pwd
Output:
import torch
import torch.nn.functional as F
import unittest
import xformers.ops as xops
import math
import time
MAX_ITER = 100
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from pprint import pprint
import time
import gc
import matplotlib.pyplot as plt
import numpy as np
model_base = "gpt2"
#
# A fatal error has been detected by the Java Runtime Environment:
#
# SIGSEGV (0xb) at pc=0x00007f7a24d482ab, pid=38579, tid=0x00007f7a24b41340
#
# JRE version: OpenJDK Runtime Environment (Zulu 8.62.0.19-CA-linux64) (8.0_332-b09) (build 1.8.0_332-b09)
# Java VM: OpenJDK 64-Bit Server VM (25.332-b09 mixed mode linux-amd64 compressed oops)
# Problematic frame:
# C [libpthread.so.0+0x142ab] raise+0xcb
#