Skip to content

Instantly share code, notes, and snippets.

@Kautenja
Last active October 27, 2018 21:28
Show Gist options
  • Save Kautenja/b5690b8216bb95a5c043a6fc566f19e2 to your computer and use it in GitHub Desktop.
Save Kautenja/b5690b8216bb95a5c043a6fc566f19e2 to your computer and use it in GitHub Desktop.
Methods to get the model name of GPUs used by TensorFlow.
"""Methods to get the model name of GPUs used by TensorFlow."""
from tensorflow.python.client import device_lib
def get_device_model(device) -> str:
    """
    Return the model of a TensorFlow device.

    The model is parsed out of the device's ``physical_device_desc`` string,
    which is treated as comma separated ``key: value`` fields, e.g.
    ``'device: 0, name: Tesla K80, pci bus id: 0000:00:04.0'``.

    Args:
        device: the device to get the model name of; only its
            ``physical_device_desc`` attribute is read

    Returns:
        a string describing the model name of the device: the text after the
        last ``', name: '`` marker up to the next comma. When the description
        has no such marker, the text before its first comma is returned
        (the whole description if it contains no comma)

    """
    # get the physical description of the device
    desc = device.physical_device_desc
    # keep what follows the last ', name: ' marker; rpartition yields the
    # whole string in the final slot when the marker is missing
    model = desc.rpartition(', name: ')[2]
    # remove the additional fields that trail the name
    return model.partition(',')[0]
def get_gpu_models() -> list:
    """
    List the model names of the GPUs reported by TensorFlow.

    Returns:
        one model name string per local device whose ``device_type`` is
        'GPU', in the order the devices are listed

    """
    models = []
    # walk every device TensorFlow lists for this machine
    for local_device in device_lib.list_local_devices():
        # anything that is not a GPU is skipped
        if local_device.device_type != 'GPU':
            continue
        models.append(get_device_model(local_device))
    return models
# explicitly define the outward facing API of this module; the entry must
# name a function that actually exists here (the previous reference to
# get_available_gpus raised NameError on import)
__all__ = ['get_gpu_models']
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment