Skip to content

Instantly share code, notes, and snippets.

@botcs
Created October 22, 2019 09:56
Show Gist options
  • Select an option

  • Save botcs/1355f8fbfc6a8251195e7d82c5cf733e to your computer and use it in GitHub Desktop.

Select an option

Save botcs/1355f8fbfc6a8251195e7d82c5cf733e to your computer and use it in GitHub Desktop.
class SetOpsArithmeticModule(nn.Module):
    """Parameter-free module that mimics set operations with elementwise arithmetic.

    The three primitive "set" operators are plain tensor arithmetic:

    * subtract  -> ``x - y``
    * intersect -> ``x * y``
    * union     -> ``x + y``

    ``forward`` flattens both inputs to ``(batch, -1)`` and returns every
    combination of these operators computed below, as an 18-element tuple.

    The operators are static methods rather than lambdas stored on the
    instance, so the module can be pickled (and therefore passed whole to
    ``torch.save``). They are still reachable as ``self.subtract_op`` etc.
    and can be overridden by subclasses or per instance.
    """

    def __init__(self, *args, **kwargs):
        # Extra positional/keyword arguments are accepted and ignored so the
        # constructor signature stays compatible with existing callers.
        super().__init__()

    @staticmethod
    def subtract_op(x, y):
        """Return the elementwise difference ``x - y``."""
        return x - y

    @staticmethod
    def intersect_op(x, y):
        """Return the elementwise product ``x * y``."""
        return x * y

    @staticmethod
    def union_op(x, y):
        """Return the elementwise sum ``x + y``."""
        return x + y

    def forward(self, a: torch.Tensor, b: torch.Tensor) -> tuple:
        """Compute all arithmetic "set" combinations of ``a`` and ``b``.

        Both inputs are first reshaped to ``(size(0), -1)``.
        NOTE(review): the elementwise operators presumably expect ``a`` and
        ``b`` to have matching (or broadcastable) flattened shapes - confirm
        against callers.

        In the local names, ``S`` = subtract, ``I`` = intersect, ``U`` = union;
        e.g. ``a_S_b_I_a`` is ``a - (b * a)``.

        Returns:
            An 18-tuple, in this order::

                out_a, out_b,
                a_S_b, b_S_a, a_U_b, b_U_a, a_I_b, b_I_a,
                a_S_b_b, b_S_a_a, a_I_b_b, b_I_a_a, a_U_b_b, b_U_a_a,
                a_S_b_I_a, b_S_a_I_b, a_S_a_I_b, b_S_b_I_a

            where ``out_a = (a - b * a) + a * b`` and
            ``out_b = (b - a * b) + b * a``.
        """
        # reshape (not view) so non-contiguous inputs, e.g. transposed or
        # permuted tensors, are flattened instead of raising RuntimeError.
        a = a.reshape(a.size(0), -1)
        b = b.reshape(b.size(0), -1)

        # Single and repeated subtraction.
        a_S_b = self.subtract_op(a, b)
        b_S_a = self.subtract_op(b, a)
        a_S_b_b = self.subtract_op(a_S_b, b)
        b_S_a_a = self.subtract_op(b_S_a, a)

        # Intersections and subtraction of an intersection from each operand.
        a_I_b = self.intersect_op(a, b)
        b_I_a = self.intersect_op(b, a)
        a_S_b_I_a = self.subtract_op(a, b_I_a)
        b_S_a_I_b = self.subtract_op(b, a_I_b)
        a_S_a_I_b = self.subtract_op(a, a_I_b)
        b_S_b_I_a = self.subtract_op(b, b_I_a)

        # Repeated intersection.
        a_I_b_b = self.intersect_op(a_I_b, b)
        b_I_a_a = self.intersect_op(b_I_a, a)

        # Single and repeated union.
        a_U_b = self.union_op(a, b)
        b_U_a = self.union_op(b, a)
        a_U_b_b = self.union_op(a_U_b, b)
        b_U_a_a = self.union_op(b_U_a, a)

        # Union of (operand minus intersection) with the intersection.
        out_a = self.union_op(a_S_b_I_a, a_I_b)
        out_b = self.union_op(b_S_a_I_b, b_I_a)

        return (
            out_a, out_b,
            a_S_b, b_S_a, a_U_b, b_U_a, a_I_b, b_I_a,
            a_S_b_b, b_S_a_a, a_I_b_b, b_I_a_a, a_U_b_b, b_U_a_a,
            a_S_b_I_a, b_S_a_I_b, a_S_a_I_b, b_S_b_I_a,
        )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment