import logging
import random

import ray
from transformers import RagConfig, RagRetriever, RagTokenizer
from transformers import DPRContextEncoderTokenizerFast
from transformers.file_utils import requires_datasets, requires_faiss
from transformers.models.rag.retrieval_rag import CustomHFIndex
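For context, a minimal sketch of how these imports fit together (the checkpoint names and the use_dummy_dataset flag are illustrative assumptions, not taken from the snippet above):

# Hypothetical usage sketch; "facebook/rag-token-base" and use_dummy_dataset
# are assumptions for illustration, not part of the original snippet.
tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-base")
retriever = RagRetriever.from_pretrained(
    "facebook/rag-token-base",
    index_name="exact",
    use_dummy_dataset=True,  # small dummy wiki_dpr index instead of the full one
)
ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(
    "facebook/dpr-ctx_encoder-single-nq-base"  # tokenizes passages for re-embedding
)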
# `load_from_disk` and `concatenate_datasets` come from the `datasets` library.
if isOtherThreadIndexBusy:
    if not threadHandle_index.is_alive():  # the re-indexing thread has finished
        # Gather the embedded dataset shards written by the worker processes
        # and merge them into a single passages dataset on disk.
        saved_dataset_shards = []
        for address in data_shard_addressses:
            saved_dataset_shards.append(load_from_disk(address))
        concat = concatenate_datasets(saved_dataset_shards)
        concat.save_to_disk(self.config.passages_path)
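To illustrate the save/load round trip these calls rely on, a self-contained sketch (the paths and toy data are hypothetical):

from datasets import Dataset, concatenate_datasets, load_from_disk

# Write two tiny shards to disk, as the embedding workers would.
shard_paths = ["/tmp/shard_0", "/tmp/shard_1"]  # hypothetical locations
for i, path in enumerate(shard_paths):
    Dataset.from_dict({"text": [f"passage {i}"]}).save_to_disk(path)

# Reload and merge them into one dataset, mirroring the loop above.
merged = concatenate_datasets([load_from_disk(p) for p in shard_paths])
merged.save_to_disk("/tmp/passages")
print(len(merged))  # 2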
def training_step(self, batch, batch_idx) -> Dict:
    global stepCount
    global isEmUpdateBusy
    global isAddIndexBusy
    global processes
    global isOtherThreadIndexBusy
    if self.trainer.global_rank == 0:  # initialize the parallel embedding-computation process only on the master DDP rank
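The `isOtherThreadIndexBusy` flag works together with a thread handle whose liveness gets polled, as in the shard-merging snippet above; a minimal sketch of that pattern (the `build_index` worker is hypothetical):

import threading

def build_index():
    ...  # hypothetical worker, e.g. add freshly embedded passages to a FAISS index

# Launch the indexing thread and mark it busy.
threadHandle_index = threading.Thread(target=build_index)
threadHandle_index.start()
isOtherThreadIndexBusy = True

# Later, e.g. inside training_step: once the thread has finished,
# clear the flag and merge the dataset shards.
if isOtherThreadIndexBusy and not threadHandle_index.is_alive():
    isOtherThreadIndexBusy = False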
# See the training_step hook in pytorch-lightning for further details on this function.
def training_step(self, batch, batch_idx) -> Dict:
    global isEmUpdateBusy  # global flag used by the parallel embedding-computation process
    if self.trainer.global_rank == 0:  # initialize the parallel embedding-computation process only on the master DDP rank
        if batch_idx != 0 and batch_idx % 500 == 0:  # we want the embeddings to get updated on every 500th step
            # Any number of free GPUs can be assigned to update the embeddings (optional).
            free_gpu_list = []
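One way to populate `free_gpu_list` is to probe per-device memory with torch; a sketch under the assumption that "free" means nothing is currently allocated by this process (the original script may instead take the GPU ordering from a command-line argument):

import torch

# Assumption: a GPU counts as free if this process has no memory allocated
# on it; swap in whatever criterion the training script actually uses.
free_gpu_list = []
for device_id in range(torch.cuda.device_count()):
    if torch.cuda.memory_allocated(device_id) == 0:
        free_gpu_list.append(device_id)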
""" Trains an agent with (stochastic) Policy Gradients on Pong. Uses OpenAI Gym. """ | |
import numpy as np | |
import cPickle as pickle | |
import gym | |
# hyperparameters | |
H = 200 # number of hidden layer neurons | |
batch_size = 10 # every how many episodes to do a param update? | |
learning_rate = 1e-4 | |
gamma = 0.99 # discount factor for reward |
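gamma feeds the discounted-return computation that appears later in this script; a sketch of that standard step (the reset at nonzero rewards is Pong-specific, because each point scored ends a game):

def discount_rewards(r):
    """ take 1D float array of rewards and compute discounted reward """
    discounted_r = np.zeros_like(r)
    running_add = 0
    for t in reversed(range(r.size)):
        if r[t] != 0:
            running_add = 0  # reset the sum, since this was a game boundary (Pong specific!)
        running_add = running_add * gamma + r[t]
        discounted_r[t] = running_add
    return discounted_r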