Last active
October 27, 2018 21:28
-
-
Save Kautenja/b5690b8216bb95a5c043a6fc566f19e2 to your computer and use it in GitHub Desktop.
Methods to get the model name of GPUs used by TensorFlow.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Methods to get the model name of GPUs used by TensorFlow.""" | |
from tensorflow.python.client import device_lib | |
def get_device_model(device) -> str:
    """
    Return the model name of a TensorFlow device.

    The device's ``physical_device_desc`` string is treated as a
    comma-separated list of ``key: value`` fields, and the value of the
    ``name`` field is returned.

    Args:
        device: the device to get the model name of; any object exposing a
            ``physical_device_desc`` string attribute (e.g. an entry of
            ``device_lib.list_local_devices()``)

    Returns:
        the model name of the device, or an empty string if the description
        contains no ``name`` field (for example, an empty description)

    """
    # assumed format, inferred from the original ", name: " parsing:
    # "device: 0, name: <model>, pci bus id: ..., compute capability: ..."
    # NOTE(review): confirm against the TensorFlow version in use
    for field in device.physical_device_desc.split(', '):
        key, sep, value = field.partition(': ')
        # accept a genuine "name: ..." field wherever it appears, including
        # as the very first field (which a ", name: " split would miss)
        if sep and key == 'name':
            return value
    # no name field: return an empty string rather than an unrelated segment
    return ''
def get_gpu_models() -> list:
    """
    Return the model names of the GPUs visible to TensorFlow.

    Returns:
        a list with one model-name string per local device whose
        ``device_type`` is ``'GPU'``, in the order TensorFlow lists them

    """
    models = []
    for local_device in device_lib.list_local_devices():
        # skip anything that is not a GPU (e.g. CPU devices)
        if local_device.device_type != 'GPU':
            continue
        models.append(get_device_model(local_device))
    return models
# explicitly define the outward facing API of this module
# (previously referenced the undefined name `get_available_gpus`, which
# raised NameError on import)
__all__ = [get_device_model.__name__, get_gpu_models.__name__]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment