Created
June 17, 2020 15:04
-
-
Save mberr/c37a8068b38cabc98228db2cbe358043 to your computer and use it in GitHub Desktop.
Find the maximal parameter value for a given CUDA device by successive halving.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Find maximal parameter value for a given CUDA device by successive halving."""
import logging
from typing import Callable, Tuple, TypeVar

import torch
R = TypeVar('R')

# Module-level logger; the original code called ``logger.info`` without ever defining it.
logger = logging.getLogger(__name__)


def _is_cuda_oom_error(error: RuntimeError) -> bool:
    """Return whether ``error`` signals that CUDA ran out of memory."""
    # Newer PyTorch versions raise a dedicated subclass of RuntimeError; older ones only
    # identify the condition through the message text.
    oom_type = getattr(torch.cuda, 'OutOfMemoryError', None)
    if oom_type is not None and isinstance(error, oom_type):
        return True
    # ``str(error)`` is safe even when ``error.args`` is empty or holds non-string values,
    # unlike indexing ``error.args[0]`` directly.
    return 'CUDA out of memory.' in str(error)


def maximize_memory_utilization(
    func: Callable[..., R],
    parameter_name: str,
    parameter_max_value: int,
    *args,
    **kwargs
) -> Tuple[R, int]:
    """
    Iteratively halve a parameter value until ``func`` runs without a CUDA out-of-memory error.

    ``func`` is called with ``parameter_name`` passed as a keyword argument set to the current
    value. If the call fails with a CUDA out-of-memory ``RuntimeError``, the CUDA cache is
    emptied and the value is halved (integer division) before the next attempt.

    :param func:
        The callable.
    :param parameter_name:
        The name of the parameter to maximise.
    :param parameter_max_value:
        The maximum value to start with. Must be at least 1.
    :param args:
        Additional positional arguments for func. Does _not_ include parameter_name!
    :param kwargs:
        Additional keyword-based arguments for func. Does _not_ include parameter_name!
    :return:
        The result, as well as the maximum value which led to successful execution.
    :raises ValueError:
        If ``parameter_max_value`` is smaller than 1.
    :raises MemoryError:
        If execution runs out of CUDA memory even with ``parameter_name=1``.
    :raises RuntimeError:
        Any RuntimeError raised by ``func`` that is not a CUDA out-of-memory error is re-raised.
    """
    if parameter_max_value < 1:
        raise ValueError(f'{parameter_name} must start at a positive value, but got {parameter_max_value}.')

    direct_success = True
    while parameter_max_value > 0:
        p_kwargs = {parameter_name: parameter_max_value}
        try:
            result = func(*args, **p_kwargs, **kwargs)
        except RuntimeError as runtime_error:
            # Only CUDA OOM errors are recoverable by shrinking the parameter.
            if not _is_cuda_oom_error(runtime_error):
                raise
        else:
            if not direct_success:
                logger.info('Execution succeeded with %s=%d', parameter_name, parameter_max_value)
            return result, parameter_max_value

        # Recovery deliberately happens *outside* the except clause. While the exception is
        # alive, its traceback keeps the failed call's frames (and their tensors) referenced,
        # so emptying the cache inside the handler could not release that memory.
        direct_success = False
        torch.cuda.empty_cache()
        logger.info('Execution failed with %s=%d', parameter_name, parameter_max_value)
        parameter_max_value //= 2

    raise MemoryError(f'Execution did not even succeed with {parameter_name}=1.')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment