Skip to content

Instantly share code, notes, and snippets.

@Tillsten
Last active August 29, 2015 14:07
Show Gist options
  • Select an option

  • Save Tillsten/cee3373e3967c84be496 to your computer and use it in GitHub Desktop.

Select an option

Save Tillsten/cee3373e3967c84be496 to your computer and use it in GitHub Desktop.
Most basic proof of concept for using cassowary (a constraint solver) with matplotlib (mpl).
# -*- coding: utf-8 -*-
"""
Created on Tue Sep 30 14:20:54 2014
@author: tillsten
"""
import cassowary # can be installed via pip, pure python.
import matplotlib.pyplot as plt
import matplotlib.textpath
from cassowary import Variable, WEAK, STRONG, MEDIUM, REQUIRED
# Margin kept clear along every figure edge, in figure-fraction units
# (Rect.add_constraints keeps rectangles inside [padding, 1 - padding]).
figure_padding = 5e-2
class Rect(object):
    """Basic rectangle representation using cassowary variables.

    Besides the four edges, the width, height, centers and minimal sizes are
    exposed as variables so other objects can write constraints against them.

    Parameters
    ----------
    name : str
        Prefix for the names of all variables; also shown by ``repr``.
    lower_left, upper_right : (float, float)
        Initial ``(x, y)`` values of the lower-left and upper-right corners.
    """

    def __init__(self, name, lower_left=(0, 0), upper_right=(0, 0)):
        self.name = name
        self.top = Variable(self.name + 'top', upper_right[1])
        self.bottom = Variable(self.name + 'bottom', lower_left[1])
        self.left = Variable(self.name + 'left', lower_left[0])
        # The right edge is the x coordinate (index 0) of the upper-right corner.
        self.right = Variable(self.name + 'right', upper_right[0])
        self.width = Variable(self.name + 'width')
        self.height = Variable(self.name + 'height')
        self.h_center = Variable(self.name + 'h_center')
        self.v_center = Variable(self.name + 'v_center')
        self.min_width = Variable(self.name + 'min_width', 0)
        self.min_height = Variable(self.name + 'min_height', 0)

    def add_constraints(self, sol, strength=MEDIUM):
        """Register this rectangle's stays and structural constraints.

        Parameters
        ----------
        sol : cassowary.SimplexSolver
            Solver receiving the constraints.
        strength : cassowary strength
            Strength of the stays on the four edges; the minimal sizes always
            get STRONG stays.  The relations below are added without an
            explicit strength.
        """
        for edge in (self.top, self.bottom, self.left, self.right):
            sol.add_stay(edge, strength)
        for size in (self.min_width, self.min_height):
            sol.add_stay(size, STRONG)
        # Derived quantities.
        sol.add_constraint(self.width == self.right - self.left)
        sol.add_constraint(self.height == self.top - self.bottom)
        sol.add_constraint(self.h_center == (self.left + self.right) / 2.)
        sol.add_constraint(self.v_center == (self.top + self.bottom) / 2.)
        sol.add_constraint(self.width >= self.min_width)
        sol.add_constraint(self.height >= self.min_height)
        # Stay inside the figure, leaving ``figure_padding`` at each edge.
        sol.add_constraint(self.bottom >= 0 + figure_padding)
        sol.add_constraint(self.left >= 0 + figure_padding)
        sol.add_constraint(self.top <= 1 - figure_padding)
        sol.add_constraint(self.right <= 1 - figure_padding)

    def get_mpl_rect(self):
        """Return ``(left, bottom, width, height)`` from the solved values."""
        return (self.left.value, self.bottom.value,
                self.width.value, self.height.value)

    def __repr__(self):
        # The upper-right corner is printed as (x, y) = (right, top).
        args = (self.name, self.left.value, self.bottom.value,
                self.right.value, self.top.value)
        return 'Rect: %s, ll: (%f, %f), ur: (%f, %f)' % args
