tensorflow の layers を keras.layers で置き換える書き方

今後 tf.layerstf.keras.layers で置き換わるらしいという話を聞きました。

参考: https://www.tensorflow.org/api_docs/python/tf/layers/Layer

ということは全部 keras 風に書かないといけないの?と思ったのですが、実は layers の部分だけ置き換えるというのもできるようです。

このへんの書き換えで盛大にハマったのでメモを残しておきます。

例としてサイズ2のベクトルを受け取ってスカラーを返す単純なモデルを使います。


従来の tensorflow の書き方。

import numpy as np
import tensorflow as tf


inputs = tf.placeholder(tf.float32, shape=(None, 2))
dense = tf.layers.dense(inputs, 1)

session = tf.Session()
session.run(tf.global_variables_initializer())
result = session.run(
    dense, feed_dict={inputs: np.array([[1.0, -1.0]])}
)
print(result)


tf.keras.layers を使って全部 keras 風に書く書き方。

import numpy as np
import tensorflow as tf


inputs = tf.keras.Input(shape=(2,))
dense = tf.keras.layers.Dense(1)(inputs)

model = tf.keras.Model(inputs, dense)
result = model.predict(np.array([[1.0, -1.0]]))
print(result)


layers の部分だけ tf.keras.layers を使う書き方。

import numpy as np
import tensorflow as tf


inputs = tf.placeholder(tf.float32, shape=(None, 2))
dense = tf.keras.layers.Dense(1)(inputs)

session = tf.Session()
session.run(tf.global_variables_initializer())
result = session.run(
    dense, feed_dict={inputs: np.array([[1.0, -1.0]])}
)
print(result)


tf.keras.Inputtf.placeholder の代わりに使える模様。tf.keras.Input を使う場合 shape にバッチサイズの次元は書かないことに注意。

import numpy as np
import tensorflow as tf


inputs = tf.keras.Input(shape=(2,))
dense = tf.keras.layers.Dense(1)(inputs)

session = tf.Session()
session.run(tf.global_variables_initializer())
result = session.run(
    dense, feed_dict={inputs: np.array([[1.0, -1.0]])}
)
print(result)


最後にだめな事例。tf.keras.Inputtf.placeholder の代わりに使えるけれど、その逆は無理。tf.placeholdertf.keras.Modelinputs として与えようとすると Input tensors to a Model must come from `tf.layers.Input` と怒られてしまいます。

import numpy as np
import tensorflow as tf


inputs = tf.placeholder(tf.float32, shape=(None, 2))
dense = tf.keras.layers.Dense(1)(inputs)

# inputs は placeholder ではダメ!
model = tf.keras.Model(inputs, dense)
result = model.predict(np.array([[1.0, -1.0]]))
print(result)


以上です。

どうか気をつけてほしい。テンサーフローのやみは深い。