LSTMBlockCell

public final class LSTMBlockCell

Computes the LSTM cell forward propagation for one time step.

This implementation uses one weight matrix and one bias vector, with optional peephole connections.

This kernel op implements the following mathematical equations:

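    xh = [x, h_prev]
    [i, f, ci, o] = xh * w + b
    f = f + forget_bias

    if not use_peephole:
      wci = wcf = wco = 0

    i = sigmoid(cs_prev * wci + i)
    f = sigmoid(cs_prev * wcf + f)
    ci = tanh(ci)

    cs = ci .* i + cs_prev .* f
    cs = clip(cs, cell_clip)

    o = sigmoid(cs * wco + o)
    co = tanh(cs)
    h = co .* o

Here xh * w is a matrix product, while .* and the peephole products (cs_prev * wci, cs_prev * wcf, cs * wco) are element-wise; clip(cs, cell_clip) bounds each element of cs to [-cell_clip, cell_clip].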
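To make the equations concrete, here is a plain-Java reference sketch of a single time step for one example (batch size 1). LstmStepSketch and its step method are hypothetical names, not part of the TensorFlow Java API; the sketch assumes w is stored with the gate columns ordered i, f, ci, o, and that a non-positive cell_clip disables clipping.

```java
/**
 * Plain-Java reference sketch of one LSTMBlockCell time step for a single
 * example (batch size 1). Hypothetical helper for illustration only; it is
 * not part of the TensorFlow Java API.
 *
 * Assumed layout: w has numInputs + cellSize rows and 4 * cellSize columns,
 * with the gate columns ordered i, f, ci, o as in [i, f, ci, o] = xh * w + b.
 */
final class LstmStepSketch {

    /** New hidden output h and cell state cs for one step. */
    static final class Result {
        final double[] h;
        final double[] cs;

        Result(double[] h, double[] cs) {
            this.h = h;
            this.cs = cs;
        }
    }

    static double sigmoid(double v) {
        return 1.0 / (1.0 + Math.exp(-v));
    }

    static Result step(double[] x, double[] csPrev, double[] hPrev,
                       double[][] w, double[] wci, double[] wcf, double[] wco,
                       double[] b, double forgetBias, double cellClip,
                       boolean usePeephole) {
        int n = csPrev.length; // cell size

        // xh = [x, h_prev]
        double[] xh = new double[x.length + n];
        System.arraycopy(x, 0, xh, 0, x.length);
        System.arraycopy(hPrev, 0, xh, x.length, n);

        // [i, f, ci, o] = xh * w + b  (vector-matrix product, 4 * n outputs)
        double[] z = b.clone();
        for (int r = 0; r < xh.length; r++) {
            for (int col = 0; col < 4 * n; col++) {
                z[col] += xh[r] * w[r][col];
            }
        }

        double[] h = new double[n];
        double[] cs = new double[n];
        for (int k = 0; k < n; k++) {
            // Peephole weights are treated as 0 when usePeephole is false.
            double pci = usePeephole ? wci[k] : 0.0;
            double pcf = usePeephole ? wcf[k] : 0.0;
            double pco = usePeephole ? wco[k] : 0.0;

            double i = sigmoid(csPrev[k] * pci + z[k]);
            double f = sigmoid(csPrev[k] * pcf + z[n + k] + forgetBias);
            double ci = Math.tanh(z[2 * n + k]);

            double cell = ci * i + csPrev[k] * f;
            // Assumption: a non-positive cellClip disables clipping.
            if (cellClip > 0.0) {
                cell = Math.max(-cellClip, Math.min(cellClip, cell));
            }
            cs[k] = cell;

            double o = sigmoid(cell * pco + z[3 * n + k]);
            h[k] = Math.tanh(cell) * o;
        }
        return new Result(h, cs);
    }
}
```

Because it mirrors the pseudocode line by line, a sketch like this can be handy for cross-checking the op's outputs on tiny inputs; it makes no attempt to match the performance of the fused kernel.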

Nested Classes

class LSTMBlockCell.Options Optional attributes for LSTMBlockCell  

Public Methods

static LSTMBlockCell.Options  cellClip(Float cellClip)
    Sets the value to clip the 'cs' value to.
Output<T>  ci()
    The cell input.
Output<T>  co()
    The cell after the tanh.
static <T extends Number> LSTMBlockCell<T>  create(Scope scope, Operand<T> x, Operand<T> csPrev, Operand<T> hPrev, Operand<T> w, Operand<T> wci, Operand<T> wcf, Operand<T> wco, Operand<T> b, Options... options)
    Factory method to create a class wrapping a new LSTMBlockCell operation.
Output<T>  cs()
    The cell state before the tanh.
Output<T>  f()
    The forget gate.
static LSTMBlockCell.Options  forgetBias(Float forgetBias)
    Sets the forget gate bias.
Output<T>  h()
    The output h vector.
Output<T>  i()
    The input gate.
Output<T>  o()
    The output gate.
static LSTMBlockCell.Options  usePeephole(Boolean usePeephole)
    Sets whether to use peephole weights.

Public Methods

public static LSTMBlockCell.Options cellClip (Float cellClip)

Parameters
cellClip Value to clip the 'cs' value to; each element of cs is clipped to [-cellClip, cellClip].
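A minimal sketch of the element-wise clip applied to cs, using a hypothetical CellClipSketch.clip helper in plain Java rather than the TensorFlow op. Treating a non-positive cellClip as "no clipping" is an assumption about the kernel, not something this page states.

```java
/** Hypothetical helper showing element-wise clipping of the cell state cs. */
final class CellClipSketch {
    // Returns a copy of cs with each element clipped to [-cellClip, cellClip].
    // Assumption: a non-positive cellClip means "no clipping".
    static double[] clip(double[] cs, float cellClip) {
        double[] out = cs.clone();
        if (cellClip <= 0f) {
            return out;
        }
        for (int k = 0; k < out.length; k++) {
            out[k] = Math.max(-cellClip, Math.min(cellClip, out[k]));
        }
        return out;
    }
}
```

For example, clipping {-5.0, 0.5, 4.0} with cellClip = 3.0f yields {-3.0, 0.5, 3.0}.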

public Output<T> ci ()

The cell input.

public Output<T> co ()

The cell after the tanh.

public static <T extends Number> LSTMBlockCell<T> create (Scope scope, Operand<T> x, Operand<T> csPrev, Operand<T> hPrev, Operand<T> w, Operand<T> wci, Operand<T> wcf, Operand<T> wco, Operand<T> b, Options... options)

Factory method to create a class wrapping a new LSTMBlockCell operation.

Parameters
scope current scope
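x The input to the LSTM cell, shape (batch_size, num_inputs).
csPrev Value of the cell state at the previous time step, shape (batch_size, cell_size).
hPrev Output of the cell at the previous time step, shape (batch_size, cell_size).
w The weight matrix, shape (num_inputs + cell_size, 4 * cell_size).
wci The peephole weights for the input gate, shape (cell_size); applied element-wise to csPrev.
wcf The peephole weights for the forget gate, shape (cell_size); applied element-wise to csPrev.
wco The peephole weights for the output gate, shape (cell_size); applied element-wise to cs.
b The bias vector, shape (4 * cell_size).
options carries optional attribute values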
Returns
  • a new instance of LSTMBlockCell
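The expected operand shapes follow from the equations: xh = [x, h_prev] has num_inputs + cell_size columns, and xh * w + b must produce the four gate blocks i, f, ci, o of cell_size columns each. The hypothetical LstmBlockCellShapes class below sketches a pre-flight check one might run before calling create; it is not part of the TensorFlow Java API, and it assumes shapes are passed as long[] dimension arrays such as {batchSize, numInputs}.

```java
import java.util.Arrays;

/**
 * Hypothetical shape checker for LSTMBlockCell inputs. The expected shapes
 * are derived from xh = [x, h_prev] and [i, f, ci, o] = xh * w + b; they are
 * not read from the TensorFlow API.
 */
final class LstmBlockCellShapes {

    // w multiplies xh (numInputs + cellSize columns) and yields 4 * cellSize gate values.
    static long[] weightShape(long numInputs, long cellSize) {
        return new long[] {numInputs + cellSize, 4 * cellSize};
    }

    // Column range [start, end) of a gate in xh * w + b; order is i (0), f (1), ci (2), o (3).
    static long[] gateColumns(int gate, long cellSize) {
        if (gate < 0 || gate > 3) {
            throw new IllegalArgumentException("gate must be in 0..3");
        }
        return new long[] {gate * cellSize, (gate + 1) * cellSize};
    }

    // Throws IllegalArgumentException if the shapes are inconsistent with the equations.
    static void validate(long[] x, long[] csPrev, long[] hPrev, long[] w, long[] b) {
        long batch = x[0];
        long numInputs = x[1];
        long cellSize = csPrev[1];
        require(csPrev[0] == batch && hPrev[0] == batch, "batch size mismatch");
        require(hPrev[1] == cellSize, "hPrev must have cellSize columns");
        require(Arrays.equals(w, weightShape(numInputs, cellSize)), "unexpected w shape");
        require(Arrays.equals(b, new long[] {4 * cellSize}), "unexpected b shape");
    }

    private static void require(boolean ok, String message) {
        if (!ok) {
            throw new IllegalArgumentException(message);
        }
    }
}
```

With num_inputs = 3 and cell_size = 5, for instance, w must be (8, 20) and the ci block occupies columns 10 through 14.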

public Output<T> cs ()

The cell state before the tanh.

public Output<T> f ()

The forget gate.

public static LSTMBlockCell.Options forgetBias (Float forgetBias)

Parameters
forgetBias The forget gate bias, added to the forget gate pre-activation f before the sigmoid.
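Since forget_bias is added to the forget gate pre-activation before the sigmoid, its effect can be seen in isolation. The hypothetical ForgetBiasSketch.forgetGate helper below computes f for one unit with peepholes disabled, so the cs_prev * wcf term is 0.

```java
/** Hypothetical helper: forget gate value for one unit with peepholes disabled. */
final class ForgetBiasSketch {
    // f = sigmoid(fPre + forget_bias); with usePeephole false the cs_prev * wcf term is 0.
    static double forgetGate(double preActivation, float forgetBias) {
        return 1.0 / (1.0 + Math.exp(-(preActivation + forgetBias)));
    }
}
```

With a zero pre-activation, a forgetBias of 1.0 raises f from 0.5 to about 0.731, so a larger share of cs_prev carries over into cs = ci .* i + cs_prev .* f.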

public Output<T> h ()

The output h vector.

public Output<T> i ()

The input gate.

public Output<T> o ()

The output gate.

public static LSTMBlockCell.Options usePeephole (Boolean usePeephole)

Parameters
usePeephole Whether to use the peephole weights wci, wcf and wco. When false, they are treated as zero.
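Note that create still takes wci, wcf and wco as required operands; per the equations, they are simply treated as zero when usePeephole is false. The hypothetical PeepholeSketch.inputGate helper below illustrates this for the input gate of a single unit.

```java
/** Hypothetical helper: input gate for one unit, with or without the peephole term. */
final class PeepholeSketch {
    // i = sigmoid(cs_prev * wci + iPre). When usePeephole is false the
    // equations set wci to 0, so the supplied wci value has no effect.
    static double inputGate(double csPrev, double wci, double iPre, boolean usePeephole) {
        double effectiveWci = usePeephole ? wci : 0.0;
        return 1.0 / (1.0 + Math.exp(-(csPrev * effectiveWci + iPre)));
    }
}
```

With csPrev = 2.0, wci = 5.0 and a zero pre-activation, the gate is exactly 0.5 when peepholes are off and close to 1.0 when they are on.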