Asking GPT-3.5 to refactor gpt2() into a parallel prefix over the inputs; it suggested the following code:
import numpy as np

def gpt2(inputs, wte, wpe, blocks, ln_f, n_head, chunk_size):
    # Split the token ids into equal chunks (np.split requires that
    # chunk_size divide len(inputs) exactly).
    n_chunks = len(inputs) // chunk_size
    chunked_inputs = np.split(inputs, n_chunks)

    # Embed each chunk: token embeddings plus that chunk's slice of the
    # positional embeddings.
    embeddings = []
    for i, chunk in enumerate(chunked_inputs):
        emb_chunk = wte[chunk] + wpe[range(i * chunk_size, (i + 1) * chunk_size)]
        embeddings.append(emb_chunk)
    embeddings = np.concatenate(embeddings, axis=0)

    for block in blocks:
        # Apply the transformer block to each chunk independently.
        attn_outputs = []
        for i, chunk in enumerate(np.split(embeddings, n_chunks)):
            attn_chunk = transformer_block(chunk, **block, n_head=n_head)
            attn_outputs.append(attn_chunk)
        attn_outputs = np.concatenate(attn_outputs, axis=0)

        # Running sum of chunk outputs; the first chunk's prefix stays zero
        # because the loop starts at i = 1.
        attn_outputs_prefix = np.zeros_like(attn_outputs)
        for i in range(1, n_chunks):
            attn_outputs_prefix[i * chunk_size:(i + 1) * chunk_size] = \
                attn_outputs_prefix[(i - 1) * chunk_size:i * chunk_size] + \
                attn_outputs[i * chunk_size:(i + 1) * chunk_size]
        embeddings = embeddings + attn_outputs_prefix

    # Final layer norm, then project onto the vocabulary via the tied
    # embedding matrix (transformer_block and layer_norm are assumed defined
    # as in the original gpt2 implementation).
    final_output = layer_norm(embeddings, **ln_f) @ wte.T
    return final_output
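The prefix step above is still a sequential Python loop. As a minimal sketch (mine, not GPT-3.5's; the helper name chunk_prefix is hypothetical), the same running sum can be written as a vectorized scan with np.cumsum, which is the associative operation a true parallel prefix would distribute:

import numpy as np

def chunk_prefix(attn_outputs, n_chunks, chunk_size):
    # View the flat (n_chunks * chunk_size, d) array as one slab per chunk.
    chunked = attn_outputs.reshape(n_chunks, chunk_size, -1)
    # Inclusive cumulative sum over the chunk axis, minus the first chunk,
    # reproduces the loop exactly: prefix[i] = attn[1] + ... + attn[i],
    # with prefix[0] = 0.
    prefix = np.cumsum(chunked, axis=0) - chunked[0]
    return prefix.reshape(attn_outputs.shape)

In NumPy the scan still runs in linear time, but because addition is associative the same reduction can be evaluated in logarithmic depth on parallel hardware, which is presumably what the "parallel prefix" framing was after.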
GPT-3.5 also suggested permuting the token space, given a list of prompts, to speed up inference.
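No code accompanied that suggestion, and "permute the token space" is ambiguous; one charitable reading is permuting the prompt order so similar-length prompts batch together with minimal padding. A hypothetical sketch (the helper batch_by_length is my name, not GPT-3.5's):

import numpy as np

def batch_by_length(prompts):
    # Sort prompts by token count so each batch pads to a similar length;
    # keep the inverse permutation to restore the original order afterwards.
    order = np.argsort([len(p) for p in prompts])
    inverse = np.argsort(order)
    return [prompts[i] for i in order], inverse

After inference, indexing the outputs with inverse puts the results back in the caller's original prompt order.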