Bidirectional wrapper for RNNs.
Inherits From: Wrapper
tf.keras.layers.Bidirectional(
    layer, merge_mode='concat', weights=None, backward_layer=None, **kwargs
)
Arguments | |
---|---|
layer
|
keras.layers.RNN instance, such as keras.layers.LSTM or
keras.layers.GRU. It could also be a keras.layers.Layer instance
that meets the following criteria:
|
merge_mode
|
Mode by which outputs of the forward and backward RNNs will be combined. One of {'sum', 'mul', 'concat', 'ave', None}. If None, the outputs are not combined and are returned as a list instead. Default value is 'concat'. |
backward_layer
|
Optional keras.layers.RNN or keras.layers.Layer instance
to be used to handle backwards input processing. If backward_layer is
not provided, the layer instance passed as the layer argument will be
used to generate the backward layer automatically.
Note that the provided backward_layer should have properties
matching those of the layer argument; in particular, it should have the
same values for stateful, return_state, return_sequences, etc.
In addition, backward_layer and layer should have different go_backwards argument values.
A ValueError will be raised if these requirements are not met.
|
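The merge_mode options can be pictured with a small plain-Python sketch. This is an illustration of the combining rules only, not the Keras implementation: merge_outputs is a hypothetical helper, and each output is reduced to a single feature vector for one time step.

```python
from typing import List, Optional, Union

Vector = List[float]


def merge_outputs(
    forward: Vector, backward: Vector, merge_mode: Optional[str] = "concat"
) -> Union[Vector, List[Vector]]:
    """Combine one forward and one backward output vector (illustrative only)."""
    if merge_mode is None:
        # No merging: both outputs are returned as a list.
        return [forward, backward]
    if merge_mode == "concat":
        # The feature axes are joined, so the width doubles.
        return forward + backward
    if len(forward) != len(backward):
        raise ValueError("element-wise modes need equal-length vectors")
    if merge_mode == "sum":
        return [f + b for f, b in zip(forward, backward)]
    if merge_mode == "mul":
        return [f * b for f, b in zip(forward, backward)]
    if merge_mode == "ave":
        return [(f + b) / 2 for f, b in zip(forward, backward)]
    raise ValueError(f"invalid merge_mode: {merge_mode!r}")


fwd = [1.0, 2.0, 3.0]
bwd = [4.0, 5.0, 6.0]
for mode in ("concat", "sum", "mul", "ave", None):
    print(mode, merge_outputs(fwd, bwd, mode))
```

With 'concat' the feature width doubles, so wrapping an LSTM with 10 units gives 20 output features; the element-wise modes keep the width at 10.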
Call arguments:
The call arguments for this layer are the same as those of the wrapped RNN layer.
Raises | |
---|---|
ValueError
|
If the backward_layer requirements described under Arguments are not met.
|
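The requirements on backward_layer can be sketched as a small check. This is a plain-Python illustration under stated assumptions, not the validation Keras runs internally: RNNConfig and check_backward_layer are hypothetical names, and only the properties named in the Arguments section are compared, using the Keras attribute names return_state and return_sequences.

```python
from dataclasses import dataclass


@dataclass
class RNNConfig:
    """Hypothetical stand-in for the relevant settings of an RNN layer."""
    stateful: bool = False
    return_state: bool = False
    return_sequences: bool = False
    go_backwards: bool = False


def check_backward_layer(layer: RNNConfig, backward_layer: RNNConfig) -> None:
    """Raise ValueError when the documented requirements are not met."""
    # The two layers must read the sequence in opposite directions.
    if layer.go_backwards == backward_layer.go_backwards:
        raise ValueError(
            "layer and backward_layer need different go_backwards values"
        )
    # The remaining properties must match.
    for name in ("stateful", "return_state", "return_sequences"):
        if getattr(layer, name) != getattr(backward_layer, name):
            raise ValueError(
                f"mismatched {name!r} between layer and backward_layer"
            )


forward = RNNConfig(return_sequences=True)
backward = RNNConfig(return_sequences=True, go_backwards=True)
check_backward_layer(forward, backward)  # passes silently

try:
    # Same direction on both sides, so the check rejects the pair.
    check_backward_layer(forward, RNNConfig(return_sequences=True))
except ValueError as err:
    print("rejected:", err)
```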
Examples:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Activation, Bidirectional, Dense, LSTM

# Stacked bidirectional LSTMs with the default merge_mode='concat'
model = Sequential()
model.add(Bidirectional(LSTM(10, return_sequences=True), input_shape=(5, 10)))
model.add(Bidirectional(LSTM(10)))
model.add(Dense(5))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='rmsprop')

# With custom backward layer
model = Sequential()
forward_layer = LSTM(10, return_sequences=True)
backward_layer = LSTM(10, activation='relu', return_sequences=True,
                      go_backwards=True)
model.add(Bidirectional(forward_layer, backward_layer=backward_layer,
                        input_shape=(5, 10)))
model.add(Dense(5))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
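As a rough picture of what the wrapper computes when return_sequences=True, the following plain-Python sketch runs a toy recurrent function in both directions and concatenates the results per time step. It is a conceptual illustration only: toy_rnn (a running sum) stands in for an LSTM, and both toy_rnn and bidirectional are hypothetical names, not Keras API.

```python
from typing import Callable, List

Steps = List[float]


def toy_rnn(inputs: Steps) -> Steps:
    """Toy recurrent layer: the state is a running sum of the inputs seen so far."""
    state = 0.0
    outputs = []
    for x in inputs:
        state = state + x
        outputs.append(state)
    return outputs


def bidirectional(rnn: Callable[[Steps], Steps], inputs: Steps) -> List[List[float]]:
    """Run `rnn` in both directions and concatenate the outputs per time step."""
    forward = rnn(inputs)
    # The backward pass reads the sequence from the last step to the first...
    backward = rnn(inputs[::-1])
    # ...and its outputs are reversed again so index t lines up with input step t.
    backward = backward[::-1]
    return [[f, b] for f, b in zip(forward, backward)]


print(bidirectional(toy_rnn, [1.0, 2.0, 3.0]))
# [[1.0, 6.0], [3.0, 5.0], [6.0, 3.0]]
```

Each output step pairs a summary of the inputs up to that step (forward) with a summary of the inputs from that step onward (backward), which is the motivation for wrapping an RNN this way.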
Attributes | |
---|---|
constraints
|
Methods
reset_states
reset_states()