# Last active: September 16, 2024 04:30
# Save jph00/aac855c4dbfb9d53a265c26cd9ee76cf to your computer and use it in GitHub Desktop.
# Make get_peft_model() fast
# This file contains bidirectional Unicode text that may be interpreted or
# compiled differently than what appears below. To review, open the file in an
# editor that reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters.
from contextlib import contextmanager
from functools import wraps

from bitsandbytes.nn.modules import Linear8bitLt, Linear4bit
from torch.nn import init
def noop(x=None, *args, **kwargs):
    """Hand back the first argument untouched.

    Any further positional or keyword arguments are accepted and discarded,
    which lets this function stand in for ``init.kaiming_uniform_`` inside
    ``no_kaiming``.
    """
    del args, kwargs  # accepted only for call-compatibility; never used
    return x
@contextmanager
def no_kaiming():
    """Disable ``init.kaiming_uniform_`` for the duration of a ``with`` block.

    While the block runs, the module attribute ``init.kaiming_uniform_`` is
    rebound to ``noop``. Any code that looks it up through ``init`` at call
    time therefore does nothing and just gets its first argument back. The
    previously bound function is put back on exit, whether the block finishes
    normally or raises.
    """
    saved_initializer = init.kaiming_uniform_
    init.kaiming_uniform_ = noop
    try:
        yield
    finally:
        # Restore unconditionally so an exception cannot leave `init` patched.
        init.kaiming_uniform_ = saved_initializer
# Capture the unpatched bitsandbytes constructors before the classes are
# monkeypatched, so the replacement __init__ wrappers can delegate to them.
_old_8init, _old_4init = Linear8bitLt.__init__, Linear4bit.__init__
# Replacement for ``Linear4bit.__init__``: run the saved original constructor
# inside ``no_kaiming()`` so ``init.kaiming_uniform_`` is a no-op while it runs.
#
# All arguments are forwarded verbatim. The earlier hard-coded signature was
# copied from ``Linear8bitLt`` and always passed ``has_fp16_weights``,
# ``memory_efficient_backward``, ``threshold`` and ``index``. Linear4bit's
# constructor takes a different parameter set (``compute_dtype``,
# ``quant_type``, ... — see the bitsandbytes API docs), so construction failed.
# ``wraps`` keeps the original's name, docstring and ``__wrapped__`` link, so
# ``inspect.signature(Linear4bit)`` still reports the real parameters.
@wraps(_old_4init)
def _new_4init(self, *args, **kwargs):
    with no_kaiming():
        return _old_4init(self, *args, **kwargs)
# Replacement for ``Linear8bitLt.__init__``: run the saved original constructor
# inside ``no_kaiming()`` so ``init.kaiming_uniform_`` is a no-op while it runs.
#
# Arguments are forwarded verbatim rather than re-declared. This keeps the
# wrapper in lockstep with whatever signature and defaults the installed
# bitsandbytes version defines, and it matches ``_new_4init``.
# ``wraps`` preserves the original's metadata and sets ``__wrapped__`` so
# signature introspection still sees the real parameters.
@wraps(_old_8init)
def _new_8init(self, *args, **kwargs):
    with no_kaiming():
        return _old_8init(self, *args, **kwargs)
# Install the wrappers. From here on, constructing either bitsandbytes layer
# runs its original __init__ inside no_kaiming().
for _cls, _patched_init in ((Linear8bitLt, _new_8init), (Linear4bit, _new_4init)):
    _cls.__init__ = _patched_init
del _cls, _patched_init  # don't leak loop temporaries into the module namespace
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.