This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
def stable_softmax(x, axis=None):
    """Return a numerically stable softmax of ``x`` along ``axis``.

    Adapted from ``scipy.special.softmax``. The maximum along ``axis`` is
    subtracted before exponentiating, so large inputs do not overflow
    ``np.exp``. The result is then normalized so the entries along ``axis``
    sum to 1.

    Args:
        x: Array-like input values.
        axis: Axis (or axes) to normalize over. ``None`` normalizes over
            every element of ``x``.

    Returns:
        An array broadcast-compatible with ``x`` holding the softmax values.
    """
    # keepdims=True keeps the reduced axes as size-1 dims so the
    # subtraction and the division below broadcast against ``x``.
    peak = np.max(x, axis=axis, keepdims=True)
    shifted = np.exp(np.subtract(x, peak))
    normalizer = np.sum(shifted, axis=axis, keepdims=True)
    return shifted / normalizer
def get_prob(arr: np.ndarray, temp: float) -> np.ndarray: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright 2024 AllenAI. All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import deque | |
import queue | |
import time | |
import numpy as np | |
import ray | |
from vllm import SamplingParams, LLM | |
import wandb | |
from open_instruct.dataset_transformation import TokenizerConfig, get_cached_dataset_rlvr | |
from open_instruct.vllm_utils3 import create_vllm_engines | |
from transformers import HfArgumentParser |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from transformers import AutoModelForCausalLM, AutoTokenizer | |
import torch | |
import torch.nn.functional as F | |
model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M") | |
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M") | |
tokenizer.add_special_tokens({"pad_token": "<PAD>"}) | |
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") | |
device = torch.device("cpu") | |
model.to(device) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import numpy as np | |
p = 100 # padding token id | |
o = 1 # observation (prompt / input ids) | |
a = 2 # action (response ids) | |
queries = [ | |
[p, p, o, o, o], |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
import torch.optim as optim | |
# Create target distribution (fixed) | |
target_logits = torch.randn(10) | |
target_log_probs = torch.log_softmax(target_logits, dim=0) | |
# Create learnable distribution | |
learnable_logits = nn.Parameter(torch.rand_like(target_logits)) # Initialize randomly |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"name": "material-ui-nextjs-ts", | |
"version": "5.0.0", | |
"lockfileVersion": 3, | |
"requires": true, | |
"packages": { | |
"": { | |
"name": "material-ui-nextjs-ts", | |
"version": "5.0.0", | |
"dependencies": { |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Taken and modified from https://github.com/huggingface/trl | |
# Copyright 2024 The AllenAI Team. All rights reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json | |
import os | |
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0" | |
import gradio as gr | |
import numpy as np | |
import pandas as pd | |
from apscheduler.schedulers.background import BackgroundScheduler | |
from huggingface_hub import HfApi, snapshot_download |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding=utf-8 | |
# Adapted from | |
# https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/models/olmo/modeling_olmo.py | |
# Copyright 2024 The vLLM team. | |
# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved. | |
# | |
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX | |
# and OPT implementations in this library. It has been modified from its | |
# original forms to accommodate minor architectural differences compared | |
# to GPT-NeoX and OPT used by the Meta AI team that trained the model. |
NewerOlder