Last active
March 23, 2017 08:58
-
-
Save josh-gree/e5f5421ce34fb8371c0f471ffce930de to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from . import ChunkSource | |
import numpy as np | |
import h5py | |
class MySource(ChunkSource):
    """Chunked source that reads one named dataset from a list of HDF5 files.

    Every path in ``fnames`` contributes exactly one sample. That sample is
    the dataset called ``name`` inside the file.
    """

    def __init__(self, fnames, name, *args, sample_shape=(256, 256, 1), **kwargs):
        """Create the source.

        Args:
            fnames: sequence of HDF5 file paths, one sample per file.
            name: key of the dataset to read from each file.
            sample_shape: shape of a single sample, without the leading
                batch axis. Defaults to the previously hard-coded
                ``(256, 256, 1)``.
            *args, **kwargs: forwarded unchanged to ``ChunkSource.__init__``
                (``chunk_size`` is passed this way by callers).
        """
        super().__init__(*args, **kwargs)
        self.fnames = fnames
        self.name = name
        self.sample_shape = tuple(sample_shape)

    def shape(self):
        """Return the shape of one sample (no leading batch axis)."""
        return self.sample_shape

    def __len__(self):
        """Return the number of samples, i.e. the number of files."""
        return len(self.fnames)

    def shuffle(self, indices):
        """Reorder the files so that new position k holds old file ``indices[k]``.

        NOTE(review): presumably ``indices`` is a permutation of
        ``range(len(self))`` supplied by the provider - confirm.
        """
        self.fnames = [self.fnames[i] for i in indices]

    def can_shuffle(self):
        """ This source can be shuffled.
        """
        return True

    def __iter__(self):
        """Yield the samples in chunks of at most ``self.chunk_size``.

        Each chunk is a float64 array of shape ``(n,) + self.shape()``, where
        ``n`` is the number of files in that chunk. The final chunk is shorter
        when ``len(self)`` is not a multiple of ``chunk_size``. It is no longer
        padded with all-zero samples, so the total number of yielded samples
        equals ``len(self)``.

        ``self.chunk_size`` is assumed to be set by ``ChunkSource.__init__``.
        """
        num_entries = len(self)
        for start in range(0, num_entries, self.chunk_size):
            chunk_fnames = self.fnames[start:start + self.chunk_size]
            # Size the buffer to the files actually in this chunk; the old
            # code always used chunk_size and zero-padded the last chunk.
            out = np.zeros((len(chunk_fnames),) + tuple(self.shape()))
            for ind, fname in enumerate(chunk_fnames):
                # The context manager closes the file even if the read fails.
                with h5py.File(fname, "r") as handle:
                    out[ind, ...] = handle[self.name][()]
            yield out
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from . import Supplier | |
from ..sources import MySource | |
class MySupplier(Supplier):
    """Supplier exposing two ``MySource`` sources, ``'x'`` and ``'y'``.

    The text file given as ``names`` lists one HDF5 path per line. Both
    sources read the same files; ``'x'`` reads dataset ``'x'`` and ``'y'``
    reads dataset ``'y'``.
    """

    @classmethod
    def get_name(cls):
        """ Returns the name of the supplier.
        """
        return 'mysupplier'

    def __init__(self, names, *args, chunk_size=1, **kwargs):
        """Read the list of HDF5 files and build the ``'x'``/``'y'`` sources.

        Args:
            names: path of a text file with one HDF5 path per line.
            chunk_size: chunk size passed to each ``MySource``. Defaults to 1,
                the value that was previously hard-coded.
            *args, **kwargs: forwarded unchanged to ``Supplier.__init__``.

        Raises:
            OSError: if ``names`` cannot be opened.
        """
        super().__init__(*args, **kwargs)
        # Use a context manager so the list file is closed promptly, and skip
        # blank lines (e.g. a trailing empty line) so no '' path reaches h5py.
        with open(names, 'r', encoding='utf-8') as listing:
            fnames = [line.strip() for line in listing if line.strip()]
        self.data = {
            key: MySource(fnames=fnames, name=key, chunk_size=chunk_size)
            for key in ('x', 'y')
        }

    def get_sources(self, sources=None):
        """Return the requested sources from this supplier.

        Args:
            sources: ``None`` for all sources, a single key, or a list/tuple
                of keys.

        Returns:
            dict mapping each requested key to its ``MySource``.

        Raises:
            KeyError: if any requested key is not in ``self.data``.
        """
        if sources is None:
            sources = list(self.data.keys())
        elif not isinstance(sources, (list, tuple)):
            sources = [sources]
        for source in sources:
            if source not in self.data:
                raise KeyError(
                    'Invalid data key: {}. Valid keys are: {}'.format(
                        source, ', '.join(str(k) for k in self.data.keys())
                    ))
        return {k: self.data[k] for k in sources}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
