Created
March 5, 2024 20:25
-
-
Save Rocketknight1/77003f78147a9485a0f619e7202bb030 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding=utf-8 | |
# Copyright 2024 Google Inc. HuggingFace Inc. team. All rights reserved. | |
# | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
""" TensorFlow Gemma model.""" | |
import math | |
import warnings | |
from typing import List, Optional, Tuple, Union | |
import tensorflow as tf | |
from ...activations import get_tf_activation | |
from ...modeling_tf_outputs import TFBaseModelOutputWithPast, TFCausalLMOutputWithPast, TFSequenceClassifierOutputWithPast | |
from ...modeling_tf_utils import ( | |
TFPreTrainedModel, | |
TFCausalLanguageModelingLoss, | |
TFSequenceClassificationLoss, | |
get_initializer, | |
unpack_inputs, | |
keras_serializable, | |
) | |
from ...tf_utils import shape_list, stable_softmax, scaled_dot_product_attention | |
from ...utils import ( | |
add_start_docstrings, | |
add_start_docstrings_to_model_forward, | |
logging, | |
replace_return_docstrings, | |
) | |
from .configuration_gemma import GemmaConfig | |
# Module-level logger; used by TFGemmaAttention for its `warning_once` about a missing `layer_idx`.
logger = logging.get_logger(__name__) | |
# Passed as `config_class` to `replace_return_docstrings` on the model `call` methods in this file.
_CONFIG_FOR_DOC = "GemmaConfig" | |
def _get_unpad_data(attention_mask):
    """Collect the bookkeeping needed to run attention over un-padded (packed) sequences.

    Args:
        attention_mask (`tf.Tensor`):
            Mask whose non-zero entries mark real tokens. Presumably `(batch_size, sequence_length)` — TODO confirm
            against callers.

    Returns:
        `tuple` of:
            - `indices` (`tf.Tensor`, int64, 1-D): flat positions of the non-masked tokens.
            - `cu_seqlens` (`tf.Tensor`, int32, 1-D): cumulative sequence lengths with a leading 0, so that sequence
              `i` spans `cu_seqlens[i]:cu_seqlens[i + 1]` in the flattened token list.
            - `max_seqlen_in_batch` (`tf.Tensor`, int32 scalar): length of the longest sequence.
    """
    seqlens_in_batch = tf.reduce_sum(tf.cast(attention_mask, tf.int32), axis=-1)
    # Single-argument `tf.where` returns coordinates of shape (num_true, 1); flatten them into a 1-D index vector.
    indices = tf.reshape(tf.where(tf.reshape(attention_mask, [-1])), [-1])
    max_seqlen_in_batch = tf.reduce_max(seqlens_in_batch)
    # Prepend a 0 so the offsets start at the beginning of the flattened batch.
    cu_seqlens = tf.pad(tf.cumsum(seqlens_in_batch), [[1, 0]])
    return (
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )
class TFGemmaRMSNorm(tf.keras.layers.Layer):
    """Root-mean-square normalization over the last axis, scaled by `(1 + weight)`.

    `weight` is initialized to zeros, so a freshly created layer performs a plain RMS normalization. The computation
    runs in float32 and the result is cast back to the input dtype.

    Args:
        hidden_size (`int`): Size of the normalized (last) dimension.
        eps (`float`, *optional*, defaults to 1e-6): Added to the mean square before the reciprocal square root.
    """

    def __init__(self, hidden_size, eps=1e-6, **kwargs):
        super().__init__(**kwargs)
        self.eps = eps
        self.weight = self.add_weight(shape=(hidden_size,), initializer="zeros", trainable=True, name="weight")

    def _norm(self, x):
        # x / sqrt(mean(x^2) + eps) along the last axis.
        return x * tf.math.rsqrt(tf.reduce_mean(tf.square(x), axis=-1, keepdims=True) + self.eps)

    def call(self, x):
        # Normalize in float32 for numerical stability even when the model runs in half precision.
        output = self._norm(tf.cast(x, tf.float32))
        output = output * (1.0 + tf.cast(self.weight, tf.float32))
        # Hand the result back in the caller's dtype so downstream layers see a consistent dtype.
        return tf.cast(output, x.dtype)
