Flatten in Keras video answering machine example

In Keras' video question answering example (https://keras.io/getting-started/functional-api-guide/), what does the vision_model.add(Flatten()) at the end of the convolutional neural net do and why is it needed?
Full source:
from keras.layers import Conv2D, MaxPooling2D, Flatten
from keras.layers import Input, LSTM, Embedding, Dense
from keras.models import Model, Sequential
# First, let's define a vision model using a Sequential model.
# This model will encode an image into a vector.
vision_model = Sequential()
vision_model.add(Conv2D(64, (3, 3), activation='relu', padding='same', input_shape=(224, 224, 3)))
vision_model.add(Conv2D(64, (3, 3), activation='relu'))
vision_model.add(MaxPooling2D((2, 2)))
vision_model.add(Conv2D(128, (3, 3), activation='relu', padding='same'))
vision_model.add(Conv2D(128, (3, 3), activation='relu'))
vision_model.add(MaxPooling2D((2, 2)))
vision_model.add(Conv2D(256, (3, 3), activation='relu', padding='same'))
vision_model.add(Conv2D(256, (3, 3), activation='relu'))
vision_model.add(Conv2D(256, (3, 3), activation='relu'))
vision_model.add(MaxPooling2D((2, 2)))
vision_model.add(Flatten())
Then, later in the same example:
from keras.layers import TimeDistributed
video_input = Input(shape=(100, 224, 224, 3))
# This is our video encoded via the previously trained vision_model (weights are reused)
encoded_frame_sequence = TimeDistributed(vision_model)(video_input) # the output will be a sequence of vectors
encoded_video = LSTM(256)(encoded_frame_sequence) # the output will be a vector
1 Answer
Running:
vision_model.summary()
we get:
Layer (type)                 Output Shape              Param #
=================================================================
conv2d_1 (Conv2D)            (None, 224, 224, 64)      1792
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 222, 222, 64)      36928
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 111, 111, 64)      0
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 111, 111, 128)     73856
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 109, 109, 128)     147584
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 54, 54, 128)       0
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 54, 54, 256)       295168
_________________________________________________________________
conv2d_6 (Conv2D)            (None, 52, 52, 256)       590080
_________________________________________________________________
conv2d_7 (Conv2D)            (None, 50, 50, 256)       590080
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 25, 25, 256)       0
_________________________________________________________________
flatten_1 (Flatten)          (None, 160000)            0
=================================================================
Total params: 1,735,488
Trainable params: 1,735,488
Non-trainable params: 0
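The numbers in this summary can be reproduced by hand. A 3x3 convolution with padding='same' keeps the height and width, the default 'valid' padding trims 2 from each, and 2x2 max pooling halves them with floor division. Each Conv2D layer has (3 * 3 * input_channels + 1) * filters parameters, where the + 1 is the bias. The following is only a sketch in plain Python that does the arithmetic; the helper names conv_out, pool_out and conv_params are made up for illustration and are not part of Keras.

```python
def conv_out(size, kernel=3, padding="valid"):
    """Spatial size after a stride-1 convolution (hypothetical helper)."""
    return size if padding == "same" else size - kernel + 1


def pool_out(size, pool=2):
    """Spatial size after non-overlapping max pooling (floor division)."""
    return size // pool


def conv_params(in_channels, filters, kernel=3):
    """Weights plus one bias per filter."""
    return (kernel * kernel * in_channels + 1) * filters


# (kind, padding, filters), mirroring the layers of vision_model in order
layers = [
    ("conv", "same", 64), ("conv", "valid", 64), ("pool", None, None),
    ("conv", "same", 128), ("conv", "valid", 128), ("pool", None, None),
    ("conv", "same", 256), ("conv", "valid", 256), ("conv", "valid", 256),
    ("pool", None, None),
]

size, channels, total_params = 224, 3, 0
shapes = []
for kind, padding, filters in layers:
    if kind == "conv":
        total_params += conv_params(channels, filters)
        size, channels = conv_out(size, padding=padding), filters
    else:
        size = pool_out(size)
    shapes.append((size, size, channels))

# Flatten multiplies the remaining axes into a single one
flat_size = size * size * channels

print(shapes[-1])     # (25, 25, 256)
print(flat_size)      # 160000
print(total_params)   # 1735488
```

The last three values match the max_pooling2d_3 row, the flatten_1 row and the "Total params" line of the summary.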
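vision_model.add(Flatten()) reshapes the output of the last MaxPooling2D((2, 2)) layer from (None, 25, 25, 256) to (None, 160000), since 25 * 25 * 256 = 160000. It is needed because LSTM expects a sequence of vectors, i.e. a 3D input of shape (batch, timesteps, features). With Flatten, TimeDistributed(vision_model) outputs (None, 100, 160000); without it, the output would be the 5D tensor (None, 100, 25, 25, 256), which LSTM cannot consume.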
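To see what this means for the video branch, here is a small sketch of the shape bookkeeping in plain Python. It does not call Keras; time_distributed_shape and rank_with_batch are hypothetical helpers that only mimic how a per-frame model applied to every timestep affects the shape, with the batch axis left out until the end.

```python
def time_distributed_shape(video_shape, per_frame_output_shape):
    """Shape after applying a per-frame model to every timestep.

    video_shape is (timesteps, height, width, channels), batch axis omitted.
    """
    timesteps = video_shape[0]
    return (timesteps,) + tuple(per_frame_output_shape)


def rank_with_batch(shape):
    """Number of axes once the batch axis is added back."""
    return len(shape) + 1


video_shape = (100, 224, 224, 3)  # 100 frames of 224x224 RGB

# Per-frame output of vision_model with and without the final Flatten
with_flatten = time_distributed_shape(video_shape, (25 * 25 * 256,))
without_flatten = time_distributed_shape(video_shape, (25, 25, 256))

print(with_flatten, rank_with_batch(with_flatten))        # (100, 160000) 3
print(without_flatten, rank_with_batch(without_flatten))  # (100, 25, 25, 256) 5
```

Only the rank-3 case, one feature vector per frame, is the shape a recurrent layer such as LSTM works with.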