

Sparse MoE layer with per-token routing.

In this TF implementation, all experts must fit on a single device, allowing only for batch parallelism.

Uses Keras add_loss() and add_metric() APIs.

experts: Instance of FeedForwardExperts. Must have the same num_experts as the router.
router: Instance of MaskedRouter used to route tokens to the different experts.
train_capacity_factor: Scaling factor used to increase the per-expert token capacity during training. This factor plays an analogous, but slightly different, role depending on the routing assignment algorithm:

  • For "tokens choose" routing, the capacity factor only affects the maximum number of tokens that an expert will process. It does not affect how many experts a given token is routed to; see the num_selected_experts attribute of "tokens choose" routers.
  • For "experts choose" routing, because experts always fill their buffers, increasing the capacity factor increases the number of tokens that an expert will process AND indirectly increases the number of experts that a given token is routed to.
eval_capacity_factor: As above, but used during evaluation.
examples_per_group: Number of examples that form a group. The router then performs top-k token selection for each expert on a per-group basis. For example, when examples_per_group=4.0, tokens are assigned to experts in groups formed from 4 examples; when examples_per_group=0.5, each example is split into 2 groups. examples_per_group must evenly divide the local batch size. A larger group size results in slower but more accurate top-k and sorting computations, whereas a smaller group size results in faster but more approximate (and potentially less stable) routing choices. In practice, imperfect routing choices are tolerable, and a group size on the order of 4096 tokens is recommended, although this number will vary with model configuration and size.
name: Layer name.
**kwargs: Forwarded to super.
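To make the capacity-factor arguments concrete, the sketch below computes a per-expert token buffer size following the convention common to MoE implementations (capacity_factor * tokens_per_group / num_experts, rounded up). The helper name is illustrative and not part of this API.

```python
import math

def expert_capacity(tokens_per_group: int, num_experts: int,
                    capacity_factor: float) -> int:
    """Per-expert token buffer size (illustrative sketch; follows the
    common MoE convention capacity_factor * tokens_per_group / num_experts,
    rounded up). Not the library's actual implementation."""
    return int(math.ceil(capacity_factor * tokens_per_group / num_experts))

# With the recommended group size of 4096 tokens, 8 experts, and a
# capacity factor of 1.25, each expert buffers at most 640 tokens per group.
print(expert_capacity(4096, 8, 1.25))  # 640
```

Raising the capacity factor above 1.0 leaves headroom for imbalanced routing: experts that receive more than their even share of tokens drop fewer of them.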

num_experts: Number of experts (i.e. number of independent feed-forward blocks).




Applies MoeLayer.

inputs: Batch of input embeddings of shape [batch_size, seq_length, hidden_dim].
training: Whether the layer is in training mode. Dropout and jitter noise are only applied during training. If not provided, the value is taken from tf.keras.backend.

Returns the transformed inputs with the same shape as inputs: [batch_size, seq_length, hidden_dim].
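The shape-preserving behavior can be illustrated with a minimal top-1 "tokens choose" MoE in NumPy. This is a sketch, not the TF implementation: the router weights are random, the experts are plain linear maps, and capacity limits and auxiliary losses are omitted.

```python
import numpy as np

def toy_moe(inputs, expert_weights, router_weights):
    """Minimal top-1 token-choice MoE sketch (NOT the library code):
    each token goes to the expert with the highest router logit, and the
    output keeps the input's [batch_size, seq_length, hidden_dim] shape."""
    logits = inputs @ router_weights      # [batch, seq, num_experts]
    choice = logits.argmax(-1)            # top-1 expert index per token
    out = np.empty_like(inputs)
    for e, w in enumerate(expert_weights):
        mask = choice == e                # tokens routed to expert e
        out[mask] = inputs[mask] @ w      # expert "FFN" (linear here)
    return out

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 8, 16))                       # [batch, seq, hidden]
experts = [rng.normal(size=(16, 16)) for _ in range(4)]
router = rng.normal(size=(16, 4))
y = toy_moe(x, experts, router)
print(y.shape)  # (2, 8, 16), same as the input
```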

Raises ValueError if no group_size satisfying the given requirements can be found.
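The divisibility requirement on examples_per_group can be sketched as follows; the helper name and error message are illustrative, not the library's.

```python
def num_groups(batch_size: int, examples_per_group: float) -> int:
    """Illustrative check mirroring the documented requirement that
    examples_per_group evenly divide (or evenly split) the local batch."""
    groups = batch_size / examples_per_group
    if groups != int(groups) or groups < 1:
        raise ValueError(
            f"examples_per_group={examples_per_group} does not evenly "
            f"divide batch_size={batch_size}")
    return int(groups)

print(num_groups(8, 4.0))  # 2 groups of 4 examples each
print(num_groups(8, 0.5))  # 16 groups: each example is split into 2
```

A value like examples_per_group=3.0 with a batch of 8 would fail this check, which corresponds to the ValueError described above.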