LSTMBlockCell

public final class LSTMBlockCell

Computes the LSTM cell forward propagation for one time step.

This implementation uses one weight matrix and one bias vector, with optional peephole connections.

This kernel op implements the following mathematical equations:

  xh = [x, h_prev]
  [i, f, ci, o] = xh * w + b
  f = f + forget_bias

  if not use_peephole:
    wci = wcf = wco = 0

  i = sigmoid(cs_prev * wci + i)
  f = sigmoid(cs_prev * wcf + f)
  ci = tanh(ci)

  cs = ci .* i + cs_prev .* f
  cs = clip(cs, cell_clip)

  o = sigmoid(cs * wco + o)
  co = tanh(cs)
  h = co .* o
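
The equations above can be traced with a plain-Java reference sketch of one time step for a single example (batch size 1). This is not the TensorFlow kernel or its API; the class `LstmBlockCellReference` and its `step` method are hypothetical. It assumes the four gate column blocks of `w` and `b` are ordered [i, f, ci, o] exactly as the equation is written, and that `clip` clamps to [-cell_clip, cell_clip] with a non-positive value disabling clipping. Check both assumptions against the actual kernel before reusing real trained weights.

```java
/**
 * Hypothetical plain-Java reference for one LSTMBlockCell time step on a
 * single example (batch size 1). This is not the TensorFlow kernel.
 */
final class LstmBlockCellReference {

    /** The seven values produced by the equations, named like the accessors. */
    static final class Step {
        final double[] i, cs, f, o, ci, co, h;

        Step(double[] i, double[] cs, double[] f, double[] o,
             double[] ci, double[] co, double[] h) {
            this.i = i;
            this.cs = cs;
            this.f = f;
            this.o = o;
            this.ci = ci;
            this.co = co;
            this.h = h;
        }
    }

    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    /**
     * x: numInputs values; csPrev, hPrev, wci, wcf, wco: cellSize values;
     * w: (numInputs + cellSize) rows by (4 * cellSize) columns; b: 4 * cellSize values.
     * Assumption: the gate columns of w and b are ordered [i, f, ci, o].
     */
    static Step step(double[] x, double[] csPrev, double[] hPrev,
                     double[][] w, double[] wci, double[] wcf, double[] wco,
                     double[] b, double forgetBias, double cellClip,
                     boolean usePeephole) {
        int n = csPrev.length;

        // xh = [x, h_prev]
        double[] xh = new double[x.length + n];
        System.arraycopy(x, 0, xh, 0, x.length);
        System.arraycopy(hPrev, 0, xh, x.length, n);

        // [i, f, ci, o] = xh * w + b, computed as one fused matrix-vector product
        double[] gates = b.clone();
        for (int r = 0; r < xh.length; r++) {
            for (int c = 0; c < 4 * n; c++) {
                gates[c] += xh[r] * w[r][c];
            }
        }

        double[] i = new double[n], f = new double[n], ci = new double[n];
        double[] o = new double[n], cs = new double[n], co = new double[n];
        double[] h = new double[n];
        for (int k = 0; k < n; k++) {
            // if not use_peephole: wci = wcf = wco = 0
            double pi = usePeephole ? wci[k] : 0.0;
            double pf = usePeephole ? wcf[k] : 0.0;
            double po = usePeephole ? wco[k] : 0.0;

            i[k] = sigmoid(csPrev[k] * pi + gates[k]);
            f[k] = sigmoid(csPrev[k] * pf + gates[n + k] + forgetBias);
            ci[k] = Math.tanh(gates[2 * n + k]);

            cs[k] = ci[k] * i[k] + csPrev[k] * f[k];
            // Assumption: clip to [-cellClip, cellClip]; a non-positive value disables clipping.
            if (cellClip > 0) {
                cs[k] = Math.max(-cellClip, Math.min(cellClip, cs[k]));
            }

            o[k] = sigmoid(cs[k] * po + gates[3 * n + k]);
            co[k] = Math.tanh(cs[k]);
            h[k] = co[k] * o[k];
        }
        return new Step(i, cs, f, o, ci, co, h);
    }
}
```

With all-zero weights and biases, x = [0], cs_prev = [2], forget_bias = 1.0 and clipping disabled, the sketch gives i = 0.5, f = sigmoid(1) and cs = 2 * sigmoid(1), roughly 1.462; with cell_clip = 1.0 the same inputs give cs = 1.0.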
 

Nested Classes

class LSTMBlockCell.Options
  Optional attributes for LSTMBlockCell.

Public Methods

static LSTMBlockCell.Options cellClip(Float cellClip)
  Sets the value to clip 'cs' to.
Output<T> ci()
  The cell input.
Output<T> co()
  The cell state after the tanh.
static <T extends Number> LSTMBlockCell<T> create(Scope scope, Operand<T> x, Operand<T> csPrev, Operand<T> hPrev, Operand<T> w, Operand<T> wci, Operand<T> wcf, Operand<T> wco, Operand<T> b, Options... options)
  Factory method to create a class wrapping a new LSTMBlockCell operation.
Output<T> cs()
  The cell state before the tanh.
Output<T> f()
  The forget gate.
static LSTMBlockCell.Options forgetBias(Float forgetBias)
  Sets the forget gate bias.
Output<T> h()
  The output h vector.
Output<T> i()
  The input gate.
Output<T> o()
  The output gate.
static LSTMBlockCell.Options usePeephole(Boolean usePeephole)
  Sets whether to use peephole weights.

Inherited Methods

org.tensorflow.op.PrimitiveOp
final boolean equals(Object obj)
final int
Operation op()
  Returns the underlying Operation.
final String
boolean equals(Object arg0)
final Class<?> getClass()
int hashCode()
final void notify()
final void notifyAll()
String toString()
final void wait(long arg0, int arg1)
final void wait(long arg0)
final void wait()

Public Methods

public static LSTMBlockCell.Options cellClip (Float cellClip)

Parameters
cellClip Value to clip the 'cs' value to.
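
The `cs = clip(cs, cell_clip)` step can be illustrated with a small plain-Java helper. The class `CellClipSketch` is hypothetical and not part of the TensorFlow API; it assumes symmetric clamping to [-cellClip, cellClip] and treats a non-positive `cellClip` as "no clipping", which should be confirmed against the kernel.

```java
/** Hypothetical helper illustrating cs = clip(cs, cell_clip). */
final class CellClipSketch {

    /**
     * Returns a copy of cs with every element clamped to [-cellClip, cellClip].
     * Assumption: a non-positive cellClip means "no clipping".
     */
    static double[] clip(double[] cs, double cellClip) {
        double[] out = cs.clone();
        if (cellClip <= 0) {
            return out;
        }
        for (int k = 0; k < out.length; k++) {
            out[k] = Math.max(-cellClip, Math.min(cellClip, out[k]));
        }
        return out;
    }
}
```

For example, `clip(new double[] {-5.0, 0.5, 4.0}, 3.0)` returns `{-3.0, 0.5, 3.0}`.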

public Output<T> ci ()

The cell input.

public Output<T> co ()

The cell state after the tanh.

public static <T extends Number> LSTMBlockCell<T> create (Scope scope, Operand<T> x, Operand<T> csPrev, Operand<T> hPrev, Operand<T> w, Operand<T> wci, Operand<T> wcf, Operand<T> wco, Operand<T> b, Options... options)

Factory method to create a class wrapping a new LSTMBlockCell operation.

Parameters
scope current scope
x The input to the LSTM cell, shape (batch_size, num_inputs).
csPrev Value of the cell state at the previous time step, shape (batch_size, cell_size).
hPrev Output of the cell at the previous time step, shape (batch_size, cell_size).
w The weight matrix, shape (num_inputs + cell_size, 4 * cell_size).
wci The weight matrix for the input gate peephole connection.
wcf The weight matrix for the forget gate peephole connection.
wco The weight matrix for the output gate peephole connection.
b The bias vector, shape (4 * cell_size).
options carries optional attribute values
Returns
  • a new instance of LSTMBlockCell
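
Because `xh = [x, h_prev]` has num_inputs + cell_size columns and `xh * w + b` is split into four gates, the shapes of `w` and `b` follow from those of `x` and `csPrev`. The hypothetical helper below (not a TensorFlow API; `checkShapes` is an illustrative name) checks those implied shapes on plain `long[]` dimension arrays and returns cell_size.

```java
import java.util.Arrays;

/** Hypothetical shape check derived from the equations; not part of the TensorFlow API. */
final class LstmBlockCellShapes {

    /** Validates the dims of x, csPrev, hPrev, w and b and returns cell_size. */
    static long checkShapes(long[] x, long[] csPrev, long[] hPrev, long[] w, long[] b) {
        require(x.length == 2, "x must be (batch_size, num_inputs)");
        long batchSize = x[0];
        long numInputs = x[1];
        require(csPrev.length == 2 && csPrev[0] == batchSize,
                "csPrev must be (batch_size, cell_size)");
        long cellSize = csPrev[1];
        require(Arrays.equals(hPrev, new long[] {batchSize, cellSize}),
                "hPrev must be (batch_size, cell_size)");
        // xh = [x, h_prev] has num_inputs + cell_size columns; w maps it to 4 gates.
        require(Arrays.equals(w, new long[] {numInputs + cellSize, 4 * cellSize}),
                "w must be (num_inputs + cell_size, 4 * cell_size)");
        require(Arrays.equals(b, new long[] {4 * cellSize}),
                "b must be (4 * cell_size)");
        return cellSize;
    }

    private static void require(boolean ok, String message) {
        if (!ok) {
            throw new IllegalArgumentException(message);
        }
    }
}
```

For instance, x of shape (8, 3) with cs_prev of shape (8, 5) requires w of shape (8, 20) and b of shape (20).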

public Output<T> cs ()

The cell state before the tanh.

public Output<T> f ()

The forget gate.

public static LSTMBlockCell.Options forgetBias (Float forgetBias)

Parameters
forgetBias The forget gate bias.
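
Per the equations, the forget bias is added to the forget gate pre-activation before the sigmoid, so it shifts how much of cs_prev the `cs_prev .* f` term carries into cs. The hypothetical helper below (not a TensorFlow API) shows this for one element with peepholes off (wcf = 0).

```java
/** Hypothetical helper: f = sigmoid(f_pre + forget_bias), with peepholes off (wcf = 0). */
final class ForgetBiasSketch {

    static double forgetGate(double preActivation, double forgetBias) {
        return 1.0 / (1.0 + Math.exp(-(preActivation + forgetBias)));
    }

    /** The cs_prev .* f contribution to cs for one element. */
    static double retained(double csPrev, double preActivation, double forgetBias) {
        return csPrev * forgetGate(preActivation, forgetBias);
    }
}
```

With a zero pre-activation, a forget bias of 0 gives f = 0.5, while a forget bias of 1.0 gives f = sigmoid(1), roughly 0.731, so a larger share of cs_prev is retained.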

public Output<T> h ()

The output h vector.

public Output<T> i ()

The input gate.

public Output<T> o ()

The output gate.

public static LSTMBlockCell.Options usePeephole (Boolean usePeephole)

Parameters
usePeephole Whether to use peephole weights.
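
When `usePeephole` is false the equations set wci, wcf and wco to zero, so each gate depends only on its `xh * w + b` pre-activation. The hypothetical helper below (not a TensorFlow API) contrasts the input gate for one element with and without the peephole term; the forget gate follows the same pattern with wcf, and the output gate uses the new cs with wco instead of cs_prev.

```java
/** Hypothetical helper contrasting the input gate with and without peephole weights. */
final class PeepholeSketch {

    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    /** i = sigmoid(cs_prev * wci + i_pre); the equations set wci = 0 when peepholes are off. */
    static double inputGate(double iPre, double csPrev, double wci, boolean usePeephole) {
        double peephole = usePeephole ? wci : 0.0;
        return sigmoid(csPrev * peephole + iPre);
    }
}
```

With i_pre = 0, cs_prev = 2 and wci = 0.5, the gate is 0.5 without peepholes and sigmoid(1) with them.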