def get_text_extend(text, fp=None):
    """Return the extent of *text* as ``(width, height)`` in figure fractions.

    The current figure (``plt.gcf()``) is used for the conversion.

    Parameters
    ----------
    text : str
        Text to measure.
    fp : dict, optional
        Text properties in the form also passed to ``Axes.text``.
        ``fontsize`` sets the font size (points or a named size).
        A ``rotation`` of ``'vertical'`` (or a multiple of 90 degrees that
        is not a multiple of 180) swaps width and height.  Other rotations
        are measured as horizontal text.  Remaining keys are forwarded to
        ``matplotlib.textpath.TextPath``.
    """
    props = dict(fp) if fp else {}
    rotation = props.pop('rotation', 'horizontal')
    fontsize = props.pop('fontsize', None)
    # TextPath does not understand ``fontsize``; express it through a
    # FontProperties object, which also accepts names such as 'medium'.
    if fontsize is not None and 'size' not in props and 'prop' not in props:
        from matplotlib.font_manager import FontProperties
        props['prop'] = FontProperties(size=fontsize)

    if rotation == 'vertical':
        vertical = True
    elif isinstance(rotation, (int, float)):
        vertical = rotation % 180 == 90
    else:
        vertical = False

    ex = matplotlib.textpath.TextPath((0, 0), text, **props).get_extents()
    width, height = ex.width, ex.height
    if vertical:
        width, height = height, width

    fig = plt.gcf()
    # TextPath coordinates are in points, display coordinates in pixels.
    scale = fig.dpi / 72.0
    fig_x, fig_y = fig.transFigure.inverted().transform(
        [width * scale, height * scale])
    return fig_x, fig_y
# Debug output: measured extent of a sample string, in figure fractions.
print(get_text_extend('dsjklösdjgklsdjkgöjsdflögjkg'))
def align(items, attr):
    """Build equality constraints lining up one attribute of several rects.

    Parameters
    ----------
    items : sequence
        Rects to align; the first one is the reference.
    attr : str
        Name of the attribute to align, e.g. ``'h_center'``.

    Returns
    -------
    list
        ``getattr(item, attr) == getattr(items[0], attr)`` for every item,
        including the reference itself.
    """
    return [getattr(item, attr) == getattr(items[0], attr) for item in items]
class Label(Rect):
    """Rectangle that reserves room for a piece of text.

    Parameters
    ----------
    name : str
        Variable-name prefix, forwarded to ``Rect``.
    text : str
        Text to show; an empty string means no text.
    """

    def __init__(self, name, text=''):
        super(Label, self).__init__(name)
        self.text = text
        # Keyword arguments used when measuring/drawing the text.  Initialized
        # here so a Label with text is always drawable.
        self.fontprops = {}