test/file_0.h5 | |
test/file_1.h5 | |
test/file_2.h5 | |
test/file_3.h5 | |
test/file_4.h5 | |
test/file_5.h5 | |
test/file_6.h5 | |
test/file_7.h5 | |
test/file_8.h5 | |
test/file_9.h5 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Kurfile: an encoder of "module_reps" modules, each with "layer_reps"
# 3x3 convolutions + relu followed by a 2x2/stride-2 max pool. After the
# encoder comes a decoder of stride-2 transposed convolutions, each merged
# with the pool output of matching depth.
# NOTE(review): YAML indentation appears to have been lost in this copy;
# the original nesting must be restored before this file will parse.
settings: | |
layer_reps: 5 | |
module_reps: 5 | |
cs: 16 | |
model: | |
- input: | |
shape: [256,256,1] | |
name: x | |
# Encoder: j = 0..module_reps-1. Module j uses cs*2**j kernels
# (16, 32, 64, 128, 256 for j = 0..4) and its pool is named m{j}_p.
- for: | |
range: "{{module_reps}}" | |
with_index: j | |
iterate: | |
- for: | |
range: "{{layer_reps}}" | |
with_index: i | |
iterate: | |
- convolution: | |
kernels: "{{cs*2**j}}" | |
size: [3,3] | |
name: "m{{j}}_c{{i}}" | |
sink: yes | |
- activation: relu | |
# Five stride-2 pools take 256x256 down to 8x8 at m4_p.
- pool: | |
size: [2,2] | |
strides: [2,2] | |
type: max | |
name: "m{{j}}_p" | |
sink: yes | |
# Decoder: each stride-2 transposed convolution presumably doubles the
# spatial size (8 -> 16 -> ... -> 256; confirm the padding mode). Each
# one's kernel count matches the pool it is merged with (dc1/128 with
# m3_p, dc2/64 with m2_p, dc3/32 with m1_p, dc4/16 with m0_p).
# NOTE(review): merge layers m1..m4 are numbered in the opposite order
# to the m{j}_p pools they consume, e.g. m1 merges m3_p.
# NOTE(review): no merge mode is given, so the kur default applies - confirm.
- convolutiontranspose: | |
kernels: 128 | |
size: [3,3] | |
strides: [2,2] | |
name: dc1 | |
- merge: | |
inputs: [dc1,m3_p] | |
name: m1 | |
- convolutiontranspose: | |
kernels: 64 | |
size: [3,3] | |
strides: [2,2] | |
name: dc2 | |
- merge: | |
inputs: [dc2,m2_p] | |
name: m2 | |
- convolutiontranspose: | |
kernels: 32 | |
size: [3,3] | |
strides: [2,2] | |
name: dc3 | |
- merge: | |
inputs: [dc3,m1_p] | |
name: m3 | |
- convolutiontranspose: | |
kernels: 16 | |
size: [3,3] | |
strides: [2,2] | |
name: dc4 | |
- merge: | |
inputs: [dc4,m0_p] | |
name: m4 | |
- convolutiontranspose: | |
kernels: 8 | |
size: [3,3] | |
strides: [2,2] | |
name: lc | |
# 1x1 convolution down to a single channel; its name "y" matches the
# supplier's "y" source.
- convolution: | |
kernels: 1 | |
size: [1,1] | |
name: y | |
# Evaluation reads the HDF5 files listed in names.txt through the custom
# "mysupplier" supplier. The accompanying repro script reports an error
# when the provider batch_size differs from the sources' chunk_size.
evaluate: | |
data: | |
- mysupplier: | |
names: names.txt | |
provider: | |
batch_size: 1 | |
num_batches: 1 | |
hooks: | |
- myhook |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from kur.supplier import MySupplier | |
from kur.providers import BatchProvider | |
# What the evaluate section presumably does:
#   1) the supplier is created,
#   2) supplier.data is passed to a BatchProvider,
#   3) an iterator is created over the provider and ``next`` is called on it.


def _first_batch(names_file, batch_size):
    """Build a MySupplier/BatchProvider pair and return the provider's first batch."""
    supplier = MySupplier(names_file)
    provider = BatchProvider(sources=supplier.data, batch_size=batch_size)
    return next(iter(provider))


def main():
    """Reproduce the batch_size / chunk_size mismatch error."""
    # Works: the provider batch_size equals the sources' chunk_size (1).
    print(_first_batch('names.txt', batch_size=1))

    # Fails with the same error as `kur evaluate test.yml`:
    #
    #   ValueError: The truth value of an array with more than one element
    #   is ambiguous. Use a.any() or a.all()
    print(_first_batch('names.txt', batch_size=2))


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment