TensorFlow の基本的な書き方

TensorFlow の書き方を忘れたときのためのメモ（毎日のように忘れる）（テンサーフ労はわるい）。

ニューラルネットで XOR を学習・計算するサンプルです（TensorFlow 1.x の `tf.Session` / `tf.placeholder` API を使用）。

from random import randrange
import numpy as np
import tensorflow as tf

class Xor:
    """Small TF1.x graph-mode MLP that learns XOR of two binary inputs.

    Construction builds the model into the current default graph, opens a
    session, sets up TensorBoard logging under ``tensor_board`` and
    initializes every variable, so ``train``/``predict`` can be called
    immediately afterwards.
    """

    __slots__ = [
        'sess', 'x_ph', 'y_ph', 'global_step', 'output', 'loss', 'optimizer', 'summary', 'writer',
    ]

    def __init__(self):
        self.sess = tf.Session()
        self.x_ph = tf.placeholder(tf.float32, shape=(None, 2), name='x')
        self.y_ph = tf.placeholder(tf.float32, shape=(None,), name='y')
        self.global_step = tf.Variable(0, trainable=False, name='global_step')
        self.output, self.loss, self.optimizer = self._make_model(self.x_ph, self.y_ph)

        summary_dir = 'tensor_board'
        # Start from an empty log directory so TensorBoard does not mix this
        # run's curves with those of earlier runs.
        if tf.gfile.Exists(summary_dir):
            tf.gfile.DeleteRecursively(summary_dir)
        self.writer = tf.summary.FileWriter(summary_dir, self.sess.graph)
        tf.summary.scalar('loss_summary', self.loss)
        self.summary = tf.summary.merge_all()

        # Must come after _make_model: AdamOptimizer.minimize creates its own
        # slot variables, and all variables need initial values before the
        # first training step runs.
        self.sess.run(tf.global_variables_initializer())

    def train(self, x, y):
        """Run one Adam step on a batch and log the loss summary.

        Args:
            x: batch of inputs fed to the ``(None, 2)`` float placeholder.
            y: batch of targets fed to the ``(None,)`` float placeholder.
        """
        # NOTE(review): global_step is fetched in the same run that
        # increments it; TF1 does not order the two, so the logged step may
        # be the value before or after this update.
        _, summary, global_step = self.sess.run(
            [self.optimizer, self.summary, self.global_step],
            feed_dict={self.x_ph: x, self.y_ph: y},
        )
        self.writer.add_summary(summary, global_step)

    def predict(self, x):
        """Return the raw regression output, shape ``(len(x), 1)``, for inputs ``x``."""
        return self.sess.run(self.output, feed_dict={self.x_ph: x})

    def _make_model(self, x_ph, y_ph):
        """Build the MLP, its mean-squared-error loss and the Adam training op.

        Returns:
            ``(output, loss, optimizer)``: ``output`` has shape ``(None, 1)``,
            ``loss`` is a scalar, and ``optimizer`` is the op that applies one
            Adam update and increments ``self.global_step``.
        """
        dense1 = tf.layers.dense(x_ph, 256, activation=tf.nn.relu, name='dense1')
        dense2 = tf.layers.dense(dense1, 256, activation=tf.nn.relu, name='dense2')
        output = tf.layers.dense(dense2, 1, name='output')
        with tf.name_scope('loss'):
            # output is (batch, 1) while y_ph is (batch,). Subtracting them
            # directly would broadcast to (batch, batch) and compare every
            # prediction with every target, so drop the size-1 axis first.
            loss = tf.reduce_mean(tf.square(tf.squeeze(output, axis=1) - y_ph))
        optimizer = tf.train.AdamOptimizer().minimize(loss, global_step=self.global_step)
        return output, loss, optimizer

with tf.Graph().as_default():
    xor = Xor()

    # Online training: one random sample per step; the target is x0 XOR x1.
    for _ in range(1000):
        a, b = randrange(2), randrange(2)
        xor.train([[a, b]], [a ^ b])
    # Flush buffered TensorBoard events to disk.
    xor.writer.close()

    # Show predictions for a few random inputs. np.round snaps the
    # regression output to 0/1, and abs() turns a possible -0.0 into 0.0.
    for _ in range(10):
        x = [[randrange(2), randrange(2)]]
        y = abs(np.round(xor.predict(x)))
        print('x: {}, y: {}'.format(x, y))

    xor.sess.close()