Last active
September 22, 2023 20:23
-
-
Save innat/1ce396dd46496ce9aac2fce384226211 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Ref: https://gist.github.com/Rocketknight1/efc47242914788def0144b341b1ad638
import math | |
import tensorflow as tf | |
from tensorflow.keras import layers | |
class TFAdaptiveAveragePooling(layers.Layer):
    """Base layer for adaptive average pooling to a fixed output size.

    The layer stores the requested ``output_size`` and provides the two
    helpers the concrete subclasses build on: ``_get_pooling_params`` picks
    the two candidate window lengths for one axis, and ``_compute_window``
    assembles the final result from stride-1 pools computed with those two
    window lengths. Subclasses must override ``call``.

    Args:
        output_size: Target size per pooled axis. A bare (non list/tuple)
            value is wrapped into a one-element tuple.
        **kwargs: Forwarded to ``layers.Layer``.
    """

    def __init__(self, output_size, **kwargs):
        super().__init__(**kwargs)
        if not isinstance(output_size, (list, tuple)):
            output_size = (output_size,)
        self.output_size = output_size

    def get_config(self):
        """Return the layer config, including ``output_size``.

        Without this entry the constructor argument is missing from the
        serialized config and the layer cannot be re-created from it.
        """
        config = super().get_config()
        # Emit a plain list so the value is serializable regardless of how
        # the attribute is stored on the layer.
        config.update({"output_size": list(self.output_size)})
        return config

    def _get_pooling_params(self, input_dim, output_dim):
        """Return the two candidate window lengths for one axis.

        Args:
            input_dim: Length of the axis being pooled.
            output_dim: Requested length of that axis after pooling.

        Returns:
            ``(small_window, big_window)`` where ``small_window`` is
            ``ceil(input_dim / output_dim)`` and ``big_window`` is one more.
        """
        small_window = math.ceil(input_dim / output_dim)
        big_window = small_window + 1
        return small_window, big_window

    def _compute_window(self, pools, windows, input_size, target_size, axis=-1):
        """Select, per output position, the matching pooled value.

        Args:
            pools: ``(small_pool, big_pool)``, the stride-1 "VALID" average
                pools of the same input computed with the small and the big
                window along ``axis``.
            windows: ``(small_window, big_window)`` as returned by
                ``_get_pooling_params``; only the small length is used here.
            input_size: Length of ``axis`` before pooling.
            target_size: Number of output positions along ``axis``.
            axis: Axis along which the pools are concatenated and gathered.

        Returns:
            A tensor with ``target_size`` entries along ``axis``.
        """
        small_pool, big_pool = pools
        small_window, _ = windows

        # Lay both pools end to end so a single gather can read from either.
        both_pool = tf.concat([small_pool, big_pool], axis=axis)

        # Output position i covers
        # [floor(i * input_size / target_size),
        #  ceil((i + 1) * input_size / target_size)).
        window_starts = tf.math.floor(
            (tf.range(target_size, dtype=tf.float32) * input_size) / target_size
        )
        window_starts = tf.cast(window_starts, tf.int64)
        window_ends = tf.math.ceil(
            (tf.range(1, target_size + 1, dtype=tf.float32) * input_size)
            / target_size
        )
        window_ends = tf.cast(window_ends, tf.int64)

        # A non-zero difference from the small window length marks positions
        # that must be read from the big pool.
        pool_selector = tf.cast(
            window_ends - window_starts - small_window, tf.bool
        )

        # With stride 1, entry j of a pool is the window starting at j. The
        # big pool sits after the small one in `both_pool`, so its entries
        # are offset by the small pool's length along `axis`.
        # NOTE(review): relies on small_pool having a statically known size
        # along `axis`.
        small_indices = window_starts
        big_indices = window_starts + small_pool.shape[axis]
        gather_indices = tf.where(pool_selector, big_indices, small_indices)
        return tf.gather(both_pool, gather_indices, axis=axis)

    def call(self, inputs):
        """Abstract hook; concrete pooling layers provide the implementation."""
        raise NotImplementedError(
            "This method should be implemented by subclasses."
        )
class TFAdaptiveAveragePooling1D(TFAdaptiveAveragePooling):
    """Adaptive average pooling along axis 1 of ``NWC`` inputs."""

    def call(self, inputs):
        """Pool ``inputs`` so that axis 1 has ``self.output_size[0]`` entries.

        Args:
            inputs: Tensor passed to ``tf.nn.avg_pool1d`` with
                ``data_format='NWC'``.

        Returns:
            The pooled tensor produced by ``_compute_window``.

        Raises:
            ValueError: If the size of axis 1 is not statically known.
        """
        # Use the static shape: the window lengths are computed with
        # math.ceil and passed as `ksize`, so they must be plain Python
        # numbers. A value taken from tf.shape() is a symbolic tensor while
        # tracing and cannot be handled by math.ceil.
        input_dim = inputs.shape[1]
        if input_dim is None:
            raise ValueError(
                "TFAdaptiveAveragePooling1D requires a statically known size "
                "for axis 1 of its input."
            )

        output_size = self.output_size[0]
        small_window, big_window = self._get_pooling_params(input_dim, output_size)

        def _pool(window):
            # Stride-1 pool: entry j is the average of the window starting at j.
            return tf.nn.avg_pool1d(
                inputs,
                ksize=window,
                strides=1,
                padding="VALID",
                data_format='NWC',
            )

        return self._compute_window(
            [_pool(small_window), _pool(big_window)],
            [small_window, big_window],
            input_size=input_dim,
            target_size=output_size,
            axis=1,
        )