class Axes(object):
    """Constraint-based layout of an axes patch with a label on each side.

    Parameters
    ----------
    sol : cassowary.SimplexSolver
        Solver that receives all constraints of this layout.
    """

    def __init__(self, sol):
        self.patch = Rect('p_', (0, 0), (1, 1))
        self.label_left = Label('ll_')
        self.label_right = Label('lr_')
        self.label_top = Label('lt_')
        self.label_bottom = Label('lb_')
        self.title = Label('t_')
        self.labels = [self.label_top, self.label_bottom,
                       self.label_left, self.label_right]
        self.solver = sol
        self.basic_constraints()

    def basic_constraints(self):
        """Glue the labels to the patch edges and register all rect constraints."""
        ll, lr, lt, lb = (self.label_left, self.label_right,
                          self.label_top, self.label_bottom)
        p = self.patch
        # Each label touches its patch edge ...
        cons = [ll.right == p.left,
                lt.bottom == p.top,
                lr.left == p.right,
                lb.top == p.bottom]
        # ... and is centered along that edge.
        cons += align([lt, p, lb], 'h_center')
        cons += align([ll, p, lr], 'v_center')
        for con in cons:
            self.solver.add_constraint(con)
        for label in (ll, lr, lt, lb, self.title):
            label.add_constraints(self.solver)
        # The patch edges only get WEAK stays (labels use Rect's default), so
        # the patch is what moves when labels need room.  The STRONG size
        # targets make it grow as large as the figure padding allows.
        p.add_constraints(self.solver, WEAK)
        self.solver.add_constraint(p.width == 1, STRONG)
        self.solver.add_constraint(p.height == 1, STRONG)
        self.solver.solve()

    def set_title(self, text, fp=None):
        """Place a title of height 0.2 (figure fraction) above the top label.

        Parameters
        ----------
        text : str
            Title text, drawn by ``draw``.
        fp : dict, optional
            Extra keyword arguments for ``Axes.text`` when drawing the title.
        """
        self.title.text = text
        self.title.fontprops = dict(fp) if fp else {}
        self.solver.add_constraint(self.title.bottom == self.label_top.top)
        self.solver.add_constraint(self.title.height == 0.2)
        self.solver.solve()

    def add_label(self, text, where='bottom'):
        """Attach *text* to one side label and constrain its size to fit.

        Parameters
        ----------
        text : str
            Label text.
        where : {'bottom', 'top', 'left', 'right'}
            Side of the patch.  Left/right labels are drawn vertically.
        """
        rect = getattr(self, 'label_' + where)
        rect.text = text
        rect.fontprops = {'fontsize': plt.rcParams['axes.labelsize']}
        if where in ('left', 'right'):
            rect.fontprops['rotation'] = 'vertical'
            text_ex = get_text_extend(text, rect.fontprops)
            # Only the thickness across the patch edge is constrained.
            self.solver.add_constraint(rect.width == text_ex[0])
        elif where in ('top', 'bottom'):
            rect.fontprops['rotation'] = 'horizontal'
            text_ex = get_text_extend(text, rect.fontprops)
            print(text_ex)
            self.solver.add_constraint(rect.height == text_ex[1])

    def _draw_text(self, ax, fig, label):
        """Draw *label*'s text centered in its rectangle (figure coordinates)."""
        ax.text(label.h_center.value, label.v_center.value,
                label.text, transform=fig.transFigure,
                va='center', ha='center',
                **getattr(label, 'fontprops', {}))

    def draw(self, fig):
        """Solve the layout and render it into *fig*.

        Creates an axes at the patch position, a translucent rectangle for
        each side label, and the label and title texts.

        Returns
        -------
        The axes created for the patch.
        """
        self.solver.solve()
        ax = fig.add_axes(self.patch.get_mpl_rect(), frameon=1, axisbg='w')
        for label in self.labels:
            rect = label.get_mpl_rect()
            # Visualize the room reserved for the label.
            r = plt.Rectangle(rect[:2], *rect[2:], transform=fig.transFigure,
                              alpha=0.3)
            r.set_clip_on(False)
            ax.add_artist(r)
            if label.text != '':
                self._draw_text(ax, fig, label)
        if self.title.text != '':
            self._draw_text(ax, fig, self.title)
        return ax
# Demo: lay out one axes with labels on three of its sides and draw it.
# Axes.add_label reads the label font size, so configure it up front.
plt.rcParams['axes.labelsize'] = 50
sol = cassowary.SimplexSolver()
a = Axes(sol)
sol.solve()
a.add_label('label', where='bottom')
a.add_label('label', where='left')
a.add_label('label', where='top')
sol.solve()
# Render into a cleared current figure.
plt.clf()
a.draw(plt.gcf())
import cassowary # can be installed via pip, pure python.
import matplotlib.pyplot as plt
import matplotlib.textpath
from cassowary import Variable, WEAK, STRONG, MEDIUM
class Rect(object):
    """Basic rectangle representation using cassowary variables.

    Parameters
    ----------
    lower_left, upper_right : (float, float)
        Initial ``(x, y)`` values of the lower-left and upper-right corners.
    """

    def __init__(self, lower_left, upper_right):
        self.top = Variable('top', upper_right[1])
        self.bottom = Variable('bottom', lower_left[1])
        self.left = Variable('left', lower_left[0])
        # The right edge is the x coordinate (index 0) of the upper-right corner.
        self.right = Variable('right', upper_right[0])
        self.width = Variable('width')
        self.height = Variable('height')
        self.center_x = Variable('center_x')
        self.center_y = Variable('center_y')

    def add_constraints(self, sol, strength=WEAK):
        """Add stays of *strength* on the edges and the derived-size relations."""
        for edge in (self.top, self.bottom, self.left, self.right):
            sol.add_stay(edge, strength)
        sol.add_constraint(self.width == self.right - self.left)
        sol.add_constraint(self.height == self.top - self.bottom)
        sol.add_constraint(self.center_x == (self.left + self.right) / 2.)
        sol.add_constraint(self.center_y == (self.top + self.bottom) / 2.)

    def get_mpl_rect(self):
        """Return ``(left, bottom, width, height)`` from the solved values."""
        return (self.left.value, self.bottom.value,
                self.width.value, self.height.value)
def get_text_extend(text):
    """Return the ``(width, height)`` of *text* in figure fractions of ``fig``.

    The text is measured as a horizontal path at TextPath's default size,
    using the module-level figure ``fig``.
    """
    bb = matplotlib.textpath.TextPath((0, 0), text).get_extents()
    inv = fig.transFigure.inverted()
    tmp = inv.transform(bb)
    # TextPath coordinates are in points, display coordinates in pixels;
    # the mapping is affine, so the corner difference can be scaled after.
    return (tmp[1, :] - tmp[0, :]) * (fig.dpi / 72.0)
# Figure that the second layout demo measures text against and draws into.
fig = plt.figure(num=None)
class AxesLayout(object):
    """Constraint-based layout of one axes with an optional vertical y label.

    Positions are figure fractions of the module-level figure ``fig``.
    """

    def __init__(self):
        self.solver = cassowary.SimplexSolver()
        # Initial patch position; STRONG stays keep it close to this.
        self.patch = Rect((0.5, 0.1), (0.9, 0.9))
        self.patch.add_constraints(self.solver, STRONG)
        self.y_label = None
        self.y_label_text = ''

    def add_ylabel(self, text, pad=0.04):
        """Reserve room for a vertical y label to the left of the patch.

        Parameters
        ----------
        text : str
            Label text.
        pad : float
            Gap between the label and the patch, in figure fractions.

        Returns
        -------
        Rect
            The rectangle reserved for the label.
        """
        self.y_label_text = text
        text_width, text_height = get_text_extend(text)
        # The label is drawn rotated by 90 degrees, so the horizontal text's
        # height is the rectangle's width and its length the height.
        width, height = text_height, text_width
        self.y_label = Rect((0, 0), (width, height))
        self.y_label.add_constraints(self.solver)
        self.solver.add_constraint(self.patch.center_y == self.y_label.center_y)
        self.solver.add_constraint(self.y_label.width == width)
        self.solver.add_constraint(self.y_label.height == height)
        self.solver.add_constraint(self.y_label.right + pad == self.patch.left)
        return self.y_label

    def draw(self):
        """Create the axes (plus y label) from the solved layout.

        Call ``solve`` first.  Returns the axes, which gets a demo line plot.
        """
        rect = self.patch.get_mpl_rect()
        ax = fig.add_axes(rect, axisbg='w')
        if self.y_label is not None:
            # Center the rotated text vertically on the patch.
            ax.text(self.y_label.left.value,
                    self.patch.center_y.value,
                    self.y_label_text, transform=fig.transFigure,
                    rotation='vertical', va='center',
                    fontsize=plt.rcParams['axes.labelsize'])
        ax.plot([1, 2, 3])
        return ax

    def solve(self):
        """Run the solver so all variables hold their layout values."""
        self.solver.solve()
# Demo: an axes with a y label, laid out and then drawn into ``fig``.
ax = AxesLayout()
b = ax.add_ylabel(text='bla')
ax.solve()
sax = ax.draw()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment