# Source: GitHub gist adler-j/382d6aadce46b4b1b6c18ee54516db74
# (last active November 21, 2017 13:16).
# Note: GitHub flagged this file as possibly containing bidirectional Unicode
# text; to review it, open it in an editor that reveals hidden Unicode
# characters.
"""Noise-free tomographic reconstruction with TV regularization.

Simulates noise-free parallel-beam data of the modified Shepp-Logan phantom
and reconstructs it by solving

    min_x  IndicatorBox_[0, 1](x) + 0.1 * GroupL1Norm(grad(x))
    subject to  ray_trafo(x) = data

with ``odl.solvers.douglas_rachford_pd``. The PSNR of the reconstruction
against the phantom is printed at the end.

Adapted from
https://github.com/odlgroup/odl/blob/master/examples/solvers/douglas_rachford_pd_tomography_tv.py
"""
import odl
from odl.contrib.fom import psnr

# --- Problem parameters ---
NUM_ANGLES = 30  # number of projection angles in the parallel-beam geometry
TV_WEIGHT = 0.1  # weight of the TV (group L1 norm of the gradient) term

# Create ODL data structures: a 128 x 128 image on [-64, 64]^2, a
# parallel-beam geometry matched to that space, and the forward operator.
space = odl.uniform_discr([-64, -64], [64, 64], [128, 128])
geometry = odl.tomo.parallel_beam_geometry(space, num_angles=NUM_ANGLES)
ray_trafo = odl.tomo.RayTransform(space, geometry)

# Noise-free data: the exact forward projection of the phantom (no noise is
# added).
phantom = odl.phantom.shepp_logan(space, modified=True)
data = ray_trafo(phantom)

# --- Create functionals for solving the optimization problem ---

# Gradient for TV regularization
gradient = odl.Gradient(space)

# Functional to enforce 0 <= x <= 1
f = odl.solvers.IndicatorBox(space, 0, 1)

# Functional to enforce Ax = g
# Due to the splitting used in the douglas_rachford_pd solver, we only
# create the functional for the indicator function on g here; the forward
# model is handled separately (it is passed to the solver in lin_ops below).
indicator_zero = odl.solvers.IndicatorZero(ray_trafo.range)
indicator_data = indicator_zero.translated(data)

# Functional for TV minimization
cross_norm = TV_WEIGHT * odl.solvers.GroupL1Norm(gradient.range)

# Assemble operators and functionals for the solver. The two lists are paired
# element-wise: g[i] is applied to lin_ops[i](x).
lin_ops = [ray_trafo, gradient]
g = [indicator_data, cross_norm]

# Solve with initial guess x = 0.
# Step size parameters are selected to ensure convergence; see the
# douglas_rachford_pd documentation for the admissible ranges.
# sigma holds one step size per operator in lin_ops (ray_trafo, gradient).
x = ray_trafo.domain.zero()
odl.solvers.douglas_rachford_pd(
    x, f, g, lin_ops,
    tau=0.1, sigma=[0.01, 1.0], lam=1.5, niter=1000,
    # Display the current iterate every 10 iterations.
    callback=odl.solvers.CallbackShow(step=10))

# The solver updates x in place, so x now holds the reconstruction.
x.show('reconstruction')
print('psnr: {}'.format(psnr(x, phantom)))
# Example output of this script:
# psnr: 132.67350370339886