class TFAdaptiveAveragePooling2D(TFAdaptiveAveragePooling):
    """Adaptive average pooling along axes 1 and 2 of ``NHWC`` inputs."""

    def _pseudo_pool(self, inputs, output_size, axis=-1):
        """Adaptively pool ``inputs`` along a single spatial axis.

        Args:
            inputs: Tensor passed to ``tf.nn.avg_pool2d`` with
                ``data_format="NHWC"``.
            output_size: Requested length of ``axis`` after pooling.
            axis: Axis to pool; must resolve to 1 or 2. Negative values are
                counted from the end of the 4 axes of the ``NHWC`` layout.

        Returns:
            The pooled tensor produced by ``_compute_window``.

        Raises:
            ValueError: If ``axis`` is not a spatial axis, or the size of
                that axis is not statically known.
        """
        if axis < 0:
            axis += 4
        if axis not in (1, 2):
            raise ValueError(
                "TFAdaptiveAveragePooling2D can only pool along axis 1 or 2 "
                f"of an NHWC input, got axis={axis}."
            )

        input_dim = inputs.shape[axis]
        if input_dim is None:
            raise ValueError(
                "TFAdaptiveAveragePooling2D requires a statically known size "
                f"for axis {axis} of its input."
            )

        small_window, big_window = self._get_pooling_params(input_dim, output_size)

        def _pool(window):
            # Window of length `window` on the pooled axis, 1 on the other.
            ksize = [1, 1]
            ksize[axis - 1] = window
            return tf.nn.avg_pool2d(
                inputs,
                ksize=tuple(ksize),
                strides=1,
                padding="VALID",
                data_format="NHWC",
            )

        return self._compute_window(
            [_pool(small_window), _pool(big_window)],
            [small_window, big_window],
            input_size=input_dim,
            target_size=output_size,
            axis=axis,
        )

    def call(self, inputs):
        """Pool axis 1 and then axis 2 to the configured output size.

        A single-entry ``output_size`` is reused for both axes.
        """
        output_size = self.output_size
        if len(output_size) == 1:
            output_size = (output_size[0], output_size[0])

        x = self._pseudo_pool(inputs, output_size=output_size[0], axis=1)
        x = self._pseudo_pool(x, output_size=output_size[1], axis=2)
        return x
class TFAdaptiveAveragePooling3D(TFAdaptiveAveragePooling):
    """Adaptive average pooling along axes 1, 2 and 3 of ``NDHWC`` inputs."""

    def _pseudo_pool(self, inputs, output_size, axis=-1):
        """Adaptively pool ``inputs`` along a single spatial axis.

        Args:
            inputs: Tensor passed to ``tf.nn.avg_pool3d`` with
                ``data_format="NDHWC"``.
            output_size: Requested length of ``axis`` after pooling.
            axis: Axis to pool; must resolve to 1, 2 or 3. Negative values
                are counted from the end of the 5 axes of the ``NDHWC``
                layout.

        Returns:
            The pooled tensor produced by ``_compute_window``.

        Raises:
            ValueError: If ``axis`` is not a spatial axis, or the size of
                that axis is not statically known.
        """
        if axis < 0:
            axis += 5
        if axis not in (1, 2, 3):
            raise ValueError(
                "TFAdaptiveAveragePooling3D can only pool along axis 1, 2 or "
                f"3 of an NDHWC input, got axis={axis}."
            )

        input_dim = inputs.shape[axis]
        if input_dim is None:
            raise ValueError(
                "TFAdaptiveAveragePooling3D requires a statically known size "
                f"for axis {axis} of its input."
            )

        small_window, big_window = self._get_pooling_params(input_dim, output_size)

        def _pool(window):
            # Window of length `window` on the pooled axis, 1 on the others.
            ksize = [1, 1, 1]
            ksize[axis - 1] = window
            return tf.nn.avg_pool3d(
                inputs,
                ksize=tuple(ksize),
                strides=1,
                padding="VALID",
                data_format="NDHWC",
            )

        return self._compute_window(
            [_pool(small_window), _pool(big_window)],
            [small_window, big_window],
            input_size=input_dim,
            target_size=output_size,
            axis=axis,
        )

    def call(self, inputs):
        """Pool axes 1, 2 and 3 in turn to the configured output size.

        A single-entry ``output_size`` is reused for all three axes.
        """
        output_size = self.output_size
        if len(output_size) == 1:
            output_size = (output_size[0],) * 3

        x = self._pseudo_pool(inputs, output_size=output_size[0], axis=1)
        x = self._pseudo_pool(x, output_size=output_size[1], axis=2)
        x = self._pseudo_pool(x, output_size=output_size[2], axis=3)
        return x
Author
innat
commented
Sep 16, 2023
•
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment