Skip to content

Instantly share code, notes, and snippets.

@yongjun823
Created November 20, 2018 00:09
Show Gist options
  • Save yongjun823/993b20626adc82883e8529df2d18f16c to your computer and use it in GitHub Desktop.
Save yongjun823/993b20626adc82883e8529df2d18f16c to your computer and use it in GitHub Desktop.
tensorflow estimator
import tensorlayer as tl
import tensorflow as tf
from config import config, log_config
from utils import *
def read_tf_img(path, name):
    """Load one image file and return it as a standardized image tensor.

    Args:
        path: Directory prefix. It is concatenated directly in front of
            ``name``; no path separator is inserted here.
        name: File name appended to ``path``.

    Returns:
        The decoded image (decoded with ``channels=4``) after
        ``tf.image.per_image_standardization`` has been applied.
    """
    raw_bytes = tf.read_file(path + name)
    decoded = tf.image.decode_image(raw_bytes, channels=4)
    return tf.image.per_image_standardization(decoded)
def _parse_function(hr_name, lr_name):
    """Turn a pair of file names into the matching pair of image tensors.

    ``hr_name`` is read from ``config.TRAIN.hr_img_path`` and ``lr_name``
    from ``config.TRAIN.lr_img_path``.

    Returns:
        A ``(hr_img, lr_img)`` tuple of tensors produced by ``read_tf_img``.
    """
    return (
        read_tf_img(config.TRAIN.hr_img_path, hr_name),
        read_tf_img(config.TRAIN.lr_img_path, lr_name),
    )
def train_input_fn():
    """Build the training input pipeline and return its next-batch tensors.

    PNG file names are listed from ``config.TRAIN.hr_img_path`` and
    ``config.TRAIN.lr_img_path``, sorted, and paired by position. Each pair
    is decoded by ``_parse_function``; the dataset repeats indefinitely and
    is batched with ``config.TRAIN.batch_size``.

    Returns:
        The ``(hr_batch, lr_batch)`` tensors from a one-shot iterator's
        ``get_next()``.

    Raises:
        ValueError: If no PNG files are found, or if the two folders do not
            contain the same number of PNG files (position-based pairing
            would otherwise be wrong).
    """
    # Escape the dot and anchor the end so only names ending in ".png" match;
    # the previous pattern '.*.png' also accepted e.g. "a.png.bak" or "xpng".
    png_pattern = r'.*\.png$'
    hr_names = sorted(tl.files.load_file_list(
        path=config.TRAIN.hr_img_path, regx=png_pattern, printable=False))
    lr_names = sorted(tl.files.load_file_list(
        path=config.TRAIN.lr_img_path, regx=png_pattern, printable=False))

    if not hr_names or not lr_names:
        raise ValueError(
            'No PNG files found: {} in hr_img_path, {} in lr_img_path'.format(
                len(hr_names), len(lr_names)))
    if len(hr_names) != len(lr_names):
        raise ValueError(
            'HR/LR image count mismatch: {} vs {}'.format(
                len(hr_names), len(lr_names)))

    dataset = tf.data.Dataset.from_tensor_slices((hr_names, lr_names))
    # Shuffle the (cheap) file names before decoding, with a buffer covering
    # the whole list, so every epoch is a full permutation and the shuffle
    # buffer never holds decoded images.
    dataset = dataset.shuffle(buffer_size=len(hr_names))
    dataset = dataset.map(_parse_function)
    dataset = dataset.repeat()
    dataset = dataset.batch(config.TRAIN.batch_size)
    return dataset.make_one_shot_iterator().get_next()
# Smoke test for the input pipeline: fetch 100 batches and print their shapes.
# `batch_size` and `train_input` were previously used without being defined
# in this file; they are now taken from the config and from train_input_fn().
batch_size = config.TRAIN.batch_size
# NOTE(review): `np` is not imported in this file; presumably it is
# re-exported by `from utils import *` - confirm.
ni = int(np.sqrt(batch_size))
train_input = train_input_fn()
with tf.Session() as sess:
    for _ in range(100):
        sample_hr, sample_lr = sess.run(train_input)
        print(sample_hr.shape)
        print(sample_lr.shape)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment