tensorflow で embedding_lookup をすると UserWarning が出るやつの対処

tensorflow で embedding_lookup をすると UserWarning が出て困ったので、対処法をメモに残しておきます。


以下のような感じのコードを用意します。

import numpy as np
import tensorflow as tf


class Embedding(tf.keras.layers.Layer):
    def build(self, input_shape):
        self.lookup_table = self.add_variable(
            'embedding',
            (4, 16),
            dtype=tf.float32,
        )

    def call(self, inputs):
        return tf.nn.embedding_lookup(self.lookup_table, inputs)


inputs = tf.keras.Input(shape=(1,), dtype=tf.int32)
embedding = Embedding()(inputs)
flatten = tf.keras.layers.Flatten()(embedding)
dense = tf.keras.layers.Dense(256, activation='relu')(flatten)
outputs = tf.keras.layers.Dense(2, activation='softmax')(dense)

x = np.array([0, 1, 2, 3])
y = np.array([0, 0, 1, 1])

model = tf.keras.Model(inputs, outputs)
model.compile(
    'Adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)
model.fit(x, y, epochs=5)
print(model.evaluate(x, y))
print(np.argmax(model.predict(x), axis=1))

今回 Embedding というクラスを作っていて、これは0,1,2,3の4種類のスカラー値を16次元のベクトル空間に embedding するやつです。

ともかくこれを実行すると以下のような warning が出ます。

UserWarning: Converting sparse IndexedSlices
to a dense Tensor of unknown shape.
This may consume a large amount of memory.

sparse IndexedSlices というのを dense Tensor に変換するときにたくさんメモリを使うかもね、と書いてあります。

ということでこのメッセージでぐぐってやると以下の記事を見つけることができました。

stackoverflow.com

この記事によると以下のように書いてあります。

To fix this problem,
you should try to ensure that the params input to tf.gather()
(or the params inputs to tf.nn.embedding_lookup())
is a tf.Variable.
Variables can receive the sparse updates directly,
so no conversion is needed.

tf.Variable を使うとさっきのメモリをたくさんつかうかもな変換は起きないらしいです。

単純に考えると tf.get_variable()tf.Variable を作ってやればよさそうです。

class Embedding(tf.keras.layers.Layer):
    def build(self, input_shape):
        self.lookup_table = tf.get_variable(
            name='embedding',
            shape=(4, 16),
            dtype=tf.float32,
        )

    def call(self, inputs):
        return tf.nn.embedding_lookup(self.lookup_table, inputs)

ただ tf.get_variable() を使うと自分で tf.global_variables_initializer() を呼んでやらないといけません。

なので

with tf.Session() as session:
    inputs = tf.keras.Input(shape=(1,), dtype=tf.int32)
    embedding = Embedding()(inputs)
    flatten = tf.keras.layers.Flatten()(embedding)
    dense = tf.keras.layers.Dense(256, activation='relu')(flatten)
    outputs = tf.keras.layers.Dense(2, activation='softmax')(dense)

    session.run(tf.global_variables_initializer()
    model = tf.keras.Model(inputs, outputs)

という感じでせっかく keras 風に書いて session のことを忘れていたのに、一気に厳しい感じになってきます。

で、他の方法なのですが Embedding のなかで使っている add_variable には変数を取ってくるためのメソッドを外部から指定できるようです。これを使って「変数を取ってくる部分は get_variable を使ってね」という感じにしてやればよさそうです。

class Embedding(tf.keras.layers.Layer):
    def build(self, input_shape):
        self.lookup_table = self.add_variable(
            'embedding',
            (4, 16),
            dtype=tf.float32,
            getter=tf.get_variable,
        )

    def call(self, inputs):
        return tf.nn.embedding_lookup(self.lookup_table, inputs)

こうすると無事 UserWarning が消えました。めでたし。


以上です。

どうか気をつけてほしい。テンサーフローのやみは深い。