Gemma 3 27B on Colab TPU using JAX

Running the Gemma 3-27B model on a TPU using JAX in Google Colab requires careful setup, as the model is large and TPUs have specific requirements for efficient execution. Below is a step-by-step guide to help you achieve this. Note that the Gemma 3-27B model is computationally intensive, and Google Colab's free TPU resources may not be sufficient due to memory and runtime limitations. You may need access to premium Colab resources or alternative platforms like Kaggle or Google Cloud TPUs for successful execution.

Step-by-Step Guide

1. Set Up Google Colab with TPU Runtime

  • Open a new notebook in Google Colab.
  • Go to the Runtime menu, select Change runtime type, and choose TPU v2 as the hardware accelerator. Note that Colab provides TPU v2, which may not have enough memory for a 27B model, so consider upgrading to a paid plan or using an alternative platform if necessary.

2. Install Required Libraries

  • You need to install JAX, Keras, and KerasNLP, ensuring compatibility with TPUs. Use the following commands in a Colab code cell to install the necessary libraries:

    # Install JAX with TPU support (quotes keep the shell from expanding the brackets)
    !pip install -q "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
    
    # Install Keras and KerasNLP for Gemma models (quote the version constraint so
    # the shell does not treat ">" as output redirection)
    !pip install -q -U keras-nlp
    !pip install -q -U "keras>=3"
    
    # Workaround for TensorFlow compatibility on TPU
    !pip install -q -U tensorflow-cpu
    !pip install -q -U tensorflow-hub
  • Explanation:

    • jax[tpu]: Installs JAX with TPU support.
    • keras-nlp: Provides tools to work with Gemma models.
    • keras>=3: Ensures you have the latest Keras version, which supports JAX as a backend.
    • tensorflow-cpu and tensorflow-hub: Prevent TensorFlow from attempting to access the TPU directly, ensuring compatibility.
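
  • As a quick sanity check after installation, you can print the installed versions without importing keras yet (the JAX backend is selected later through an environment variable that must be set before keras is first imported):

    import importlib.metadata as md
    
    for pkg in ["jax", "keras", "keras-nlp"]:
        print(pkg, md.version(pkg))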

3. Set Up Environment Variables for Kaggle Access

  • The Gemma models are hosted on Kaggle, so you need to authenticate your access. First, obtain your Kaggle API key from your Kaggle account settings (kaggle.json).

  • In Colab, upload your kaggle.json file or set environment variables manually. Use the following code to set up access:

    import os
    from google.colab import userdata
    
    # Replace with your Kaggle username and key (or use userdata if stored in Colab secrets)
    os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')  # Or set manually
    os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')  # Or set manually
  • If you're not using Colab secrets, you can manually set os.environ["KAGGLE_USERNAME"] = "your_username" and os.environ["KAGGLE_KEY"] = "your_key".
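
  • Alternatively, you can upload kaggle.json directly and place it at ~/.kaggle/kaggle.json, the standard location the Kaggle client checks (a minimal sketch using Colab's file-upload helper):

    import os
    from google.colab import files
    
    # Opens a file picker; select your kaggle.json
    uploaded = files.upload()
    
    kaggle_dir = os.path.expanduser("~/.kaggle")
    os.makedirs(kaggle_dir, exist_ok=True)
    with open(os.path.join(kaggle_dir, "kaggle.json"), "wb") as f:
        f.write(uploaded["kaggle.json"])
    os.chmod(os.path.join(kaggle_dir, "kaggle.json"), 0o600)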

4. Configure Keras to Use JAX Backend

  • Set the Keras backend to JAX, as it is required for TPU execution. Add this code to your notebook:

    import os
    os.environ["KERAS_BACKEND"] = "jax"
    # Allocate memory efficiently for TPU
    os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"
  • Explanation:

    • KERAS_BACKEND: Ensures Keras uses JAX, which is optimized for TPUs.
    • XLA_PYTHON_CLIENT_MEM_FRACTION: Allocates 90% of TPU memory to the process, preventing memory errors.
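
  • You can confirm that the backend took effect; the environment variable must be set before keras is imported for the first time in the session (restart the runtime if keras was already imported):

    import keras
    
    # Should print "jax"
    print(keras.backend.backend())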

5. Verify TPU Availability

  • With the JAX backend, you do not need TensorFlow's tf.distribute API: JAX discovers the TPU cores directly, and the model is sharded across them with Keras' distribution API in the next step. Before loading the model, confirm that JAX can see the TPU:

    import jax
    
    # A Colab TPU v2-8 exposes 8 TPU cores
    devices = jax.devices()
    print("All devices: ", devices)
    assert devices[0].platform == "tpu", "TPU not detected - check the runtime type"
  • Explanation:

    • jax.devices(): Lists the accelerators visible to JAX; on a Colab TPU runtime it should report 8 TPU cores.
    • No tf.distribute.TPUStrategy is required: TensorFlow was deliberately installed as tensorflow-cpu, and model parallelism across the 8 cores (essential for a model as large as Gemma 3-27B) is configured with keras.distribution in the next step.

6. Load the Gemma 3-27B Model

  • Use KerasNLP to load the Gemma 3-27B model. Since this is a large model, you need to shard it across the TPU cores using Keras' distribution API. Here's how to do it:

    import keras_nlp
    import keras
    
    # Create a device mesh for TPU cores (1, 8) shape
    device_mesh = keras.distribution.DeviceMesh(
        (1, 8),  # 1 batch dimension, 8 model dimensions (for 8 TPU cores)
        ["batch", "model"],
        devices=keras.distribution.list_devices()
    )
    
    # Define layout map for sharding weights
    layout_map = keras.distribution.LayoutMap(device_mesh)
    layout_map["token_embedding/embeddings"] = ("model", None)
    layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = ("model", None, None)
    layout_map["decoder_block.*attention_output.*kernel"] = ("model", None, None)
    layout_map["decoder_block.*ffw_gating.*kernel"] = ("model", None)
    layout_map["decoder_block.*ffw_linear.*kernel"] = ("model", None)
    
    # Set the distribution layout
    model_parallel = keras.distribution.ModelParallel(device_mesh, layout_map, batch_dim_name="batch")
    keras.distribution.set_distribution(model_parallel)
    
    # Load the Gemma 3-27B model; the distribution set above shards it across the TPU cores
    gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma3_27b_en")
    gemma_lm.summary()
  • Explanation:

    • DeviceMesh: Represents the TPU cores as a mesh for distributing computation. The (1, 8) shape shards the model across all 8 TPU cores.
    • LayoutMap: Specifies how model weights and tensors are sharded or replicated across TPU cores.
    • ModelParallel: Enables model parallelism, which is necessary for large models like Gemma 3-27B that don't fit on a single TPU core. Note that newer Keras releases may expect keyword arguments here (e.g., ModelParallel(layout_map=layout_map, batch_dim_name="batch"), with the mesh taken from the LayoutMap); adjust the call above if you hit a signature error.
    • from_preset: Loads the pre-trained Gemma 3-27B model. Note that you may need to specify the exact preset name, which can be found on Kaggle or in the KerasNLP documentation.
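
  • To confirm that the weights were sharded rather than replicated, you can inspect one decoder block after loading. This is a sketch assuming the JAX backend (each Keras variable is then backed by a jax.Array that exposes its sharding) and the decoder_block_N layer naming used by the Gemma backbone; adjust the layer name if your model version differs:

    # Print how one decoder block's weights are laid out across the TPU cores
    decoder_block = gemma_lm.backbone.get_layer("decoder_block_1")
    for variable in decoder_block.weights:
        print(variable.path, variable.shape, variable.value.sharding)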

7. Test the Model

  • Once the model is loaded, you can test it by generating text. Use the generate method provided by KerasNLP:

    # Generate text
    prompt = "What is the meaning of life?"
    output = gemma_lm.generate(prompt, max_length=64)
    print(output)
  • Explanation:

    • generate: Generates text based on the prompt. The max_length argument limits the length of the generated sequence.
    • The first run may be slow due to XLA compilation, but subsequent runs will be faster.
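
  • generate also accepts a list of prompts, which lets you batch several requests through one compiled call (a small sketch; keep the batch modest on a memory-constrained TPU):

    prompts = [
        "Explain JAX in one sentence.",
        "Write a haiku about TPUs.",
    ]
    outputs = gemma_lm.generate(prompts, max_length=64)
    for text in outputs:
        print(text)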

8. Handle Potential Issues

  • Memory Errors: If you encounter out-of-memory errors, it may indicate that the Colab TPU v2-8 (with 8 GB of HBM per core, 64 GB in total) is insufficient for the 27B model. Consider:
    • Using a smaller batch size or reducing max_length.
    • Switching to a paid Colab plan, which may offer higher memory limits or newer accelerator types.
    • Using Kaggle, which offers TPU v3 for free, or Google Cloud TPUs, which provide more memory and newer TPU versions.
  • Slow Compilation: The first run of generate compiles the model with XLA, which can take several minutes. Subsequent runs will be much faster.
  • Authentication Errors: Ensure your Kaggle API key is correctly set up to access the Gemma model.

9. Optimize for Performance

  • To improve performance, consider the following:
    • Use mixed precision (e.g., bfloat16) to reduce memory usage and speed up computation. You can set this in KerasNLP by enabling mixed precision policies before loading the model:

      from keras import mixed_precision
      mixed_precision.set_global_policy('mixed_bfloat16')
    • Ensure your input data is preprocessed efficiently and stored in a format accessible to the TPU (e.g., Google Cloud Storage for large datasets).
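
  • Because the backend, precision policy, and distribution all have to be in place before the model weights are created, the configuration from steps 4 and 6, the mixed-precision policy above, and the generation call from step 7 can be condensed into a single cell. This is a sketch of the required ordering, reusing the same calls and the gemma3_27b_en preset name from this guide (verify the preset on Kaggle):

    import os
    os.environ["KERAS_BACKEND"] = "jax"                    # 1. backend, before importing keras
    os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"
    
    import keras
    import keras_nlp
    from keras import mixed_precision
    
    mixed_precision.set_global_policy("mixed_bfloat16")    # 2. precision policy
    
    device_mesh = keras.distribution.DeviceMesh(           # 3. sharding layout (as in step 6)
        (1, 8), ["batch", "model"], devices=keras.distribution.list_devices()
    )
    layout_map = keras.distribution.LayoutMap(device_mesh)
    layout_map["token_embedding/embeddings"] = ("model", None)
    layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = ("model", None, None)
    layout_map["decoder_block.*attention_output.*kernel"] = ("model", None, None)
    layout_map["decoder_block.*ffw_gating.*kernel"] = ("model", None)
    layout_map["decoder_block.*ffw_linear.*kernel"] = ("model", None)
    keras.distribution.set_distribution(
        keras.distribution.ModelParallel(device_mesh, layout_map, batch_dim_name="batch")
    )
    
    gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma3_27b_en")  # 4. load last
    print(gemma_lm.generate("What is the meaning of life?", max_length=64))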

10. Clean Up Resources

  • After running your experiments, disconnect the TPU runtime to free up resources:
    • Go to Runtime > Disconnect and delete runtime.

Important Notes

  • Colab Limitations: The free Colab TPU v2 may not have enough memory to run the Gemma 3-27B model efficiently. For large models, consider using Kaggle (which offers TPU v3) or Google Cloud TPUs (which support TPU v3 and newer generations). Kaggle notebooks can be set up similarly, and Google Cloud TPUs require setting up a TPU VM, which can be connected to Colab via the "Connect to a custom GCE VM" option.
  • Alternative Platforms: If Colab's resources are insufficient, refer to the KerasNLP documentation for distributed training on Google Cloud TPUs or Kaggle. For example, Kaggle provides a tutorial on running Gemma models on TPU v3, which has more memory than Colab's TPU v2.
  • Model Availability: Ensure the gemma3_27b_en preset is actually available in your KerasNLP version. Gemma 3 support landed in the newer KerasHub library (the successor to KerasNLP), so older keras-nlp releases may not recognize it; check the latest KerasHub/KerasNLP documentation or the Gemma 3 page on Kaggle for the correct class and preset name. A hedged sketch of the KerasHub route is shown after this list.
  • Debugging: If you encounter errors, test your code on a CPU or GPU first to ensure the model and data pipeline work correctly, as TPUs are harder to debug.
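
  • If keras_nlp cannot find a Gemma 3 preset, a likely alternative is the newer KerasHub package, where the Gemma 3 classes live. This is a hedged sketch: the keras-hub package and its Gemma3CausalLM class exist, but the preset string below is a placeholder that must be replaced with the exact name listed on the model's Kaggle page:

    # Run in a Colab cell; verify the preset name on Kaggle before using it
    !pip install -q -U keras-hub
    
    import keras_hub
    
    # "gemma3_27b_preset_name" is a placeholder, not a real preset identifier
    gemma_lm = keras_hub.models.Gemma3CausalLM.from_preset("gemma3_27b_preset_name")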

Example Notebook

For a complete, working example, you can adapt the official KerasNLP tutorials available on Google AI for Developers or Kaggle. These tutorials often include notebooks specifically designed for running Gemma models on TPUs. Look for titles like "Distributed tuning with Gemma using Keras" or "Get started with Gemma using KerasNLP."

By following these steps, you should be able to run the Gemma 3-27B model on a TPU using JAX in Google Colab, provided you have sufficient resources. If you face persistent issues, consider scaling down to a smaller model (e.g., Gemma 7B) or upgrading your hardware resources.

If you only need inference for the Gemma 3-27B model on a TPU using JAX in Google Colab (and not training or fine-tuning), the setup can be streamlined to focus on loading the model, setting up the TPU, and running inference. Below is a step-by-step guide tailored for inference, while keeping resource efficiency in mind. Note that the 27B model is very large, and Google Colab's free TPU v2 may not have sufficient memory, so you may need premium Colab resources, Kaggle TPUs, or Google Cloud TPUs.

Step-by-Step Guide for Inference

1. Set Up Google Colab with TPU Runtime

  • Open a new notebook in Google Colab.
  • Go to the Runtime menu, select Change runtime type, and choose TPU v2 as the hardware accelerator. Be aware that TPU v2 has limited memory (8 GB of HBM per core, 64 GB across the v2-8), so you may need to upgrade to a premium plan or use Kaggle/Google Cloud TPUs for the 27B model.

2. Install Required Libraries

  • Install JAX with TPU support, Keras, and KerasNLP. Use the following commands in a Colab code cell:

    # Install JAX with TPU support (quotes keep the shell from expanding the brackets)
    !pip install -q "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
    
    # Install Keras and KerasNLP for Gemma models (quote the version constraint so
    # the shell does not treat ">" as output redirection)
    !pip install -q -U keras-nlp
    !pip install -q -U "keras>=3"
    
    # Workaround for TensorFlow compatibility on TPU
    !pip install -q -U tensorflow-cpu
    !pip install -q -U tensorflow-hub
  • Explanation:

    • jax[tpu]: Provides JAX with TPU support.
    • keras-nlp and keras>=3: Enable loading and running Gemma models with JAX as the backend.
    • tensorflow-cpu and tensorflow-hub: Prevent TensorFlow from interfering with TPU execution.

3. Set Up Environment Variables for Kaggle Access

  • The Gemma models are hosted on Kaggle, so you need to authenticate your access. Obtain your Kaggle API key (kaggle.json) from your Kaggle account settings.

  • In Colab, either upload your kaggle.json file or set environment variables manually. Use the following code to set up access:

    import os
    from google.colab import userdata
    
    # Replace with your Kaggle username and key (or use userdata if stored in Colab secrets)
    os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')  # Or set manually
    os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')  # Or set manually
  • If you're not using Colab secrets, you can manually set os.environ["KAGGLE_USERNAME"] = "your_username" and os.environ["KAGGLE_KEY"] = "your_key".

4. Configure Keras to Use JAX Backend

  • Set the Keras backend to JAX, as it is required for TPU execution. Add this code to your notebook:

    import os
    os.environ["KERAS_BACKEND"] = "jax"
    # Allocate memory efficiently for TPU
    os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"
  • Explanation:

    • KERAS_BACKEND: Ensures Keras uses JAX, optimized for TPUs.
    • XLA_PYTHON_CLIENT_MEM_FRACTION: Allocates 90% of TPU memory to the process, reducing the likelihood of memory errors.

5. Verify TPU Availability

  • With the JAX backend, you do not need TensorFlow's tf.distribute API: JAX discovers the TPU cores directly, and the model is sharded across them with Keras' distribution API in the next step, which is how the 27B model fits in memory. Before loading the model, confirm that JAX can see the TPU:

    import jax
    
    # A Colab TPU v2-8 exposes 8 TPU cores
    devices = jax.devices()
    print("All devices: ", devices)
    assert devices[0].platform == "tpu", "TPU not detected - check the runtime type"
  • Explanation:

    • jax.devices(): Lists the accelerators visible to JAX; on a Colab TPU runtime it should report 8 TPU cores.
    • No tf.distribute.TPUStrategy is required: TensorFlow was deliberately installed as tensorflow-cpu, and model parallelism across the 8 cores (essential for the 27B model) is configured with keras.distribution in the next step.

6. Load the Gemma 3-27B Model for Inference

  • Use KerasNLP to load the Gemma 3-27B model, sharding it across TPU cores for efficient inference. Since you're only doing inference, you can simplify the setup compared to training. Here's the code:

    import keras_nlp
    import keras
    
    # Create a device mesh for TPU cores (1, 8) shape
    device_mesh = keras.distribution.DeviceMesh(
        (1, 8),  # 1 batch dimension, 8 model dimensions (for 8 TPU cores)
        ["batch", "model"],
        devices=keras.distribution.list_devices()
    )
    
    # Define layout map for sharding weights
    layout_map = keras.distribution.LayoutMap(device_mesh)
    layout_map["token_embedding/embeddings"] = ("model", None)
    layout_map["decoder_block.*attention.*(query|key|value).*kernel"] = ("model", None, None)
    layout_map["decoder_block.*attention_output.*kernel"] = ("model", None, None)
    layout_map["decoder_block.*ffw_gating.*kernel"] = ("model", None)
    layout_map["decoder_block.*ffw_linear.*kernel"] = ("model", None)
    
    # Set the distribution layout
    model_parallel = keras.distribution.ModelParallel(device_mesh, layout_map, batch_dim_name="batch")
    keras.distribution.set_distribution(model_parallel)
    
    # Load the Gemma 3-27B model; the distribution set above shards it across the TPU cores
    gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma3_27b_en")
    gemma_lm.summary()
  • Explanation:

    • DeviceMesh: Represents the TPU cores as a mesh for distributing computation. The (1, 8) shape shards the model across all 8 TPU cores.
    • LayoutMap: Specifies how model weights are sharded across TPU cores, ensuring efficient memory usage.
    • ModelParallel: Enables model parallelism, necessary for large models that don't fit on a single TPU core.
    • from_preset: Loads the pre-trained Gemma 3-27B model. Ensure the preset name gemma3_27b_en is correct; if not, check the KerasNLP documentation or Kaggle for the exact preset name.

7. Run Inference

  • Perform inference by generating text based on a prompt. Use the generate method provided by KerasNLP:

    # Generate text
    prompt = "What is the meaning of life?"
    output = gemma_lm.generate(prompt, max_length=64)
    print(output)
  • Explanation:

    • generate: Generates text based on the prompt. The max_length argument limits the length of the generated sequence.
    • The first run may be slow due to XLA compilation, but subsequent runs will be faster.
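
  • You can also control the decoding strategy before generating. KerasNLP causal LMs accept a sampler through compile; this sketch assumes the keras_nlp.samplers API available in recent KerasNLP releases (changing the sampler triggers a fresh XLA compilation on the next call):

    # Switch from the default sampler to top-k sampling for more varied output
    gemma_lm.compile(sampler=keras_nlp.samplers.TopKSampler(k=50))
    print(gemma_lm.generate("Tell me a short story about a robot.", max_length=64))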

8. Optimize for Inference Performance

  • To improve inference performance and reduce memory usage, consider the following:
    • Use Mixed Precision: Enable bfloat16 to reduce memory usage and speed up computation. Add this code before loading the model:

      from keras import mixed_precision
      mixed_precision.set_global_policy('mixed_bfloat16')
    • Reduce Sequence Length: Use a shorter max_length (e.g., 32 instead of 64) to reduce memory usage, especially if you're encountering out-of-memory errors.

    • Batch Size: For inference, the batch size is typically 1 (as in the example above). Avoid increasing the batch size unless you have sufficient memory, as it can lead to crashes on Colab TPUs.
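
  • To see the effect of XLA compilation, time the first call against a repeat call with the same prompt and max_length (plain Python timing, so it measures end-to-end latency rather than device time):

    import time
    
    for run in range(2):
        start = time.perf_counter()
        gemma_lm.generate("What is the meaning of life?", max_length=64)
        print(f"Run {run + 1}: {time.perf_counter() - start:.1f}s")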

9. Handle Potential Issues

  • Memory Errors: If you encounter out-of-memory errors, it may indicate that the Colab TPU v2 is insufficient for the 27B model. Consider:
    • Reducing max_length further.
    • Switching to a paid Colab plan, which may offer higher memory limits or newer accelerator types.
    • Using Kaggle, which offers TPU v3 for free, or Google Cloud TPUs, which provide more memory and newer TPU versions.
  • Slow Compilation: The first run of generate compiles the model with XLA, which can take several minutes. Subsequent runs will be faster.
  • Authentication Errors: Ensure your Kaggle API key is correctly set up to access the Gemma model.
  • Model Availability: If the gemma3_27b_en preset is not available in your KerasNLP version, check the newer KerasHub library (the successor to KerasNLP) and the Gemma 3 page on Kaggle for the correct class, preset name, or model weights.

10. Clean Up Resources

  • After running your experiments, disconnect the TPU runtime to free up resources:
    • Go to Runtime > Disconnect and delete runtime.

Important Notes

  • Colab Limitations: The free Colab TPU v2 may not have enough memory to run the Gemma 3-27B model efficiently, especially for inference with longer sequences. For large models, consider using Kaggle (which offers TPU v3) or Google Cloud TPUs (which support TPU v3 and newer generations). Kaggle notebooks can be set up similarly, and Google Cloud TPUs require setting up a TPU VM, which can be connected to Colab via the "Connect to a custom GCE VM" option.
  • Alternative Platforms: If Colab's resources are insufficient, refer to the KerasNLP documentation for distributed inference on Google Cloud TPUs or Kaggle. For example, Kaggle provides tutorials on running Gemma models on TPU v3, which has more memory than Colab's TPU v2.
  • Debugging: If you encounter errors, test your code on a CPU or GPU first with a smaller model (e.g., Gemma 2B or 7B) to ensure the pipeline works correctly, as TPUs are harder to debug; a minimal sketch is shown after this list.
  • Inference-Only Focus: Since you're only doing inference, you don't need to worry about training-related configurations (e.g., optimizers, loss functions). The setup above is optimized for generating text efficiently.
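
  • As a minimal debugging sketch (assuming the standard keras-nlp Gemma presets such as gemma_2b_en are available to your Kaggle account), you can validate the backend, authentication, and generation pipeline on a single GPU or CPU before moving to the TPU:

    import os
    os.environ["KERAS_BACKEND"] = "jax"  # must be set before the first keras import
    
    import keras_nlp
    
    # Small Gemma preset that fits on a single Colab GPU; swap in the 27B preset on the TPU
    small_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
    print(small_lm.generate("Hello, my name is", max_length=32))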

Example Notebook

For a complete, working example, you can adapt the official KerasNLP tutorials available on Google AI for Developers or Kaggle. Look for titles like "Get started with Gemma using KerasNLP" and modify them to focus on inference only, as shown above.

By following these steps, you should be able to run inference with the Gemma 3-27B model on a TPU using JAX in Google Colab, provided you have sufficient resources. If you face persistent issues, consider scaling down to a smaller model (e.g., Gemma 7B) or upgrading your hardware resources.
