Skip to content

Instantly share code, notes, and snippets.

@stjordanis
Forked from chadbrewbaker/ParallelGPT.md
Created April 3, 2023 13:07
Show Gist options
  • Save stjordanis/be398bb224b56c0a7839a9b346ee9f31 to your computer and use it in GitHub Desktop.
Save stjordanis/be398bb224b56c0a7839a9b346ee9f31 to your computer and use it in GitHub Desktop.
ParallelGPT

Asking GPT3.5 to refactor gpt2() into a parallel prefix over inputs with code from:

https://github.com/jaymody/picoGPT/blob/3b7f4d180bb125a76b2f4f7b3a74268e5ec5f131/gpt2_pico.py#L37

def gpt2(inputs, wte, wpe, blocks, ln_f, n_head, chunk_size):
    """Chunked GPT-2 style forward pass over a sequence of token ids.

    The sequence is processed in ``n_chunks = len(inputs) // chunk_size``
    contiguous chunks: each transformer block is applied chunk-wise, the
    per-chunk outputs are combined with an inclusive prefix sum over the
    chunk axis, and that prefix is added back into the residual stream.

    Args:
        inputs: 1-D sequence of token ids; its length must be a positive
            multiple of ``chunk_size``.
        wte: token-embedding matrix, shape (vocab, d).
        wpe: position-embedding matrix; at least ``len(inputs)`` rows.
        blocks: list of per-block parameter dicts passed to
            ``transformer_block`` via ``**block``.
        ln_f: parameter dict for the final ``layer_norm``.
        n_head: number of attention heads, forwarded to ``transformer_block``.
        chunk_size: number of tokens per chunk.

    Returns:
        Logits of shape (len(inputs), vocab): ``layer_norm(emb) @ wte.T``.

    Raises:
        ValueError: if ``chunk_size`` is not positive or ``len(inputs)`` is
            not a positive multiple of ``chunk_size`` (the original code
            failed inside ``np.split`` with a cryptic message).
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if len(inputs) == 0 or len(inputs) % chunk_size != 0:
        raise ValueError(
            f"len(inputs)={len(inputs)} must be a positive multiple of "
            f"chunk_size={chunk_size}"
        )
    n_chunks = len(inputs) // chunk_size

    # Token + position embeddings. Vectorized form of the original chunked
    # loop: positions i*chunk_size .. (i+1)*chunk_size-1 over all chunks are
    # exactly 0 .. len(inputs)-1.
    embeddings = wte[inputs] + wpe[np.arange(len(inputs))]

    for block in blocks:
        # Apply the transformer block independently to each chunk.
        attn_outputs = np.concatenate(
            [
                transformer_block(chunk, **block, n_head=n_head)
                for chunk in np.split(embeddings, n_chunks)
            ],
            axis=0,
        )

        # Inclusive prefix sum over the chunk axis. BUG FIX: the original
        # loop started the running sum at chunk 1 with a zero-initialized
        # accumulator, so chunk 0's output was silently dropped from the
        # residual update; cumsum includes every chunk.
        chunked = attn_outputs.reshape(
            n_chunks, chunk_size, *attn_outputs.shape[1:]
        )
        attn_prefix = np.cumsum(chunked, axis=0).reshape(attn_outputs.shape)

        embeddings = embeddings + attn_prefix

    # Project back to vocabulary logits with tied embedding weights.
    return layer_norm(embeddings, **ln_f) @ wte.T
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment