Initializing GRU (or RNN in general) inside nested models in Keras


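I need to build two Keras models:

1. gru_model, which contains a GRU layer whose initial_state is supplied as an input of the model;

2. wrapper_model, which calls gru_model on its own inputs and returns the outputs of gru_model.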

According to the output below, gru_model is built and called successfully, but wrapper_model fails when predict is called in this setting. Interestingly, if the initial_state parameter is not passed to the GRU layer in gru_model, wrapper_model works as well.

Please advise me on this issue. Thanks!


import numpy as np
from keras import Input, Model
from keras.layers import GRU

BATCH_SIZE = 2
SEQ_LEN = 3
EMB_SIZE = 4
HIDDEN_DIM = 5


def build_gru_model():
    gru_input = Input(shape=(SEQ_LEN, EMB_SIZE), dtype='float32', name='gru_input')
    gru_hs = Input(shape=(HIDDEN_DIM,), dtype='float32', name='gru_hs')

    output, updated_hs = GRU(
        units=HIDDEN_DIM,
        name='gru',
        return_state=True
    )(gru_input, initial_state=gru_hs)  # wrapper_model fails to predict if initial_state is defined here

    return Model(
        inputs=[gru_input, gru_hs],
        outputs=[output, updated_hs],
        name='gru_model')


def build_wrapper_model():
    gru_input = Input(shape=(SEQ_LEN, EMB_SIZE), dtype='float32', name='caller_gru_input')
    gru_hs = Input(shape=(HIDDEN_DIM,), dtype='float32', name='caller_gru_hs')

    # Reuse the already-built gru_model (created in __main__) as a nested model
    output, updated_hs = gru_model(inputs=[gru_input, gru_hs])

    return Model(
        inputs=[gru_input, gru_hs],
        outputs=[output, updated_hs],
        name='caller_model')


if __name__ == '__main__':
    ids_input = np.ones((BATCH_SIZE, SEQ_LEN, EMB_SIZE), dtype=np.float32)
    hs_input = np.zeros((BATCH_SIZE, HIDDEN_DIM), dtype=np.float32)

    gru_model = build_gru_model()
    res_output, res_updated_hs = gru_model.predict([ids_input, hs_input])  # Works
    print('\ngru_model outputs:\n{}'.format(res_output))
    print('\ngru_model hs:\n{}'.format(res_updated_hs))

    caller_model = build_wrapper_model()
    res_output, res_updated_hs = caller_model.predict([ids_input, hs_input])  # Fails if initial_state is passed to GRU
    print('\nwrapper_model outputs:\n{}'.format(res_output))
    print('\nwrapper_model hs:\n{}'.format(res_updated_hs))


/home/nicolas/Code/test_gru_init.py
Using TensorFlow backend.
2018-08-08 16:44:55.013947: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA

gru_model outputs:
[[-0.6254787 0.5550256 -0.14170723 -0.51687145 0.09609278]
[-0.6254787 0.5550256 -0.14170723 -0.51687145 0.09609278]]

gru_model hs:
[[-0.6254787 0.5550256 -0.14170723 -0.51687145 0.09609278]
[-0.6254787 0.5550256 -0.14170723 -0.51687145 0.09609278]]

Traceback (most recent call last):
File "/home/nicolas/Code/Replika/cakechat-internal/tools/test_gru_init.py", line 49, in <module>
res_output, res_updated_hs = caller_model.predict([ids_input, hs_input]) # Fails if inital_state is passed to GRU
File "/home/nicolas/Code/keras/engine/training.py", line 1172, in predict
steps=steps)
File "/home/nicolas/Code/keras/engine/training_arrays.py", line 297, in predict_loop
batch_outs = f(ins_batch)
File "/home/nicolas/Code/keras/backend/tensorflow_backend.py", line 2661, in __call__
return self._call(inputs)
File "/home/nicolas/Code/keras/backend/tensorflow_backend.py", line 2631, in _call
fetched = self._callable_fn(*array_vals)
File "/home/nicolas/Code/tensorflow/python/client/session.py", line 1454, in __call__
self._session._session, self._handle, args, status, None)
File "/home/nicolas/Code/tensorflow/python/framework/errors_impl.py", line 519, in __exit__
c_api.TF_GetCode(self.status.status))
tensorflow.python.framework.errors_impl.InvalidArgumentError: You must feed a value for placeholder tensor 'gru_hs' with dtype float and shape [?,5]
[[Node: gru_hs = Placeholder[dtype=DT_FLOAT, shape=[?,5], _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]

Process finished with exit code 1
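For reference, a GRU given an initial state simply starts its recurrence from that state instead of from zeros. The NumPy sketch below is a plain re-implementation for illustration only, not Keras code: it uses random weights and a logistic sigmoid for the gates, and the helper names (make_params, gru_forward) are hypothetical, so its numbers will not match the output above. It does reproduce two things visible in that output: with return_sequences=False the layer output and the returned state hold the same values, and identical input rows give identical output rows.

```python
import numpy as np

# Sizes taken from the question's constants.
BATCH_SIZE, SEQ_LEN, EMB_SIZE, HIDDEN_DIM = 2, 3, 4, 5


def sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))


def make_params(rng):
    """Hypothetical random weights: input kernel (EMB_SIZE x HIDDEN_DIM), recurrent kernel
    (HIDDEN_DIM x HIDDEN_DIM) and bias for each of the update, reset and candidate parts."""
    shapes = [(EMB_SIZE, HIDDEN_DIM), (HIDDEN_DIM, HIDDEN_DIM), (HIDDEN_DIM,)] * 3
    return [rng.standard_normal(s) for s in shapes]


def gru_forward(x, h0, params):
    """Run a GRU over x (batch, seq_len, emb), starting from h0 (batch, hidden).

    Returns (output, state) in the spirit of GRU(return_state=True) with
    return_sequences=False: both are the hidden state after the last time step."""
    Wz, Uz, bz, Wr, Ur, br, Wh, Uh, bh = params
    h = h0  # the initial state replaces the default all-zeros state
    for t in range(x.shape[1]):
        xt = x[:, t, :]
        z = sigmoid(xt @ Wz + h @ Uz + bz)             # update gate
        r = sigmoid(xt @ Wr + h @ Ur + br)             # reset gate
        h_cand = np.tanh(xt @ Wh + (r * h) @ Uh + bh)  # candidate state
        h = z * h + (1.0 - z) * h_cand                 # new hidden state
    return h, h


rng = np.random.default_rng(0)
params = make_params(rng)
ids_input = np.ones((BATCH_SIZE, SEQ_LEN, EMB_SIZE), dtype=np.float32)
hs_input = np.zeros((BATCH_SIZE, HIDDEN_DIM), dtype=np.float32)

output, updated_hs = gru_forward(ids_input, hs_input, params)
print(output.shape)                       # (2, 5)
print(np.allclose(output, updated_hs))    # True
print(np.allclose(output[0], output[1]))  # True: both input rows are identical
```

Feeding a different hs_input changes the result, which is the whole point of passing initial_state as a model input.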



--



Apparently, there is a bug in Keras that causes this problem. For details, see https://github.com/keras-team/keras/issues/9385
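Note that the error asks for gru_hs, the inner model's own Input, rather than the wrapper's caller_gru_hs. This suggests that the graph rebuilt for wrapper_model still references gru_model's placeholder for the state. The toy sketch below illustrates one way such a failure can arise: when a recorded layer call is re-applied to new tensors, positional inputs are remapped, but a tensor captured in a stored keyword argument is replayed unchanged. Everything here (Tensor, Node, reapply, toy_gru) is hypothetical and deliberately simplified; it is not Keras's implementation, and the linked issue is the place to confirm the actual cause.

```python
class Tensor:
    """Stand-in for a symbolic tensor; only its name matters here."""
    def __init__(self, name):
        self.name = name


class Node:
    """Records one layer call: positional inputs plus stored keyword arguments."""
    def __init__(self, fn, inputs, kwargs=None):
        self.fn = fn
        self.inputs = inputs        # remapped when the graph is re-applied
        self.kwargs = kwargs or {}  # replayed verbatim (the hypothetical flaw)


def reapply(node, input_map):
    """Re-run a recorded call on new tensors, as a nested model does (toy version)."""
    new_inputs = [input_map[t.name] for t in node.inputs]
    return node.fn(*new_inputs, **node.kwargs)


def toy_gru(x, state=None, initial_state=None):
    """Report which tensors the 'GRU' actually received."""
    hs = state if state is not None else initial_state
    return (x.name, hs.name)


# Inner model's own inputs (names from the question).
gru_input, gru_hs = Tensor('gru_input'), Tensor('gru_hs')
# Outer model's inputs, keyed by the inner input they replace.
outer = {'gru_input': Tensor('caller_gru_input'), 'gru_hs': Tensor('caller_gru_hs')}

# State passed as a keyword argument: still bound to the inner placeholder.
kw_node = Node(toy_gru, [gru_input], {'initial_state': gru_hs})
print(reapply(kw_node, outer))   # ('caller_gru_input', 'gru_hs')

# State passed as a positional input: remapped to the outer tensor.
pos_node = Node(toy_gru, [gru_input, gru_hs])
print(reapply(pos_node, outer))  # ('caller_gru_input', 'caller_gru_hs')
```

In the keyword case the rebuilt graph depends on a placeholder that is never fed, which matches the shape of the error in the question; without any state argument, only remapped inputs are involved, which is consistent with wrapper_model working when initial_state is omitted.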