@jhelgert
Last active December 20, 2022 10:09

Reducing integer multiplication to bitwise operations is a well-known approach for accelerating things like matrix multiplication, see xy for instance.

One of the examples in xy is the following: Given a matrix $X \in \{ 0,1,2 \}^{m \times p}$, we want to calculate $X X^\top$ as fast as possible. Since matrix multiplication is just a collection of scalar products, we only consider the latter for clarity. A scalar product in turn consists of only two operations: integer multiplication and integer addition. The former is the more expensive of the two, so we'd like to implement the integer multiplication $a \cdot b$ with $a, b \in \{0,1,2\}$ through bitwise operations.

One reasonable choice is to use two bitwise operations, AND and OR (both cheap to compute), together with two small lookup tables:

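#include <array>
#include <cstdint>

// lookupOR is indexed by a | b (values 0..3), lookupAND by a & b (values 0..2)
constexpr static std::array<std::uint8_t, 4> lookupOR{0, 0, 0, 2};
constexpr static std::array<std::uint8_t, 3> lookupAND{0, 1, 4};

// Computes a * b for a, b in {0, 1, 2} without an integer multiplication
std::uint8_t multiply(std::uint8_t a, std::uint8_t b){
    return lookupAND[a & b] + lookupOR[a | b];
}
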
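As a quick sanity check, the two tables can be ported to Python and compared against `a * b` for all nine input pairs. This is only a sketch: the helper `dot` and the sample vectors are illustrative assumptions and not part of the original snippet.

```python
# Python port of the two lookup tables from the C++ snippet
LOOKUP_OR = [0, 0, 0, 2]   # indexed by a | b
LOOKUP_AND = [0, 1, 4]     # indexed by a & b


def multiply(a, b):
    """Product of a, b in {0, 1, 2} using only AND, OR and table lookups."""
    return LOOKUP_AND[a & b] + LOOKUP_OR[a | b]


def dot(x, y):
    """Scalar product of two vectors with entries in {0, 1, 2} (hypothetical helper)."""
    return sum(multiply(a, b) for a, b in zip(x, y))


# Exhaustive check over all input pairs
assert all(multiply(a, b) == a * b for a in range(3) for b in range(3))

print(dot([0, 1, 2, 2], [2, 1, 1, 2]))  # 0 + 1 + 2 + 4 = 7
```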
How can we find such lookup tables?

Instead of doing it by hand, we could formulate this task as an integer optimization problem.

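Let's assume a set of non-negative integers $I$ is given (think of $I = \{ 0, 1, 2\}$) with $N := \max I$. First, we bound the results of the three bitwise operations:

  1. $a\, \text{AND} \,b \leq \max_{a', b' \in I} (a'\, \text{AND} \,b') =: N_1$ for all $a,b \in I$
  2. $a\, \text{XOR} \,b \leq \max_{a', b' \in I} (a'\, \text{XOR} \,b') =: N_2$ for all $a,b \in I$
  3. $a\, \text{OR} \,b \leq \max_{a', b' \in I} (a'\, \text{OR} \,b') =: N_3$ for all $a,b \in I$.

Only $N_1 \leq N$ holds in general. XOR and OR can exceed $\max\{a,b\}$, so $N_2$ and $N_3$ may be larger than $N$.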
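The largest value each operation can produce fixes the size of the corresponding lookup table. A minimal sketch for $I = \{0, 1, 2\}$, using the same variable names as the implementation further below, shows that AND stays below the operands while OR and XOR do not:

```python
I = [0, 1, 2]
N = max(I)

# Largest result of each bitwise operation over all pairs in I
max_and = max(a & b for a in I for b in I)
max_xor = max(a ^ b for a in I for b in I)
max_or = max(a | b for a in I for b in I)

# AND never exceeds the smaller operand ...
assert all((a & b) <= min(a, b) for a in I for b in I)
# ... whereas OR can exceed the larger one, e.g. 1 | 2 == 3
exceeds = [(a, b) for a in I for b in I if (a | b) > max(a, b)]

print(max_and, max_xor, max_or)  # 2 3 3
print(exceeds)                   # [(1, 2), (2, 1)]
```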

Now we introduce three integer vectors $l_1 \in [-N, N^2]^{N_1+1}$, $l_2 \in [-N, N^2]^{N_2+1}$, $l_3 \in [-N, N^2]^{N_3+1}$ representing our desired lookup tables (the implementation below uses the looser upper bound $2N^2$). In detail, the integer variable $l_{j,k}$ represents the value of the $j$-th lookup table at index $k \in \{0, \dots, N_j\}$. In order to choose between the lookup tables, we further introduce three integer variables $s_1,s_2,s_3 \in \{-1, 0, 1\}$: $s_j = 0$ switches the $j$-th table off, while $s_j = \pm 1$ adds or subtracts its entry.

Next we define the three sets $B_1 = \{ a \, \text{AND}\, b \mid a, b \in I \}$, $B_2 = \{ a \, \text{XOR}\, b \mid a, b \in I \}$ and $B_3 = \{ a \, \text{OR}\, b \mid a, b \in I \}$.
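These sets need not cover every index of a table. As an illustration, the following sketch computes them for $I = \{1, 2, 3, 4\}$ (the set used in the implementation further below) and lists the indices that no pair $(a, b)$ ever accesses. The helper `unused` is a hypothetical name introduced here.

```python
I = [3, 1, 2, 4]

# Index sets: every value the respective operation can produce on I
B_and = {a & b for a in I for b in I}
B_xor = {a ^ b for a in I for b in I}
B_or = {a | b for a in I for b in I}


def unused(B):
    """Table indices in 0..max(B) that no pair (a, b) ever accesses."""
    return sorted(set(range(max(B) + 1)) - B)


print(sorted(B_and), unused(B_and))  # [0, 1, 2, 3, 4] []
print(sorted(B_xor), unused(B_xor))  # [0, 1, 2, 3, 5, 6, 7] [4]
print(sorted(B_or), unused(B_or))    # [1, 2, 3, 4, 5, 6, 7] [0]
```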

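We want to use as few lookup tables as possible, and as many values as possible inside these tables should be zero. Table entries whose index never occurs as a result of the corresponding operation are unconstrained and are simply driven to zero by the objective. Hence, we want to solve the following optimization problem: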

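$$ \min \; \sum_{j=1}^{3} s_j^2 + \sum_{j=1}^{3} \sum_{k} l_{j,k}^2 $$

$$ \text{s.t.} \quad a \cdot b = s_1 \cdot l_{1,\, a \,\text{AND}\, b} + s_2 \cdot l_{2,\, a \,\text{XOR}\, b} + s_3 \cdot l_{3,\, a \,\text{OR}\, b} \quad \forall a, b \in I $$

The implementation below additionally weights the first sum by a factor of 100, so that using fewer tables takes priority over small table entries.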
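To make the constraints concrete, take $I = \{0, 1, 2\}$ and read the constraint as one lookup per table, indexed by $a\,\text{AND}\,b$, $a\,\text{XOR}\,b$ and $a\,\text{OR}\,b$ (as in the implementation below). Assuming, purely for illustration, $s_1 = s_3 = 1$ and $s_2 = 0$, the nine pairs $(a, b)$ reduce to six distinct equations, since $(a, b)$ and $(b, a)$ yield the same one:

$$
\begin{aligned}
(0,0):&\quad l_{1,0} + l_{3,0} = 0, & (1,1):&\quad l_{1,1} + l_{3,1} = 1, \\
(0,1):&\quad l_{1,0} + l_{3,1} = 0, & (1,2):&\quad l_{1,0} + l_{3,3} = 2, \\
(0,2):&\quad l_{1,0} + l_{3,2} = 0, & (2,2):&\quad l_{1,2} + l_{3,2} = 4.
\end{aligned}
$$

The tables from the C++ snippet, $l_1 = (0, 1, 4)$ and $l_3 = (0, 0, 0, 2)$, satisfy all six.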

from gurobipy import Model, GRB, quicksum as qsum
import numpy as np
I = [3, 1, 2, 4]
N = max(I)
max_and = max(a & b for a in I for b in I)
max_xor = max(a ^ b for a in I for b in I)
max_or = max(a | b for a in I for b in I)
mdl = Model()
# Variables
s = mdl.addVars(3, vtype='I', lb=-1, ub=1)
lookup_and = mdl.addVars(max_and+1, vtype='I', lb=-N, ub=N*N*2)
lookup_xor = mdl.addVars(max_xor+1, vtype='I', lb=-N, ub=N*N*2)
lookup_or = mdl.addVars(max_or+1, vtype='I', lb=-N, ub=N*N*2)
# Objective
term1 = qsum(s[i]**2 for i in range(3))
term2 = qsum(lookup_and[i]**2 for i in range(max_and+1))
term3 = qsum(lookup_xor[i]**2 for i in range(max_xor+1))
term4 = qsum(lookup_or[i]**2 for i in range(max_or+1))
mdl.setObjective(100*term1+term2+term3+term4)
# Constraints
for a in I:
    for b in I:
        # a*b must equal the signed sum of one entry per lookup table
        rhs = s[0]*lookup_and[a&b] + s[1]*lookup_xor[a^b] + s[2]*lookup_or[a|b]
        mdl.addConstr(a*b == rhs, name=f"{a:02d}_times_{b:02d}")
# Optimize
mdl.Params.NonConvex = 2
mdl.Params.PoolSolutions = 20
mdl.Params.PoolSearchMode = 2
mdl.optimize()
if mdl.Status == GRB.OPTIMAL:
    # Print every solution in the pool, skipping tables that are switched off
    for k in range(mdl.SolCount):
        mdl.Params.SolutionNumber = k
        # Round before comparing: integer variables are reported as floats
        if round(s[0].Xn) != 0:
            AND = np.array([lookup_and[i].Xn for i in range(max_and+1)])
            print(f"s[0] = {s[0].Xn:+.02f}, {AND = }")
        if round(s[1].Xn) != 0:
            XOR = np.array([lookup_xor[i].Xn for i in range(max_xor+1)])
            print(f"s[1] = {s[1].Xn:+.02f}, {XOR = }")
        if round(s[2].Xn) != 0:
            OR = np.array([lookup_or[i].Xn for i in range(max_or+1)])
            print(f"s[2] = {s[2].Xn:+.02f}, {OR = }")
        print("-"*80)
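For a tiny instance, the formulation can be cross-checked without a solver by exhaustive search. The sketch below is not the Gurobi model: it assumes $I = \{0, 1, 2\}$, fixes the selection to AND and OR only (that is, $s = (1, 0, 1)$) and restricts the table entries to $\{-1, \dots, 4\}$, a narrower range than the model's bounds, to keep the search small. Under these assumptions it recovers the tables from the C++ snippet.

```python
from itertools import product

I = [0, 1, 2]
pairs = [(a, b) for a in I for b in I if a <= b]  # (a, b) and (b, a) coincide
values = range(-1, 5)                             # assumed entry range -1..4

best = None
# 3 entries for the AND table (indices 0..2), 4 for the OR table (indices 0..3)
for table in product(values, repeat=7):
    and_tab, or_tab = table[:3], table[3:]
    # Feasibility: every product must be reproduced by the two lookups
    if all(and_tab[a & b] + or_tab[a | b] == a * b for a, b in pairs):
        cost = sum(v * v for v in table)  # sum of squared table entries
        if best is None or cost < best[0]:
            best = (cost, and_tab, or_tab)

print(best)  # (21, (0, 1, 4), (0, 0, 0, 2))
```

The reported cost covers only the table entries. The constant contribution of the fixed $s$ is left out because it does not affect which tables are optimal.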