This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
import torch.optim as optim | |
import numpy as np | |
from torch.distributions import Categorical | |
# Define the policy network | |
class PolicyNetwork(nn.Module): | |
def __init__(self, state_dim, action_dim): | |
super(PolicyNetwork, self).__init__() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// concurrent_queue.cpp | |
#include "concurrent_queue.h" | |
#include <iostream> | |
template <typename T> | |
void ConcurrentQueue<T>::push(T value) { | |
std::lock_guard<std::mutex> lock(mutex); | |
queue.push(std::move(value)); | |
cv.notify_one(); // Notify one waiting thread if it's waiting | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// concurrent_queue.h | |
#ifndef CONCURRENT_QUEUE_H | |
#define CONCURRENT_QUEUE_H | |
#include <torch/torch.h> | |
#include <queue> | |
#include <mutex> | |
#include <condition_variable> | |
template <typename T> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch import nn | |
import torch.nn.functional as F | |
from torch_geometric.data import Data | |
from torch_geometric.nn.conv import DNAConv | |
import numpy as np | |
class DynamicGraphNN(torch.nn.Module): | |
def __init__(self, in_channels, out_channels, num_layers = 2, heads=1, groups=1): | |
super().__init__() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
import numpy as np | |
import pandas as pd | |
from torch.utils.data import DataLoader, Dataset, TensorDataset | |
# Assuming the model is a simple neural network for regression/classification | |
class SimpleNN(nn.Module): | |
def __init__(self, input_size, output_size): | |
super(SimpleNN, self).__init__() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn.functional as F | |
def sample_reweight(loss_curve, loss_values, k_th, alpha1=1.0, alpha2=1.0, bins_sr=10, decay=0.9): | |
""" | |
The SR module of Double Ensemble using PyTorch. | |
Args: | |
- loss_curve: Tensor, shape (N, T), the loss curve for each sample over training iterations. | |
- loss_values: Tensor, shape (N,), the loss of the current ensemble on each sample. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse | |
import time | |
import torch | |
import torch.distributed as dist | |
import transformers | |
from transformers import LlamaForCausalLM, LlamaTokenizer | |
import colossalai | |
from colossalai.inference import CaiInferEngine |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from vllm import LLM, SamplingParams | |
import torch | |
from torch import distributed as dist | |
import time | |
from tqdm import tqdm | |
import numpy as np | |
# # Create an LLM. | |
llm = LLM( | |
# model="/home/lclcq/share/llama-7b", |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import warnings | |
import time | |
import torch | |
import torch.distributed as dist | |
import argparse | |
from packaging import version | |
import colossalai |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from vllm import LLM, SamplingParams | |
import torch | |
from torch import distributed as dist | |
import time | |
from tqdm import tqdm | |
import numpy as np | |
# # Create an LLM. | |
llm = LLM( | |
model="/home/lclcq/share/llama-7b", |
NewerOlder