Last active: March 23, 2016 00:57
-
-
Save koher/579b4149bd6bf9dda327 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding:utf-8 | |
from random import choice

import numpy as np
import tensorflow as tf
# Numeric encoding of the two calls fed to the network.
zun, doko = -1.0, 1.0
# A "kiyoshi" pattern is zun_length zuns followed by a single doko.
zun_length = 4
zundoko_length = zun_length + 1
def make_zundoko_list(length):
    """Return ``length`` random calls, each chosen uniformly from zun/doko."""
    calls = (zun, doko)
    return [choice(calls) for _ in range(length)]
def solve_zundoko_list(zundoko_list, index=0, zun_count=0):
    """Find the first "zun x zun_length, doko" pattern in a call sequence.

    Scans left to right, counting consecutive zuns. The first doko preceded
    by at least ``zun_length`` zuns ends the pattern.

    Implemented as a single pass. The previous recursive version sliced the
    list on every call (O(n^2) copying) and recursed once per element, which
    hits Python's recursion limit on long inputs.

    Args:
        zundoko_list: iterable of ``zun`` / ``doko`` values.
        index: position that ``zundoko_list[0]`` occupies in the full
            sequence, so a caller can resume a scan part-way through.
        zun_count: number of consecutive zuns already seen immediately
            before ``zundoko_list[0]``.

    Returns:
        The position of the first of the ``zun_length`` zuns directly before
        the matching doko. If no pattern is found, returns
        ``index + len(zundoko_list) - zun_length``, which is one past the
        last position a full pattern could start at. This script treats that
        value as the extra "no pattern" class.
    """
    position = index
    run = zun_count
    for value in zundoko_list:
        if run >= zun_length and value == doko:
            return position - zun_length
        # Anything other than a zun (doko or otherwise) resets the run.
        run = run + 1 if value == zun else 0
        position += 1
    return position - zun_length
def make_solved_zundoko_list(length, answer):
    """Build a random call list whose ``solve_zundoko_list`` result is ``answer``.

    Args:
        length: number of calls in the list.
        answer: desired result of ``solve_zundoko_list`` on the returned
            list. It must either lie in ``[0, length - zundoko_length]``
            (a pattern starting at that position) or equal
            ``length - zun_length`` (no pattern at all).

    Returns:
        A list of ``length`` zun/doko values.

    Raises:
        ValueError: if ``answer`` is not a value ``solve_zundoko_list`` can
            return for a list of this length. Previously, a too-large answer
            looped forever, and a negative one corrupted the list's length
            through the slice assignment.
    """
    no_pattern = length - zun_length
    if answer != no_pattern and not 0 <= answer <= length - zundoko_length:
        raise ValueError(
            'answer must be in [0, %d] or equal %d, got %r'
            % (length - zundoko_length, no_pattern, answer))

    zundoko_list = make_zundoko_list(length)
    # Break every pattern that starts before `answer` by turning the first of
    # its zun_length zuns into a doko. Each pass removes one zun, so the loop
    # terminates.
    solved_answer = solve_zundoko_list(zundoko_list)
    while solved_answer < answer:
        zundoko_list[solved_answer] = doko
        solved_answer = solve_zundoko_list(zundoko_list)
    # Plant the pattern at `answer` when it fits inside the list. For the
    # "no pattern" answer it never fits, and the loop above has already
    # removed every pattern.
    if answer + zundoko_length <= length:
        zundoko_list[answer:answer + zundoko_length] = [zun] * zun_length + [doko]
    return zundoko_list
def dense_to_one_hot(index, num_classes):
    """Return a float one-hot list of length ``num_classes`` with 1.0 at ``index``.

    An ``index`` outside ``range(num_classes)`` yields an all-zero list.
    """
    return [float(k == index) for k in range(num_classes)]
list_length = 100
# Classes: every position a "zun zun zun zun doko" pattern can start at in a
# list of list_length calls (0 .. list_length - zundoko_length), plus one extra
# class for "no pattern". solve_zundoko_list reports the extra class as
# list_length - zun_length, i.e. the last class index.
num_classes = (list_length - zundoko_length + 1) + 1

