This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch import nn | |
class RationalActivation(nn.Module): | |
""" | |
A rational activation function with trainable parameters. | |
Inspired by https://arxiv.org/abs/2205.01549. | |
.. seealso:: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Several similarity matrix normalization methods.""" | |
import torch | |
def csls( | |
sim: torch.FloatTensor, | |
k: Optional[int] = 1, | |
) -> torch.FloatTensor: | |
""" | |
Apply CSLS normalization to a similarity matrix. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Determine optimal threshold for Macro F1 score.""" | |
from typing import Tuple | |
import numpy | |
from sklearn.metrics._ranking import _binary_clf_curve | |
def f1_scores( | |
precision: numpy.ndarray, | |
recall: numpy.ndarray, | |
) -> numpy.ndarray: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch import nn | |
import logging | |
logger = logging.getLogger(__name__) | |
# pylint: disable=abstract-method | |
class ExtendedModule(nn.Module): | |
"""Extends nn.Module by a few utility methods.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def latex_bold(text: str) -> str: | |
"""Format text in bold font using Latex.""" | |
return rf"\textbf{{{text}}}" | |
def highlight_max( | |
data: pandas.Series, | |
float_formatter: Callable[[float], str] = "{:2.2f}".format, | |
highlighter: Callable[[str], str] = latex_bold, | |
) -> pandas.Series: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import random | |
import string | |
from typing import Sequence | |
def random_sentence_list( | |
num_sentences: int = 1, | |
word_sep: str = ' ', | |
min_num_words: int = 1, | |
max_num_words: int = 1, |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Find maximal parameter value for a given CUDA device by successive halvening.""" | |
from typing import Callable, Tuple, TypeVar | |
import torch | |
R = TypeVar('R') | |
def maximize_memory_utilization( | |
func: Callable[..., R], | |
parameter_name: str, |