Skip to content

Instantly share code, notes, and snippets.

@josh-gree
Last active March 23, 2017 08:58
Show Gist options
  • Save josh-gree/e5f5421ce34fb8371c0f471ffce930de to your computer and use it in GitHub Desktop.
Save josh-gree/e5f5421ce34fb8371c0f471ffce930de to your computer and use it in GitHub Desktop.
from . import ChunkSource
import numpy as np
import h5py
class MySource(ChunkSource):
    """Chunked source that reads one named dataset out of a list of HDF5 files.

    Every file in ``fnames`` contributes exactly one sample: the dataset
    stored under the key ``name``.
    """

    def __init__(self, fnames, name, *args, **kwargs):
        """Create the source.

        Args:
            fnames: List of HDF5 file paths, one sample per file.
            name: Key of the dataset to read from each file.
            *args, **kwargs: Forwarded to ``ChunkSource`` (presumably where
                ``chunk_size`` is consumed -- confirm against the base class).
        """
        super().__init__(*args, **kwargs)
        self.fnames = fnames
        self.name = name

    def shape(self):
        """Return the (hard-coded) shape of a single sample."""
        return (256, 256, 1)

    def __len__(self):
        """Return the number of samples, i.e. the number of files."""
        return len(self.fnames)

    def shuffle(self, indices):
        """Reorder the file list so that position ``k`` holds ``fnames[indices[k]]``."""
        self.fnames = [self.fnames[i] for i in indices]

    def can_shuffle(self):
        """ This source can be shuffled.
        """
        return True

    def __iter__(self):
        """Yield arrays of shape ``(n,) + self.shape()``, one per chunk.

        ``n`` equals ``self.chunk_size`` for every chunk except possibly the
        last, which holds only the remaining samples. (Previously the last
        chunk was always allocated with ``chunk_size`` rows, so a short final
        chunk carried trailing all-zero samples that were never in the data.)
        """
        sample_shape = self.shape()
        num_entries = len(self)
        for start in range(0, num_entries, self.chunk_size):
            chunk_fnames = self.fnames[start:start + self.chunk_size]
            # Size the buffer to the files actually in this chunk.
            out = np.zeros((len(chunk_fnames),) + sample_shape)
            for ind, fname in enumerate(chunk_fnames):
                # Context manager closes the file even if the read fails.
                with h5py.File(fname, "r") as f:
                    out[ind, ...] = f[self.name]
            yield out
from . import Supplier
from ..sources import MySource
class MySupplier(Supplier):
    """Supplier exposing two sources, ``'x'`` and ``'y'``, backed by HDF5 files.

    The file paths are read from a text listing (one path per line); both
    sources share the same list of files and differ only in the dataset key
    they read.
    """

    @classmethod
    def get_name(cls):
        """ Returns the name of the supplier.
        """
        return 'mysupplier'

    def __init__(self, names, *args, **kwargs):
        """Create the supplier.

        Args:
            names: Path to a text file listing one HDF5 file path per line.
                Surrounding whitespace is stripped and blank lines are
                ignored.
            *args, **kwargs: Forwarded to ``Supplier``.
        """
        super().__init__(*args, **kwargs)
        # Close the listing deterministically, and drop blank lines (e.g. a
        # trailing newline) so they do not turn into empty file names.
        with open(names, 'r') as listing:
            stripped = (line.strip() for line in listing)
            fnames = [fname for fname in stripped if fname]
        self.data = {
            'x': MySource(fnames=fnames, name='x', chunk_size=1),
            'y': MySource(fnames=fnames, name='y', chunk_size=1),
        }

    def get_sources(self, sources=None):
        """Return the requested sources as a ``{key: source}`` dict.

        Args:
            sources: ``None`` for every source, a single key, or a list/tuple
                of keys.

        Raises:
            KeyError: If any requested key is not one of this supplier's
                sources.
        """
        if sources is None:
            sources = list(self.data.keys())
        elif not isinstance(sources, (list, tuple)):
            # A single key was given; normalize to a list.
            sources = [sources]
        for source in sources:
            if source not in self.data:
                raise KeyError(
                    'Invalid data key: {}. Valid keys are: {}'.format(
                        source, ', '.join(str(k) for k in self.data.keys())
                    ))
        return {k: self.data[k] for k in sources}
test/file_0.h5
test/file_1.h5
test/file_2.h5
test/file_3.h5
test/file_4.h5
test/file_5.h5
test/file_6.h5
test/file_7.h5
test/file_8.h5
test/file_9.h5
# Template variables; referenced below through {{...}} expressions.
settings:
# Range of the inner `for` loop (index i).
layer_reps: 5
# Range of the outer `for` loop (index j).
module_reps: 5
# Base kernel count; the looped convolution uses cs*2**j kernels.
cs: 16
# Network definition: input `x`, looped convolution/pool stages, then a chain
# of stride-2 transposed convolutions merged with the pooling outputs.
model:
- input:
shape: [256,256,1]
name: x
# Outer loop over j; the pool named m{j}_p depends on j.
- for:
range: "{{module_reps}}"
with_index: j
iterate:
# Inner loop over i; produces convolutions named m{j}_c{i}.
- for:
range: "{{layer_reps}}"
with_index: i
iterate:
- convolution:
kernels: "{{cs*2**j}}"
size: [3,3]
name: "m{{j}}_c{{i}}"
sink: yes
- activation: relu
- pool:
size: [2,2]
strides: [2,2]
type: max
name: "m{{j}}_p"
sink: yes
# Each transposed convolution below is merged with one pooling output:
# dc1 with m3_p, dc2 with m2_p, dc3 with m1_p, dc4 with m0_p.
- convolutiontranspose:
kernels: 128
size: [3,3]
strides: [2,2]
name: dc1
- merge:
inputs: [dc1,m3_p]
name: m1
- convolutiontranspose:
kernels: 64
size: [3,3]
strides: [2,2]
name: dc2
- merge:
inputs: [dc2,m2_p]
name: m2
- convolutiontranspose:
kernels: 32
size: [3,3]
strides: [2,2]
name: dc3
- merge:
inputs: [dc3,m1_p]
name: m3
- convolutiontranspose:
kernels: 16
size: [3,3]
strides: [2,2]
name: dc4
- merge:
inputs: [dc4,m0_p]
name: m4
- convolutiontranspose:
kernels: 8
size: [3,3]
strides: [2,2]
name: lc
# Final 1x1 convolution with a single kernel; this layer is named `y`.
- convolution:
kernels: 1
size: [1,1]
name: y
# Evaluation uses the custom `mysupplier` supplier, configured with names.txt.
evaluate:
data:
- mysupplier:
names: names.txt
provider:
batch_size: 1
num_batches: 1
# NOTE(review): `myhook` is not defined anywhere in the visible source --
# confirm it is registered before running.
hooks:
- myhook
from kur.supplier import MySupplier
from kur.providers import BatchProvider
# What I imagine is basically going on in the evaluate section:
#   1) a supplier is created
#   2) supplier.data is passed to a BatchProvider
#   3) an iterator is created and next() is called on it
#
# The first pass (batch_size=1) works: the provider batch_size and the
# source chunk_size are equal.
#
# The second pass (batch_size=2) does not -- it fails with the same error
# as when calling `kur evaluate test.yml`:
#
#   ValueError: The truth value of an array with more than one element
#   is ambiguous. Use a.any() or a.all()
for batch_size in (1, 2):
    supplier = MySupplier('names.txt')
    provider = BatchProvider(sources=supplier.data, batch_size=batch_size)
    batches = iter(provider)
    print(next(batches))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment