{ "cells": [ { "cell_type": "markdown", "source": [ "# Neural networks and gradient calculations with PyTorch" ], "metadata": {} }, { "cell_type": "markdown", "source": [ "# Introduction\n", "\n", "The latest Reinforcement Learning assignment involves some basic neural networks and gradient computations. PyTorch, an open-source machine learning framework, provides tools to construct, train and deploy neural networks, as well as to compute derivatives automatically. This tutorial gives you some basic background on neural networks and gradient calculations to help you start the assignment. It has two parts. The first part covers the basic linear regression model and a simple neural network; you can skip it if you already know this material. The second part explains how to stop gradients from flowing through part of a computation. This operation is used when implementing the DQN algorithm, so please make sure you understand it before Exercise 4.\n", "\n", "If you are interested in diving deeper into PyTorch and deep learning, please check out these excellent resources:\n", "\n", "1. PyTorch Tutorials: https://pytorch.org/tutorials/\n", "\n", "2. Dive into Deep Learning: https://d2l.ai/\n", "\n", "3. Deep Learning courses: CS-E4890, Aalto; CS231N, Stanford http://cs231n.stanford.edu/" ], "metadata": {} }, { "cell_type": "code", "execution_count": 1, "source": [ "import random\n", "\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import numpy as np\n", "import matplotlib.pyplot as plt" ], "outputs": [], "metadata": {} }, { "cell_type": "markdown", "source": [ "# Part I\n", "## 1. Let's start with linear regression\n", "Suppose we have a set of data points generated from a linear model $y_i = wx_i + b$ with additive noise. Our task is to estimate the model's weight $w$ and bias $b$ from this data." 
], "metadata": {} }, { "cell_type": "code", "execution_count": 2, "source": [ "# generate synthetic data\n", "def synthetic_data(w, b, num_examples): #@save\n", " \"\"\"Generate y = Xw + b + noise.\"\"\"\n", " X = torch.normal(0, 1, (num_examples, len(w)))\n", " y = torch.matmul(X, w) + b\n", " y += torch.normal(0, 0.5, y.shape) # additive noise\n", " return X, y.reshape((-1, 1))\n", "\n", "true_w = torch.tensor([-3.4])\n", "true_b = torch.tensor([4.2])\n", "\n", "# generate data\n", "features, labels = synthetic_data(true_w, true_b, 1000) # generate 1000 data points" ], "outputs": [], "metadata": {} }, { "cell_type": "code", "execution_count": 3, "source": [ "# plot data\n", "plt.scatter(features, labels)" ], "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "" ] }, "metadata": {}, "execution_count": 3 }, { "output_type": "display_data", "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAaiElEQVR4nO3df4zc9X3n8dd710MZk7ZrhFvhCcYWipyLz2VXXRFXPp2Cm2BSCkyoqIvINdLdhfaP9GqOW3VJfLFJXLF3vhyWTlVVqqLmFB81CWQPzmlNWnPK1TrTrON1jAu+kgQM6yi4Zw892KGMd9/3x853PTs7353f8/0xz4e0wjP79cxnZPPej9/f9+f9NncXACD5BqJeAACgMwjoAJASBHQASAkCOgCkBAEdAFJiVRRvet111/mGDRuieGsASKwTJ078vbuvDft+JAF9w4YNmpqaiuKtASCxzOz1lb5PygUAUoKADgApQUAHgJQgoANAShDQASAlIqlyacXkyRntP3JW5wtFrRvKamzHJuVHclEvCwBiIxEBffLkjB5+5rSKpTlJ0kyhqIefOS1JBHUAKEtEymX/kbOLwTxQLM1p/5GzEa0IAOInEQH9fKHY1PMA0I8SEdDXDWWbeh4A+lEiAvrYjk3KZgaXPJfNDGpsx6aIVgQA8ZOIm6LBjU+qXAAgXCICurQQ1AngABAuESkXAEB9BHQASAkCOgCkBAEdAFKCgA4AKUFAB4CUIKADQEo0HNDN7Akze8vMXqp4bq+ZzZjZdPnrV7qzTABAPc3s0P9U0u01nn/M3YfLX9/qzLIAAM1qOKC7+3ckXeziWgAAbejE0f/PmdlvSpqS9JC7X6p1kZk9IOkBSVq/fn0H3rZ7mI4EIInavSn6h5JukjQs6ceSvhJ2obs/7u6j7j66du3aNt+2e4LpSDOFolxXpiNNnpyJemkAsKK2Arq7/8Td59x9XtIfS7qlM8uKDtORACRVWwHdzK6vePgpSS+FXZsUTEcCkFQN59DN7ElJH5N0nZm9KWmPpI+Z2bAkl/SapN/q/BJ7a91QVjM1gncwHYn8OoC4Mnfv+ZuOjo761NRUz9+3EUEOvTLtk
s0M6tF7tkhS6PcI6gC6zcxOuPto2Pc5KVolP5LTo/dsUW4oK5OUG8ouBmzy6wDiLDETi3opbDoS+XUAccYOvQlBHr3R5wGglwjoTRjbsUnZzOCS57KZQY3t2BTRigDgClIuTQjSMFS5AIgjAnqTwvLrABA1Ui4AkBIEdABICVIuEeC0KYBuIKD3yO7J03ryxTc0V3UyN+jmKImgDqAtpFx6YPfkaX3t+LllwTzAaVMAnUBA74EnX3yj7jUzhaK2TRyl7zqAlhHQeyBsZ16NYRoA2kFA74FBs4avJf0CoFUE9B6476M3NHU9zb4AtIIqlzY0Wn64L7/QS/1rx8819Lo0+wLQCnboLWp2mPS+/BatWZ2p+7o0+wLQKnboLVpp2EWtXfrkyRm9897lmq9lWpjhlxvK6tYPr9X+I2f14KFp/Ww2IzOpMFviABKAugjoLWp22MX+I2dVmq9d7eJa2Jnf+uG1evrEzOIPikKxtHjNTKGosa+f0iPPnSHAA6iJlEuLmh12Ue9GZ7E0pydffGPZrr9Sad51abbUUIoHQP8hoLeo2WEXjdzobLRePUCJI4BKBPQWrTRMupZaPwA6gRJHAAFy6G1oZthF5bSjmUJx8UZou4YaqJwB0B/YofdQfiSnsR2blBvKytXcCdIwTWZpAKQYO/QeCmrXgxufc+7KZgZ1dWZAl2ZLy67PlfPuMyukVd4uLv99APoTO/QeCqtdd1foDdZ6uXdOlQIIENB7KOwG5tvF0rIbrL/2i7nFA0ZXZwaUzSz/ozLRdhfAFaRcemjdULZm+mTdUHbJDdbq1Myl2ZKymUF9eut6vfDKhcXXCNLnM4Wixr5xStLKU48YfQekGzv0Hmq0dj0sNfPCKxd0bHx7zZ4wpTnXrkPTGvnS8zV367snT+vBQ9MN954BkDwE9B5qtHa9XluBWjdQA5dmSxr7xqklgXry5IwOHj+3rEyyWJrT3mfPtPRZAMQPKZcea6R2faXUTCNKc66Hnjq12ODrH94rhda8F4olTZ6cIfUCpAA79Biql5oZytY/TDTnLtdCwA7pCbaI9gFAOrBDj6HKU6W1bmD+6s3XNzwsoxFBKoebpkCymUdw1HB0dNSnpqZ6/r5psW3i6IqHjZqVKwfvysoaaWmfdoI7ED0zO+Huo6HfJ6Anz8bxwx3pAyNJA7YQtOv9NcgMmD5w9Sp6sQMRqhfQyaEnULunQ4eyGZmk1ZkBzXtj/WDoxQ7EX8MB3cyeMLO3zOyliueuNbNvm9nflf+7pjvLRKV2WvEOZTOa3nObfjRxh/7xcuv7/GJpTg89dUobxw9zUhWIiWZ26H8q6faq58Yl/ZW7f0jSX5Ufo8sq69lbEQThZgdqVAsqadixA/HQcEB39+9Iulj19N2Svlr+9Vcl5TuzLNSTH8np2Ph2Hdg5rGaa8BaKV9ImncT0JCB67ebQf97dfyxJ5f/+XPtLQjPyI7mO3SBtF9OTgGj17KaomT1gZlNmNnXhwoVevW1faDX10mlMTwKi1W5A/4mZXS9J5f++FXahuz/u7qPuPrp27do23xaVujWvtFlvz5bIowMRajegPyvpM+Vff0bSf2/z9dCCdm+Sdsq8pIef+X6kawD6WTNli09K+t+SNpnZm2b2ryRNSPqEmf2dpE+UHyMCwU3S9qeUtqdYmmeXDkSk4V4u7n5fyLd+uUNrQQeEdWrspV2HprX/yFlOkwI9xknRlIlLPr1WbfrkyRltmzjKYSSgSwjoKROXfLq0tDY9GKvHxCSgewjoKVR56Kh6t26Stt10bc928UH6J2ysHoeRgM4hoKdYrZF3j+0c1sHP/tKS563Ld1J3T54OPXQ0UyiySwc6hPa50OTJGY19/ZRK9UYbtWF1ZkCzpfma3zNJ929dr335LV17fyANaJ+LuvIjOe2/9+au5t3Dgrm00I/94PFz7NSBNhHQIelK3v21iTt0YOewVmeW/9XoZt7dxWxToF2kXBCq1ozRL3zztN59f
67+b26RaaGW/tYPr9ULr1xgvilQoV7KhSHRCJUfyS0JorsnuxvMJS2WNFYOwQ5KHIM1AaiNgI6GPfniG5G9d7E0p73Pnln2LwYCPHAFAR0Na3fCUbsKxZIKxZIkdu1ALdwURcMGu12w3iQOJgFLEdDRsPs+ekPUS1iGKUnAFQR0NGxffos+vXV9rHbqLtHoCyijbBEt2zZxNPJWvYHMoOmaq1bp7WKJG6ZILU6Komvi0qpXkkpzrkKxRCdH9DUCOlpW3ao3TqkYbpiiH1G2iLZUHz7aOH5Y0RY3XhGXdBDQKwR0dFQcRuAFTFpMu3AgCf2AgI6OGtuxSQ8/c3rJMAuTItm1u6TPP/N9FUvzi+/PgSSkGTl0dFStoRr3b12vzEA0+fXZimAeIL+OtGKHjo6rzqtL0uiN1+rBQ9Oxya9zIAlpREBHTwQBvjodE5Wh1RlJCzn2R547o0uzCz1ihrIZ7b1rM+kYJBIBHT0TBMngBuXQ6sxiIO21d967rN2Tp3Xou2+oNHfl3w2FYkljXz+1ZL1AUnBSFJGK8rTpoFloB8ncUFbHxrf3eEXAyjgpilgb27EpshumK7UDJseOJCLlgkgFaY3PP/P9FQdJ95pLuunhb2nOXTlq15EQ7NARufxITn/75U/q01vXR72UJYIdPL1hkBQEdMTGvvwWHdg5vKSG/cDOYW276dqol6ZiaU4PPXWKoI5YI+WCWKlVwy5Jx35wMYLVLDXnzilTxBo7dMRenE51csoUcUZAR+zFreJkplDUyJeeJ/2C2CHlgtiLUwfHwKXZknYdmtauQ9OSpAGT5l1UxCBS7NARe3GajBRmvlzSPlMo6sFD09o9eTraBaEvEdARe9WTkeLOJR08fo6UDHqOgI5EyI/kdGx8u+Iz5G5lrnjdzEV/IKAjUdaF7NLjNM80ELebuUi/jgR0M3vNzE6b2bSZ0XULXVMrn57NDK7YlyUyJtIu6KlO7tBvdffhlTqBAe2qNREprvl1d2nXoWlt/uJfENjRE5QtInHCTpPGZXhGtXffn9OuQ9Oaev2i9uW3SFrYuTO4Gp3WqYDukp43M5f0R+7+ePUFZvaApAckaf36eDVhQvJVDs+YKRQX68Lj5ODxcxq9caEvTeUPHwZXo1M6MuDCzNa5+3kz+zlJ35b0O+7+nbDrGXCBXqjcBQ+sMMyil4LUUK2DUgzVQD31Blx0ZIfu7ufL/33LzL4p6RZJoQEd6IXK1MzkyZlYpGRWOvFKVQza1fZNUTO7xsx+Ovi1pNskvdTu6wKdlITDSWElmUCjOlHl8vOS/trMTkn6G0mH3f0vOvC6QEcFh5MO7ByOXSuBzIBpbMemqJeBhGs75eLuP5R0cwfWAvREkIZ55LkzujRbing1C+IzfA9JxklR9KX8SE577tysiOZTLzM373rkuTNRLwMJRx06+tb+I2djVdp4abbEYGq0hR06+tZKVSWZiLbulYOpdx2aZpAGmkJAR98KqyrJDWW1/96bNZTN9HhFy12aLenhZ04T1NGQjhwsahYHixAHtWrTs5lBPXrPlpqpjt2Tp/W14+d6ucRFAyb9zNUZvV0s0Sqgj/XkYBGQRJXtAur1VJk8OaOnT0S3S553qVBcqMihVQDCsEMHGrBt4mjs5ppKzDDtN/V26OTQgQbE9Vh+cPN0w/hhbZs4Sq69zxHQgQYk4Vh+kIohqPcvAjrQgFqTkjKDMTmVVKFYmtNDT50iqPcpcuhAg6qHUrz7j5cXb1TGkWlhUAF59vSgygXokOpJSRvHD0e4mvqCrRpVMf2DlAvQoiTk1QPF0pz2PkuvmLQjoAMtqpVXj7NCsURuPeUI6ECLkjA0oxodHdONHDrQhuoxd3HqsV7LpdmFXTq59HQioAMdEscZprXsP3JW+ZHcsqodKmGSj4AOdEEQGB966tRiS9y4OF8oLvuBQyVMOpBDB7okP5LTfMyCuSTJpH/71PSyfz0US3Paf+RsRItCJxDQgS4KK2208
teg9f60qbtCJzXFtWcNGkNAB7qoVmljNjOox3YO60cTd+grvx6v+epJqq3HcgR0oIsqSxtNC8fwKwdo5EdyWrM6+slI0sIPmrEdm6JeBtrATVGgy6pbBlTbc+fmWFTEVObQuTGaTDTnAmIgKCGMyxCNNaszuuMXrtcLr1ygrDFG6jXnIqADMRN2QCnonhilNasz2nPnZgJ7RJhYBCRMfiSnk1+8TQd2Di/Jvd+/dX3kvWMuzZYYohFj5NCBmKqVex+98VrtffZMpH3Yg86N7NLjhx06kCD5kZym99wW9TJUKJa0YfywPvLv/5zdeowQ0AG0bLY0r12HprV78nTUS4EI6EAiDWXjUbseOHj8HDv1GCCgAwm0967NygzEZ0i1S/SBiQECOpBA+ZGc9t9782IVzJrVmcgD/EyhqI3jh7Vt4ii79YhQhw6kRGV/86HVGRVmS5HVrZuk+7eu1778lohWkE716tApWwRSorrMMcohG66FvProjddS3thD7NCBFKvctQ+YRTZsI0frgI5ghw70scpd+8bxw5GtY6ZQ1Ng3Ti2uCd3BTVGgT0Td67w053rkuTM1vzd5ckbbJo5yU7VNHQnoZna7mZ01s1fNbLwTrwmgs8Z2bFLUhY7VDcekK7n+mUJRrivzTQnqzWs7oJvZoKQ/kPRJSR+RdJ+ZfaTd1wXQWfmRnO7fuj7qZSyz/8hZ5pt2SCd26LdIetXdf+ju70v6M0l3d+B1AXTYvvwWHdg5rGwmmmxrrROuYXNMmW/avE78qeYkvVHx+M3yc0uY2QNmNmVmUxcuXOjA2wJoRX4kp5e//Mll7XkP7BzWgZ3DXR2JVyiWluXIw3L7Uef8k6jtskUzu1fSDnf/1+XH/0LSLe7+O2G/h7JFIN52T57WwePnunYwKTNgumrVgN59v3aNfDYzuGT2Khb0YsDFm5JuqHj8QUnnO/C6ACKyL79Fj5V38N1QmvfQYL5mdYZg3qJOBPTvSvqQmW00s6sk/YakZzvwugAilB/J6dj49q4F9TCXZkvaf+QsVS4taPtgkbtfNrPPSToiaVDSE+5eu9gUQOKM7dikBw9N97QvTFC6OPX6RQZVN4Gj/wDq2hDRKdPqwdj9nltnSDSAtvU67RKo3m5Sn74yAjqAusZ2bFI2Mxj1MiRRn74SAjqAuvIjOT16z5bIduqVBszo+RKCbosAGlLduTGq4RlBC+DgxmmwNrBDB9CCuJziJKe+FAEdQNPilFOfKRRJv5QR0AE0rTKnHvSCiRItdxeQQwfQkuoZptsmjmqmRgXKmtUZrb5qlc6X+513S5B+yY/kloze66cDSezQAXRErTRMNjOoPXdu1rHx7frRxB0atO6O2DhfKPb1wAwCOoCOqJWGqT7Ved9Hbwh/gQ5YN5Tt64EZpFwAdEx1GqbavvwW/ejCOzr2g4tdef+g70wt/XAgiR06gJ46+NlfWhyuUc+nt65vqpomP5Lr64EZBHQAPRe05n1t4o7QwJ4bympffosevWdLU7n3sFz+2I5Nba05CQjoACJVLwDnR3Kab7Ar7LaJo5JUN5efVuTQAUQqCLS1ygyD8sNGyx2DipZH79miY+Pbm1pHGkod6YcOIJaC8sPqipVG5IaydQN6ZQD/2WxG775/WaW5K/Ewjr3X6YcOIJFqlR82ql5FS3WteqFYWhLMpWSWOhLQAcTSSkG5XuVLvYqWRn9YJK3UkRw6gFhaN5St2UogV85vP/LcGV2aLS37fnVFS63ceKOBOmmljuzQAcRSveqX90rzy37PmtWZJXnvWm0AHjw0rWymfuhLYqkjO3QAsbRS9cu2iaM1Uyarr1q15CbmI8+dWXadS5otzSszaEvy5pkB0weuXqXCbCmxVS4EdACxFdZKICxlEjw/eXImNCUTuOaqVbrmp1YlukyxGikXAImz0vH+IM2yUjCXpLeLJR0b367Hdg5Lkh48NL1kUMbkyRltmziaqPml7NABJM7Yjk3LatSDnHejF
SyVwT+4Psixf33qnL537u0lzydhfik7dACJs1Kr3kYqWFYK/i7p2A8uJrIFLzt0AIkUll8PK3cMDGUz2nvXZuVHcqGtdsOs9LpxwA4dQKqEDbAeymZ0YOewpvfctviDoNk6825PXGoXO3QAqbJSuWO1Wz+8Vl87fq7h155z1+TJmWWvFdbYq9cNvwjoAFKn3uSkwAuvXGj6tStvjtYqjwxuoE69flFPn5jp6Y1Vui0C6Fsbxw833Jq30lA2IzOtWBo5aKa5GvG1kU6QYei2CAAhWu3VUiiW6ta51wrmUncbfhHQAfStsBuonTAQcv+0mw2/COgA+lZQz96N6pX5Ghv0bjf8IqAD6Gv5kZy+8us3d22nHjBJv/aLjd2sbRUBHUDfq3XydCibqXntUDbTUvB3tVZV0wyqXACghlozTYM5o9KVOveh1Rm9895llWrlWGowqeWa9HpVLm3VoZvZXkmflRT82Pm8u3+rndcEgDiod0CpMhgHB4gaaQ0QDNroRk16Jw4WPebu/6kDrwMAsbLSAaVap0AlLdvVhwmafXUyoJNDB4Am1RptF+y4g1y8VL/3S6dr0juxQ/+cmf2mpClJD7n7pVoXmdkDkh6QpPXr13fgbQEgGrXa7gY77mPj25fsurdNHA1NxXS6Jr3uDt3M/tLMXqrxdbekP5R0k6RhST+W9JWw13H3x9191N1H165d26n1A0DP1RuB18i1kjpek153h+7uH2/khczsjyX9j7ZXBAAxF9ZzvdaOO+zaNaszHa9JbyuHbmbXVzz8lKSX2lsOAMRfrZYBYadAw67dc+fmjq+r3Rz6fzSzYS1U4rwm6bfaXRAAxF0zPdebubZdHCwCgISgfS4A9AkCOgCkBAEdAFKCgA4AKUFAB4CUiKTKxcwuSHq9529c23WS/j7qRXRAGj5HGj6DxOeImzR8juAz3OjuoUftIwnocWJmUyuVASVFGj5HGj6DxOeImzR8jkY/AykXAEgJAjoApAQBXXo86gV0SBo+Rxo+g8TniJs0fI6GPkPf59ABIC3YoQNAShDQASAlCOiSzOzLZvZ9M5s2s+fNbF3Ua2qWme03s1fKn+ObZjYU9ZpaYWb3mtkZM5s3s8SVmpnZ7WZ21sxeNbPxqNfTCjN7wszeMrPEzjcwsxvM7AUze7n89+l3o15TK8zsajP7GzM7Vf4cj6x4PTl0ycx+xt3/ofzrfyPpI+7+2xEvqylmdpuko+5+2cz+gyS5++9FvKymmdk/kTQv6Y8k/Tt3T0yfZTMblPR/JH1C0puSvivpPnf/20gX1iQz++eS3pH0X939n0a9nlaUh+9c7+7fM7OflnRCUj6BfxYm6Rp3f8fMMpL+WtLvuvvxWtezQ5cUBPOya7QwsCNR3P15d79cfnhc0gejXE+r3P1ldz8b9TpadIukV939h+7+vqQ/k3R3xGtqmrt/R9LFqNfRDnf/sbt/r/zr/yfpZUmdnyjRZb7gnfLDTPkrND4R0MvM7PfN7A1J90v6YtTradO/lPTnUS+iD+UkvVHx+E0lMIikjZltkDQi6cWIl9ISMxs0s2lJb0n6truHfo6+Cehm9pdm9lKNr7slyd2/4O43SDoo6XPRrra2ep+hfM0XJF3WwueIpUY+R0JZjecS96+9NDGzD0h6WtKuqn+JJ4a7z7n7sBb+1X2LmYWmwdqdKZoY7v7xBi/9b5IOS9rTxeW0pN5nMLPPSPpVSb/sMb450sSfRdK8KemGiscflHQ+orX0vXLO+WlJB939majX0y53L5jZ/5R0u6SaN6z7Zoe+EjP7UMXDuyS9EtVaWmVmt0v6PUl3ufts1OvpU9+V9CEz22hmV0n6DUnPRrymvlS+mfgnkl529/8c9XpaZWZrg4o1M8tK+rhWiE9UuUgys6clbdJCdcXrkn7b3WeiXVVzzOxVST8l6f+WnzqetEodSTKzT0n6L5LWSipImnb3HZEuqglm9iuSDkgalPSEu/9+tCtqnpk9KeljW
mjZ+hNJe9z9TyJdVJPM7J9J+l+STmvh/2tJ+ry7fyu6VTXPzH5B0le18PdpQNJT7v6l0OsJ6ACQDqRcACAlCOgAkBIEdABICQI6AKQEAR0AUoKADgApQUAHgJT4/whekE1DgF3ZAAAAAElFTkSuQmCC", "text/plain": [ "