Skip to content

Instantly share code, notes, and snippets.

@joaofig
Created October 13, 2018 18:08
Show Gist options
  • Select an option

  • Save joaofig/d1dc98bce3b3bcc78ee17e0cbcfe0071 to your computer and use it in GitHub Desktop.

Select an option

Save joaofig/d1dc98bce3b3bcc78ee17e0cbcfe0071 to your computer and use it in GitHub Desktop.
def fit(self, s_k):
    """Estimate the knee index ``K`` of a curve and store it on ``self.K``.

    The first estimate is the position of the largest jump between
    consecutive reciprocals ``1 / s_k``, each jump divided by ``log(k)``.
    When that estimate lies strictly inside the curve
    (``2 < k < len(s_k) - 3``) it is kept only if fitting the log-log
    curve ``(log k, log s_k)`` as two segments around ``k`` gives a higher
    mean score from ``self.calculate_r2`` than fitting the whole curve;
    otherwise ``K`` falls back to 1.  Estimates near either end are kept
    without that check.

    ``self.K`` is set to a 0-dim integer tensor when the estimate is kept,
    or to the Python int ``1`` on fallback.

    Args:
        s_k: 1-D torch tensor of curve values.  Element 0 is never used in
            the reciprocal or log-log terms.  Elements from index 1 on are
            used as divisors and passed to ``torch.log``, so they should be
            strictly positive to keep every term finite.  Non-floating
            tensors (e.g. integer counts) are promoted to ``float32``.

    Returns:
        ``self``, so calls can be chained.

    Raises:
        ValueError: if ``s_k`` is not a torch tensor, is not 1-D, or has
            fewer than 3 elements (no jump between reciprocals exists).
    """
    if not torch.is_tensor(s_k):
        raise ValueError('s_k must be a torch tensor.')
    if s_k.dim() != 1:
        raise ValueError('s_k must be a 1-D tensor.')
    item_count = s_k.size()[0]
    if item_count < 3:
        # d_k below would be empty and torch.argmax would fail obscurely.
        raise ValueError('s_k must contain at least 3 elements.')
    # r_k is built with ones_like(s_k), so an integer s_k would silently
    # truncate every reciprocal 1/s_k on assignment; promote it first.
    if not torch.is_floating_point(s_k):
        s_k = s_k.to(torch.float32)
    k_k = torch.arange(0, item_count, dtype=torch.float32)
    # Reciprocals from index 1 on; r_k[0] is a placeholder d_k never reads.
    r_k = torch.ones_like(s_k)
    r_k[1:] = 1.0 / s_k[1:]
    # Jump between consecutive reciprocals, divided by log(k) for k >= 2.
    d_k = (r_k[2:] - r_k[1:-1]) / torch.log(k_k[2:])
    # First estimate for K: d_k[i] corresponds to s_k index i + 2.
    k = torch.argmax(d_k) + 2
    # Log-log coordinates for s_k indices 1..n-1 (index 0 would be log(0)).
    x = torch.log(k_k[1:])
    y = torch.log(s_k[1:])
    # Score of the whole curve, used as the baseline.
    r2 = self.calculate_r2(x, y)
    if 2 < k < item_count - 3:
        # x/y are offset by one from s_k: x[:k] covers s_k indices 1..k and
        # x[k+1:] covers indices k+2..n-1.
        # NOTE(review): s_k index k+1 falls in neither segment — confirm
        # this exclusion is intended.
        r2_1 = self.calculate_r2(x[:k], y[:k])
        r2_2 = self.calculate_r2(x[k+1:], y[k+1:])
        # Keep the knee only if the two-segment fit beats the single fit.
        if (r2_1 + r2_2) / 2.0 > r2:
            self.K = k
        else:
            self.K = 1
    else:
        self.K = k
    return self
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment