# The following file was produced by taking the Flax code for gpt-j-6b, renaming the model and config to gpt-neox, and concatenating the PyTorch gpt-neox code below it.
# Our job is to modify it so that it is correct for gpt-neox; right now the Flax parts are still a copy of gpt-j-6b with the names replaced.
from functools import partial
from typing import Optional, Tuple, Union

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import torch
import torch.utils.checkpoint
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen import combine_masks, make_causal_mask
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
from torch.nn import CrossEntropyLoss

from ...modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from ...modeling_utils import PreTrainedModel
from ...utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from .configuration_gpt_neox import GPTNeoXConfig

# NOTE: `nn` refers to flax.linen here. The concatenated PyTorch reference classes below
# still use `nn.Linear`, `nn.LayerNorm`, `nn.ModuleList`, etc.; switch them to `torch.nn`
# (or drop them) once the Flax port is complete.

logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "gpt_neox"
_CONFIG_FOR_DOC = "GPTNeoXConfig"
GPTNeoX_START_DOCSTRING = r""" | |
This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads,
etc.).
This model is also a Flax Linen
[flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
regular Flax Module and refer to the Flax documentation for all matters related to general usage and behavior.
Finally, this model supports inherent JAX features such as: | |
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit) | |
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation) | |
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) | |
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap) | |
Parameters: | |
config ([`GPTNeoXConfig`]): Model configuration class with all the parameters of the model. | |
Initializing with a config file does not load the weights associated with the model, only the | |
configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. | |
dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): | |
The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and | |
`jax.numpy.bfloat16` (on TPUs). | |
This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If | |
specified all the computation will be performed with the given `dtype`. | |
**Note that this only specifies the dtype of the computation and does not influence the dtype of model | |
parameters.** | |
If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and | |
[`~FlaxPreTrainedModel.to_bf16`]. | |
""" | |
GPTNeoX_INPUTS_DOCSTRING = r""" | |
Args: | |
input_ids (`numpy.ndarray` of shape `(batch_size, input_ids_length)`): | |
`input_ids_length` = `sequence_length`. Indices of input sequence tokens in the vocabulary. | |
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and | |
[`PreTrainedTokenizer.__call__`] for details. | |
[What are input IDs?](../glossary#input-ids) | |
attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): | |
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: | |
- 1 for tokens that are **not masked**, | |
- 0 for tokens that are **masked**. | |
[What are attention masks?](../glossary#attention-mask) | |
position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): | |
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, | |
config.max_position_embeddings - 1]`. | |
past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`): | |
Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast | |
auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*. | |
output_attentions (`bool`, *optional*): | |
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned | |
tensors for more detail. | |
output_hidden_states (`bool`, *optional*): | |
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for | |
more detail. | |
return_dict (`bool`, *optional*): | |
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. | |
""" | |
def create_sinusoidal_positions(num_pos, dim): | |
inv_freq = 1.0 / (10000 ** (np.arange(0, dim, 2) / dim)) | |
sinusoid_inp = np.einsum("i , j -> i j", np.arange(num_pos), inv_freq).astype("float32") | |
sin, cos = np.sin(sinusoid_inp), np.cos(sinusoid_inp) | |
sentinel = dim // 2 + dim % 2 | |
out = np.zeros((num_pos, dim)) | |
out[:, 0:sentinel] = sin | |
out[:, sentinel:] = cos | |
return jnp.array(out) | |
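# A minimal sanity check for the helper above (not part of the original file):
# the returned table has shape (num_pos, dim); the first ceil(dim / 2) columns
# hold sin values and the remaining columns hold cos values of pos * inv_freq.
def _check_sinusoidal_positions():
    table = create_sinusoidal_positions(num_pos=8, dim=4)
    assert table.shape == (8, 4)
    # position 0: sin(0) = 0 for the first half, cos(0) = 1 for the second half
    assert jnp.allclose(table[0], jnp.array([0.0, 0.0, 1.0, 1.0]))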
def rotate_every_two(tensor): | |
rotate_half_tensor = jnp.stack((-tensor[:, :, :, 1::2], tensor[:, :, :, ::2]), axis=-1) | |
rotate_half_tensor = rotate_half_tensor.reshape(rotate_half_tensor.shape[:-2] + (-1,)) | |
return rotate_half_tensor | |
def rotate_half(x): | |
"""Rotates half the hidden dims of the input.""" | |
x1 = x[..., : x.shape[-1] // 2] | |
x2 = x[..., x.shape[-1] // 2 :] | |
return torch.cat((-x2, x1), dim=-1) | |
def pt_apply_rotary_pos_emb(q, k, cos, sin, position_ids): | |
gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1] | |
gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3]) | |
cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices) | |
sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices) | |
q_embed = (q * cos) + (rotate_half(q) * sin) | |
k_embed = (k * cos) + (rotate_half(k) * sin) | |
return q_embed, k_embed | |
def apply_rotary_pos_emb(tensor, sincos): | |
sin_pos, cos_pos = sincos | |
sin_pos = sin_pos[:, :, None, :].repeat(2, 3) | |
cos_pos = cos_pos[:, :, None, :].repeat(2, 3) | |
return (tensor * cos_pos) + (rotate_every_two(tensor) * sin_pos) | |
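# NOTE (port): `rotate_every_two`/`apply_rotary_pos_emb` above implement the GPT-J
# convention (interleaved even/odd pairs), whereas the PyTorch GPT-NeoX code uses
# `rotate_half`/`pt_apply_rotary_pos_emb` (the first half of the head dim is rotated
# against the second half). The Flax attention below will have to switch to the NeoX
# convention. A minimal JAX sketch of the half-rotation (hypothetical helper, not part
# of the original file):
def jnp_rotate_half(x):
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return jnp.concatenate((-x2, x1), axis=-1)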
class RotaryEmbedding(torch.nn.Module): | |
def __init__(self, dim, max_position_embeddings, base=10000, device=None): | |
super().__init__() | |
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) | |
self.register_buffer("inv_freq", inv_freq) | |
# Build here to make `torch.jit.trace` work. | |
self.max_seq_len_cached = max_position_embeddings | |
t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype) | |
freqs = torch.einsum("i,j->ij", t, self.inv_freq) | |
# Different from paper, but it uses a different permutation in order to obtain the same calculation | |
emb = torch.cat((freqs, freqs), dim=-1) | |
self.cos_cached = emb.cos()[None, None, :, :] | |
self.sin_cached = emb.sin()[None, None, :, :] | |
def forward(self, x, seq_len=None): | |
# x: [bs, num_attention_heads, seq_len, head_size] | |
# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case. | |
if seq_len > self.max_seq_len_cached: | |
self.max_seq_len_cached = seq_len | |
t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype) | |
freqs = torch.einsum("i,j->ij", t, self.inv_freq) | |
# Different from paper, but it uses a different permutation in order to obtain the same calculation | |
emb = torch.cat((freqs, freqs), dim=-1).to(x.device) | |
self.cos_cached = emb.cos()[None, None, :, :] | |
self.sin_cached = emb.sin()[None, None, :, :] | |
return self.cos_cached[:seq_len, ...].to(x.device), self.sin_cached[:seq_len, ...].to(x.device) | |
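# Shape sanity check for the cached tables built above (hedged, illustrative values only;
# not part of the original file): with dim=16 and max_position_embeddings=32,
# inv_freq is (8,), freqs is (32, 8), emb is (32, 16), and cos/sin are (1, 1, 32, 16).
def _check_rotary_cache_shapes():
    rot = RotaryEmbedding(dim=16, max_position_embeddings=32)
    assert rot.inv_freq.shape == (8,)
    assert rot.cos_cached.shape == (1, 1, 32, 16)
    assert rot.sin_cached.shape == (1, 1, 32, 16)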
class GPTNeoXAttention(nn.Module): | |
def __init__(self, config): | |
super().__init__() | |
self.num_attention_heads = config.num_attention_heads | |
self.hidden_size = config.hidden_size | |
self.head_size = self.hidden_size // self.num_attention_heads | |
self.rotary_ndims = int(self.head_size * config.rotary_pct) | |
max_positions = config.max_position_embeddings | |
self.register_buffer( | |
"bias", | |
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view( | |
1, 1, max_positions, max_positions | |
), | |
) | |
self.register_buffer("masked_bias", torch.tensor(-1e9)) | |
self.rotary_emb = RotaryEmbedding( | |
self.rotary_ndims, config.max_position_embeddings, base=config.rotary_emb_base | |
) | |
self.norm_factor = torch.sqrt(torch.tensor(self.head_size, dtype=torch.float32)).to(torch.get_default_dtype()) | |
self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size) | |
self.dense = nn.Linear(config.hidden_size, config.hidden_size) | |
def forward( | |
self, | |
hidden_states: torch.FloatTensor, | |
attention_mask: torch.FloatTensor, | |
position_ids: torch.LongTensor, | |
head_mask: Optional[torch.FloatTensor] = None, | |
layer_past: Optional[Tuple[torch.Tensor]] = None, | |
use_cache: Optional[bool] = False, | |
output_attentions: Optional[bool] = False, | |
): | |
has_layer_past = layer_past is not None | |
# Compute QKV | |
# Attention heads [batch, seq_len, hidden_size] | |
# --> [batch, seq_len, (np * 3 * head_size)] | |
qkv = self.query_key_value(hidden_states) | |
# [batch, seq_len, (num_heads * 3 * head_size)] | |
# --> [batch, seq_len, num_heads, 3 * head_size] | |
new_qkv_shape = qkv.size()[:-1] + (self.num_attention_heads, 3 * self.head_size) | |
qkv = qkv.view(*new_qkv_shape) | |
# [batch, seq_len, num_attention_heads, 3 * head_size] --> 3 [batch, num_attention_heads, seq_len, head_size] | |
query = qkv[..., : self.head_size].permute(0, 2, 1, 3) | |
key = qkv[..., self.head_size : 2 * self.head_size].permute(0, 2, 1, 3) | |
value = qkv[..., 2 * self.head_size :].permute(0, 2, 1, 3) | |
# Compute rotary embeddings on rotary_ndims | |
query_rot = query[..., : self.rotary_ndims] | |
query_pass = query[..., self.rotary_ndims :] | |
key_rot = key[..., : self.rotary_ndims] | |
key_pass = key[..., self.rotary_ndims :] | |
# Compute token offset for rotary embeddings (when decoding) | |
seq_len = key.shape[-2] | |
if has_layer_past: | |
seq_len += layer_past[0].shape[-2] | |
cos, sin = self.rotary_emb(value, seq_len=seq_len) | |
query, key = pt_apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids) | |
query = torch.cat((query, query_pass), dim=-1) | |
key = torch.cat((key, key_pass), dim=-1) | |
# Cache QKV values | |
if has_layer_past: | |
past_key = layer_past[0] | |
past_value = layer_past[1] | |
key = torch.cat((past_key, key), dim=-2) | |
value = torch.cat((past_value, value), dim=-2) | |
present = (key, value) if use_cache else None | |
# Compute attention | |
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) | |
# Reshape outputs | |
attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_size) | |
attn_output = self.dense(attn_output) | |
outputs = (attn_output, present) | |
if output_attentions: | |
outputs += (attn_weights,) | |
return outputs | |
@classmethod | |
def _split_heads(cls, tensor, num_attention_heads, attn_head_size): | |
""" | |
Splits hidden dim into attn_head_size and num_attention_heads | |
""" | |
# tensor: [bs, seq_len, hidden_size] | |
new_shape = tensor.size()[:-1] + (num_attention_heads, attn_head_size) | |
# -> [bs, seq_len, num_attention_heads, attn_head_size] | |
tensor = tensor.view(new_shape) | |
# -> [bs, num_attention_heads, seq_len, attn_head_size] | |
tensor = tensor.permute(0, 2, 1, 3) | |
return tensor | |
@classmethod | |
def _merge_heads(cls, tensor, num_attention_heads, attn_head_size): | |
""" | |
Merges attn_head_size dim and num_attn_heads dim into hidden dim | |
""" | |
# tensor [bs, num_attention_heads, seq_len, attn_head_size] | |
tensor = tensor.permute(0, 2, 1, 3).contiguous() | |
# -> [bs, seq_len, num_attention_heads, attn_head_size] | |
tensor = tensor.view(tensor.size(0), tensor.size(1), num_attention_heads * attn_head_size) | |
# -> [bs, seq_len, hidden_size] | |
return tensor | |
def _attn(self, query, key, value, attention_mask=None, head_mask=None): | |
# q, k, v: [bs, num_attention_heads, seq_len, attn_head_size] | |
# compute causal mask from causal mask buffer | |
batch_size, num_attention_heads, query_length, attn_head_size = query.size() | |
key_length = key.size(-2) | |
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] | |
query = query.view(batch_size * num_attention_heads, query_length, attn_head_size) | |
key = key.view(batch_size * num_attention_heads, key_length, attn_head_size) | |
attn_scores = torch.zeros( | |
batch_size * num_attention_heads, | |
query_length, | |
key_length, | |
dtype=query.dtype, | |
device=key.device, | |
) | |
attn_scores = torch.baddbmm( | |
attn_scores, | |
query, | |
key.transpose(1, 2), | |
beta=1.0, | |
alpha=(torch.tensor(1.0, dtype=self.norm_factor.dtype, device=self.norm_factor.device) / self.norm_factor), | |
) | |
attn_scores = attn_scores.view(batch_size, num_attention_heads, query_length, key_length) | |
mask_value = torch.finfo(attn_scores.dtype).min | |
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`. | |
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device` | |
mask_value = torch.tensor(mask_value, dtype=attn_scores.dtype).to(attn_scores.device) | |
attn_scores = torch.where(causal_mask, attn_scores, mask_value) | |
if attention_mask is not None: | |
# Apply the attention mask | |
attn_scores = attn_scores + attention_mask | |
attn_weights = nn.functional.softmax(attn_scores, dim=-1) | |
attn_weights = attn_weights.to(value.dtype) | |
# Mask heads if we want to | |
if head_mask is not None: | |
attn_weights = attn_weights * head_mask | |
attn_output = torch.matmul(attn_weights, value) | |
return attn_output, attn_weights | |
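# NOTE: `_attn` above is standard scaled dot-product attention. `torch.baddbmm`
# computes `beta * input + alpha * (query @ key.T)`; with a zero `input`, `beta=1.0`
# and `alpha=1/norm_factor` this is just Q K^T / sqrt(head_size), after which the
# causal mask, the additive attention mask, softmax and the value aggregation are applied.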
class FlaxGPTNeoXAttention(nn.Module): | |
config: GPTNeoXConfig | |
dtype: jnp.dtype = jnp.float32 | |
causal: bool = True | |
is_cross_attention: bool = False | |
def setup(self): | |
config = self.config | |
self.embed_dim = config.hidden_size | |
self.num_heads = config.num_attention_heads | |
self.head_dim = self.embed_dim // self.num_heads | |
self.rotary_dim = config.rotary_dim | |
dense = partial( | |
nn.Dense, | |
self.embed_dim, | |
use_bias=False, | |
dtype=self.dtype, | |
kernel_init=jax.nn.initializers.normal(self.config.initializer_range), | |
) | |
self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense() | |
self.out_proj = dense() | |
self.resid_dropout = nn.Dropout(rate=config.resid_pdrop) | |
self.causal_mask = make_causal_mask(jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool") | |
pos_embd_dim = self.rotary_dim or self.embed_dim | |
self.embed_positions = create_sinusoidal_positions(config.max_position_embeddings, pos_embd_dim) | |
def _split_heads(self, hidden_states): | |
return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim)) | |
def _merge_heads(self, hidden_states): | |
return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,)) | |
@nn.compact | |
def _concatenate_to_cache(self, key, value, query, attention_mask): | |
""" | |
This function takes projected key, value states from a single input token and concatenates the states to cached | |
states from previous steps. This function is slighly adapted from the official Flax repository: | |
https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 | |
""" | |
# detect if we're initializing by absence of existing cache data. | |
is_initialized = self.has_variable("cache", "cached_key") | |
cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) | |
cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) | |
cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) | |
if is_initialized: | |
*batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape | |
# update key, value caches with our new 1d spatial slices | |
cur_index = cache_index.value | |
indices = (0,) * len(batch_dims) + (cur_index, 0, 0) | |
key = lax.dynamic_update_slice(cached_key.value, key, indices) | |
value = lax.dynamic_update_slice(cached_value.value, value, indices) | |
cached_key.value = key | |
cached_value.value = value | |
num_updated_cache_vectors = query.shape[1] | |
cache_index.value = cache_index.value + num_updated_cache_vectors | |
# causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements. | |
pad_mask = jnp.broadcast_to( | |
jnp.arange(max_length) < cur_index + num_updated_cache_vectors, | |
tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), | |
) | |
attention_mask = combine_masks(pad_mask, attention_mask) | |
return key, value, attention_mask | |
def __call__( | |
self, | |
hidden_states, | |
attention_mask, | |
position_ids, | |
deterministic: bool = True, | |
init_cache: bool = False, | |
output_attentions: bool = False, | |
): | |
query = self.q_proj(hidden_states) | |
key = self.k_proj(hidden_states) | |
value = self.v_proj(hidden_states) | |
query = self._split_heads(query) | |
key = self._split_heads(key) | |
value = self._split_heads(value) | |
sincos = jnp.take(self.embed_positions, position_ids, axis=0) | |
sincos = jnp.split(sincos, 2, axis=-1) | |
if self.rotary_dim is not None: | |
k_rot = key[:, :, :, : self.rotary_dim] | |
k_pass = key[:, :, :, self.rotary_dim :] | |
q_rot = query[:, :, :, : self.rotary_dim] | |
q_pass = query[:, :, :, self.rotary_dim :] | |
k_rot = apply_rotary_pos_emb(k_rot, sincos) | |
q_rot = apply_rotary_pos_emb(q_rot, sincos) | |
key = jnp.concatenate([k_rot, k_pass], axis=-1) | |
query = jnp.concatenate([q_rot, q_pass], axis=-1) | |
else: | |
key = apply_rotary_pos_emb(key, sincos) | |
query = apply_rotary_pos_emb(query, sincos) | |
query_length, key_length = query.shape[1], key.shape[1] | |
if self.has_variable("cache", "cached_key"): | |
mask_shift = self.variables["cache"]["cache_index"] | |
max_decoder_length = self.variables["cache"]["cached_key"].shape[1] | |
causal_mask = lax.dynamic_slice( | |
self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length) | |
) | |
else: | |
causal_mask = self.causal_mask[:, :, :query_length, :key_length] | |
batch_size = hidden_states.shape[0] | |
causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) | |
attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) | |
attention_mask = combine_masks(attention_mask, causal_mask) | |
dropout_rng = None | |
if not deterministic and self.config.attn_pdrop > 0.0: | |
dropout_rng = self.make_rng("dropout") | |
# During fast autoregressive decoding, we feed one position at a time, | |
# and cache the keys and values step by step. | |
if self.has_variable("cache", "cached_key") or init_cache: | |
key, value, attention_mask = self._concatenate_to_cache(key, value, query, attention_mask) | |
# transform boolean mask into float mask | |
attention_bias = lax.select( | |
attention_mask > 0, | |
jnp.full(attention_mask.shape, 0.0).astype(self.dtype), | |
jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype), | |
) | |
# usual dot product attention | |
attn_weights = dot_product_attention_weights( | |
query, | |
key, | |
bias=attention_bias, | |
dropout_rng=dropout_rng, | |
dropout_rate=self.config.attn_pdrop, | |
deterministic=deterministic, | |
dtype=self.dtype, | |
precision=None, | |
) | |
attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value) | |
attn_output = self._merge_heads(attn_output) | |
attn_output = self.out_proj(attn_output) | |
attn_output = self.resid_dropout(attn_output, deterministic=deterministic) | |
outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) | |
return outputs | |
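# NOTE (port): FlaxGPTNeoXAttention above is still the GPT-J layout. To match the PyTorch
# GPTNeoXAttention above it, it needs at least:
#   - a single fused `query_key_value` projection (with bias) instead of separate
#     q_proj/k_proj/v_proj without bias,
#   - `rotary_ndims = int(head_dim * config.rotary_pct)` instead of `config.rotary_dim`,
#   - the `rotate_half` rotary convention and `config.rotary_emb_base`,
#   - no attention/residual dropout (GPTNeoXConfig has no `attn_pdrop`/`resid_pdrop`).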
class GPTNeoXMLP(nn.Module): | |
def __init__(self, config): | |
super().__init__() | |
self.dense_h_to_4h = nn.Linear(config.hidden_size, config.intermediate_size) | |
self.dense_4h_to_h = nn.Linear(config.intermediate_size, config.hidden_size) | |
self.act = ACT2FN[config.hidden_act] | |
def forward(self, hidden_states): | |
hidden_states = self.dense_h_to_4h(hidden_states) | |
hidden_states = self.act(hidden_states) | |
hidden_states = self.dense_4h_to_h(hidden_states) | |
return hidden_states | |
class FlaxGPTNeoXMLP(nn.Module): | |
config: GPTNeoXConfig | |
intermediate_size: int | |
dtype: jnp.dtype = jnp.float32 | |
def setup(self): | |
embed_dim = self.config.hidden_size | |
kernel_init = jax.nn.initializers.normal(self.config.initializer_range) | |
self.fc_in = nn.Dense(self.intermediate_size, dtype=self.dtype, kernel_init=kernel_init) | |
self.fc_out = nn.Dense(embed_dim, dtype=self.dtype, kernel_init=kernel_init) | |
self.act = ACT2FN[self.config.activation_function] | |
self.dropout = nn.Dropout(rate=self.config.resid_pdrop) | |
def __call__(self, hidden_states, deterministic: bool = True): | |
hidden_states = self.fc_in(hidden_states) | |
hidden_states = self.act(hidden_states) | |
hidden_states = self.fc_out(hidden_states) | |
hidden_states = self.dropout(hidden_states, deterministic=deterministic) | |
return hidden_states | |
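# NOTE (port): the PyTorch GPTNeoXMLP above uses `config.intermediate_size`,
# `config.hidden_act`, bias on both dense layers and no dropout; FlaxGPTNeoXMLP
# above still reads the GPT-J fields (an `intermediate_size` passed in by the block,
# `config.activation_function`, `config.resid_pdrop`).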
class GPTNeoXLayer(nn.Module): | |
def __init__(self, config): | |
super().__init__() | |
self.use_parallel_residual = config.use_parallel_residual | |
self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) | |
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) | |
self.attention = GPTNeoXAttention(config) | |
self.mlp = GPTNeoXMLP(config) | |
def forward( | |
self, | |
hidden_states: Optional[torch.FloatTensor], | |
attention_mask: Optional[torch.FloatTensor] = None, | |
position_ids: Optional[torch.LongTensor] = None, | |
head_mask: Optional[torch.FloatTensor] = None, | |
use_cache: Optional[bool] = False, | |
layer_past: Optional[Tuple[torch.Tensor]] = None, | |
output_attentions: Optional[bool] = False, | |
): | |
attention_layer_outputs = self.attention( | |
self.input_layernorm(hidden_states), | |
attention_mask=attention_mask, | |
position_ids=position_ids, | |
layer_past=layer_past, | |
head_mask=head_mask, | |
use_cache=use_cache, | |
output_attentions=output_attentions, | |
) | |
attn_output = attention_layer_outputs[0] # output_attn: attn_output, present, (attn_weights) | |
outputs = attention_layer_outputs[1:] | |
if self.use_parallel_residual: | |
# pseudocode: | |
# x = x + attn(ln1(x)) + mlp(ln2(x)) | |
mlp_output = self.mlp(self.post_attention_layernorm(hidden_states)) | |
hidden_states = mlp_output + attn_output + hidden_states | |
else: | |
# pseudocode: | |
# x = x + attn(ln1(x)) | |
# x = x + mlp(ln2(x)) | |
attn_output = attn_output + hidden_states | |
mlp_output = self.mlp(self.post_attention_layernorm(attn_output)) | |
hidden_states = mlp_output + attn_output | |
if use_cache: | |
outputs = (hidden_states,) + outputs # hidden_states, present, (attn_weights) | |
else: | |
outputs = (hidden_states,) + outputs[1:] # hidden_states, (attn_weights) | |
return outputs | |
class FlaxGPTNeoXBlock(nn.Module): | |
config: GPTNeoXConfig | |
dtype: jnp.dtype = jnp.float32 | |
def setup(self): | |
hidden_size = self.config.hidden_size | |
inner_dim = self.config.n_inner if self.config.n_inner is not None else 4 * hidden_size | |
self.ln_1 = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype) | |
self.attn = FlaxGPTNeoXAttention(self.config, dtype=self.dtype) | |
self.mlp = FlaxGPTNeoXMLP(self.config, inner_dim, dtype=self.dtype) | |
def __call__( | |
self, | |
hidden_states, | |
attention_mask=None, | |
position_ids=None, | |
deterministic: bool = True, | |
init_cache: bool = False, | |
output_attentions: bool = False, | |
): | |
residual = hidden_states | |
hidden_states = self.ln_1(hidden_states) | |
attn_outputs = self.attn( | |
hidden_states, | |
attention_mask=attention_mask, | |
position_ids=position_ids, | |
deterministic=deterministic, | |
init_cache=init_cache, | |
output_attentions=output_attentions, | |
) | |
attn_output = attn_outputs[0] | |
feed_forward_hidden_states = self.mlp(hidden_states, deterministic=deterministic) | |
# residual connection | |
hidden_states = attn_output + feed_forward_hidden_states + residual | |
return (hidden_states,) + attn_outputs[1:] | |
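# NOTE (port): the PyTorch GPTNeoXLayer above keeps two layer norms
# (`input_layernorm`, `post_attention_layernorm`) and supports
# `config.use_parallel_residual`, while FlaxGPTNeoXBlock above still has the
# single-layernorm GPT-J layout and reads `config.n_inner` / `config.layer_norm_epsilon`
# instead of `config.intermediate_size` / `config.layer_norm_eps`.
# Sketch of the two residual forms (pseudocode, not part of the original file):
#     parallel:   x = x + attn(ln_attn(x)) + mlp(ln_mlp(x))
#     sequential: x = x + attn(ln_attn(x));  x = x + mlp(ln_mlp(x))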
class GPTNeoXPreTrainedModel(PreTrainedModel): | |
""" | |
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained | |
models. | |
""" | |
config_class = GPTNeoXConfig | |
base_model_prefix = "gpt_neox" | |
supports_gradient_checkpointing = True | |
_no_split_modules = ["GPTNeoXLayer"] | |
def _init_weights(self, module): | |
"""Initialize the weights""" | |
if isinstance(module, nn.Linear): | |
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) | |
if module.bias is not None: | |
module.bias.data.zero_() | |
elif isinstance(module, nn.Embedding): | |
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) | |
if module.padding_idx is not None: | |
module.weight.data[module.padding_idx].zero_() | |
elif isinstance(module, nn.LayerNorm): | |
module.bias.data.zero_() | |
module.weight.data.fill_(1.0) | |
def _set_gradient_checkpointing(self, module, value=False): | |
if isinstance(module, GPTNeoXModel): | |
module.gradient_checkpointing = value | |
class FlaxGPTNeoXPreTrainedModel(FlaxPreTrainedModel): | |
""" | |
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained | |
models. | |
""" | |
config_class = GPTNeoXConfig | |
base_model_prefix = "transformer" | |
module_class: nn.Module = None | |
def __init__( | |
self, | |
config: GPTNeoXConfig, | |
input_shape: Tuple = (1, 1), | |
seed: int = 0, | |
dtype: jnp.dtype = jnp.float32, | |
_do_init: bool = True, | |
**kwargs, | |
): | |
module = self.module_class(config=config, dtype=dtype, **kwargs) | |
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) | |
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: | |
# init input tensors | |
input_ids = jnp.zeros(input_shape, dtype="i4") | |
attention_mask = jnp.ones_like(input_ids) | |
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape) | |
params_rng, dropout_rng = jax.random.split(rng) | |
rngs = {"params": params_rng, "dropout": dropout_rng} | |
if self.config.add_cross_attention: | |
encoder_hidden_states = jnp.zeros(input_shape + (self.config.n_embd,)) | |
encoder_attention_mask = attention_mask | |
module_init_outputs = self.module.init( | |
rngs, | |
input_ids, | |
attention_mask, | |
position_ids, | |
encoder_hidden_states, | |
encoder_attention_mask, | |
return_dict=False, | |
) | |
else: | |
module_init_outputs = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False) | |
random_params = module_init_outputs["params"] | |
if params is not None: | |
random_params = flatten_dict(unfreeze(random_params)) | |
params = flatten_dict(unfreeze(params)) | |
for missing_key in self._missing_keys: | |
params[missing_key] = random_params[missing_key] | |
self._missing_keys = set() | |
return freeze(unflatten_dict(params)) | |
else: | |
return random_params | |
def init_cache(self, batch_size, max_length): | |
r""" | |
Args: | |
batch_size (`int`): | |
batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. | |
max_length (`int`): | |
maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized | |
cache. | |
""" | |
# init input variables to retrieve cache | |
input_ids = jnp.ones((batch_size, max_length)) | |
attention_mask = jnp.ones_like(input_ids) | |
position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) | |
init_variables = self.module.init( | |
jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True | |
) | |
return init_variables["cache"] | |
@add_start_docstrings_to_model_forward(GPTNeoX_INPUTS_DOCSTRING) | |
def __call__( | |
self, | |
input_ids, | |
attention_mask=None, | |
position_ids=None, | |
params: dict = None, | |
past_key_values: dict = None, | |
dropout_rng: jax.random.PRNGKey = None, | |
train: bool = False, | |
output_attentions: Optional[bool] = None, | |
output_hidden_states: Optional[bool] = None, | |
return_dict: Optional[bool] = None, | |
): | |
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions | |
output_hidden_states = ( | |
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states | |
) | |
return_dict = return_dict if return_dict is not None else self.config.return_dict | |
batch_size, sequence_length = input_ids.shape | |
if position_ids is None: | |
if past_key_values is not None: | |
raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.") | |
position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)) | |
if attention_mask is None: | |
attention_mask = jnp.ones((batch_size, sequence_length)) | |
# Handle any PRNG if needed | |
rngs = {} | |
if dropout_rng is not None: | |
rngs["dropout"] = dropout_rng | |
inputs = {"params": params or self.params} | |
# If past_key_values are passed, the cache is already initialized; a private flag `init_cache` has to be
# passed down to make sure the cache is used. The cache must also be marked as mutable so that it can be
# updated by the FlaxGPTNeoXAttention module.
if past_key_values: | |
inputs["cache"] = past_key_values | |
mutable = ["cache"] | |
else: | |
mutable = False | |
outputs = self.module.apply( | |
inputs, | |
jnp.array(input_ids, dtype="i4"), | |
jnp.array(attention_mask, dtype="i4"), | |
jnp.array(position_ids, dtype="i4"), | |
not train, | |
False, | |
output_attentions, | |
output_hidden_states, | |
return_dict, | |
rngs=rngs, | |
mutable=mutable, | |
) | |
# add updated cache to model output | |
if past_key_values is not None and return_dict: | |
outputs, past_key_values = outputs | |
outputs["past_key_values"] = unfreeze(past_key_values["cache"]) | |
return outputs | |
elif past_key_values is not None and not return_dict: | |
outputs, past_key_values = outputs | |
outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:] | |
return outputs | |
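# Hedged usage sketch for `init_cache` / `past_key_values` (not part of the original file;
# assumes the Flax port is functional). It mirrors what `prepare_inputs_for_generation`
# does further below: build the cache and a static attention mask of length `max_length`
# before step-by-step decoding.
def _example_fast_decoding_setup(model, input_ids):
    batch_size, seq_length = input_ids.shape
    max_length = seq_length + 16
    past_key_values = model.init_cache(batch_size, max_length)
    extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
    position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
    return model(
        input_ids,
        attention_mask=extended_attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
    )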
class FlaxGPTNeoXBlockCollection(nn.Module): | |
config: GPTNeoXConfig | |
dtype: jnp.dtype = jnp.float32 | |
def setup(self): | |
self.blocks = [ | |
FlaxGPTNeoXBlock(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers) | |
] | |
def __call__( | |
self, | |
hidden_states, | |
attention_mask=None, | |
position_ids=None, | |
deterministic: bool = True, | |
init_cache: bool = False, | |
output_attentions: bool = False, | |
output_hidden_states: bool = False, | |
return_dict: bool = True, | |
): | |
all_attentions = () if output_attentions else None | |
all_hidden_states = () if output_hidden_states else None | |
for block in self.blocks: | |
if output_hidden_states: | |
all_hidden_states += (hidden_states,) | |
layer_outputs = block( | |
hidden_states, | |
attention_mask, | |
position_ids=position_ids, | |
deterministic=deterministic, | |
init_cache=init_cache, | |
output_attentions=output_attentions, | |
) | |
hidden_states = layer_outputs[0] | |
if output_attentions: | |
all_attentions += (layer_outputs[1],) | |
# this contains possible `None` values - `FlaxGPTNeoXModule` will filter them out | |
outputs = (hidden_states, all_hidden_states, all_attentions) | |
return outputs | |
class FlaxGPTNeoXModule(nn.Module): | |
config: GPTNeoXConfig | |
dtype: jnp.dtype = jnp.float32 | |
def setup(self): | |
self.embed_dim = self.config.hidden_size | |
self.wte = nn.Embed( | |
self.config.vocab_size, | |
self.config.hidden_size, | |
embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), | |
) | |
self.dropout = nn.Dropout(rate=self.config.embd_pdrop) | |
self.h = FlaxGPTNeoXBlockCollection(self.config, dtype=self.dtype) | |
self.ln_f = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype) | |
def __call__( | |
self, | |
input_ids, | |
attention_mask, | |
position_ids, | |
deterministic=True, | |
init_cache: bool = False, | |
output_attentions: bool = False, | |
output_hidden_states: bool = False, | |
return_dict: bool = True, | |
): | |
input_embeds = self.wte(input_ids.astype("i4")) | |
hidden_states = self.dropout(input_embeds, deterministic=deterministic) | |
outputs = self.h( | |
hidden_states, | |
attention_mask, | |
position_ids=position_ids, | |
deterministic=deterministic, | |
init_cache=init_cache, | |
output_attentions=output_attentions, | |
output_hidden_states=output_hidden_states, | |
return_dict=return_dict, | |
) | |
hidden_states = outputs[0] | |
hidden_states = self.ln_f(hidden_states) | |
if output_hidden_states: | |
all_hidden_states = outputs[1] + (hidden_states,) | |
outputs = (hidden_states, all_hidden_states) + outputs[2:] | |
else: | |
outputs = (hidden_states,) + outputs[1:] | |
if not return_dict: | |
return tuple(v for v in outputs if v is not None) | |
return FlaxBaseModelOutput( | |
last_hidden_state=hidden_states, | |
hidden_states=outputs[1], | |
attentions=outputs[-1], | |
) | |
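# NOTE (port): the PyTorch GPTNeoXModel below names its pieces `embed_in`, `layers` and
# `final_layer_norm` and has no embedding dropout, while FlaxGPTNeoXModule above still
# uses the GPT-J names (`wte`, `h`, `ln_f`) and reads `config.embd_pdrop` /
# `config.layer_norm_epsilon`, which GPTNeoXConfig does not define.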
@add_start_docstrings( | |
"The bare GPTNeoX Model transformer outputting raw hidden-states without any specific head on top.", | |
GPT_NEOX_START_DOCSTRING, | |
) | |
class GPTNeoXModel(GPTNeoXPreTrainedModel): | |
def __init__(self, config): | |
super().__init__(config) | |
self.config = config | |
self.embed_in = nn.Embedding(config.vocab_size, config.hidden_size) | |
self.layers = nn.ModuleList([GPTNeoXLayer(config) for _ in range(config.num_hidden_layers)]) | |
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) | |
self.gradient_checkpointing = False | |
# Initialize weights and apply final processing | |
self.post_init() | |
def get_input_embeddings(self): | |
return self.embed_in | |
def set_input_embeddings(self, value): | |
self.embed_in = value | |
@add_start_docstrings_to_model_forward(GPT_NEOX_INPUTS_DOCSTRING.format("batch_size, sequence_length")) | |
@add_code_sample_docstrings( | |
checkpoint=_CHECKPOINT_FOR_DOC, | |
real_checkpoint=_REAL_CHECKPOINT_FOR_DOC, | |
output_type=BaseModelOutputWithPast, | |
config_class=_CONFIG_FOR_DOC, | |
) | |
def forward( | |
self, | |
input_ids: Optional[torch.LongTensor] = None, | |
attention_mask: Optional[torch.FloatTensor] = None, | |
position_ids: Optional[torch.LongTensor] = None, | |
head_mask: Optional[torch.FloatTensor] = None, | |
inputs_embeds: Optional[torch.FloatTensor] = None, | |
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, | |
use_cache: Optional[bool] = None, | |
output_attentions: Optional[bool] = None, | |
output_hidden_states: Optional[bool] = None, | |
return_dict: Optional[bool] = None, | |
) -> Union[Tuple, BaseModelOutputWithPast]: | |
r""" | |
past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): | |
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. | |
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that | |
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all | |
`decoder_input_ids` of shape `(batch_size, sequence_length)`. | |
use_cache (`bool`, *optional*): | |
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see | |
`past_key_values`). | |
""" | |
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions | |
output_hidden_states = ( | |
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states | |
) | |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict | |
use_cache = use_cache if use_cache is not None else self.config.use_cache | |
if input_ids is not None and inputs_embeds is not None: | |
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") | |
elif input_ids is not None: | |
input_shape = input_ids.size() | |
elif inputs_embeds is not None: | |
input_shape = inputs_embeds.size()[:-1] | |
else: | |
raise ValueError("You have to specify either input_ids or inputs_embeds") | |
batch_size, seq_length = input_shape | |
if past_key_values is None: | |
past_length = 0 | |
past_key_values = tuple([None] * self.config.num_hidden_layers) | |
else: | |
past_length = past_key_values[0][0].size(-2) | |
if position_ids is None: | |
device = input_ids.device if input_ids is not None else inputs_embeds.device | |
position_ids = torch.arange(past_length, seq_length + past_length, dtype=torch.long, device=device) | |
position_ids = position_ids.unsqueeze(0).view(-1, seq_length) | |
else: | |
position_ids = position_ids.view(-1, seq_length).long() | |
# Attention mask. | |
if attention_mask is not None: | |
assert batch_size > 0, "batch_size has to be defined and > 0" | |
attention_mask = attention_mask.view(batch_size, -1) | |
# We create a 3D attention mask from a 2D tensor mask. | |
# Sizes are [batch_size, 1, 1, to_seq_length] | |
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] | |
# this attention mask is more simple than the triangular masking of causal attention | |
# used in OpenAI GPT, we just need to prepare the broadcast dimension here. | |
attention_mask = attention_mask[:, None, None, :] | |
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for | |
# masked positions, this operation will create a tensor which is 0.0 for | |
# positions we want to attend and the dtype's smallest value for masked positions. | |
# Since we are adding it to the raw scores before the softmax, this is | |
# effectively the same as removing these entirely. | |
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility | |
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min | |
# Prepare head mask if needed | |
# 1.0 in head_mask indicate we keep the head | |
# attention_probs has shape bsz x n_heads x N x N | |
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] | |
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] | |
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) | |
if inputs_embeds is None: | |
inputs_embeds = self.embed_in(input_ids) | |
hidden_states = inputs_embeds | |
if self.gradient_checkpointing and self.training: | |
if use_cache: | |
logger.warning( | |
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." | |
) | |
use_cache = False | |
presents = () if use_cache else None | |
all_attentions = () if output_attentions else None | |
all_hidden_states = () if output_hidden_states else None | |
for i, (layer, layer_past) in enumerate(zip(self.layers, past_key_values)): | |
if output_hidden_states: | |
all_hidden_states = all_hidden_states + (hidden_states,) | |
if self.gradient_checkpointing and self.training: | |
def create_custom_forward(module): | |
def custom_forward(*inputs): | |
# None for layer_past | |
return module(*inputs, use_cache, None, output_attentions) | |
return custom_forward | |
outputs = torch.utils.checkpoint.checkpoint( | |
create_custom_forward(layer), | |
hidden_states, | |
attention_mask, | |
position_ids, | |
head_mask[i], | |
) | |
else: | |
outputs = layer( | |
hidden_states, | |
attention_mask=attention_mask, | |
position_ids=position_ids, | |
head_mask=head_mask[i], | |
layer_past=layer_past, | |
use_cache=use_cache, | |
output_attentions=output_attentions, | |
) | |
hidden_states = outputs[0] | |
if use_cache is True: | |
presents = presents + (outputs[1],) | |
if output_attentions: | |
all_attentions = all_attentions + (outputs[2 if use_cache else 1],) | |
hidden_states = self.final_layer_norm(hidden_states) | |
# Add last hidden state | |
if output_hidden_states: | |
all_hidden_states = all_hidden_states + (hidden_states,) | |
if not return_dict: | |
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_attentions] if v is not None) | |
return BaseModelOutputWithPast( | |
last_hidden_state=hidden_states, | |
past_key_values=presents, | |
hidden_states=all_hidden_states, | |
attentions=all_attentions, | |
) | |
@add_start_docstrings( | |
"The bare GPTNeoX Model transformer outputting raw hidden-states without any specific head on top.", | |
GPTNeoX_START_DOCSTRING, | |
) | |
class FlaxGPTNeoXModel(FlaxGPTNeoXPreTrainedModel): | |
module_class = FlaxGPTNeoXModule | |
append_call_sample_docstring( | |
FlaxGPTNeoXModel, | |
_CHECKPOINT_FOR_DOC, | |
FlaxCausalLMOutput, | |
_CONFIG_FOR_DOC, | |
) | |
@add_start_docstrings( | |
"""GPTNeoX Model with a `language modeling` head on top for CLM fine-tuning.""", GPT_NEOX_START_DOCSTRING | |
) | |
class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel): | |
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] | |
def __init__(self, config): | |
super().__init__(config) | |
self.gpt_neox = GPTNeoXModel(config) | |
self.embed_out = nn.Linear(config.hidden_size, config.vocab_size, bias=False) | |
# Initialize weights and apply final processing | |
self.post_init() | |
def get_output_embeddings(self): | |
return self.embed_out | |
def set_output_embeddings(self, new_embeddings): | |
self.embed_out = new_embeddings | |
@add_start_docstrings_to_model_forward(GPT_NEOX_INPUTS_DOCSTRING.format("batch_size, sequence_length")) | |
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) | |
def forward( | |
self, | |
input_ids: Optional[torch.LongTensor] = None, | |
attention_mask: Optional[torch.FloatTensor] = None, | |
position_ids: Optional[torch.LongTensor] = None, | |
inputs_embeds: Optional[torch.FloatTensor] = None, | |
head_mask: Optional[torch.FloatTensor] = None, | |
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, | |
labels: Optional[torch.LongTensor] = None, | |
use_cache: Optional[bool] = None, | |
output_attentions: Optional[bool] = None, | |
output_hidden_states: Optional[bool] = None, | |
return_dict: Optional[bool] = None, | |
) -> Union[Tuple, CausalLMOutputWithPast]: | |
r""" | |
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): | |
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape | |
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape | |
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. The two additional tensors are | |
only required when the model is used as a decoder in a Sequence to Sequence model. | |
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
`past_key_values` input) to speed up sequential decoding. | |
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that | |
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all | |
`decoder_input_ids` of shape `(batch_size, sequence_length)`. | |
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): | |
Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in | |
`[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are
ignored (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
use_cache (`bool`, *optional*): | |
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see | |
`past_key_values`). | |
Returns: | |
Example: | |
```python | |
>>> from transformers import AutoTokenizer, GPTNeoXForCausalLM, GPTNeoXConfig | |
>>> import torch | |
>>> tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b") | |
>>> config = GPTNeoXConfig.from_pretrained("EleutherAI/gpt-neox-20b") | |
>>> config.is_decoder = True | |
>>> model = GPTNeoXForCausalLM.from_pretrained("EleutherAI/gpt-neox-20b", config=config) | |
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") | |
>>> outputs = model(**inputs) | |
>>> prediction_logits = outputs.logits | |
```""" | |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict | |
outputs = self.gpt_neox( | |
input_ids, | |
attention_mask=attention_mask, | |
position_ids=position_ids, | |
head_mask=head_mask, | |
inputs_embeds=inputs_embeds, | |
past_key_values=past_key_values, | |
use_cache=use_cache, | |
output_attentions=output_attentions, | |
output_hidden_states=output_hidden_states, | |
return_dict=return_dict, | |
) | |
hidden_states = outputs[0] | |
lm_logits = self.embed_out(hidden_states) | |
lm_loss = None | |
if labels is not None: | |
# move labels to correct device to enable model parallelism | |
labels = labels.to(lm_logits.device) | |
# we are doing next-token prediction; shift prediction scores and input ids by one | |
shift_logits = lm_logits[:, :-1, :].contiguous() | |
labels = labels[:, 1:].contiguous() | |
loss_fct = CrossEntropyLoss() | |
lm_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)) | |
if not return_dict: | |
output = (lm_logits,) + outputs[1:] | |
return ((lm_loss,) + output) if lm_loss is not None else output | |
return CausalLMOutputWithPast( | |
loss=lm_loss, | |
logits=lm_logits, | |
past_key_values=outputs.past_key_values, | |
hidden_states=outputs.hidden_states, | |
attentions=outputs.attentions, | |
) | |
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **kwargs): | |
input_shape = input_ids.shape | |
# cut decoder_input_ids if past is used | |
if past_key_values and past_key_values[0] is not None: | |
input_ids = input_ids[:, -1:] | |
position_ids = kwargs.get("position_ids", None) | |
if attention_mask is not None and position_ids is None: | |
# create position_ids on the fly for batch generation | |
position_ids = attention_mask.long().cumsum(-1) - 1 | |
position_ids.masked_fill_(attention_mask == 0, 1) | |
if past_key_values: | |
position_ids = position_ids[:, -1].unsqueeze(-1) | |
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly | |
if attention_mask is None: | |
attention_mask = input_ids.new_ones(input_shape) | |
return { | |
"input_ids": input_ids, | |
"attention_mask": attention_mask, | |
"position_ids": position_ids, | |
"past_key_values": past_key_values, | |
} | |
def _reorder_cache(self, past_key_values, beam_idx): | |
reordered_past = () | |
for layer_past in past_key_values: | |
reordered_past += ( | |
tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], | |
) | |
return reordered_past | |
class FlaxGPTNeoXForCausalLMModule(nn.Module): | |
config: GPTNeoXConfig | |
dtype: jnp.dtype = jnp.float32 | |
def setup(self): | |
self.transformer = FlaxGPTNeoXModule(self.config, dtype=self.dtype) | |
self.lm_head = nn.Dense( | |
self.config.vocab_size, | |
dtype=self.dtype, | |
kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), | |
) | |
def __call__( | |
self, | |
input_ids, | |
attention_mask, | |
position_ids, | |
deterministic: bool = True, | |
init_cache: bool = False, | |
output_attentions: bool = False, | |
output_hidden_states: bool = False, | |
return_dict: bool = True, | |
): | |
outputs = self.transformer( | |
input_ids, | |
attention_mask, | |
position_ids, | |
deterministic=deterministic, | |
init_cache=init_cache, | |
output_attentions=output_attentions, | |
output_hidden_states=output_hidden_states, | |
return_dict=return_dict, | |
) | |
hidden_states = outputs[0] | |
if self.config.tie_word_embeddings: | |
shared_kernel = self.transformer.variables["params"]["wte"]["embedding"].T | |
lm_logits = self.lm_head.apply({"params": {"kernel": shared_kernel}}, hidden_states) | |
else: | |
lm_logits = self.lm_head(hidden_states) | |
if not return_dict: | |
return (lm_logits,) + outputs[1:] | |
return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions) | |
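# NOTE (port): GPT-NeoX does not tie the output projection to the input embeddings; the
# PyTorch GPTNeoXForCausalLM above uses a separate, bias-free `embed_out` linear layer.
# The `tie_word_embeddings` branch above is a GPT-J leftover, and the Flax `lm_head`
# should probably be created with `use_bias=False` to match.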
@add_start_docstrings( | |
""" | |
The GPTNeoX Model transformer with a language modeling head on top. | |
""", | |
GPTNeoX_START_DOCSTRING, | |
) | |
class FlaxGPTNeoXForCausalLM(FlaxGPTNeoXPreTrainedModel): | |
module_class = FlaxGPTNeoXForCausalLMModule | |
def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None): | |
# initializing the cache | |
batch_size, seq_length = input_ids.shape | |
past_key_values = self.init_cache(batch_size, max_length) | |
# Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. | |
# But since GPTNeoX uses a causal mask, those positions are masked anyways. | |
# Thus we can create a single static attention_mask here, which is more efficient for compilation | |
extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") | |
if attention_mask is not None: | |
position_ids = attention_mask.cumsum(axis=-1) - 1 | |
extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0)) | |
else: | |
position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) | |
return { | |
"past_key_values": past_key_values, | |
"attention_mask": extended_attention_mask, | |
"position_ids": position_ids, | |
} | |
def update_inputs_for_generation(self, model_outputs, model_kwargs): | |
model_kwargs["past_key_values"] = model_outputs.past_key_values | |
model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1 | |
return model_kwargs | |
append_call_sample_docstring( | |
FlaxGPTNeoXForCausalLM, | |
_CHECKPOINT_FOR_DOC, | |
FlaxCausalLMOutput, | |
_CONFIG_FOR_DOC, | |
) |