{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "headers" }, "source": [ "Project: /overview/_project.yaml\n", "Book: /overview/_book.yaml\n", "\n", "\n", "\n", "\n", "\n", "\n", "{% comment %}\n", "The source of truth file can be found [here]: http://google3/zz\n", "{% endcomment %}" ] }, { "cell_type": "markdown", "metadata": { "id": "metadata" }, "source": [
"The Keras *functional API* is a way to create models that is more flexible than the `tf.keras.Sequential` API. The functional API can handle models with non-linear topology, models with shared layers, and models with multiple inputs or outputs.\n",
"\n",
"A deep learning model is usually a directed acyclic graph (DAG) of layers, so the functional API is a way to build *graphs of layers*.\n",
"\n",
"Consider the following model:\n",
"\n",
"```\n",
"(input: 784-dimensional vectors)\n",
"       ↧\n",
"[Dense (64 units, relu activation)]\n",
"       ↧\n",
"[Dense (64 units, relu activation)]\n",
"       ↧\n",
"[Dense (10 units, softmax activation)]\n",
"       ↧\n",
"(output: logits of a probability distribution over 10 classes)\n",
"```\n",
"\n",
"This is a basic graph with three layers. To build this model with the functional API, start by creating an input node:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T21:22:27.798180Z",
"iopub.status.busy": "2022-12-14T21:22:27.797454Z",
"iopub.status.idle": "2022-12-14T21:22:27.805574Z",
"shell.execute_reply": "2022-12-14T21:22:27.805022Z"
},
"id": "8d477c91955a"
},
"outputs": [],
"source": [
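"import tensorflow as tf\n",
"from tensorflow import keras\n",
"from tensorflow.keras import layers\n",
"\n",
"inputs = keras.Input(shape=(784,))"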
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "13c14d993620"
},
"source": [
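"The shape of the data is set as a 784-dimensional vector. The batch size is always omitted, since only the shape of each sample is specified.\n",
"\n",
"For instance, if you have an image input with a shape of `(32, 32, 3)`, you would use:"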
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T21:22:27.809003Z",
"iopub.status.busy": "2022-12-14T21:22:27.808768Z",
"iopub.status.idle": "2022-12-14T21:22:27.813011Z",
"shell.execute_reply": "2022-12-14T21:22:27.812452Z"
},
"id": "e4732e8e279b"
},
"outputs": [],
"source": [
"# Just for demonstration purposes.\n",
"img_inputs = keras.Input(shape=(32, 32, 3))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "971bf8b5588f"
},
"source": [
"The returned `inputs` object contains information about the shape and `dtype` of the input data that you feed to your model. Here's the shape:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T21:22:27.816291Z",
"iopub.status.busy": "2022-12-14T21:22:27.815926Z",
"iopub.status.idle": "2022-12-14T21:22:27.821946Z",
"shell.execute_reply": "2022-12-14T21:22:27.821402Z"
},
"id": "ee96c179846a"
},
"outputs": [
{
"data": {
"text/plain": [
"TensorShape([None, 784])"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"inputs.shape"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "866eee86d63e"
},
"source": [
"Here's the dtype:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T21:22:27.825709Z",
"iopub.status.busy": "2022-12-14T21:22:27.825229Z",
"iopub.status.idle": "2022-12-14T21:22:27.829172Z",
"shell.execute_reply": "2022-12-14T21:22:27.828663Z"
},
"id": "480be92067f3"
},
"outputs": [
{
"data": {
"text/plain": [
"tf.float32"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"inputs.dtype"
]
},
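{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can set both of these explicitly when declaring the input. The following minimal sketch assumes TensorFlow 2.x with its bundled Keras. `batch_size` pins the otherwise unspecified batch axis, and `dtype` overrides the `float32` default. The names `flexible` and `fixed` are illustrative only.\n",
"\n",
"```python\n",
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"\n",
"# Default: only the per-sample shape is given, so the batch axis is None.\n",
"flexible = keras.Input(shape=(784,))\n",
"print(flexible.shape)  # (None, 784)\n",
"\n",
"# Optionally fix the batch size and choose a non-default dtype.\n",
"fixed = keras.Input(shape=(784,), batch_size=32, dtype=\"int32\")\n",
"print(fixed.shape, fixed.dtype)\n",
"```\n",
"\n",
"Leaving the batch axis as `None` is usually preferable, because the same model can then accept batches of any size."
]
},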
{
"cell_type": "markdown",
"metadata": {
"id": "6c93172cdfba"
},
"source": [
"You create a new node in the graph of layers by calling a layer on this `inputs` object:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T21:22:27.832400Z",
"iopub.status.busy": "2022-12-14T21:22:27.832052Z",
"iopub.status.idle": "2022-12-14T21:22:31.311325Z",
"shell.execute_reply": "2022-12-14T21:22:31.310539Z"
},
"id": "b50da8b1c28d"
},
"outputs": [],
"source": [
"dense = layers.Dense(64, activation=\"relu\")\n",
"x = dense(inputs)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0f36afe42ff3"
},
"source": [
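"The \"layer call\" action is like drawing an arrow from \"inputs\" to the layer you created. You \"pass\" the inputs to the `dense` layer and get `x` as the output.\n",
"\n",
"Let's add a few more layers to the graph of layers:"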
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T21:22:31.315381Z",
"iopub.status.busy": "2022-12-14T21:22:31.315126Z",
"iopub.status.idle": "2022-12-14T21:22:31.338333Z",
"shell.execute_reply": "2022-12-14T21:22:31.337760Z"
},
"id": "463d5cd0c484"
},
"outputs": [],
"source": [
"x = layers.Dense(64, activation=\"relu\")(x)\n",
"outputs = layers.Dense(10)(x)"
]
},
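{
"cell_type": "markdown",
"metadata": {},
"source": [
"Each layer call returns a new symbolic tensor, and Keras infers that tensor's shape from the layer that produced it. The minimal sketch below assumes TensorFlow 2.x. It rebuilds the same three-layer stack under illustrative `demo_*` names, so it runs independently of the cells above, and prints the inferred shapes.\n",
"\n",
"```python\n",
"from tensorflow import keras\n",
"from tensorflow.keras import layers\n",
"\n",
"demo_inputs = keras.Input(shape=(784,))\n",
"demo_hidden = layers.Dense(64, activation=\"relu\")(demo_inputs)\n",
"demo_hidden = layers.Dense(64, activation=\"relu\")(demo_hidden)\n",
"demo_outputs = layers.Dense(10)(demo_hidden)\n",
"\n",
"# The batch axis stays None; the last axis equals each layer's `units`.\n",
"print(demo_hidden.shape)   # (None, 64)\n",
"print(demo_outputs.shape)  # (None, 10)\n",
"```"
]
},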
{
"cell_type": "markdown",
"metadata": {
"id": "e379f089b044"
},
"source": [
"At this point, you can create a `Model` by specifying its inputs and outputs in the graph of layers:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T21:22:31.341925Z",
"iopub.status.busy": "2022-12-14T21:22:31.341339Z",
"iopub.status.idle": "2022-12-14T21:22:31.349152Z",
"shell.execute_reply": "2022-12-14T21:22:31.348616Z"
},
"id": "7820cc2209a6"
},
"outputs": [],
"source": [
"model = keras.Model(inputs=inputs, outputs=outputs, name=\"mnist_model\")"
]
},
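{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can call a `Model` built this way directly on concrete data. The self-contained sketch below assumes NumPy and TensorFlow 2.x; `sketch_model` is an illustrative name for a model with the same architecture. The final `Dense(10)` layer has no activation, so its outputs are unnormalized logits, and `tf.nn.softmax` turns them into per-class probabilities.\n",
"\n",
"```python\n",
"import numpy as np\n",
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"from tensorflow.keras import layers\n",
"\n",
"sketch_inputs = keras.Input(shape=(784,))\n",
"h = layers.Dense(64, activation=\"relu\")(sketch_inputs)\n",
"h = layers.Dense(64, activation=\"relu\")(h)\n",
"sketch_model = keras.Model(sketch_inputs, layers.Dense(10)(h), name=\"mnist_model\")\n",
"\n",
"# Any batch size works, because the batch axis was left unspecified.\n",
"batch = np.random.rand(3, 784).astype(\"float32\")\n",
"logits = sketch_model(batch)\n",
"probs = tf.nn.softmax(logits, axis=-1)\n",
"print(logits.shape)  # (3, 10)\n",
"print(probs.numpy().sum(axis=-1))  # each row sums to about 1.0\n",
"```"
]
},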
{
"cell_type": "markdown",
"metadata": {
"id": "c9aa111852d3"
},
"source": [
"Let's check out what the model summary looks like:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T21:22:31.352537Z",
"iopub.status.busy": "2022-12-14T21:22:31.352026Z",
"iopub.status.idle": "2022-12-14T21:22:31.363236Z",
"shell.execute_reply": "2022-12-14T21:22:31.362703Z"
},
"id": "4949ab8242e8"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"mnist_model\"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"_________________________________________________________________\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" Layer (type) Output Shape Param # \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"=================================================================\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" input_1 (InputLayer) [(None, 784)] 0 \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" dense (Dense) (None, 64) 50240 \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" dense_1 (Dense) (None, 64) 4160 \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" dense_2 (Dense) (None, 10) 650 \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"=================================================================\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total params: 55,050\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Trainable params: 55,050\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Non-trainable params: 0\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"_________________________________________________________________\n"
]
}
],
"source": [
"model.summary()"
]
},
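{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `Param #` column follows directly from the layer sizes. A `Dense` layer with `n_in` inputs and `units` outputs holds an `n_in * units` kernel plus `units` biases. The plain-Python arithmetic below reproduces the numbers in the summary and needs no TensorFlow. `dense_params` is an illustrative helper, not a Keras function.\n",
"\n",
"```python\n",
"def dense_params(n_in, units):\n",
"    # One weight per (input, unit) pair, plus one bias per unit.\n",
"    return n_in * units + units\n",
"\n",
"\n",
"layer_params = [\n",
"    dense_params(784, 64),  # dense\n",
"    dense_params(64, 64),   # dense_1\n",
"    dense_params(64, 10),   # dense_2\n",
"]\n",
"print(layer_params)       # [50240, 4160, 650]\n",
"print(sum(layer_params))  # 55050\n",
"```"
]
},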
{
"cell_type": "markdown",
"metadata": {
"id": "99ab8535d6c3"
},
"source": [
"You can also plot the model as a graph:"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T21:22:31.370085Z",
"iopub.status.busy": "2022-12-14T21:22:31.369553Z",
"iopub.status.idle": "2022-12-14T21:22:31.502884Z",
"shell.execute_reply": "2022-12-14T21:22:31.502115Z"
},
"id": "6872f1b1b8b8"
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAOcAAAFgCAYAAABANm70AAAABmJLR0QA/wD/AP+gvaeTAAAgAElEQVR4nO3de0xUZ/4G8GeAYbjPgIto8YKXWtsupZG6VStFpaKNl1EKIt5Yra7RtnY1VuvaGuKa3Zq2tt1U19a2a5u4ETSR1eq60uiaKGNqEXC1xSrGrIpQxMIycnFgvr8/GubX0wEVHGbeGZ5PchJ5zzvnfOc953HmnJk5RyciAiJSzR4/T1dARO1jOIkUxXASKYrhJFJUwC8bLBYLtmzZ4olaiHqsPXv2OLU5vXJevXoVe/fudUtB9P/27t2La9eueboMcrNr1651mDenV8427SWZuo9Op8PKlSsxa9YsT5dCbpSXl4fMzMx25/GYk0hRDCeRohhOIkUxnESKYjiJFMVwEimK4SRSFMNJpCiGk0hRDCeRohhOIkUxnESKYjiJFMVwPoBDhw5h2LBhCAjo8Mc93SYsLAw6nU4zvfPOO26vwxV86bm4kkvCabVa8fDDD2Pq1KmuWJzyysvLMX36dKxbtw5VVVUeqcFqtaK4uBgAYDabISJYvXq1R2p5UL70XFzJJeEUEdjtdtjtdlcsrluFhYVh7NixD7SMN998E2PGjEFRURHCw8NdVJlvc8W49zQueT8WHh6O8vJyVyzKK3z66acIDg72dBnk43jM2QUMJrnDA4czPz9fcyDf1NTUbvuVK1eQmZkJk8mEXr16YerUqZpX23feecfRt1+/fjh9+jRSUlIQHh6OkJAQjB8/HidPnnT037Rpk6P/z98uHT582NH+q1/9ymn5t2/fxsmTJx19PHEypzv1hHFvaWlBbm4uJk6ciD59+iA4OBjx8fH44IMPHIdWtbW1TieZNm3a5Hj8z9vT09Mdy66ursaKFSsQFxeHwMBAREdHIy0tDSUlJR2O8YULFzBr1iz06tXL0Xbz5s0uPz8H+YXc3Fxpp/mezGazAJDGxsZ2281msxQWForVapWCggIJDg6WkSNHOi0nISFBQkNDZfTo0Y7+p0+flieeeEICAwPl3//+t6Z/aGioPPPMM07LSUxMlF69ejm1d9S/q2JjY8Xf3/+BlwNAcnNzO/WY4uJix9j+kreN+92eyy8dOHBAAMif/vQnuXXrllRXV8tf/vIX8fPzk9WrV2v6Tpo0Sfz8/OTSpUtOyxk9erTs2rXL8XdFRYUMHDhQYmJi5ODBg1JfXy/nzp2T5ORkCQoKksLCQs3j28Y4OTlZjh07Jrdv35ZTp06Jv7+/VFdX3/N5iNw1b3luC+eBAwc07enp6QLA6UkkJCQIACkuLta0nz17VgBIQkKCpp3hvHs4vWXcOxvOcePGObXPmzdP9Hq91NXVOdr+9a9/CQBZvny5pu+JEyckNjZW7ty542jLzs4WAJrAiojcuHFDDAaDJCYmatrbxvjQoUP3rLkjdwun2445R44cqfm7f//+AICKigqnvqGhoXjyySc1bfHx8XjooYdQWlqKGzdudF+hPsYXx33q1Kk4duyYU3tCQgJsNhvOnz/vaEtNTUV8fDx27tyJmpoaR/vbb7+NV155BXq93tGWn58PPz8/p48E+/Tpg8cffxxFRUXtXr70N7/5jSuelhO3hdNoNGr+DgwMBIB2P34xmUztLqN3794AgB9++MHF1fkuXxz3uro6bNiwAfHx8YiMjHQc57322msAgIaGBk3/3//+92hoaMC2bdsAAN9//z2OHj2K3/3ud44+zc3NqKurg91uh9FodDpePXPmDADg4sWLTvWEhoZ2y/NU8mxtTU0NpJ07E7btHG07CwD4+fnhzp07Tn1ra2vbXbZOp3NRlb7HW8Z92rRp+OMf/4glS5bg+++/h91uh4jgvffeAwCn5zB37lzExMTgww8/RHNzM959911kZ2cjMjLS0cdgMMBkMiEgIAA2mw0i0u40fvx4lz2Pe1EynE1NTTh9+rSm7T//+Q8qKiqQkJCAvn
37Otr79u2L69eva/pWVlbiv//9b7vLDgkJ0exUjzzyCD7++GMXVu+9VB/3gIAAnD9/HidPnkSfPn2wYsUKREdHO4Lf2NjY7uMMBgOWL1+OH374Ae+++y527dqFV1991alfWloaWlpaNGen22zevBkDBgxAS0tLp2p+EEqG02g04g9/+AMsFgtu376Nb775BvPmzUNgYCA++OADTd/U1FRUVFTgww8/hNVqRXl5OV599VXN//I/N2LECHz//fe4evUqLBYLLl++jKSkJHc8LeV5w7j7+/tj3LhxqKysxNtvv42bN2+isbERx44dw/bt2zt83PLlyxEcHIw33ngDzz33HIYOHerU589//jOGDBmCRYsW4Z///Cfq6upw69YtfPTRR9i4cSPeeecd93701omzR+3at2+fANBMc+fOFYvF4tS+fv16kZ/ec2imKVOmOJaXkJAgsbGx8u2338qkSZMkPDxcgoODJTk5WU6cOOG0/traWlm8eLH07dtXgoODZezYsXL69GlJTEx0LH/t2rWO/mVlZZKUlCShoaHSv39/2bp1630/1zZtp/Lbm3bs2NHp5bWNSWfO1oaGhjqt++233/bKcW/vuXQ0fffdd1JdXS1Lly6V/v37i16vl5iYGPntb38rr7/+uqPfL8+siogsWbJEAMjx48c7HNeamhpZtWqVDB48WPR6vURHR0tqaqoUFBQ4+rQ3xp3JzM+55aMUV2nbSXqazobT1XrCuH/22WfthtaTlPgohcjTtm/fjlWrVnm6jPvGcJLP+uSTTzBz5kxYrVZs374dP/74o1fdxU2ZcLZ9B7O0tBTXr1+HTqfDG2+84bb1//JzrfamnJwct9XjLp4e9+6Wn5+PyMhI/PWvf8Xu3bu96rvUOhHth0Jt9wuUdj7vou6j0+mQm5vrVf+z04O7S972KPPKSURaDCeRohhOIkUxnESKYjiJFMVwEimK4SRSFMNJpCiGk0hRDCeRohhOIkUxnESKYjiJFNXh72cyMjLcWQcBeO+997Bnzx5Pl0Fu1N51cNs4/WTMYrFgy5Yt3V4UuVZ1dTW+++47PPvss54uhbqgnf+U9ziFk7wTf4frc/h7TiJVMZxEimI4iRTFcBIpiuEkUhTDSaQohpNIUQwnkaIYTiJFMZxEimI4iRTFcBIpiuEkUhTDSaQohpNIUQwnkaIYTiJFMZxEimI4iRTFcBIpiuEkUhTDSaQohpNIUQwnkaIYTiJFMZxEimI4iRTFcBIpiuEkUhTDSaQohpNIUQwnkaIYTiJFBXi6AOq8a9euITs7G62trY62mzdvIiAgAOPGjdP0feSRR/DRRx+5uUJyBYbTC/Xr1w9XrlzB5cuXneYdP35c83dSUpK7yiIX49taL7VgwQLo9fp79ps9e7YbqqHuwHB6qblz58Jms921z2OPPYbHH3/cTRWRqzGcXmro0KF44oknoNPp2p2v1+uRnZ3t5qrIlRhOL7ZgwQL4+/u3O6+lpQWzZs1yc0XkSgynF8vKyoLdbndq1+l0ePrppxEXF+f+oshlGE4v9tBDD2HMmDHw89NuRn9/fyxYsMBDVZGrMJxebv78+U5tIoIXXnjBA9WQKzGcXi4jI0Pzyunv74/nnnsOvXv39mBV5AoMp5eLjIxEamqq48SQiGDevHkeropcgeH0AfPmzXOcGAoICMD06dM9XBG5AsPpA6ZPnw6DweD4d0REhIcrIldQ5ru1165dQ2FhoafL8FojRoxAYWEhBg0ahLy8PE+X47VU+mxYJyLi6SIAIC8vD5mZmZ4ug3o4ReIAAHuUeeVso9DgeI2MjAzY7XYMHToUmzdv9nQ5XknFFwcec/oIPz8/5OTkeLoMciGG04cEBwd7ugRyIYaTSFEMJ5GiGE4iRTGcRIpiOIkUxXASKYrhJFIUw0mkKIaTSFEMJ5GiGE4iRflcOHfv3g2dTgedToegoCBPl6OssLAwxzi1TX5+foiMjERCQgKWL1+OoqIiT5fZo/
lcOGfPng0RQUpKiqdLUZrVakVxcTEAwGw2Q0Rgs9lQVlaGjRs3oqysDE899RQWLlyIhoYGD1fbM/lcOKnr/P39ERMTA7PZjKNHj2LNmjXYuXMnsrKy+DtbD2A4qUNvvfUWnn76aezfvx+7d+/2dDk9DsNJHdLpdHj55ZcBANu2bfNwNT2P14ezrKwMM2bMgNFoRGhoKJKSknDixIkO+1dXV2PFihWIi4tDYGAgoqOjkZaWhpKSEkef/Px8zYmSK1euIDMzEyaTCb169cLUqVNRXl6uWW5zczM2bNiA4cOHIyQkBFFRUZg2bRr279+vuQP1/dagirFjxwIATp06pbnlIMfRDUQRubm50tlyLl68KCaTSWJjY+XIkSNSX18vZ8+eldTUVImLixODwaDpX1FRIQMHDpSYmBg5ePCg1NfXy7lz5yQ5OVmCgoKksLBQ099sNgsAMZvNUlhYKFarVQoKCiQ4OFhGjhyp6bt48WIxGo1y5MgRaWhokMrKSlm9erUAkGPHjnW5hvuRnp4u6enpnX5ccXGx4/l1pLGxUQAIAKmoqOjSc/CGcezK/tfN8pSppiuDk5GRIQBk7969mvbr16+LwWBwCmd2drYAkF27dmnab9y4IQaDQRITEzXtbTvVgQMHNO3p6ekCQKqrqx1tgwYNkjFjxjjVOGzYMM1O1dka7kd3hrOhocEpnL44jgznXXRlcMLDwwWA1NfXO82Lj493CqfRaBQ/Pz+pq6tz6j9ixAgBIFevXnW0te1UlZWVmr4rV64UAFJaWupoW7ZsmQCQJUuWiMVikZaWlnZr7mwN96M7w1leXi4ARK/Xy507d0TEN8dRxXB67TFnc3Mz6uvrERQUhLCwMKf5v7yRT3NzM+rq6mC322E0Gp0+gD9z5gwA4OLFi07LMhqNmr8DAwMBQHNvzK1bt+KLL77A5cuXkZKSgoiICEyePBn79u1zSQ2e0nb8Pnr0aOj1eo6jG3ltOA0GA8LDw9HU1ASr1eo0/9atW079TSYTAgICYLPZICLtTuPHj+9SPTqdDvPnz8dXX32F2tpa5OfnQ0SQlpaGLVu2uKUGV7Pb7di6dSsA4KWXXgLAcXQnrw0nADz//PMAgMOHD2vab968iQsXLjj1T0tLQ0tLC06ePOk0b/PmzRgwYABaWlq6VIvJZEJZWRkAQK/XY+LEiY6zlQcPHnRLDa62bt06fP3115g5cyYyMjIc7RxHN3HXG+h76cp7/kuXLklUVJTmbO358+dl0qRJ0rt3b6djzqqqKhkyZIgMHjxYDh06JLW1tVJTUyPbt2+XkJAQyc3N1fRvO1ZqbGzUtK9du1YASHFxsaPNaDRKcnKylJaWSlNTk1RVVUlOTo4AkE2bNnW5hvvhqmPO1tZWqaqqkvz8fJkwYYIAkEWLFklDQ4Pmcb44jioecypTTVcH58KFCzJjxgyJiIhwnJr/8ssvJSUlxXGW8cUXX3T0r6mpkVWrVsngwYNFr9dLdHS0pKamSkFBgaOPxWJxPLZtWr9+vYiIU/uUKVNERKSkpESWLl0qjz76qISEhEhUVJSMGjVKduzYIXa7XVPz/dTQGV0JZ2hoqNNz0el0YjQaJT4+XpYtWyZFRUUdPt7XxlHFcCp3IyNFyvEqbW859+zZ4+FKvJeC+98erz7mJPJlDCeRohhOIkUxnESKYjiJFMVwEimK4SRSFMNJpCiGk0hRDCeRohhOIkUxnESKYjiJFMVwEimK4SRSFMNJpCiGk0hRAZ4u4Jfy8vI8XYLXuXbtGgCO3YOwWCyeLsGJcuHMzMz0dAlei2PnW5S5hhA9GAWvgUMPhtcQIlIVw0mkKIaTSFEMJ5GiGE4iRTGcRIpiOIkUxXASKYrhJFIUw0mkKIaTSFEMJ5GiGE4iRTGcRIpiOIkUxXASKYrhJFIUw0mkKIaTSFEMJ5GiGE4iRTGcRIpiOIkUxXASKYrhJFIUw0mkKIaTSFEMJ5GiGE4iRTGcRIpiOIkUxXASKY
rhJFKUcredp3urrq7Gvn37NG3ffPMNAODjjz/WtIeFhWHOnDluq41ch7ed90LNzc2Ijo7G7du34e/vDwAQEYgI/Pz+/82QzWbDggUL8Pnnn3uqVOo63nbeGxkMBmRkZCAgIAA2mw02mw0tLS1obW11/G2z2QCAr5pejOH0UnPmzMGdO3fu2sdkMiElJcVNFZGrMZxeavz48YiOju5wvl6vx7x58xAQwNMK3orh9FJ+fn6YM2cOAgMD251vs9mQlZXl5qrIlRhOL5aVldXhW9u+ffti9OjRbq6IXInh9GJPP/00Bg4c6NSu1+uRnZ0NnU7ngarIVRhOLzd//nzo9XpNG9/S+gaG08vNnTvX8bFJm6FDh+KJJ57wUEXkKgynlxs+fDgee+wxx1tYvV6PhQsXergqcgWG0wcsWLDA8U0hm82GWbNmebgicgWG0wfMnj0bra2tAIDExEQMHTrUwxWRKzCcPmDgwIEYOXIkgJ9eRck3dPsX3/Py8pCZmdmdqyByOzf8XmSP277blZub665V9Uj/+9//sG3bNrz++usd9nnvvfcAACtXrnRXWT7HYrHg/fffd8u63BZOnqTofsnJyXj44Yc7nL9nzx4A3BYPyl3h5DGnD7lbMMn7MJxEimI4iRTFcBIpiuEkUhTDSaQohpNIUQwnkaIYTiJFMZxEimI4iRTFcBIpiuEkUpTXhHP37t3Q6XTQ6XQICgrydDludejQIQwbNswjV28PCwtzjHvb5Ofnh8jISCQkJGD58uUoKipye109gdeEc/bs2RCRHnXvj/LyckyfPh3r1q1DVVWVR2qwWq0oLi4GAJjNZogIbDYbysrKsHHjRpSVleGpp57CwoUL0dDQ4JEafZXXhLMnevPNNzFmzBgUFRUhPDzc0+U4+Pv7IyYmBmazGUePHsWaNWuwc+dOZGVlueMKAT0G73KjsE8//RTBwcGeLuOe3nrrLRw/fhz79+/H7t27eUFrF+Erp8K8IZgAoNPp8PLLLwMAtm3b5uFqfIey4SwrK8OMGTNgNBoRGhqKpKQknDhxosP+1dXVWLFiBeLi4hAYGIjo6GikpaWhpKTE0Sc/P19zYuPKlSvIzMyEyWRCr169MHXqVJSXl2uW29zcjA0bNmD48OEICQlBVFQUpk2bhv379zsuR9mZGnzV2LFjAQCnTp3SXIGe2+UBSDfLzc2Vzq7m4sWLYjKZJDY2Vo4cOSL19fVy9uxZSU1Nlbi4ODEYDJr+FRUVMnDgQImJiZGDBw9KfX29nDt3TpKTkyUoKEgKCws1/c1mswAQs9kshYWFYrVapaCgQIKDg2XkyJGavosXLxaj0ShHjhyRhoYGqayslNWrVwsAOXbsWJdr6KzY2Fjx9/d/oGWkp6dLenp6px9XXFzsGK+ONDY2CgABIBUVFSLim9ulK/tzF+UpGc6MjAwBIHv37tW0X79+XQwGg1M4s7OzBYDs2rVL037jxg0xGAySmJioaW/bCQ4cOKBpT09PFwBSXV3taBs0aJCMGTPGqcZhw4ZpdoLO1tBZqoezoaHBKZy+uF16fDjDw8MFgNTX1zvNi4+Pdwqn0WgUPz8/qaurc+o/YsQIASBXr151tLXtBJWVlZq+K1euFABSWlrqaFu2bJkAkCVLlojFYpGWlpZ2a+5sDZ2lejjLy8sFgOj1erlz546I+OZ2cWc4lTvmbG5uRn19PYKCghAWFuY0v3fv3k796+rqYLfbYTQanT4wP3PmDADg4sWLTssyGo2av9vuEm232x1tW7duxRdffIHLly8jJSUFERERmDx5Mvbt2+eSGnxF2/mA0aNHQ6/Xc7u4gHLhNBgMCA8PR1NTE6xWq9P8W7duOfU3mUwICAiAzWaDiLQ7jR8/vkv16HQ6zJ8/H1999RVqa2uRn58PEUFaWhq2bNnilhpUZ7fbsXXrVgDASy+9BIDbxRWUCycAPP/88wCAw4cPa9pv3ryJCxcuOPVPS0tDS0sLTp486TRv8+bNGD
BgAFpaWrpUi8lkQllZGYCfbq83ceJEx9nFgwcPuqUG1a1btw5ff/01Zs6ciYyMDEc7t8sD6u43zl15j37p0iWJiorSnK09f/68TJo0SXr37u10zFlVVSVDhgyRwYMHy6FDh6S2tlZqampk+/btEhISIrm5uZr+bcc2jY2Nmva1a9cKACkuLna0GY1GSU5OltLSUmlqapKqqirJyckRALJp06Yu19BZKh1ztra2SlVVleTn58uECRMEgCxatEgaGho0j/PF7dLjTwiJiFy4cEFmzJghERERjlPpX375paSkpDjOCr744ouO/jU1NbJq1SoZPHiw6PV6iY6OltTUVCkoKHD0sVgsjse2TevXrxcRcWqfMmWKiIiUlJTI0qVL5dFHH5WQkBCJioqSUaNGyY4dO8Rut2tqvp8aOuPAgQNOdbVNO3bs6PTyuhLO0NBQp3XrdDoxGo0SHx8vy5Ytk6Kiog4f72vbxZ3hdNtdxrp5NXQf2t5ytt0zhTrPjfvzHiWPOYlI0RNCRMRwut0vP2trb8rJyfF0maQA/mTMzXjsTfeLr5xEimI4iRTFcBIpiuEkUhTDSaQohpNIUQwnkaIYTiJFMZxEimI4iRTFcBIpiuEkUhTDSaQot/0qRafTuWtVdA/cFt6h28M5ZswY5ObmdvdqejyLxYL333+fY+1Duv0aQuQevFaTz+E1hIhUxXASKYrhJFIUw0mkKIaTSFEMJ5GiGE4iRTGcRIpiOIkUxXASKYrhJFIUw0mkKIaTSFEMJ5GiGE4iRTGcRIpiOIkUxXASKYrhJFIUw0mkKIaTSFEMJ5GiGE4iRTGcRIpiOIkUxXASKYrhJFIUw0mkKIaTSFEMJ5GiGE4iRTGcRIpy223nyXVsNhusVqum7fbt2wCAH3/8UdOu0+lgMpncVhu5DsPphWpqatCvXz+0trY6zYuKitL8PW7cOBw7dsxdpZEL8W2tF+rTpw+effZZ+PndffPpdDpkZWW5qSpyNYbTS82fPx86ne6uffz8/PDCCy+4qSJyNYbTS73wwgvw9/fvcL6/vz8mT56MXr16ubEqciWG00tFRERg8uTJCAho/7SBiGDevHluropcieH0YvPmzWv3pBAABAYGYurUqW6uiFyJ4fRi06ZNQ0hIiFN7QEAAZs6cibCwMA9URa7CcHqxoKAgpKWlQa/Xa9pbWlowd+5cD1VFrsJwerk5c+bAZrNp2iIiIjBx4kQPVUSuwnB6ueeee07zxQO9Xo/Zs2cjMDDQg1WRKzCcXi4gIACzZ892vLW12WyYM2eOh6siV2A4fUBWVpbjrW1MTAySkpI8XBG5AsPpA5555hk89NBDAH765tC9vtZH3qHbv/husViwZcuW7l5NjxceHg4AKC4uRkZGhoer8X179uzp9nV0+3+xV69exd69e7t7NT3egAEDEB4ejsjIyA77nDp1CqdOnXJjVb7n2rVrbtuf3faTMXf8T9PT5eXlYdasWR3Ob3tF5bboury8PGRmZrplXTw48SF3CyZ5H4aTSFEMJ5GiGE4iRTGcRIpiOIkUxXASKYrhJFIUw0mkKIaTSFEMJ5GiGE4iRTGcRIrymnDu3r0bOp0OOp0OQUFBni6n2/3444/Yvn07JkyYgKioKAQHB+Phhx/G3LlzUVpa6rY6wsLCHOPeNvn5+SEyMhIJCQlYvnw5ioqK3FZPT+I14Zw9ezZEBCkpKZ4uxS1ee+01vPLKKzCbzfj2229RU1ODzz77DCUlJUhMTER+fr5b6rBarSguLgYAmM1miAhsNhvKysqwceNGlJWV4amnnsLChQvR0NDglpp6Cq8JZ0+0aNEivPrqq+jTpw9CQkKQlJSEv//972htbcWaNWs8Vpe/vz9iYmJgNptx9OhRrFmzBjt37kRWVhZExGN1+Rren1NRn3zySbvtCQkJCA4ORnl5OUTknncac4e33noLx48fx/79+7F7927edtBF+MrpZW7fvo3Gxkb8+te/ViKYwE/3AX
355ZcBANu2bfNwNb5D2XCWlZVhxowZMBqNCA0NRVJSEk6cONFh/+rqaqxYsQJxcXEIDAxEdHQ00tLSUFJS4uiTn5+vObFx5coVZGZmwmQyoVevXpg6dSrKy8s1y21ubsaGDRswfPhwhISEICoqCtOmTcP+/fudbiJ0PzU8qLZLjKxfv95ly3SFsWPHAvjpOkU/vwJ9T9ku3UK6WW5urnR2NRcvXhSTySSxsbFy5MgRqa+vl7Nnz0pqaqrExcWJwWDQ9K+oqJCBAwdKTEyMHDx4UOrr6+XcuXOSnJwsQUFBUlhYqOlvNpsFgJjNZiksLBSr1SoFBQUSHBwsI0eO1PRdvHixGI1GOXLkiDQ0NEhlZaWsXr1aAMixY8e6XENXVFZWSkxMjCxevLhLj09PT5f09PROP664uNgxXh1pbGwUAAJAKioqRMQ3t0tX9ucuylMynBkZGQJA9u7dq2m/fv26GAwGp3BmZ2cLANm1a5em/caNG2IwGCQxMVHT3rYTHDhwQNOenp4uAKS6utrRNmjQIBkzZoxTjcOGDdPsBJ2tobNu3rwpTz75pGRmZkpLS0uXltGd4WxoaHAKpy9ulx4fzvDwcAEg9fX1TvPi4+Odwmk0GsXPz0/q6uqc+o8YMUIAyNWrVx1tbTtBZWWlpu/KlSsFgJSWljrali1bJgBkyZIlYrFYOgxGZ2voDKvVKomJiTJnzpwuB1Oke8NZXl4uAESv18udO3dExDe3izvDqdwxZ3NzM+rr6xEUFNTu/SV79+7t1L+urg52ux1Go9HpA/MzZ84AAC5evOi0LKPRqPm77eY/drvd0bZ161Z88cUXuHz5MlJSUhx3lN63b59LariXlpYWZGRkIDY2Fp9//vldbzXvSW3nA0aPHg29Xu/z28UdlAunwWBAeHg4mpqaYLVanebfunXLqb/JZEJAQABsNhtEpN1p/PjxXapHp9Nh/vz5+Oqrr1BbW4v8/HyICNLS0hxXsu/OGpYuXYrm5mbk5eVpbjE/dOcBt98AAAJ/SURBVOhQZS4QbbfbsXXrVgDASy+9BMD3t4s7KBdOAHj++ecBAIcPH9a037x5ExcuXHDqn5aWhpaWFpw8edJp3ubNmzFgwAC0tLR0qRaTyYSysjIAP91eb+LEiY6ziwcPHuzWGnJycnD+/Hn84x//gMFg6FL97rBu3Tp8/fXXmDlzpuZWEL66Xdymu984d+U9+qVLlyQqKkpztvb8+fMyadIk6d27t9MxZ1VVlQwZMkQGDx4shw4dktraWqmpqZHt27dLSEiI5Obmavq3Hds0NjZq2teuXSsApLi42NFmNBolOTlZSktLpampSaqqqiQnJ0cAyKZNm7pcw7387W9/c5xg6WiyWCydWqarjjlbW1ulqqpK8vPzZcKECQJAFi1aJA0NDZrH+eJ26fEnhERELly4IDNmzJCIiAjHqfQvv/xSUlJSHDvniy++6OhfU1Mjq1atksGDB4ter5fo6GhJTU2VgoICRx+LxeK0g69fv15ExKl9ypQpIiJSUlIiS5culUcffVRCQkIkKipKRo0aJTt27BC73a6p+X5quF9TpkxRIpyhoaFO69XpdGI0GiU+Pl6WLVsmRUVFHT7e17aLO8OpE+neL0O23Vuim1dD94H3Snlwbtyf9yh5zElEip4QIiKG0+1++Vlbe1NOTo6nyyQF8CdjbsZjb7pffOUkUhTDSaQohpNIUQwnkaIYTiJFMZxEimI4iRTFcBIpiuEkUhTDSaQohpNIUQwnkaIYTiJFue1XKT+/8BN5RtvV+rgtuu7atWtuW1e3h7N///5IT0/v7tXQfRg1apSnS/B6/fr1c9v+3O3XECKiLuE1hIhUxXASKYrhJFIUw0mkqP8DKFnjfb+C/sgAAAAASUVORK5CYII=\n",
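"text/plain": []
},
"metadata": {},
"output_type": "display_data"
}
],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`tf.keras` includes a wide range of built-in layers, for example:\n",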
"\n",
"- Convolutional layers: `Conv1D`, `Conv2D`, `Conv3D`, `Conv2DTranspose`\n",
"- Pooling layers: `MaxPooling1D`, `MaxPooling2D`, `MaxPooling3D`, `AveragePooling1D`\n",
"- RNN layers: `GRU`, `LSTM`, `ConvLSTM2D`\n",
"- `BatchNormalization`, `Dropout`, `Embedding`, etc.\n",
"\n",
"However, if you don't find what you need, it's easy to extend the API by creating your own layers. All layers subclass the `Layer` class and implement:\n",
"\n",
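"- A `call` method, which specifies the computation done by the layer.\n",
"- A `build` method, which creates the weights of the layer (this is just a style convention, since you can also create weights in `__init__`).\n",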
"\n",
"To learn more about creating layers from scratch, read the [custom layers and models](https://tensorflow.google.cn/guide/keras/custom_layers_and_models) guide.\n",
"\n",
"The following is a basic implementation of `tf.keras.layers.Dense`:"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T21:23:05.840498Z",
"iopub.status.busy": "2022-12-14T21:23:05.839958Z",
"iopub.status.idle": "2022-12-14T21:23:05.885319Z",
"shell.execute_reply": "2022-12-14T21:23:05.884710Z"
},
"id": "1d9faf1f622a"
},
"outputs": [],
"source": [
"class CustomDense(layers.Layer):\n",
" def __init__(self, units=32):\n",
" super(CustomDense, self).__init__()\n",
" self.units = units\n",
"\n",
" def build(self, input_shape):\n",
" self.w = self.add_weight(\n",
" shape=(input_shape[-1], self.units),\n",
" initializer=\"random_normal\",\n",
" trainable=True,\n",
" )\n",
" self.b = self.add_weight(\n",
" shape=(self.units,), initializer=\"random_normal\", trainable=True\n",
" )\n",
"\n",
" def call(self, inputs):\n",
" return tf.matmul(inputs, self.w) + self.b\n",
"\n",
"\n",
"inputs = keras.Input((4,))\n",
"outputs = CustomDense(10)(inputs)\n",
"\n",
"model = keras.Model(inputs, outputs)"
]
},
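{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see what `build` does, the sketch below repeats the `CustomDense` definition so it runs on its own (TensorFlow 2.x assumed) and then calls the layer on a concrete batch. No weights exist until that first call. At that point `build` receives the input shape and creates a `(4, 10)` kernel and a `(10,)` bias.\n",
"\n",
"```python\n",
"import tensorflow as tf\n",
"from tensorflow.keras import layers\n",
"\n",
"\n",
"class CustomDense(layers.Layer):\n",
"    def __init__(self, units=32):\n",
"        super().__init__()\n",
"        self.units = units\n",
"\n",
"    def build(self, input_shape):\n",
"        # Called once, on first use, with the shape of the incoming inputs.\n",
"        self.w = self.add_weight(\n",
"            shape=(input_shape[-1], self.units), initializer=\"random_normal\", trainable=True\n",
"        )\n",
"        self.b = self.add_weight(\n",
"            shape=(self.units,), initializer=\"random_normal\", trainable=True\n",
"        )\n",
"\n",
"    def call(self, inputs):\n",
"        return tf.matmul(inputs, self.w) + self.b\n",
"\n",
"\n",
"layer = CustomDense(10)\n",
"print(len(layer.weights))  # 0: not built yet\n",
"y = layer(tf.ones((2, 4)))\n",
"print(y.shape)  # (2, 10)\n",
"print(len(layer.weights))  # 2: kernel and bias\n",
"```"
]
},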
{
"cell_type": "markdown",
"metadata": {
"id": "b8933568358c"
},
"source": [
"To support serialization in your custom layer, define a `get_config` method that returns the constructor arguments of the layer instance:"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T21:23:05.888754Z",
"iopub.status.busy": "2022-12-14T21:23:05.888186Z",
"iopub.status.idle": "2022-12-14T21:23:05.915702Z",
"shell.execute_reply": "2022-12-14T21:23:05.915099Z"
},
"id": "b22a134918a2"
},
"outputs": [],
"source": [
"class CustomDense(layers.Layer):\n",
" def __init__(self, units=32):\n",
" super(CustomDense, self).__init__()\n",
" self.units = units\n",
"\n",
" def build(self, input_shape):\n",
" self.w = self.add_weight(\n",
" shape=(input_shape[-1], self.units),\n",
" initializer=\"random_normal\",\n",
" trainable=True,\n",
" )\n",
" self.b = self.add_weight(\n",
" shape=(self.units,), initializer=\"random_normal\", trainable=True\n",
" )\n",
"\n",
" def call(self, inputs):\n",
" return tf.matmul(inputs, self.w) + self.b\n",
"\n",
" def get_config(self):\n",
" return {\"units\": self.units}\n",
"\n",
"\n",
"inputs = keras.Input((4,))\n",
"outputs = CustomDense(10)(inputs)\n",
"\n",
"model = keras.Model(inputs, outputs)\n",
"config = model.get_config()\n",
"\n",
"new_model = keras.Model.from_config(config, custom_objects={\"CustomDense\": CustomDense})"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "015abf7d0508"
},
"source": [
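"Optionally, implement the class method `from_config(cls, config)`, which recreates a layer instance from its config dictionary. The default implementation of `from_config` is:\n",
"\n",
"```python\n",
"@classmethod\n",
"def from_config(cls, config):\n",
"    return cls(**config)\n",
"```"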
]
},
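{
"cell_type": "markdown",
"metadata": {},
"source": [
"A common refinement applies if you want standard layer arguments such as `name` to survive a round trip. Accept `**kwargs` in `__init__`, pass them to the base class, and merge the base `Layer` config into `get_config`. The inherited `from_config` then rebuilds the layer with `cls(**config)`. `ScaledDense` is a hypothetical layer used only for illustration, and TensorFlow 2.x is assumed.\n",
"\n",
"```python\n",
"import tensorflow as tf\n",
"from tensorflow.keras import layers\n",
"\n",
"\n",
"class ScaledDense(layers.Layer):\n",
"    # Hypothetical Dense-like layer (no bias) whose output is multiplied by `scale`.\n",
"    def __init__(self, units=32, scale=1.0, **kwargs):\n",
"        super().__init__(**kwargs)  # forwards name, dtype, trainable, ...\n",
"        self.units = units\n",
"        self.scale = scale\n",
"\n",
"    def build(self, input_shape):\n",
"        self.w = self.add_weight(\n",
"            shape=(input_shape[-1], self.units), initializer=\"random_normal\", trainable=True\n",
"        )\n",
"\n",
"    def call(self, inputs):\n",
"        return self.scale * tf.matmul(inputs, self.w)\n",
"\n",
"    def get_config(self):\n",
"        # Start from the base config (name, trainable, dtype, ...) and add this layer's args.\n",
"        config = super().get_config()\n",
"        config.update({\"units\": self.units, \"scale\": self.scale})\n",
"        return config\n",
"\n",
"\n",
"original = ScaledDense(8, scale=0.5, name=\"scaled\")\n",
"restored = ScaledDense.from_config(original.get_config())\n",
"print(restored.name, restored.units, restored.scale)  # scaled 8 0.5\n",
"```"
]
},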
{
"cell_type": "markdown",
"metadata": {
"id": "b4ead34e01dd"
},
"source": [
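"## When to use the functional API\n",
"\n",
"Should you use the Keras functional API to create a new model, or subclass the `Model` class directly? In general, the functional API is higher-level, easier, and safer, and it has a number of features that subclassed models do not support.\n",
"\n",
"However, model subclassing provides greater flexibility when building models that are not easily expressed as directed acyclic graphs of layers. For example, you could not implement a Tree-RNN with the functional API and would have to subclass `Model` directly.\n",
"\n",
"For an in-depth look at the differences between the functional API and model subclassing, read [Symbolic and imperative APIs in TensorFlow 2.0](https://blog.tensorflow.org/2019/01/what-are-symbolic-and-imperative-apis.html).\n",
"\n",
"### Functional API strengths:\n",
"\n",
"The following properties also hold for Sequential models (which are also data structures), but not for subclassed models (which are Python bytecode, not data structures).\n",
"\n",
"#### Less verbose\n",
"\n",
"There is no `super(MyClass, self).__init__(...)`, no `def call(self, ...):`, and so on.\n",
"\n",
"Compare:\n",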
"\n",
"```python\n",
"inputs = keras.Input(shape=(32,))\n",
"x = layers.Dense(64, activation='relu')(inputs)\n",
"outputs = layers.Dense(10)(x)\n",
"mlp = keras.Model(inputs, outputs)\n",
"```\n",
"\n",
"With the subclassed version:\n",
"\n",
"```python\n",
"class MLP(keras.Model):\n",
"\n",
"    def __init__(self, **kwargs):\n",
"        super(MLP, self).__init__(**kwargs)\n",
"        self.dense_1 = layers.Dense(64, activation='relu')\n",
"        self.dense_2 = layers.Dense(10)\n",
"\n",
"    def call(self, inputs):\n",
"        x = self.dense_1(inputs)\n",
"        return self.dense_2(x)\n",
"\n",
"# Instantiate the model.\n",
"mlp = MLP()\n",
"# Necessary to create the model's state.\n",
"# The model doesn't have a state until it's called at least once.\n",
"_ = mlp(tf.zeros((1, 32)))\n",
"```\n",
"\n",
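"#### Model validation while defining its connectivity graph\n",
"\n",
"In the functional API, the input specification (shape and dtype) is created in advance (using `Input`). Every time you call a layer, the layer checks that the specification passed to it matches its assumptions, and raises a helpful error message if it does not.\n",
"\n",
"This guarantees that any model you can build with the functional API will run. All debugging, other than convergence-related debugging, happens statically during model construction rather than at execution time. This is similar to type checking in a compiler.\n",
"\n",
"#### A functional model is plottable and inspectable\n",
"\n",
"You can plot the model as a graph, and you can easily access intermediate nodes in this graph. For example, to extract and reuse the activations of intermediate layers (as seen in a previous example):\n",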
"\n",
"```python\n",
"features_list = [layer.output for layer in vgg19.layers]\n",
"feat_extraction_model = keras.Model(inputs=vgg19.input, outputs=features_list)\n",
"```\n",
"\n",
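"#### A functional model can be serialized or cloned\n",
"\n",
"Because a functional model is a data structure rather than a piece of code, it is safely serializable. You can save it as a single file and use that file to recreate the exact same model without access to any of the original code. See the [serialization and saving](https://tensorflow.google.cn/guide/keras/save_and_serialize/) guide.\n",
"\n",
"To serialize a subclassed model, the implementer must specify `get_config()` and `from_config()` methods at the model level.\n",
"\n",
"### Functional API weakness:\n",
"\n",
"#### It does not support dynamic architectures\n",
"\n",
"The functional API treats models as DAGs of layers. This is true for most deep learning architectures, but not all: for example, recursive networks and Tree RNNs do not follow this assumption and cannot be implemented in the functional API."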
]
},
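{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can check the strengths listed above in a few lines. This minimal sketch assumes TensorFlow 2.x; the layer sizes and the names `probe`, `features`, and `rebuilt` are illustrative. It shows three things:\n",
"\n",
"1. A shape mismatch is rejected while the graph is being defined.\n",
"2. A feature-extraction model can be built from intermediate outputs.\n",
"3. The architecture can be rebuilt from `get_config()` alone.\n",
"\n",
"```python\n",
"from tensorflow import keras\n",
"from tensorflow.keras import layers\n",
"\n",
"# 1) Validation while defining the graph: a Dense layer built for width 8\n",
"#    rejects a width-5 input immediately, before any data flows.\n",
"probe = layers.Dense(4)\n",
"probe(keras.Input(shape=(8,)))\n",
"caught = False\n",
"try:\n",
"    probe(keras.Input(shape=(5,)))\n",
"except ValueError:\n",
"    caught = True\n",
"print(\"Mismatch caught at definition time:\", caught)\n",
"\n",
"# A small functional model for the next two checks.\n",
"inputs = keras.Input(shape=(8,))\n",
"x = layers.Dense(4, activation=\"relu\")(inputs)\n",
"x = layers.Dense(3, activation=\"relu\")(x)\n",
"outputs = layers.Dense(2)(x)\n",
"model = keras.Model(inputs, outputs)\n",
"\n",
"# 2) Inspectable: any layer's symbolic output can become a model output.\n",
"features = keras.Model(model.inputs, [layer.output for layer in model.layers[1:]])\n",
"print([tuple(t.shape) for t in features.outputs])\n",
"\n",
"# 3) Serializable: rebuild the same architecture from its config alone.\n",
"rebuilt = keras.Model.from_config(model.get_config())\n",
"print(rebuilt.count_params() == model.count_params())\n",
"```"
]
},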
{
"cell_type": "markdown",
"metadata": {
"id": "72992d4ed462"
},
"source": [
"## Mix-and-match API styles\n",
"\n",
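"Choosing between the functional API and model subclassing isn't a binary decision that restricts you to one category of models. All models in the `tf.keras` API can interact with each other, whether they're `Sequential` models, functional models, or subclassed models written from scratch.\n",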
"\n",
"You can always use a functional model or `Sequential` model as part of a subclassed model or layer:"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T21:23:05.919621Z",
"iopub.status.busy": "2022-12-14T21:23:05.918997Z",
"iopub.status.idle": "2022-12-14T21:23:05.976125Z",
"shell.execute_reply": "2022-12-14T21:23:05.975569Z"
},
"id": "3c6221508766"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(1, 10, 32)\n"
]
}
],
"source": [
"units = 32\n",
"timesteps = 10\n",
"input_dim = 5\n",
"\n",
"# Define a Functional model\n",
"inputs = keras.Input((None, units))\n",
"x = layers.GlobalAveragePooling1D()(inputs)\n",
"outputs = layers.Dense(1)(x)\n",
"model = keras.Model(inputs, outputs)\n",
"\n",
"\n",
"class CustomRNN(layers.Layer):\n",
" def __init__(self):\n",
" super(CustomRNN, self).__init__()\n",
" self.units = units\n",
" self.projection_1 = layers.Dense(units=units, activation=\"tanh\")\n",
" self.projection_2 = layers.Dense(units=units, activation=\"tanh\")\n",
" # Our previously-defined Functional model\n",
" self.classifier = model\n",
"\n",
" def call(self, inputs):\n",
" outputs = []\n",
" state = tf.zeros(shape=(inputs.shape[0], self.units))\n",
" for t in range(inputs.shape[1]):\n",
" x = inputs[:, t, :]\n",
" h = self.projection_1(x)\n",
" y = h + self.projection_2(state)\n",
" state = y\n",
" outputs.append(y)\n",
" features = tf.stack(outputs, axis=1)\n",
" print(features.shape)\n",
" return self.classifier(features)\n",
"\n",
"\n",
"rnn_model = CustomRNN()\n",
"_ = rnn_model(tf.zeros((1, timesteps, input_dim)))"
]
},
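{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cell above embeds a functional model inside a subclassed layer. You can embed a `Sequential` model the same way. The following minimal sketch assumes TensorFlow 2.x and uses a `Sequential` block as the head of a subclassed `keras.Model`; `Classifier`, `head`, and the layer sizes are illustrative.\n",
"\n",
"```python\n",
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"from tensorflow.keras import layers\n",
"\n",
"# A Sequential block that serves as the classification head.\n",
"head = keras.Sequential([layers.Dense(16, activation=\"relu\"), layers.Dense(3)])\n",
"\n",
"\n",
"class Classifier(keras.Model):\n",
"    def __init__(self, head, **kwargs):\n",
"        super().__init__(**kwargs)\n",
"        self.pool = layers.GlobalAveragePooling1D()\n",
"        self.head = head  # tracked like any other layer, so its weights are included\n",
"\n",
"    def call(self, inputs):\n",
"        return self.head(self.pool(inputs))\n",
"\n",
"\n",
"clf = Classifier(head)\n",
"out = clf(tf.zeros((2, 10, 5)))\n",
"print(out.shape)  # (2, 3)\n",
"print(len(clf.trainable_weights))  # 4: two kernels and two biases from `head`\n",
"```"
]
},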
{
"cell_type": "markdown",
"metadata": {
"id": "41f42eb2a9c0"
},
"source": [
"You can use any subclassed layer or model in the functional API, as long as it implements a `call` method that follows one of these patterns:\n",
"\n",
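"- `call(self, inputs, **kwargs)`: where `inputs` is a tensor or a nested structure of tensors (e.g. a list of tensors), and `**kwargs` are non-tensor arguments (non-inputs).\n",
"- `call(self, inputs, training=None, **kwargs)`: where `training` is a boolean indicating whether the layer should behave in training mode or in inference mode.\n",
"- `call(self, inputs, mask=None, **kwargs)`: where `mask` is a boolean mask tensor (useful for RNNs, for instance).\n",
"- `call(self, inputs, training=None, mask=None, **kwargs)`: of course, you can have both masking and training-specific behavior at the same time.\n",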
"\n",
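"Additionally, if you implement the `get_config` method on your custom layer or model, the functional models you create will still be serializable and cloneable.\n",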
"\n",
"Here's a quick example of a custom RNN, written from scratch, used in a functional model:"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {
"execution": {
"iopub.execute_input": "2022-12-14T21:23:05.979347Z",
"iopub.status.busy": "2022-12-14T21:23:05.978887Z",
"iopub.status.idle": "2022-12-14T21:23:06.161023Z",
"shell.execute_reply": "2022-12-14T21:23:06.160378Z"
},
"id": "3deb90222d05"
},
"outputs": [],
"source": [
"units = 32\n",
"timesteps = 10\n",
"input_dim = 5\n",
"batch_size = 16\n",
"\n",
"\n",
"class CustomRNN(layers.Layer):\n",
" def __init__(self):\n",
" super(CustomRNN, self).__init__()\n",
" self.units = units\n",
" self.projection_1 = layers.Dense(units=units, activation=\"tanh\")\n",
" self.projection_2 = layers.Dense(units=units, activation=\"tanh\")\n",
" self.classifier = layers.Dense(1)\n",
"\n",
" def call(self, inputs):\n",
" outputs = []\n",
" state = tf.zeros(shape=(inputs.shape[0], self.units))\n",
" for t in range(inputs.shape[1]):\n",
" x = inputs[:, t, :]\n",
" h = self.projection_1(x)\n",
" y = h + self.projection_2(state)\n",
" state = y\n",
" outputs.append(y)\n",
" features = tf.stack(outputs, axis=1)\n",
" return self.classifier(features)\n",
"\n",
"\n",
"# Note that you specify a static batch size for the inputs with the `batch_shape`\n",
"# arg, because the inner computation of `CustomRNN` requires a static batch size\n",
"# (when you create the `state` zeros tensor).\n",
"inputs = keras.Input(batch_shape=(batch_size, timesteps, input_dim))\n",
"x = layers.Conv1D(32, 3)(inputs)\n",
"outputs = CustomRNN()(x)\n",
"\n",
"model = keras.Model(inputs, outputs)\n",
"\n",
"rnn_model = CustomRNN()\n",
"_ = rnn_model(tf.zeros((1, 10, 5)))"
]
}
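,
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The following sketch illustrates the `training` pattern from the list above. It assumes TensorFlow 2.x and NumPy. `NoisyIdentity` is a hypothetical layer that adds Gaussian noise only in training mode. When you call the functional model with `training=True` or `training=False`, Keras forwards that flag to layers whose `call` accepts a `training` argument. The layer's `get_config` keeps the model serializable.\n",
"\n",
"```python\n",
"import numpy as np\n",
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"from tensorflow.keras import layers\n",
"\n",
"\n",
"class NoisyIdentity(layers.Layer):\n",
"    # Hypothetical layer: adds Gaussian noise in training mode only.\n",
"    def __init__(self, stddev=1.0, **kwargs):\n",
"        super().__init__(**kwargs)\n",
"        self.stddev = stddev\n",
"\n",
"    def call(self, inputs, training=None):\n",
"        if training:\n",
"            return inputs + tf.random.normal(tf.shape(inputs), stddev=self.stddev)\n",
"        return tf.identity(inputs)  # inference mode (or unspecified): pass through\n",
"\n",
"    def get_config(self):\n",
"        config = super().get_config()\n",
"        config.update({\"stddev\": self.stddev})\n",
"        return config\n",
"\n",
"\n",
"inputs = keras.Input(shape=(4,))\n",
"outputs = NoisyIdentity()(inputs)\n",
"model = keras.Model(inputs, outputs)\n",
"\n",
"x = np.ones((2, 4), dtype=\"float32\")\n",
"print(np.allclose(model(x, training=False), x))  # True: no noise at inference\n",
"print(np.allclose(model(x, training=True), x))   # almost surely False\n",
"\n",
"# Because get_config is implemented, the model can be rebuilt from its config.\n",
"rebuilt = keras.Model.from_config(\n",
"    model.get_config(), custom_objects={\"NoisyIdentity\": NoisyIdentity}\n",
")\n",
"```"
]
}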
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "functional.ipynb",
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
}
},
"nbformat": 4,
"nbformat_minor": 0
}