Skip to content

Instantly share code, notes, and snippets.

@jph00
Last active September 16, 2024 04:30
Show Gist options
  • Save jph00/aac855c4dbfb9d53a265c26cd9ee76cf to your computer and use it in GitHub Desktop.
Make get_peft_model() fast
from bitsandbytes.nn.modules import Linear8bitLt, Linear4bit
from contextlib import contextmanager
def noop(x=None, *args, **kwargs):
    "No-op stand-in for an initializer: ignore every extra argument and hand back `x` unchanged."
    return x
@contextmanager
def no_kaiming():
    """Temporarily disable `init.kaiming_uniform_` within the `with` body.

    Swaps the initializer for the do-nothing `noop` and guarantees the
    original function is restored on exit, even if the body raises.
    """
    saved = init.kaiming_uniform_
    init.kaiming_uniform_ = noop
    try:
        yield
    finally:
        init.kaiming_uniform_ = saved
# Keep references to the original bitsandbytes constructors so the wrapper
# functions below can delegate to them (and so the patch can be reverted by
# reassigning these back onto the classes).
_old_8init = Linear8bitLt.__init__
_old_4init = Linear4bit.__init__
def _new_4init(self, *args, **kwargs):
    """Replacement for `Linear4bit.__init__` that skips kaiming weight init.

    The original gist code copied Linear8bitLt's parameter list here, but
    Linear4bit's constructor takes different keyword arguments (e.g.
    `compute_dtype`, `quant_type` in current bitsandbytes), so valid calls
    would fail with TypeError. Forwarding `*args, **kwargs` verbatim keeps
    the wrapper correct for any constructor signature, while running the
    real constructor inside `no_kaiming()` so the (immediately overwritten)
    random initialization becomes a no-op.
    """
    with no_kaiming():
        return _old_4init(self, *args, **kwargs)
def _new_8init(self, *args, **kwargs):
    """Replacement for `Linear8bitLt.__init__` that skips kaiming weight init.

    The original hard-coded one bitsandbytes version's parameter list, so
    any argument added or renamed upstream would be rejected or silently
    dropped. A verbatim `*args, **kwargs` pass-through stays correct across
    library versions; the real constructor runs inside `no_kaiming()` so the
    expensive (and immediately overwritten) initialization is a no-op.
    """
    with no_kaiming():
        return _old_8init(self, *args, **kwargs)
# Install the patched constructors: from here on, any Linear8bitLt/Linear4bit
# created (e.g. while loading a quantized model before get_peft_model(), per
# the gist title) skips the kaiming weight initialization.
Linear8bitLt.__init__ = _new_8init
Linear4bit.__init__ = _new_4init
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment