Skip to content

Instantly share code, notes, and snippets.

@adler-j
Created February 17, 2017 10:09
Show Gist options
  • Save adler-j/f19d122034a43cc91b7cba2cddfbadfb to your computer and use it in GitHub Desktop.
Save adler-j/f19d122034a43cc91b7cba2cddfbadfb to your computer and use it in GitHub Desktop.
import odl
import numpy as np
class L2SquaredSmart(odl.solvers.Functional):
    """Squared L2 data discrepancy ``||op(x) - data||^2`` with a fast gradient.

    The gradient ``2 * (A^T A x - A^T data)`` is evaluated without applying
    ``op`` or ``op.adjoint`` in each call: ``A^T data`` is computed once in
    the constructor, and ``A^T A`` is replaced by a multiplication with the
    Fourier transform of ``2 * pi / |x|`` between a padded forward and an
    inverse Fourier transform.

    NOTE(review): this presumes the normal operator ``A^T A`` of ``op`` is
    (up to discretization and scaling) a convolution with ``2 * pi / |x|``,
    as for a 2D parallel-beam ray transform -- confirm for other operators.
    """

    def __init__(self, op, data):
        """Initialize a new instance.

        Parameters
        ----------
        op : `odl.Operator`
            Forward operator ``A``. Its domain is used as the domain of the
            functional and must be two-dimensional, since only the first
            two entries of ``op.domain.shape`` are used for the padding.
        data : ``op.range`` element
            Measured data ``b`` that ``op(x)`` is compared against.
        """
        super(L2SquaredSmart, self).__init__(op.domain, linear=False)

        self.op = op
        self.data = data

        # Back-projected data A^T b; constant, so computed only once.
        self.optdata = op.adjoint(data)

        # Create padded FT: each axis is enlarged to 2 * n + 1 points with
        # constant padding before the Fourier transform is applied.
        padding = odl.ResizingOperator(
            op.domain,
            ran_shp=(op.domain.shape[0] * 2 + 1,
                     op.domain.shape[1] * 2 + 1),
            pad_mode='constant')
        ft = odl.trafos.FourierTransform(padding.range, impl='pyfftw')
        self.padded_ft = ft * padding

        # Create kernel: sample 2 * pi / |x| on the padded grid and keep
        # its Fourier transform as the multiplier used in the gradient.
        self.img_kernel = ft.domain.element(
            lambda x: 2 * np.pi / (np.sqrt(x[0] ** 2 + x[1] ** 2)))
        self.kernel = ft(self.img_kernel)

    def _call(self, x):
        """Return ``||op(x) - data||^2`` using the forward operator itself.

        Only the gradient uses the Fourier-based shortcut; the value is
        computed directly so that solvers and callbacks which need the
        objective value can evaluate the functional.
        """
        return (self.op(x) - self.data).norm() ** 2

    @property
    def gradient(self):
        """Operator ``x -> 2 * (AtA(x) - A^T data)``.

        ``AtA`` pads and Fourier transforms ``x``, multiplies by the
        precomputed kernel and transforms back. The operator is rebuilt on
        every access of this property.
        """
        AtA = self.padded_ft.inverse * self.kernel * self.padded_ft
        return 2 * (AtA - self.optdata)
# Reconstruction space: 300 x 300 samples on [-20, 20]^2 in single precision.
reco_space = odl.uniform_discr(
    min_pt=[-20, -20], max_pt=[20, 20], shape=[300, 300], dtype='float32')

# Parallel-beam geometry built from the space, and the ray transform on it.
geometry = odl.tomo.parallel_beam_geometry(reco_space)
ray_trafo = odl.tomo.RayTransform(reco_space, geometry, impl='scikit')

# Synthetic data: forward projection of the modified Shepp-Logan phantom.
phantom = odl.phantom.shepp_logan(reco_space, modified=True)
data = ray_trafo(phantom)

# Both runs use the same fixed step length and iteration count so that the
# timings are comparable; each run gets its own display callback.
descent_settings = {'line_search': 0.005, 'maxiter': 100}

with odl.util.Timer('Optimized'):
    # Optimized functional
    func = L2SquaredSmart(ray_trafo, data)
    x = reco_space.zero()
    odl.solvers.steepest_descent(
        func, x, callback=odl.solvers.CallbackShow(), **descent_settings)

with odl.util.Timer('Classical'):
    # Classical functional: squared L2 norm composed with the residual
    residual = ray_trafo - data
    func = odl.solvers.L2NormSquared(ray_trafo.range) * residual
    x = reco_space.zero()
    odl.solvers.steepest_descent(
        func, x, callback=odl.solvers.CallbackShow(), **descent_settings)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment