tensorflow で embedding_lookup をすると UserWarning が出て困ったので、対処法をメモに残しておきます。
以下のような感じのコードを用意します。
import numpy as np import tensorflow as tf class Embedding(tf.keras.layers.Layer): def build(self, input_shape): self.lookup_table = self.add_variable( 'embedding', (4, 16), dtype=tf.float32, ) def call(self, inputs): return tf.nn.embedding_lookup(self.lookup_table, inputs) inputs = tf.keras.Input(shape=(1,), dtype=tf.int32) embedding = Embedding()(inputs) flatten = tf.keras.layers.Flatten()(embedding) dense = tf.keras.layers.Dense(256, activation='relu')(flatten) outputs = tf.keras.layers.Dense(2, activation='softmax')(dense) x = np.array([0, 1, 2, 3]) y = np.array([0, 0, 1, 1]) model = tf.keras.Model(inputs, outputs) model.compile( 'Adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'], ) model.fit(x, y, epochs=5) print(model.evaluate(x, y)) print(np.argmax(model.predict(x), axis=1))
今回 Embedding
というクラスを作っていて、これは0,1,2,3の4種類のスカラー値を16次元のベクトル空間に embedding するやつです。
ともかくこれを実行すると以下のような warning が出ます。
UserWarning: Converting sparse IndexedSlices to a dense Tensor of unknown shape. This may consume a large amount of memory.
sparse IndexedSlices というのを dense Tensor に変換するときにたくさんメモリを使うかもね、と書いてあります。
ということでこのメッセージでぐぐってやると以下の記事を見つけることができました。
この記事によると以下のように書いてあります。
To fix this problem, you should try to ensure that the params input to tf.gather() (or the params inputs to tf.nn.embedding_lookup()) is a tf.Variable. Variables can receive the sparse updates directly, so no conversion is needed.
tf.Variable
を使うとさっきのメモリをたくさんつかうかもな変換は起きないらしいです。
単純に考えると tf.get_variable()
で tf.Variable
を作ってやればよさそうです。
class Embedding(tf.keras.layers.Layer): def build(self, input_shape): self.lookup_table = tf.get_variable( name='embedding', shape=(4, 16), dtype=tf.float32, ) def call(self, inputs): return tf.nn.embedding_lookup(self.lookup_table, inputs)
ただ tf.get_variable()
を使うと自分で tf.global_variables_initializer()
を呼んでやらないといけません。
なので
with tf.Session() as session: inputs = tf.keras.Input(shape=(1,), dtype=tf.int32) embedding = Embedding()(inputs) flatten = tf.keras.layers.Flatten()(embedding) dense = tf.keras.layers.Dense(256, activation='relu')(flatten) outputs = tf.keras.layers.Dense(2, activation='softmax')(dense) session.run(tf.global_variables_initializer() model = tf.keras.Model(inputs, outputs)
という感じでせっかく keras 風に書いて session のことを忘れていたのに、一気に厳しい感じになってきます。
で、他の方法なのですが Embedding
のなかで使っている add_variable
には変数を取ってくるためのメソッドを外部から指定できるようです。これを使って「変数を取ってくる部分は get_variable
を使ってね」という感じにしてやればよさそうです。
class Embedding(tf.keras.layers.Layer): def build(self, input_shape): self.lookup_table = self.add_variable( 'embedding', (4, 16), dtype=tf.float32, getter=tf.get_variable, ) def call(self, inputs): return tf.nn.embedding_lookup(self.lookup_table, inputs)
こうすると無事 UserWarning が消えました。めでたし。
以上です。
どうか気をつけてほしい。テンサーフローのやみは深い。