# Training set: 100000 lists whose target answers cycle through every class,
# so each class is (almost) equally represented.
zundoko_lists = [
    make_solved_zundoko_list(list_length, sample % num_classes)
    for sample in range(100000)
]
zundoko_answers = [
    dense_to_one_hot(solved, num_classes)
    for solved in map(solve_zundoko_list, zundoko_lists)
]

# Held-out set: unconstrained random lists labelled with their true answers.
# NOTE(review): not referenced anywhere later in this script.
test_zundoko_lists = [make_zundoko_list(list_length) for _ in range(1000)]
test_zundoko_answers = [
    dense_to_one_hot(solved, num_classes)
    for solved in map(solve_zundoko_list, test_zundoko_lists)
]
# --- Model -------------------------------------------------------------------
# Input batch: one row of list_length zun/doko values per example.
x = tf.placeholder(tf.float32, [None, list_length])
# View each row as a list_length x 1 single-channel image so conv2d can slide
# a window along the sequence.
x_reshaped = tf.reshape(x, [-1, list_length, 1, 1])

# Hidden layer: one convolution filter spanning 5 consecutive calls, plus bias
# and ReLU. SAME padding with stride 1 keeps one output per input position.
W1 = tf.Variable(tf.truncated_normal([5, 1, 1, 1], stddev=0.1))
b1 = tf.Variable(tf.truncated_normal([1], stddev=0.1))
h1_reshaped = tf.nn.relu(
    tf.nn.conv2d(x_reshaped, W1, strides=[1, 1, 1, 1], padding='SAME') + b1
)
h1 = tf.reshape(h1_reshaped, [-1, list_length])

# Output layer: fully connected list_length -> num_classes, then softmax.
W2 = tf.Variable(tf.truncated_normal([list_length, num_classes], stddev=0.1))
b2 = tf.Variable(tf.truncated_normal([num_classes], stddev=0.1))
y = tf.nn.softmax(tf.matmul(h1, W2) + b2)

# --- Objective ---------------------------------------------------------------
# One-hot target class per example.
y_ = tf.placeholder(tf.float32, [None, num_classes])
# Cross-entropy summed over the batch. y is clipped so log() never sees 0.
cross_entropy = -tf.reduce_sum(
    y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0))
)
train_step = tf.train.GradientDescentOptimizer(1e-5).minimize(cross_entropy)

# --- Session -----------------------------------------------------------------
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
# Full-batch gradient descent: 1000 steps, each over the entire training set.
# The training data never changes, so convert it to float32 arrays once here.
# Feeding the nested Python lists directly makes Session.run convert them to
# an array again on every one of the 1000 calls.
train_feed = {
    x: np.asarray(zundoko_lists, dtype=np.float32),
    y_: np.asarray(zundoko_answers, dtype=np.float32),
}
for _ in range(1000):
    sess.run(train_step, feed_dict=train_feed)
# Predicted class = index of the largest softmax output. This op was used
# below but never defined, which raised NameError.
answer = tf.argmax(y, 1)

# Generate a fresh random sequence and ask the trained network where the
# "zun zun zun zun doko" pattern starts.
zundoko_list = make_zundoko_list(list_length)
zundoko_answer = int(sess.run(answer, feed_dict={x: [zundoko_list]})[0])
zundoko_string_list = ['ズン' if zundoko == zun else 'ドコ' for zundoko in zundoko_list]
# Cut the printout right after the predicted pattern. For the "no pattern"
# class the cut point lies past the end, so the whole list is kept.
zundoko_string_list = zundoko_string_list[:min(zundoko_answer + zundoko_length, len(zundoko_string_list))]
for zundoko_string in zundoko_string_list:
    print(zundoko_string)
# The lengths match only when the whole predicted pattern fit inside the list.
if zundoko_answer + zundoko_length == len(zundoko_string_list):
    print('キ・ヨ・シ!')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.