tf.contrib.recurrent.Recurrent
Compute a recurrent neural net.
    tf.contrib.recurrent.Recurrent(
        theta, state0, inputs, cell_fn, cell_grad=None, extras=None,
        max_input_length=None, use_tpu=False, aligned_end=False
    )
Roughly, Recurrent() computes the following:

    state = state0
    for t in inputs' sequence length:
      state = cell_fn(theta, state, inputs[t, :])
      accumulate_state[t, :] = state
    return accumulate_state, state
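The rough loop above can be written out as a plain-Python reference (a sketch only: NumPy arrays stand in for tensors, a single array stands in for a structure of tensors, and cell_fn here returns only the new state, as in the rough form above):

    import numpy as np

    def recurrent_reference(theta, state0, inputs, cell_fn):
      # Runs cell_fn once per step along the leading (time) axis of inputs.
      state = state0
      accumulated = []
      for t in range(inputs.shape[0]):
        state = cell_fn(theta, state, inputs[t, :])
        accumulated.append(state)
      # accumulate_state stacks the per-step states along a new time axis.
      return np.stack(accumulated), state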
theta, state, and inputs are all structures of tensors.

inputs[t, :] means taking a slice out of every tensor in inputs.

accumulate_state[t, :] = state means that every tensor in
'state' is stashed into a slice of the corresponding tensor in
accumulate_state.
cell_fn is a Python callable that computes (i.e., builds the TensorFlow
graph for) one forward step of the recurrent neural network. Any two
calls of cell_fn must describe two identical computations.

By construction, Recurrent()'s backward computation does not access
any intermediate values computed by cell_fn during the forward
computation. Recurrent() may be extended to support that by taking a
customized backward function of cell_fn.
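As a concrete (and hedged) illustration, the sketch below builds a simple Elman-style cell and hands it to Recurrent(). It assumes TensorFlow 1.15 (where tf.contrib is available); the names elman_cell, w_x, and w_h, the shapes, and the use of plain dicts as the "structures of tensors" are illustrative assumptions, not requirements documented here:

    import tensorflow as tf

    def elman_cell(theta, state, inputs):
      # One forward step: h' = tanh(x @ w_x + h @ w_h).
      h = tf.tanh(tf.matmul(inputs['x'], theta['w_x']) +
                  tf.matmul(state['h'], theta['w_h']))
      # The second return value must match the structure of `extras`;
      # this cell stashes nothing, so it returns an empty structure.
      return {'h': h}, {}

    seq_len, batch, dim = 8, 4, 16
    theta = {'w_x': tf.get_variable('w_x', [dim, dim]),
             'w_h': tf.get_variable('w_h', [dim, dim])}
    state0 = {'h': tf.zeros([batch, dim])}
    inputs = {'x': tf.placeholder(tf.float32, [seq_len, batch, dim])}

    # acc_state holds every per-step state; final_state is the last one.
    acc_state, final_state = tf.contrib.recurrent.Recurrent(
        theta=theta, state0=state0, inputs=inputs, cell_fn=elman_cell)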
Args

theta: Weights. A structure of tensors.

state0: Initial state. A structure of tensors.

inputs: Inputs. A structure of tensors.

cell_fn: A Python function which computes:
    state1, extras = cell_fn(theta, state0, inputs[t, :])

cell_grad: A Python function which computes (a sketch follows this list):
    dtheta, dstate0, dinputs[t, :] = cell_grad(
        theta, state0, inputs[t, :], extras, dstate1)

extras: A structure of tensors. The second return value of every
invocation of cell_fn must be a structure of tensors with keys and
shapes matching this extras.

max_input_length: Maximum length of the effective input, used to
truncate the computation if the inputs have been allocated to a
larger size. A scalar tensor.

use_tpu: Whether or not we are running on TPU.

aligned_end: A boolean indicating whether the sequence is aligned at
the end.
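For the hypothetical Elman cell sketched earlier, a matching cell_grad might look like the following. This is a sketch under the same assumptions; it recomputes the forward pre-activation, since (as noted above) the backward computation cannot read intermediates from the forward pass:

    def elman_cell_grad(theta, state0, inputs, extras, dstate1):
      # Recompute the forward step; forward intermediates are unavailable.
      pre = (tf.matmul(inputs['x'], theta['w_x']) +
             tf.matmul(state0['h'], theta['w_h']))
      h = tf.tanh(pre)
      # Backprop through tanh: d tanh(u)/du = 1 - tanh(u)^2.
      dpre = dstate1['h'] * (1.0 - tf.square(h))
      dtheta = {'w_x': tf.matmul(inputs['x'], dpre, transpose_a=True),
                'w_h': tf.matmul(state0['h'], dpre, transpose_a=True)}
      dstate0 = {'h': tf.matmul(dpre, theta['w_h'], transpose_b=True)}
      dinputs = {'x': tf.matmul(dpre, theta['w_x'], transpose_b=True)}
      return dtheta, dstate0, dinputs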
Returns

accumulate_state and the final state.