from typing import List, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F


def get_fair_ch_idx(channels_list, inp_choice_idx) -> Tuple:
    """
    Return the (start, end) channel indices to slice under fair sampling.

    :param channels_list: list of candidate channel counts, one per width choice.
    :param inp_choice_idx: index of the selected width within ``channels_list``.
    :return: tuple ``(st_ch_idx, ed_ch_idx)`` delimiting the extra channels that are
        concatenated after the first ``min(channels_list)`` base channels.
    """
    min_channel = min(channels_list)
    base_ch = min_channel
    max_channel = max(channels_list)
    num_ch_choice = len(channels_list)
    # Step between consecutive width choices, assuming evenly spaced channel counts.
    if num_ch_choice > 1:
        ch_step = (max_channel - min_channel) // (num_ch_choice - 1)
    else:
        ch_step = 0
    mid_choice_idx_point = (len(channels_list) - 1) / 2
    if inp_choice_idx <= mid_choice_idx_point:
        # Lower half of the choices: grow the slice forward from the base channels.
        st_ch_idx = base_ch
        ed_ch_idx = base_ch + inp_choice_idx * ch_step
    else:
        # Upper half of the choices: grow the slice backward from the widest channel.
        st_ch_idx = max_channel - (channels_list[inp_choice_idx] - base_ch)
        ed_ch_idx = max_channel
    return st_ch_idx, ed_ch_idx
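
# Fair-sampling sketch (hypothetical numbers, not from the original gist): with
# evenly spaced widths [32, 48, 64, 80, 96],
#   get_fair_ch_idx([32, 48, 64, 80, 96], 1) -> (32, 48)   # 32 base + weight[32:48] = 48 channels
#   get_fair_ch_idx([32, 48, 64, 80, 96], 3) -> (48, 96)   # 32 base + weight[48:96] = 80 channels
# The base channels are always kept, and the remainder is taken from either the
# front or the back of the full-width kernel, so every channel gets sampled.
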
class SlimmableConv2d(nn.Conv2d):
    def __init__(
        self,
        in_channels_list,
        out_channels_list,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups_list=[1],
        bias=True,
        fair=False,
    ):
        # Allocate the kernel at the largest width; narrower widths reuse slices of it.
        super(SlimmableConv2d, self).__init__(
            max(in_channels_list),
            max(out_channels_list),
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=max(groups_list),
            bias=bias,
        )
        self.fair = fair
        self.in_channels_list = in_channels_list
        self.out_channels_list = out_channels_list
        self.groups_list = groups_list
        if self.groups_list == [1]:
            self.groups_list = [1 for _ in range(len(in_channels_list))]
        self.base_in_ch = min(in_channels_list)
        self.base_out_ch = min(out_channels_list)

    def forward(self, inputs: List[Union[torch.Tensor, int]]):
        x, out_ch_idx = inputs
        # Infer the active input width from the channel dimension of x.
        in_ch_idx = list(self.in_channels_list).index(x.shape[1])
        groups = self.groups_list[in_ch_idx]
        weight, bias = self.get_kernel(in_ch_idx, out_ch_idx)
        y = F.conv2d(x, weight, bias, self.stride, self.padding, self.dilation, groups)
        return y
    def get_kernel(self, inp_ch_idx, out_ch_idx) -> Tuple:
        if self.fair:
            # Fair sampling: always keep the base channels, then append the slice
            # chosen by get_fair_ch_idx so that all channels get trained over time.
            inp_st_ch_idx, inp_ed_ch_idx = get_fair_ch_idx(
                self.in_channels_list, inp_ch_idx
            )
            out_st_ch_idx, out_ed_ch_idx = get_fair_ch_idx(
                self.out_channels_list, out_ch_idx
            )
            out_sliced_weight = torch.cat(
                [
                    self.weight[: self.base_out_ch],
                    self.weight[out_st_ch_idx:out_ed_ch_idx],
                ],
                dim=0,
            )
            weight = torch.cat(
                [
                    out_sliced_weight[:, : self.base_in_ch],
                    out_sliced_weight[:, inp_st_ch_idx:inp_ed_ch_idx],
                ],
                dim=1,
            )
            if self.bias is not None:
                bias = torch.cat(
                    [
                        self.bias[: self.base_out_ch],
                        self.bias[out_st_ch_idx:out_ed_ch_idx],
                    ],
                    dim=0,
                )
            else:
                bias = self.bias
        else:
            # Standard slimmable conv: take the leading out_ch x inp_ch block of the kernel.
            inp_ch, out_ch = (
                self.in_channels_list[inp_ch_idx],
                self.out_channels_list[out_ch_idx],
            )
            weight = self.weight[:out_ch, :inp_ch, :, :]
            if self.bias is not None:
                bias = self.bias[:out_ch]
            else:
                bias = self.bias
        return weight, bias
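
# Hypothetical usage sketch (not part of the original gist): the channel lists,
# tensor shapes, and width indices below are made-up values for illustration only.
if __name__ == "__main__":
    conv = SlimmableConv2d(
        in_channels_list=[16, 24, 32],
        out_channels_list=[32, 48, 64],
        kernel_size=3,
        padding=1,
        fair=True,
    )
    x = torch.randn(2, 24, 8, 8)  # 24 input channels -> input width choice 1
    y = conv([x, 2])              # width choice 2 selects the widest output slice
    print(y.shape)                # torch.Size([2, 64, 8, 8])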