Created
December 28, 2011 01:44
-
-
Save josharian/1525765 to your computer and use it in GitHub Desktop.
A failed attempt to adapt http://scikit-learn.org/stable/auto_examples/decomposition/plot_image_denoising.html#example-decomposition-plot-image-denoising-py to use pipelines
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from time import time | |
import pylab as pl | |
import scipy as sp | |
import numpy as np | |
from sklearn.decomposition import MiniBatchDictionaryLearning | |
from sklearn.feature_extraction.image import PatchExtractor | |
from sklearn.feature_extraction.image import reconstruct_from_patches_2d | |
from sklearn.pipeline import Pipeline | |
from sklearn.preprocessing import Scaler | |
############################################################################### | |
# Load Lena image and extract patches | |
lena = sp.lena() / 256.0 | |
# downsample for higher speed | |
lena = lena[::2, ::2] + lena[1::2, ::2] + lena[::2, 1::2] + lena[1::2, 1::2] | |
lena /= 4.0 | |
height, width = lena.shape | |
# Distort the right half of the image | |
print 'Distorting image...' | |
distorted = lena.copy() | |
distorted[:, height / 2:] += 0.075 * np.random.randn(width, height / 2) | |
# Extract all clean patches from the left half of the image | |
print 'Extracting clean patches...' | |
t0 = time() | |
patch_size = (7, 7) | |
pipeline = Pipeline([("extract", PatchExtractor(patch_size)), ("scale", Scaler()), ("sparse", MiniBatchDictionaryLearning(n_atoms=100, alpha=1e-2, n_iter=500))]) | |
data = distorted[:, :height / 2] | |
print "Data shape", data.shape | |
RESHAPE = False # Fails for either value of RESHAPE | |
if RESHAPE: | |
data.shape = (1, data.shape[0], data.shape[1]) | |
print "Data reshaped", data.shape | |
############################################################################### | |
# Learn the dictionary from clean patches | |
print 'Learning the dictionary... ' | |
t0 = time() | |
V = pipeline.fit(data).components_ | |
print "V shape", V.shape | |
dt = time() - t0 | |
print 'done in %.2fs.' % dt |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment