From 772daa72c591b7d661ca896c2a9824cd20566af0 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Fri, 16 Aug 2024 09:34:25 -0700 Subject: [PATCH] Keep column name, and add an axis. --- site/en/tutorials/keras/regression.ipynb | 79 ++++++++++++++---------- 1 file changed, 48 insertions(+), 31 deletions(-) diff --git a/site/en/tutorials/keras/regression.ipynb b/site/en/tutorials/keras/regression.ipynb index d5da28a782..5064d7ce17 100644 --- a/site/en/tutorials/keras/regression.ipynb +++ b/site/en/tutorials/keras/regression.ipynb @@ -368,10 +368,7 @@ "test_features = test_dataset.copy()\n", "\n", "train_labels = train_features.pop('MPG')\n", - "test_labels = test_features.pop('MPG')", - "\n", - "train_features = tf.convert_to_tensor(train_features, dtype=tf.float32)\n", - "test_features = tf.convert_to_tensor(test_features, dtype=tf.float32)" + "test_labels = test_features.pop('MPG')" ] }, { @@ -492,7 +489,7 @@ }, "outputs": [], "source": [ - "first = np.array(train_features[:1])\n", + "first = np.array(train_features[:1], dtype=float)\n", "\n", "with np.printoptions(precision=2, suppress=True):\n", " print('First example:', first)\n", @@ -548,9 +545,11 @@ }, "outputs": [], "source": [ - "horsepower = np.array(train_features[:, 2])\n", + "horsepower = np.array(train_features['Horsepower'], dtype=float)\n", + "horsepower = horsepower[:, None]\n", + "print(horsepower.shape)\n", "\n", - "horsepower_normalizer = layers.Normalization(input_shape=[1,], axis=None)\n", + "horsepower_normalizer = layers.Normalization(axis=1, input_shape=[1])\n", "horsepower_normalizer.adapt(horsepower)" ] }, @@ -574,9 +573,7 @@ "horsepower_model = tf.keras.Sequential([\n", " horsepower_normalizer,\n", " layers.Dense(units=1)\n", - "])\n", - "\n", - "horsepower_model.summary()" + "])" ] }, { @@ -601,6 +598,17 @@ "horsepower_model.predict(horsepower[:10])" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "eNDk5U4S-WVR" + }, + "outputs": [], + "source": [ + "horsepower_model.summary()" + ] + }, { "cell_type": "markdown", "metadata": { @@ -642,7 +650,7 @@ "source": [ "%%time\n", "history = horsepower_model.fit(\n", - " train_features[:, 2],\n", + " train_features['Horsepower'],\n", " train_labels,\n", " epochs=100,\n", " # Suppress logging.\n", @@ -722,7 +730,7 @@ "test_results = {}\n", "\n", "test_results['horsepower_model'] = horsepower_model.evaluate(\n", - " test_features[:, 2],\n", + " test_features['Horsepower'],\n", " test_labels, verbose=0)" ] }, @@ -756,7 +764,7 @@ "outputs": [], "source": [ "def plot_horsepower(x, y):\n", - " plt.scatter(train_features[:, 2], train_labels, label='Data')\n", + " plt.scatter(train_features['Horsepower'], train_labels, label='Data')\n", " plt.plot(x, y, color='k', label='Predictions')\n", " plt.xlabel('Horsepower')\n", " plt.ylabel('MPG')\n", @@ -1026,17 +1034,6 @@ "This model has quite a few more trainable parameters than the linear models:" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "ReAD0n6MsFK-" - }, - "outputs": [], - "source": [ - "dnn_horsepower_model.summary()" - ] - }, { "cell_type": "markdown", "metadata": { @@ -1056,12 +1053,23 @@ "source": [ "%%time\n", "history = dnn_horsepower_model.fit(\n", - " train_features[:, 2],\n", + " train_features['Horsepower'],\n", " train_labels,\n", " validation_split=0.2,\n", " verbose=0, epochs=100)" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ReAD0n6MsFK-" + }, + "outputs": [], + "source": [ + "dnn_horsepower_model.summary()" + ] + }, { "cell_type": "markdown", "metadata": { @@ -1132,7 +1140,7 @@ "outputs": [], "source": [ "test_results['dnn_horsepower_model'] = dnn_horsepower_model.evaluate(\n", - " test_features[:, 2], test_labels,\n", + " test_features['Horsepower'], test_labels,\n", " verbose=0)" ] }, @@ -1162,8 +1170,7 @@ }, "outputs": [], "source": [ - "dnn_model = build_and_compile_model(normalizer)\n", - "dnn_model.summary()" + "dnn_model = build_and_compile_model(normalizer)" ] }, { @@ -1182,6 +1189,17 @@ " verbose=0, epochs=100)" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "5R9yfjpC_qcT" + }, + "outputs": [], + "source": [ + "dnn_model.summary()" + ] + }, { "cell_type": "code", "execution_count": null, @@ -1324,7 +1342,7 @@ }, "outputs": [], "source": [ - "dnn_model.save('dnn_model.tf', save_format='tf')" + "dnn_model.save('dnn_model.keras')" ] }, { @@ -1344,7 +1362,7 @@ }, "outputs": [], "source": [ - "reloaded = tf.keras.models.load_model('dnn_model.tf')\n", + "reloaded = tf.keras.models.load_model('dnn_model.keras')\n", "\n", "test_results['reloaded'] = reloaded.evaluate(\n", " test_features, test_labels, verbose=0)" @@ -1380,7 +1398,6 @@ ], "metadata": { "colab": { - "collapsed_sections": [], "name": "regression.ipynb", "toc_visible": true },