class TFGemmaRotaryEmbedding(tf.keras.layers.Layer):
    """Produces the cosine/sine tables used by rotary position embeddings (RoPE).

    Args:
        dim (`int`): Size of the rotated feature dimension (the attention head size).
        max_position_embeddings (`int`, *optional*, defaults to 2048):
            Stored on the layer; not used in the computation below.
        base (`float`, *optional*, defaults to 10000): Base of the geometric progression of frequencies.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, **kwargs):
        super().__init__(**kwargs)
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # One inverse frequency per pair of rotated features, base^(-2i / dim); shape (ceil(dim / 2),). Computed
        # once here (eagerly) rather than in `build`, so no graph-local tensor is stored on the layer.
        self.inv_freq = 1.0 / (self.base ** (tf.range(0, self.dim, 2, dtype=tf.float32) / self.dim))

    def call(self, x, position_ids):
        """
        Args:
            x (`tf.Tensor`): Only its dtype is used: the returned tables are cast to it.
            position_ids (`tf.Tensor` of shape `(batch_size, seq_len)`): Absolute position of every token.

        Returns:
            `tuple(tf.Tensor)`: `(cos, sin)`, each of shape `(batch_size, seq_len, 2 * len(inv_freq))`.
        """
        # Outer product positions x inverse frequencies -> (batch_size, seq_len, len(inv_freq)). Kept in float32:
        # large positions lose too much precision in half-precision dtypes.
        positions = tf.cast(position_ids, tf.float32)
        freqs = positions[:, :, None] * self.inv_freq[None, None, :]
        emb = tf.concat([freqs, freqs], axis=-1)
        return tf.cast(tf.cos(emb), x.dtype), tf.cast(tf.sin(emb), x.dtype)
# Claude: Translated from PyTorch to TensorFlow | |
def rotate_half(x):
    """Split the last dimension of `x` into halves `(a, b)` and return them recombined as `(-b, a)`."""
    half = tf.shape(x)[-1] // 2
    first_half = x[..., :half]
    second_half = x[..., half:]
    return tf.concat([-second_half, first_half], axis=-1)
# Claude: Translated from PyTorch to TensorFlow | |
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`tf.Tensor`): The query tensor.
        k (`tf.Tensor`): The key tensor.
        cos (`tf.Tensor`): The cosine part of the rotary embedding.
        sin (`tf.Tensor`): The sine part of the rotary embedding.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Axis at which a size-1 dimension is inserted into `cos` and `sin` so they broadcast against `q` and `k`.
            With `cos`/`sin` of shape `(batch_size, seq_len, head_dim)`, use 1 when `q`/`k` are laid out as
            `(batch_size, heads, seq_len, head_dim)` and 2 for `(batch_size, seq_len, heads, head_dim)`.

    Returns:
        `tuple(tf.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = tf.expand_dims(cos, unsqueeze_dim)
    sin = tf.expand_dims(sin, unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
# Claude: Translated from PyTorch to TensorFlow | |
class TFGemmaMLP(tf.keras.layers.Layer):
    """Gated feed-forward block computing `down_proj(act(gate_proj(x)) * up_proj(x))`; all projections are bias-free.

    Args:
        config: Model config; reads `hidden_size`, `intermediate_size` and `hidden_act`.
    """

    def __init__(self, config, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        inner_width = config.intermediate_size
        self.gate_proj = tf.keras.layers.Dense(inner_width, use_bias=False, name="gate_proj")
        self.up_proj = tf.keras.layers.Dense(inner_width, use_bias=False, name="up_proj")
        self.down_proj = tf.keras.layers.Dense(config.hidden_size, use_bias=False, name="down_proj")
        self.act_fn = get_tf_activation(config.hidden_act)

    def call(self, x):
        gated = self.act_fn(self.gate_proj(x))
        projected = self.up_proj(x)
        return self.down_proj(gated * projected)
# Claude: Translated from PyTorch to TensorFlow | |
def repeat_kv(hidden_states, n_rep):
    """
    Repeat every key/value head `n_rep` times along axis 1 (copies of one head end up adjacent), taking the tensor
    from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_key_value_heads * n_rep, seqlen, head_dim).
    """
    # Unpacking also rejects inputs that are not rank 4, even when no repetition is needed.
    batch, kv_heads, seq_len, head_dim = shape_list(hidden_states)
    if n_rep == 1:
        return hidden_states
    # (b, kv, s, d) -> (b, kv, 1, s, d) -> (b, kv, n_rep, s, d) -> (b, kv * n_rep, s, d)
    tiled = tf.repeat(hidden_states[:, :, None, :, :], n_rep, axis=2)
    return tf.reshape(tiled, [batch, kv_heads * n_rep, seq_len, head_dim])
# Claude: Translated from PyTorch to TensorFlow | |
class TFGemmaAttention(tf.keras.layers.Layer):
    """Multi-headed attention from 'Attention Is All You Need' paper, with rotary position embeddings and grouped
    key/value heads (`num_key_value_heads` heads are repeated to match `num_attention_heads`).

    Args:
        config: Model config; reads `hidden_size`, `num_attention_heads`, `head_dim`, `num_key_value_heads`,
            `max_position_embeddings`, `rope_theta`, `attention_bias` and `attention_dropout`.
        layer_idx (`int`, *optional*): Index of the owning decoder layer.
    """

    def __init__(self, config, layer_idx=None, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = config.head_dim
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True

        if self.hidden_size % self.num_heads != 0:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        self.q_proj = tf.keras.layers.Dense(
            self.num_heads * self.head_dim, use_bias=config.attention_bias, name="q_proj"
        )
        self.k_proj = tf.keras.layers.Dense(
            self.num_key_value_heads * self.head_dim, use_bias=config.attention_bias, name="k_proj"
        )
        self.v_proj = tf.keras.layers.Dense(
            self.num_key_value_heads * self.head_dim, use_bias=config.attention_bias, name="v_proj"
        )
        self.o_proj = tf.keras.layers.Dense(self.hidden_size, use_bias=config.attention_bias, name="o_proj")
        self.rotary_emb = TFGemmaRotaryEmbedding(
            self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            base=self.rope_theta,
            name="rotary_emb",
        )
        # Keras dropout is only active when the layer runs in training mode (explicit `training` or call context).
        self.dropout = tf.keras.layers.Dropout(config.attention_dropout)

    def call(
        self,
        hidden_states,
        attention_mask=None,
        position_ids=None,
        past_key_value=None,
        output_attentions=False,
        use_cache=False,
        cache_position=None,
        training=None,
        **kwargs,
    ):
        """
        Args:
            hidden_states (`tf.Tensor` of shape `(batch_size, seq_len, hidden_size)`): Layer input.
            attention_mask (`tf.Tensor`, *optional*):
                Additive mask of shape `(batch_size, 1, target_len, target_len)`; rows are selected with
                `cache_position` and columns are cut to the key length.
            position_ids (`tf.Tensor` of shape `(batch_size, seq_len)`, *optional*):
                Token positions for the rotary embedding; defaults to `0 .. seq_len - 1`.
            past_key_value: Returned unchanged; key/value caching is not implemented in this translation.
            output_attentions (`bool`): Whether to return the attention probabilities.
            use_cache (`bool`): Accepted for interface parity; unused.
            cache_position (`tf.Tensor`, *optional*): Absolute positions of the query tokens.
            training (`bool`, *optional*): Enables attention dropout.

        Returns:
            `tuple`: `(attn_output, attn_weights or None, past_key_value)`.
        """
        bsz, q_len = shape_list(hidden_states)[:2]

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # (bsz, q_len, heads * head_dim) -> (bsz, heads, q_len, head_dim)
        query_states = tf.reshape(query_states, [bsz, q_len, self.num_heads, self.head_dim])
        query_states = tf.transpose(query_states, [0, 2, 1, 3])
        key_states = tf.reshape(key_states, [bsz, q_len, self.num_key_value_heads, self.head_dim])
        key_states = tf.transpose(key_states, [0, 2, 1, 3])
        value_states = tf.reshape(value_states, [bsz, q_len, self.num_key_value_heads, self.head_dim])
        value_states = tf.transpose(value_states, [0, 2, 1, 3])

        if position_ids is None:
            position_ids = tf.expand_dims(tf.range(q_len), 0)
        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # NOTE: key/value caching is not implemented in this translation; when `past_key_value` is given it is
        # returned unchanged and attention only covers the tokens of this call.

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = tf.matmul(query_states, key_states, transpose_b=True) / math.sqrt(self.head_dim)

        if attention_mask is not None:  # no matter the length, we just slice it
            if cache_position is not None:
                kv_len = shape_list(key_states)[-2]
                # TF cannot index with a tensor inside `[...]`, so gather the query rows along axis 2 instead.
                causal_mask = tf.gather(attention_mask, cache_position, axis=2)[:, :, :, :kv_len]
            else:
                causal_mask = attention_mask
            attn_weights = attn_weights + tf.cast(causal_mask, attn_weights.dtype)

        # upcast attention to fp32
        attn_weights = stable_softmax(tf.cast(attn_weights, tf.float32), axis=-1)
        attn_weights = tf.cast(attn_weights, query_states.dtype)
        attn_weights = self.dropout(attn_weights, training=training)
        attn_output = tf.matmul(attn_weights, value_states)

        expected_shape = [bsz, self.num_heads, q_len, self.head_dim]
        # Graph-safe shape check (a Python list comparison breaks when some dimensions are symbolic).
        tf.debugging.assert_equal(
            tf.shape(attn_output),
            expected_shape,
            message=f"`attn_output` should be of size {expected_shape}, but is {shape_list(attn_output)}",
        )

        attn_output = tf.transpose(attn_output, [0, 2, 1, 3])
        # `head_dim` comes from the config independently of `hidden_size`, so num_heads * head_dim may differ from
        # hidden_size; flatten the heads and let `o_proj` map back to hidden_size.
        attn_output = tf.reshape(attn_output, [bsz, q_len, self.num_heads * self.head_dim])
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
# Claude: Translated from PyTorch to TensorFlow | |
class TFGemmaDecoderLayer(tf.keras.layers.Layer):
    """A single pre-norm decoder block.

    The input goes through RMSNorm -> self-attention and is added back as a residual. That result then goes through
    RMSNorm -> MLP and is added back as a second residual.

    Args:
        config: Model config; reads `hidden_size` and `rms_norm_eps` (the sub-layers read more).
        layer_idx (`int`): Index of this layer, forwarded to the attention module.
    """

    def __init__(self, config, layer_idx, **kwargs):
        super().__init__(**kwargs)
        width = config.hidden_size
        norm_eps = config.rms_norm_eps
        self.hidden_size = width
        self.self_attn = TFGemmaAttention(config, layer_idx=layer_idx, name="self_attn")
        self.mlp = TFGemmaMLP(config, name="mlp")
        self.input_layernorm = TFGemmaRMSNorm(width, eps=norm_eps, name="input_layernorm")
        self.post_attention_layernorm = TFGemmaRMSNorm(width, eps=norm_eps, name="post_attention_layernorm")

    def call(
        self,
        hidden_states,
        attention_mask=None,
        position_ids=None,
        past_key_value=None,
        output_attentions=False,
        use_cache=False,
        cache_position=None,
        **kwargs,
    ):
        """
        Returns:
            `tuple`: the block output, followed by the attention probabilities when `output_attentions` is set and
            by the attention module's returned cache entry when `use_cache` is set.
        """
        if "padding_mask" in kwargs:
            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )

        # Attention sub-block: normalize, attend, add the untouched input back.
        normed = self.input_layernorm(hidden_states)
        attn_out, attn_probs, present_kv = self.self_attn(
            hidden_states=normed,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = hidden_states + attn_out

        # Feed-forward sub-block with its own residual connection.
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = hidden_states + mlp_out

        results = [hidden_states]
        if output_attentions:
            results.append(attn_probs)
        if use_cache:
            results.append(present_kv)
        return tuple(results)
# Shared class-level documentation, passed to `add_start_docstrings` on the Gemma model classes in this file.
# NOTE(review): the last link points at `PreTrainedModel.from_pretrained`; confirm whether it should reference
# `TFPreTrainedModel.from_pretrained` for this TensorFlow model.
GEMMA_START_DOCSTRING = r""" | |
This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the | |
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads | |
etc.) | |
This model is also a [tf.keras.Model](https://www.tensorflow.org/api_docs/python/tf/keras/Model) subclass. Use it | |
as a regular TF 2.0 Keras Model and refer to the TF 2.0 documentation for all matter related to general usage and | |
behavior. | |
Parameters: | |
config ([`GemmaConfig`]): | |
Model configuration class with all the parameters of the model. Initializing with a config file does not | |
load the weights associated with the model, only the configuration. Check out the | |
[`~PreTrainedModel.from_pretrained`] method to load the model weights. | |
""" | |
@add_start_docstrings(
    "The bare Gemma Model outputting raw hidden-states without any specific head on top.",
    GEMMA_START_DOCSTRING,
)
class TFGemmaPreTrainedModel(TFPreTrainedModel):
    """Common base for the TF Gemma models: config class, base-model prefix and weight-initialization helpers."""

    config_class = GemmaConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _keep_in_fp32_modules = ["inv_freq", "rotary_emb", "cos_cached", "sin_cached"]
    _no_split_modules = ["TFGemmaDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values", "causal_mask"]
    _supports_cache_class = True

    def _init_weights(self, module):
        """(Re-)initialize the weights of `module` in place.

        Dense kernels and embedding tables are drawn from a truncated normal with standard deviation
        `config.initializer_range`; Dense biases are zeroed. Other layer types are left untouched.
        """
        std = self.config.initializer_range
        if isinstance(module, tf.keras.layers.Dense):
            kernel = module.kernel
            kernel.assign(tf.keras.initializers.TruncatedNormal(stddev=std)(shape_list(kernel), dtype=kernel.dtype))
            if module.bias is not None:
                bias = module.bias
                bias.assign(tf.keras.initializers.Zeros()(shape_list(bias), dtype=bias.dtype))
        elif isinstance(module, tf.keras.layers.Embedding):
            table = module.embeddings
            table.assign(tf.keras.initializers.TruncatedNormal(stddev=std)(shape_list(table), dtype=table.dtype))
            # Keras `Embedding` has no `padding_idx` attribute, so only zero a padding row if one was attached.
            padding_idx = getattr(module, "padding_idx", None)
            if padding_idx is not None:
                table[padding_idx].assign(tf.zeros_like(table[padding_idx]))

    def _setup_cache(self, max_batch_size, max_cache_len=None):
        # Static key/value caching is not implemented in this TensorFlow translation; nothing to set up.
        pass

    def _reset_cache(self):
        """Clear the `past_key_value` attribute on the attention module of every decoder layer."""
        # Head models keep the decoder under `base_model_prefix`; the bare model is the decoder itself.
        base_model = getattr(self, self.base_model_prefix, self)
        for layer in base_model.layers:
            # Only decoder layers own a `self_attn`; skip embeddings, norms, etc.
            if hasattr(layer, "self_attn"):
                layer.self_attn.past_key_value = None
# Shared documentation of the model `call` arguments, attached via `add_start_docstrings_to_model_forward`.
# NOTE(review): several entries look inherited from other models and may need fixing:
#   - `attention_mask` ends with two bullets about a head being (not) masked, which seem to belong to a head-mask
#     argument this model does not accept;
#   - `position_ids` refers to `config.n_positions`, while this file reads `config.max_position_embeddings`;
#   - `past_key_values` describes cross-attention tensors, but the decoder layers here only use self-attention.
GEMMA_INPUTS_DOCSTRING = r""" | |
Args: | |
input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`): | |
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide | |
it. | |
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and | |
[`PreTrainedTokenizer.__call__`] for details. | |
[What are input IDs?](../glossary#input-ids) | |
attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): | |
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: | |
- 1 for tokens that are **not masked**, | |
- 0 for tokens that are **masked**. | |
[What are attention masks?](../glossary#attention-mask) | |
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and | |
[`PreTrainedTokenizer.__call__`] for details. | |
If `past_key_values` is used, optionally only the last `input_ids` have to be input (see | |
`past_key_values`). | |
If you want to change padding behavior, you should read [`modeling_tf_utils._prepare_decoder_attention_mask`] | |
and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more | |
information on the default strategy. | |
- 1 indicates the head is **not masked**, | |
- 0 indicates the head is **masked**. | |
position_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): | |
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, | |
config.n_positions - 1]`. | |
[What are position IDs?](../glossary#position-ids) | |
past_key_values (`tuple(tuple(tf.Tensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): | |
Tuple of `tuple(tf.Tensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape | |
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape | |
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. | |
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention | |
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. | |
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't | |
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` | |
of shape `(batch_size, sequence_length)`. | |
inputs_embeds (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): | |
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This | |
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the | |
model's internal embedding lookup matrix. | |
use_cache (`bool`, *optional*): | |
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see | |
`past_key_values`). | |
output_attentions (`bool`, *optional*): | |
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned | |
tensors for more detail. | |
output_hidden_states (`bool`, *optional*): | |
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for | |
more detail. | |
return_dict (`bool`, *optional*): | |
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. | |
""" | |
@add_start_docstrings(
    "The bare Gemma Model outputting raw hidden-states without any specific head on top.",
    GEMMA_START_DOCSTRING,
)
# Claude: Translated from PyTorch to TensorFlow
class TFGemmaModel(TFGemmaPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`TFGemmaDecoderLayer`]

    Args:
        config: GemmaConfig
    """

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = tf.keras.layers.Embedding(
            config.vocab_size,
            config.hidden_size,
            embeddings_initializer=get_initializer(config.initializer_range),
            name="embed_tokens",
        )
        # `layers` is a reserved, read-only property of `tf.keras.Model`, so the decoder stack lives under another
        # attribute. The Keras layer names (`layers.{i}`), which determine the weight names, are unchanged.
        self.decoder_layers = [
            TFGemmaDecoderLayer(config, layer_idx=i, name=f"layers.{i}") for i in range(config.num_hidden_layers)
        ]
        self.norm = TFGemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps, name="norm")
        # Keras creates the weights lazily from the initializers given above, so no separate init pass is needed.

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens.embeddings = value
        self.embed_tokens.vocab_size = shape_list(value)[0]

    @unpack_inputs
    @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=TFBaseModelOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def call(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        cache_position=None,
        training=False,
    ):
        r"""
        Embeds the tokens, scales them by `sqrt(hidden_size)`, runs the decoder stack and applies the final RMSNorm.

        Returns:
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = shape_list(input_ids)
        elif inputs_embeds is not None:
            input_shape = shape_list(inputs_embeds)[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # Key/value caching is not implemented in this translation, so positions always start at 0.
        past_seen_tokens = 0

        if cache_position is None:
            cache_position = tf.range(past_seen_tokens, past_seen_tokens + input_shape[1])
        if position_ids is None:
            position_ids = tf.expand_dims(cache_position, 0)

        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds)

        # Gemma scales the token embeddings by sqrt(hidden_size) before the first decoder layer.
        hidden_states = inputs_embeds * (self.config.hidden_size**0.5)

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for decoder_layer in self.decoder_layers:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
            )
            hidden_states = layer_outputs[0]
            if output_attentions:
                all_self_attns = all_self_attns + (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        # No cache is produced because caching is not implemented in this translation.
        next_cache = None

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return TFBaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    def _update_causal_mask(self, attention_mask, input_tensor):
        """Build the additive attention mask of shape `(batch_size, 1, target_len, target_len)`.

        Entries are 0 where attention is allowed and the most negative value of the input dtype where it is not:
        above the diagonal (future tokens) and, for a 2-D `attention_mask`, in columns whose mask value is 0.
        `target_len` is the larger of the sequence length and the 2-D mask length. Rows are indexed by absolute
        position, so the attention layer can select them with `cache_position`.
        """
        batch_size, seq_length = shape_list(input_tensor)[:2]
        dtype = tf.as_dtype(input_tensor.dtype)
        min_value = tf.constant(dtype.min, dtype=dtype)
        zero = tf.zeros([], dtype=dtype)

        has_2d_mask = attention_mask is not None and len(shape_list(attention_mask)) == 2
        target_length = seq_length
        if has_2d_mask:
            mask_length = shape_list(attention_mask)[-1]
            target_length = tf.maximum(seq_length, mask_length)

        positions = tf.range(target_length)
        # True strictly above the diagonal: key position later than the query position.
        is_future = positions[None, :] > positions[:, None]
        causal_mask = tf.where(is_future, min_value, zero)
        causal_mask = tf.tile(causal_mask[None, None, :, :], [batch_size, 1, 1, 1])

        if has_2d_mask:
            keep = tf.cast(attention_mask, tf.int32)
            # Columns beyond the 2-D mask are treated as real tokens (pad with 1 = "not masked").
            keep = tf.pad(keep, [[0, 0], [0, target_length - mask_length]], constant_values=1)
            is_padding = tf.equal(keep, 0)[:, None, None, :]
            causal_mask = tf.where(is_padding, min_value, causal_mask)

        return causal_mask
# Claude: Translated from PyTorch to TensorFlow | |
class TFGemmaForCausalLM(TFGemmaPreTrainedModel, TFCausalLanguageModelingLoss):
    """Gemma decoder topped with a bias-free Dense language-modeling head projecting onto the vocabulary."""

    _keys_to_ignore_on_load_missing = [r"lm_head.weight"]

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.model = TFGemmaModel(config, name="model")
        self.vocab_size = config.vocab_size
        self.lm_head = tf.keras.layers.Dense(
            config.vocab_size,
            use_bias=False,
            name="lm_head",
            kernel_initializer=get_initializer(config.initializer_range),
        )
        # Keras creates the weights lazily from the initializers given above, so no separate init pass is needed.

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens.embeddings = value
        self.model.embed_tokens.vocab_size = shape_list(value)[0]

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    @unpack_inputs
    @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=TFCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def call(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        cache_position=None,
        training=False,
    ):
        r"""
        Args:
            labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, TFGemmaForCausalLM

        >>> model = TFGemmaForCausalLM.from_pretrained("google/gemma-7b")
        >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")

        >>> prompt = "What is your favorite condiment?"
        >>> inputs = tokenizer(prompt, return_tensors="tf")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "What is your favorite condiment?"
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            training=training,
        )

        hidden_states = outputs[0]
        # Upcast to float32 so the softmax/cross-entropy over the vocabulary is stable under half precision.
        logits = tf.cast(self.lm_head(hidden_states), tf.float32)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :]
            shift_labels = labels[..., 1:]
            # Use the shared causal-LM loss so that positions labelled -100 are ignored, as the `labels`
            # documentation above promises, and the result is reduced over the remaining tokens.
            loss = self.hf_compute_loss(shift_labels, shift_logits)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return TFCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        """Assemble the keyword arguments for one generation step.

        Trims `input_ids` to the tokens not covered by `past_key_values` and derives `position_ids` from
        `attention_mask` when they are not given. It also builds `cache_position` for the processed tokens.
        """
        past_length = 0
        if past_key_values is not None:
            # NOTE: caching is not implemented in this translation (the model returns `past_key_values=None`). This
            # assumes a legacy tuple of (key, value) tensors whose axis 2 is the cached sequence length — confirm.
            past_length = shape_list(past_key_values[0][0])[2]

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
            # input)
            if attention_mask is not None and shape_list(attention_mask)[1] > shape_list(input_ids)[1]:
                input_ids = input_ids[:, -(shape_list(attention_mask)[1] - past_length) :]
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            # input_ids based on the past_length.
            elif past_length < shape_list(input_ids)[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation: count the real tokens before each position,
            # and give padded positions a dummy value of 1.
            position_ids = tf.cumsum(attention_mask, axis=-1, exclusive=True)
            position_ids = tf.where(tf.equal(attention_mask, 0), tf.ones_like(position_ids), position_ids)
            if past_key_values is not None:
                position_ids = position_ids[:, -shape_list(input_ids)[1] :]

        # TODO @gante we should only keep a `cache_position` in generate, and do +=1.
        # same goes for position ids. Could also help with continued generation.
        # Fall back to the input length when no position ids could be derived (no attention mask given).
        num_new_positions = shape_list(position_ids)[-1] if position_ids is not None else shape_list(input_ids)[1]
        cache_position = tf.range(past_length, past_length + num_new_positions)

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        # Caching is not implemented in this translation, so there is nothing to reorder; returns None.
        pass
@add_start_docstrings(
    """
    The Gemma Model transformer with a sequence classification head on top (linear layer).
    [`TFGemmaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do.
    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
    GEMMA_START_DOCSTRING,
)
# Translated from the PyTorch `GemmaForSequenceClassification`: a `TFGemmaModel` backbone followed by a bias-free
# linear `score` head whose output is pooled at one token position per row.
class TFGemmaForSequenceClassification(TFGemmaPreTrainedModel):
    def __init__(self, config, *inputs, **kwargs):
        # Extra positional/keyword arguments (e.g. a Keras `name`) are forwarded to the base class, following the
        # usual Hugging Face TF model signature.
        super().__init__(config, *inputs, **kwargs)
        self.num_labels = config.num_labels
        self.model = TFGemmaModel(config, name="model")
        # Initialise the head from the config's initializer range, mirroring the PyTorch weight initialisation.
        self.score = tf.keras.layers.Dense(
            self.num_labels,
            use_bias=False,
            kernel_initializer=get_initializer(config.initializer_range),
            name="score",
        )
        # Initialize weights and apply final processing.
        # NOTE(review): `post_init` is a PyTorch `PreTrainedModel` hook; confirm `TFGemmaPreTrainedModel` provides it.
        self.post_init()

    def get_input_embeddings(self):
        """Return the backbone's token-embedding layer."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the backbone's embedding matrix with `value` and update its vocabulary size from `value`'s first dim."""
        self.model.embed_tokens.embeddings = value
        self.model.embed_tokens.vocab_size = shape_list(value)[0]

    @unpack_inputs
    @add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=TFSequenceClassifierOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def call(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        training=False,
    ):
        r"""
        labels (`tf.Tensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        Returns:
        """
        # The empty "Returns:" line above is required: `replace_return_docstrings` substitutes the output
        # documentation there and raises ValueError at class-definition time if the placeholder is missing.
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        hidden_states = transformer_outputs[0]
        # Per-token class scores; assumed (batch, seq, num_labels) -- TODO confirm the backbone's output rank.
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size, sequence_length = shape_list(input_ids)[:2]
        else:
            batch_size, sequence_length = shape_list(inputs_embeds)[:2]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")

        if self.config.pad_token_id is not None and input_ids is not None:
            # Position of the first pad token minus one is the last real token of a right-padded row.
            # When a row has no padding (argmax over all zeros is 0) or starts with padding (left padding),
            # this gives -1, which the floor-modulo wraps to the final position -- as the PyTorch model does.
            is_pad = tf.cast(tf.equal(input_ids, self.config.pad_token_id), tf.int32)
            first_pad = tf.argmax(is_pad, axis=-1, output_type=tf.int32)
            sequence_lengths = tf.math.floormod(first_pad - 1, sequence_length)
        else:
            # No way to locate padding: pool the last position of every row. An explicit, non-negative
            # per-row index is needed because `tf.gather` with `batch_dims=1` cannot take a scalar -1.
            sequence_lengths = tf.fill([batch_size], sequence_length - 1)

        pooled_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)

        loss = None
        if labels is not None:
            # Infer the problem type once and cache it on the config, as the PyTorch implementation does.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and labels.dtype in (tf.int32, tf.int64):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = tf.keras.losses.MeanSquaredError()
                if self.num_labels == 1:
                    # Flatten both sides so one scalar target per row is compared with one scalar prediction.
                    loss = loss_fct(tf.reshape(labels, [-1]), tf.reshape(pooled_logits, [-1]))
                else:
                    loss = loss_fct(labels, pooled_logits)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
                loss = loss_fct(labels, pooled_logits)
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = tf.keras.losses.BinaryCrossentropy(from_logits=True)
                loss = loss_fct(labels, pooled_logits)

        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return TFSequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
# Translation notes (PyTorch -> TensorFlow):
# - PyTorch modules and functions have been replaced with their TensorFlow/Keras equivalents where possible,
#   e.g. `tf.keras.layers` for the model layers and `tf.keras.losses` for the loss functions.
# - `torch.Tensor` inputs and outputs have been replaced with `tf.Tensor`.
# - PyTorch operations such as `torch.matmul`, `torch.arange` and `torch.where` have been replaced with their
#   TensorFlow equivalents.
# - The `forward` methods have been renamed to `call` to match the Keras layer API.
# - The PyTorch implementation stores `past_key_values` in a custom `Cache` class. The TensorFlow cache format is
#   different and would need to be implemented separately; for now the cache-related code paths are mostly
#   unimplemented or simplified.
# - Model helpers such as `apply_rotary_pos_emb` were translated directly to TensorFlow and may need further
#   testing and optimization.
# - Comments prefixed with "Claude:" explain translation choices or flag areas that may need further work.
# Translating a model this complex requires careful testing and iteration to confirm correctness and performance.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment