Skip to content

Instantly share code, notes, and snippets.

@dboyliao
Last active March 22, 2020 14:55
Show Gist options
  • Save dboyliao/bbe6be13cb31a81e3c6a6104d1829184 to your computer and use it in GitHub Desktop.
Save dboyliao/bbe6be13cb31a81e3c6a6104d1829184 to your computer and use it in GitHub Desktop.
from ortools.sat.python import cp_model
# Pack tensors into one shared memory pool with CP-SAT: each tensor gets a
# byte offset ("start") such that tensors declared as non-overlapping never
# share bytes, and the pool's high-water mark (largest end offset) is
# minimized.
model = cp_model.CpModel()

# tensor name -> names of tensors whose byte ranges must not overlap with it
# (presumably tensors that are live at the same time -- confirm with whoever
# produces this map). The relation is treated as symmetric below.
nonoverlap_map = {
    'w:0': [],
    'x:0': ['w:0'],
    'y:0': ['x:0', 'w:0'],
    'z:0': ['x:0', 'w:0', 'y:0'],
    'u:0': ['w:0', 'z:0'],
}
max_pool_size = 1024  # 1KB

# Size in bytes of each tensor; every tensor occupies 4 bytes in this example.
tensor_sizes = {tensor_name: 4 for tensor_name in nonoverlap_map}

var_tuples = {}  # tensor name -> (start, end) offset variables
vars_map = {}  # tensor name -> interval variable spanning [start, end)
for tensor_name in nonoverlap_map:
    start = model.NewIntVar(0, max_pool_size, f'{tensor_name}_start')
    end = model.NewIntVar(0, max_pool_size, f'{tensor_name}_end')
    # The interval ties the two variables together: end == start + size.
    vars_map[tensor_name] = model.NewIntervalVar(
        start, tensor_sizes[tensor_name], end, f'{tensor_name}_interval'
    )
    var_tuples[tensor_name] = (start, end)

# One NoOverlap constraint per unordered pair of conflicting tensors.
# A single AddNoOverlap over a tensor *and all of its neighbours* would also
# force those neighbours to be disjoint from each other, which the map does
# not ask for. Sorting keeps the constraint order (and so the solve)
# deterministic across runs.
nonoverlap_pairs = sorted({
    tuple(sorted((tensor_name, other)))
    for tensor_name, others in nonoverlap_map.items()
    for other in others
    if other != tensor_name
})
for first, second in nonoverlap_pairs:
    model.AddNoOverlap([vars_map[first], vars_map[second]])

obj = model.NewIntVar(0, max_pool_size, 'obj')
# obj == max(end_vars): the number of pool bytes actually needed.
model.AddMaxEquality(obj, [end_var for _, end_var in var_tuples.values()])
model.Minimize(obj)

solver = cp_model.CpSolver()
status = solver.Solve(model)
if status in (cp_model.OPTIMAL, cp_model.FEASIBLE):
    if status != cp_model.OPTIMAL:
        # A valid layout was found, but optimality was not proven.
        print(f'solution not proven optimal: {solver.StatusName(status)}')
    for tensor_name, (start, _) in var_tuples.items():
        print(tensor_name, f'memory offset: {solver.Value(start)}')
    print(f'total memory usage: {solver.Value(obj)} bytes')
else:
    # e.g. INFEASIBLE (pool too small for the required layout), MODEL_INVALID
    # or UNKNOWN: report it instead of exiting silently.
    print(f'no memory layout found: {solver.StatusName(status)}')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment