Created
January 5, 2024 15:25
-
-
Save JohnAtl/2fd83eeee94af7053b6524064577b90e to your computer and use it in GitHub Desktop.
Distrobox with TensorFlow and NVIDIA support
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/bin/bash | |
# Adapted from https://stackoverflow.com/a/47436840 | |
# Print the ldconfig output lines that match a library name.
#   $1 - pattern handed to grep (a regular expression), e.g. "libcudart"
# Returns grep's exit status: 0 when at least one line matched.
function lib_installed() {
    # Turn EVERY ':' in LD_LIBRARY_PATH into a space so that each entry is
    # passed to ldconfig as its own directory argument. The expansion is left
    # unquoted on purpose: word splitting is what separates the entries.
    # shellcheck disable=SC2086
    /sbin/ldconfig -N -v ${LD_LIBRARY_PATH//:/ } 2>/dev/null | grep -- "$1"
}
# Report whether the library named by $1 is installed.
#   $1 - library name pattern, forwarded to lib_installed
# The matching lines from lib_installed are printed first, followed by a
# one-line verdict.
function check() {
    # An explicit if/else keeps the error branch from also firing when the
    # success message itself fails to print (the `A && B || C` pitfall).
    if lib_installed "$1"; then
        echo "$1 is installed"
    else
        echo -e "\nERROR: $1 is NOT installed\n"
    fi
}
# Look for the CUDA driver, runtime and cuDNN shared libraries.
check libcuda.so
check libcudart
check libcudnn

# Show the CUDA compiler version when nvcc is on PATH.
if command -v nvcc &>/dev/null; then
    nvcc --version
else
    echo -e "\nnvcc is not installed\n"
fi

# Show the GPU/driver status when nvidia-smi is on PATH.
if command -v nvidia-smi &>/dev/null; then
    nvidia-smi
else
    echo -e "\nnvidia-smi is not installed\n"
fi

echo "Run tensorflow_mnist_test to test GPU function for training/testing on the MNIST dataset."
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
[nvbox] | |
image=nvcr.io/nvidia/tensorflow:23.12-tf2-py3 | |
init=true | |
nvidia=true | |
pull=true | |
home="~/.local/share/distrobox/nvbox" # Sets an alternate home to not pollute your home as much | |
# volume="/mnt/nvme:/mnt/nvme:rw /mnt/btrfs:/mnt/btrfs:rw" # Replace with your shared folders | |
start_now=true | |
additional_packages="systemd openssh-server" | |
additional_packages="build-essential clang clang-tools python-is-python3" # C/C++ build tools; python-is-python3 makes the python command run python3
pre_init_hooks="echo 'Port 2222' | tee -a /etc/ssh/sshd_config" # Set ssh port to 2222 | |
pre_init_hooks="echo 'ListenAddress 127.0.0.1' | tee -a /etc/ssh/sshd_config" # Set listen address to only be localhost. | |
init_hooks=sudo -u "${USER}" "/usr/bin/cp -r ${HOME}/.ssh ~" # copy over ssh keys |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
# https://www.tensorflow.org/datasets/keras_example | |
# Import TensorFlow; if it is missing, print install instructions and exit.
# Only ImportError is handled so that Ctrl-C or an unrelated failure is not
# misreported as "please install".
try:
    import tensorflow as tf
except ImportError:
    print("Please install tensorflow:")
    print("pip install tensorflow")
    # SystemExit is raised directly because the `exit` helper is only
    # available when the `site` module has been loaded.
    raise SystemExit(-1)

# Same treatment for the TensorFlow Datasets package.
try:
    import tensorflow_datasets as tfds
except ImportError:
    print("\nPlease install the tensorflow datasets:")
    print(" pip install tensorflow_datasets")
    print("You can safely ignore pip dependency errors")
    raise SystemExit(-2)

# This script exists to exercise the GPU, so stop when TensorFlow sees none.
if not tf.config.list_physical_devices('GPU'):
    print("No GPU(s) found")
    raise SystemExit(-3)

# Load the MNIST train and test splits as (image, label) pairs, together with
# the dataset metadata.
(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)
def normalize_img(image, label):
    """Cast the image to `float32` and divide it by 255; return the label as is."""
    scaled = tf.cast(image, tf.float32) / 255.
    return scaled, label
# Training input pipeline: normalize, cache, shuffle with a buffer as large as
# the whole training split, then batch and prefetch.
ds_train = (
    ds_train
    .map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    .cache()
    .shuffle(ds_info.splits['train'].num_examples)
    .batch(128)
    .prefetch(tf.data.AUTOTUNE)
)

# Evaluation input pipeline: no shuffling, and the cache sits after batching.
ds_test = (
    ds_test
    .map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(128)
    .cache()
    .prefetch(tf.data.AUTOTUNE)
)

# Small fully connected classifier: flatten, one hidden layer, 10 outputs.
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Flatten(input_shape=(28, 28)))
model.add(tf.keras.layers.Dense(128, activation='relu'))
model.add(tf.keras.layers.Dense(10))

# The last layer has no activation, so the loss is configured for logits.
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

# Train for six epochs, validating on the test split.
model.fit(ds_train, validation_data=ds_test, epochs=6)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment