A StyleGAN PNG loader. Avoids the need for intermediate TFRecords if you have the CPU/GPU for it.
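Once the patch in the docstring below is applied, training runs through the usual StyleGAN2 entry point. The patch assumes this file is importable as artbreeder_extensions.dataset, and a hypothetical invocation (directory names are assumptions) looks like the following, with the dataset argument naming the PNG folder under --data-dir:

    python run_training.py --num-gpus=1 --data-dir=datasets --dataset=my_pngs --config=config-f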
""" | |
A StyleGAN PNG data loader. | |
Saves the need for intermediate TFRecords if you have the CPU/GPU for it. | |
Apply the following patch: | |
diff --git a/run_training.py b/run_training.py | |
index bc4c0a2..61e5a33 100755 | |
--- a/run_training.py | |
+++ b/run_training.py | |
@@ -58,7 +58,11 @@ def run(dataset, data_dir, result_dir, config_id, num_gpus, total_kimg, gamma, m | |
desc = 'stylegan2' | |
desc += '-' + dataset | |
- dataset_args = EasyDict(tfrecord_dir=dataset) | |
+ dataset_args = EasyDict( | |
+ path=os.path.join(data_dir, dataset), | |
+ class_name='artbreeder_extensions.dataset.PNGDataset', | |
+ num_workers=16 | |
+ ) | |
assert num_gpus in [1, 2, 4, 8] | |
sc.num_gpus = num_gpus | |
""" | |
from pathlib import Path

import numpy as np
import tensorflow as tf

from dnnlib import tflib


class PNGDataset:
    def __init__(
        self,
        path,
        num_workers=16,
        repeat=True,
        shuffle_mb=4096,   # Shuffle data within specified window (megabytes), 0 = disable shuffling.
        prefetch_mb=2048,  # Amount of data to prefetch (megabytes), 0 = disable prefetching.
    ):
        # Fixed 1024x1024 RGB, unconditional (no labels); mirrors the attributes
        # StyleGAN2's TFRecordDataset exposes to the training loop.
        self.shape = [3, 1024, 1024]
        self.resolution = 1024
        self.label_size = 0
        self.dtype = np.uint8
        self.label_dtype = np.float32
        self.dynamic_range = [0, 255]
        self.resolution_log2 = int(np.log2(self.resolution))

        path = Path(path)
        files = [str(x) for x in path.glob("*.png") if x.stat().st_size > 0]  # skip zero-byte files
        print(f"found {len(files)} files")
        with tf.name_scope("Dataset"), tf.device("/cpu:0"):
            self._tf_minibatch_in = tf.placeholder(
                tf.int64, name="minibatch_in", shape=[]
            )
            dset = tf.data.Dataset.from_generator(
                lambda: files,
                output_types=tf.string,
            )

            def map_func(x):
                x = tf.io.read_file(x)
                x = tf.io.decode_png(x, channels=3)  # force RGB in case a PNG carries an alpha channel
                x = tf.transpose(x, (2, 0, 1))  # HWC to CHW
                lbl = tf.zeros((0,), dtype=tf.float32)  # unconditional: empty label vector
                return (x, lbl)

            dset = dset.map(
                map_func,
                num_parallel_calls=num_workers,
            )
            # Convert the MB budgets into element counts for the shuffle/prefetch buffers.
            bytes_per_item = np.prod(self.shape) * np.dtype(self.dtype).itemsize
            if shuffle_mb > 0:
                dset = dset.shuffle(((shuffle_mb << 20) - 1) // bytes_per_item + 1)
            if repeat:
                dset = dset.repeat()
            if prefetch_mb > 0:
                dset = dset.prefetch(((prefetch_mb << 20) - 1) // bytes_per_item + 1)
            dset = dset.batch(self._tf_minibatch_in)
            self._cur_minibatch = 0
            self._tf_minibatch_np = None
            self._tf_iterator = tf.data.Iterator.from_structure(
                output_types=dset.output_types,
                output_shapes=dset.output_shapes,
            )
            self._tf_init_op = self._tf_iterator.make_initializer(dset)
    def close(self):
        pass

    def configure(self, minibatch_size, lod=0):  # lod is ignored; this loader is fixed-resolution
        assert minibatch_size >= 1
        if self._cur_minibatch != minibatch_size:
            self._tf_init_op.run({self._tf_minibatch_in: minibatch_size})
            self._cur_minibatch = minibatch_size

    # Get next minibatch as TensorFlow expressions.
    def get_minibatch_tf(self):  # => images, labels
        return self._tf_iterator.get_next()

    # Get next minibatch as NumPy arrays.
    def get_minibatch_np(self, minibatch_size, lod=0):  # => images, labels
        self.configure(minibatch_size)
        with tf.name_scope("Dataset"):
            if self._tf_minibatch_np is None:
                self._tf_minibatch_np = self.get_minibatch_tf()
            return tflib.run(self._tf_minibatch_np)

    # Get random labels as TensorFlow expression.
    def get_random_labels_tf(self, minibatch_size):  # => labels
        with tf.name_scope("Dataset"):
            return tf.zeros([minibatch_size, 0], self.label_dtype)

    # Get random labels as NumPy array.
    def get_random_labels_np(self, minibatch_size):  # => labels
        return np.zeros([minibatch_size, 0], self.label_dtype)
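

# Minimal usage sketch, not part of the original loader: it assumes TF 1.x
# with dnnlib on the path, and a hypothetical directory ./pngs of 1024x1024
# RGB images.
if __name__ == "__main__":
    tflib.init_tf()  # installs a default TF session, as StyleGAN2 does at startup
    dataset = PNGDataset("./pngs", num_workers=4)
    images, labels = dataset.get_minibatch_np(minibatch_size=4)
    print(images.shape, images.dtype)  # (4, 3, 1024, 1024) uint8
    print(labels.shape)                # (4, 0): unconditional, no labels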