Last active
April 8, 2020 23:07
-
-
Save zhezh/ccc7e7b70338c6b882e08113d7706530 to your computer and use it in GitHub Desktop.
[pytorch 分层设置学习率] #pytorch
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"分层设置学习率" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import torch\n", | |
"import torch.nn as nn\n", | |
"import torch.nn.functional as F\n", | |
"import torch.optim as optim" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"0.4.0\n" | |
] | |
} | |
], | |
"source": [ | |
"print(torch.__version__)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# A minimal fully-connected network with a single hidden layer.\n", | |
"class TwoLayerNet(torch.nn.Module):\n", | |
"    def __init__(self, D_in, H, D_out):\n", | |
"        \"\"\"Instantiate the two nn.Linear modules and keep them as members.\"\"\"\n", | |
"        super(TwoLayerNet, self).__init__()\n", | |
"        self.linear1 = nn.Linear(D_in, H)\n", | |
"        self.linear2 = nn.Linear(H, D_out)\n", | |
"\n", | |
"    def forward(self, x):\n", | |
"        \"\"\"Map a batch of inputs through linear1 -> ReLU -> linear2.\"\"\"\n", | |
"        hidden = F.relu(self.linear1(x))\n", | |
"        return self.linear2(hidden)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"N = 64       # batch size\n", | |
"D_in = 1000  # input dimension\n", | |
"H = 100      # hidden-layer width\n", | |
"D_out = 10   # output dimension" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"x = torch.randn(N, D_in)\n", | |
"y = torch.randn(N, D_out)\n", | |
"\n", | |
"# Construct our model by instantiating the class defined above\n", | |
"model = TwoLayerNet(D_in, H, D_out)\n", | |
"\n", | |
"# Sum-of-squared-errors loss. The original `size_average=False` argument is\n", | |
"# deprecated since PyTorch 1.0 (and later removed); reduction='sum' is the\n", | |
"# documented equivalent and gives identical loss values.\n", | |
"criterion = torch.nn.MSELoss(reduction='sum')" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"查看模型的参数名称" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"参数名: linear1.weight , id: 140705409071936\n", | |
"参数名: linear1.bias , id: 140705409072008\n", | |
"参数名: linear2.weight , id: 140705409072296\n", | |
"参数名: linear2.bias , id: 140705409072656\n" | |
] | |
} | |
], | |
"source": [ | |
"# List each parameter's dotted name alongside its object identity.\n", | |
"for name, param in model.named_parameters():\n", | |
"    print('参数名: {: <18}, id: {}'.format(name, id(param)))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Split the model's parameters into the linear1 group (matched by name)\n", | |
"# and everything else (matched by object identity).\n", | |
"lin1_parameters = [p for name, p in model.named_parameters()\n", | |
"                   if 'linear1' in name]\n", | |
"\n", | |
"lin1_ids = {id(p) for p in lin1_parameters}\n", | |
"other_parameters = [p for p in model.parameters() if id(p) not in lin1_ids]" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"现在获得了两组参数,一组是linear1,另一组是其他的(本程序中即linear2)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 23, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"linear1组参数id: \n", | |
"140705409071936\n", | |
"140705409072008\n", | |
"\n", | |
"\n", | |
"other组参数id: \n", | |
"140705409072296\n", | |
"140705409072656\n" | |
] | |
} | |
], | |
"source": [ | |
"# Show which parameter objects landed in each group.\n", | |
"print('linear1组参数id: ')\n", | |
"for param in lin1_parameters:\n", | |
"    print(id(param))\n", | |
"\n", | |
"print('\\n')\n", | |
"print('other组参数id: ')\n", | |
"for param in other_parameters:\n", | |
"    print(id(param))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"构造optim" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 24, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"optimizer = optim.SGD([\n", | |
" {'params': lin1_parameters},\n", | |
" {'params': other_parameters, 'lr': 1e-3}\n", | |
" ], lr=1e-4)\n", | |
"# linear1层的学习率1e-4,其它层1e-3" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"训练网络" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 25, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"0 719.4832153320312\n", | |
"1 618.053466796875\n", | |
"2 556.1675415039062\n", | |
"3 500.43212890625\n", | |
"4 447.5682678222656\n", | |
"5 396.4813537597656\n", | |
"6 347.07867431640625\n", | |
"7 300.09326171875\n", | |
"8 256.21868896484375\n", | |
"9 216.0878448486328\n", | |
"10 180.35525512695312\n", | |
"11 149.25791931152344\n", | |
"12 122.67479705810547\n", | |
"13 100.36481475830078\n", | |
"14 81.89598083496094\n", | |
"15 66.74868774414062\n", | |
"16 54.42697525024414\n", | |
"17 44.45567321777344\n", | |
"18 36.39970397949219\n", | |
"19 29.910730361938477\n", | |
"20 24.662458419799805\n", | |
"21 20.40622901916504\n", | |
"22 16.947093963623047\n", | |
"23 14.121529579162598\n", | |
"24 11.805665969848633\n", | |
"25 9.908512115478516\n", | |
"26 8.346722602844238\n", | |
"27 7.056419849395752\n", | |
"28 5.984086036682129\n", | |
"29 5.090843677520752\n", | |
"30 4.344158172607422\n", | |
"31 3.7168867588043213\n", | |
"32 3.1883127689361572\n", | |
"33 2.7413690090179443\n", | |
"34 2.3627569675445557\n", | |
"35 2.0411109924316406\n", | |
"36 1.7670531272888184\n", | |
"37 1.5327770709991455\n", | |
"38 1.3317700624465942\n", | |
"39 1.1588612794876099\n", | |
"40 1.0100197792053223\n", | |
"41 0.8820691108703613\n", | |
"42 0.7714465856552124\n", | |
"43 0.6757064461708069\n", | |
"44 0.592779815196991\n", | |
"45 0.5206921696662903\n", | |
"46 0.45796385407447815\n", | |
"47 0.403344064950943\n", | |
"48 0.3557352125644684\n", | |
"49 0.31413477659225464\n", | |
"50 0.27769696712493896\n", | |
"51 0.24573519825935364\n", | |
"52 0.2176705002784729\n", | |
"53 0.19302386045455933\n", | |
"54 0.17133362591266632\n", | |
"55 0.15221278369426727\n", | |
"56 0.135351300239563\n", | |
"57 0.12046048790216446\n", | |
"58 0.1073007881641388\n", | |
"59 0.09565050154924393\n", | |
"60 0.08532743155956268\n", | |
"61 0.07617135345935822\n", | |
"62 0.06804817914962769\n", | |
"63 0.06084805354475975\n", | |
"64 0.05446131154894829\n", | |
"65 0.04884392023086548\n", | |
"66 0.043833404779434204\n", | |
"67 0.039361849427223206\n", | |
"68 0.03537042811512947\n", | |
"69 0.03180677816271782\n", | |
"70 0.028617050498723984\n", | |
"71 0.025749675929546356\n", | |
"72 0.023186029866337776\n", | |
"73 0.02088870480656624\n", | |
"74 0.018828196451067924\n", | |
"75 0.016980471089482307\n", | |
"76 0.015320717357099056\n", | |
"77 0.013830263167619705\n", | |
"78 0.012490017339587212\n", | |
"79 0.011284894309937954\n", | |
"80 0.01020009908825159\n", | |
"81 0.00922376848757267\n", | |
"82 0.008343766443431377\n", | |
"83 0.007550488226115704\n", | |
"84 0.0068353088572621346\n", | |
"85 0.006190172396600246\n", | |
"86 0.005607489496469498\n", | |
"87 0.005081566050648689\n", | |
"88 0.00460641598328948\n", | |
"89 0.004176917020231485\n", | |
"90 0.0037886006757616997\n", | |
"91 0.0034373654052615166\n", | |
"92 0.0031197601929306984\n", | |
"93 0.0028322283178567886\n", | |
"94 0.0025717862881720066\n", | |
"95 0.002335888333618641\n", | |
"96 0.0021221640054136515\n", | |
"97 0.0019284778973087668\n", | |
"98 0.0017529240576550364\n", | |
"99 0.0015937236603349447\n", | |
"100 0.0014493277994915843\n", | |
"101 0.0013182209804654121\n", | |
"102 0.0011992763029411435\n", | |
"103 0.0010912807192653418\n", | |
"104 0.0009932058164849877\n", | |
"105 0.0009040983277373016\n", | |
"106 0.0008233404951170087\n", | |
"107 0.0007499091443605721\n", | |
"108 0.0006831525824964046\n", | |
"109 0.0006224379176273942\n", | |
"110 0.0005672484403476119\n", | |
"111 0.0005170325748622417\n", | |
"112 0.0004713317903224379\n", | |
"113 0.0004297299892641604\n", | |
"114 0.0003918729198630899\n", | |
"115 0.0003573991998564452\n", | |
"116 0.0003260155499447137\n", | |
"117 0.0002974196686409414\n", | |
"118 0.0002713669091463089\n", | |
"119 0.0002476317167747766\n", | |
"120 0.00022600177908316255\n", | |
"121 0.0002062939602183178\n", | |
"122 0.00018832141358871013\n", | |
"123 0.0001719275169307366\n", | |
"124 0.0001569933956488967\n", | |
"125 0.00014336439198814332\n", | |
"126 0.00013093784218654037\n", | |
"127 0.0001196042139781639\n", | |
"128 0.0001092585880542174\n", | |
"129 9.982137999031693e-05\n", | |
"130 9.120586764765903e-05\n", | |
"131 8.33444792078808e-05\n", | |
"132 7.616882066940889e-05\n", | |
"133 6.961503822822124e-05\n", | |
"134 6.363449210766703e-05\n", | |
"135 5.817100827698596e-05\n", | |
"136 5.318074545357376e-05\n", | |
"137 4.8620247980579734e-05\n", | |
"138 4.4461063225753605e-05\n", | |
"139 4.0658000216353685e-05\n", | |
"140 3.7181245716055855e-05\n", | |
"141 3.400376590434462e-05\n", | |
"142 3.110080797341652e-05\n", | |
"143 2.8448537705116905e-05\n", | |
"144 2.6025612896773964e-05\n", | |
"145 2.3810694983694702e-05\n", | |
"146 2.1784513592137955e-05\n", | |
"147 1.9931732822442427e-05\n", | |
"148 1.8238761185784824e-05\n", | |
"149 1.669217635935638e-05\n", | |
"150 1.5273904864443466e-05\n", | |
"151 1.3978798961034045e-05\n", | |
"152 1.279618481930811e-05\n", | |
"153 1.1711815204762388e-05\n", | |
"154 1.071973474608967e-05\n", | |
"155 9.814746590564027e-06\n", | |
"156 8.985691238194704e-06\n", | |
"157 8.226681529777125e-06\n", | |
"158 7.533013558713719e-06\n", | |
"159 6.897260846017161e-06\n", | |
"160 6.3151273934636265e-06\n", | |
"161 5.784675977338338e-06\n", | |
"162 5.296439212543191e-06\n", | |
"163 4.851282938034274e-06\n", | |
"164 4.442661975190276e-06\n", | |
"165 4.068862381245708e-06\n", | |
"166 3.7270233406161424e-06\n", | |
"167 3.413488684600452e-06\n", | |
"168 3.1275621950044297e-06\n", | |
"169 2.8650440526689636e-06\n", | |
"170 2.624645276227966e-06\n", | |
"171 2.405057784926612e-06\n", | |
"172 2.2035642359696794e-06\n", | |
"173 2.019183966694982e-06\n", | |
"174 1.850062517405604e-06\n", | |
"175 1.6953827071120031e-06\n", | |
"176 1.5531543340330245e-06\n", | |
"177 1.4232090279620024e-06\n", | |
"178 1.3044416391494451e-06\n", | |
"179 1.1956101388932439e-06\n", | |
"180 1.0957942322420422e-06\n", | |
"181 1.0046904890259611e-06\n", | |
"182 9.207004154632159e-07\n", | |
"183 8.440935062026256e-07\n", | |
"184 7.73405361087498e-07\n", | |
"185 7.091608722475939e-07\n", | |
"186 6.500075642179581e-07\n", | |
"187 5.959356030871277e-07\n", | |
"188 5.460437932924833e-07\n", | |
"189 5.007885874874773e-07\n", | |
"190 4.5921461833131616e-07\n", | |
"191 4.2098395169887226e-07\n", | |
"192 3.8619594988631434e-07\n", | |
"193 3.539971658028662e-07\n", | |
"194 3.244428512516606e-07\n", | |
"195 2.975311303998751e-07\n", | |
"196 2.726736170188815e-07\n", | |
"197 2.5038809781108284e-07\n", | |
"198 2.2951981293317658e-07\n", | |
"199 2.1044718323537381e-07\n", | |
"200 1.931230571017295e-07\n", | |
"201 1.7704486765524052e-07\n", | |
"202 1.625433725394032e-07\n", | |
"203 1.4899816846991598e-07\n", | |
"204 1.3666438292148086e-07\n", | |
"205 1.2532166238088394e-07\n", | |
"206 1.1494875451489861e-07\n", | |
"207 1.054767295727288e-07\n", | |
"208 9.676008971837291e-08\n", | |
"209 8.872893886291422e-08\n", | |
"210 8.135914697504631e-08\n", | |
"211 7.470005414234038e-08\n", | |
"212 6.852687306491134e-08\n", | |
"213 6.293034005011577e-08\n", | |
"214 5.7790256136058815e-08\n", | |
"215 5.302708672161316e-08\n", | |
"216 4.8682117892440147e-08\n", | |
"217 4.453647051150256e-08\n", | |
"218 4.099248585021087e-08\n", | |
"219 3.761395817036828e-08\n", | |
"220 3.453242669593237e-08\n", | |
"221 3.170368501059784e-08\n", | |
"222 2.915122720992258e-08\n", | |
"223 2.6717120960029206e-08\n", | |
"224 2.451883673870725e-08\n", | |
"225 2.2580659120308155e-08\n", | |
"226 2.0710798409595554e-08\n", | |
"227 1.9040477639009623e-08\n", | |
"228 1.7526446072224644e-08\n", | |
"229 1.6122124080197864e-08\n", | |
"230 1.4806063042271944e-08\n", | |
"231 1.3633435713700237e-08\n", | |
"232 1.2587334730085331e-08\n", | |
"233 1.1583493275679757e-08\n", | |
"234 1.0632591695980409e-08\n", | |
"235 9.816670143436568e-09\n", | |
"236 9.059893280038978e-09\n", | |
"237 8.38562641547469e-09\n", | |
"238 7.752122499482539e-09\n", | |
"239 7.144894009769587e-09\n", | |
"240 6.591914125664289e-09\n", | |
"241 6.152251152968802e-09\n", | |
"242 5.6817590632363135e-09\n", | |
"243 5.275397008119853e-09\n", | |
"244 4.890560845183245e-09\n", | |
"245 4.5295536210687715e-09\n", | |
"246 4.211996085246028e-09\n", | |
"247 3.908982026956664e-09\n", | |
"248 3.6504110845214655e-09\n", | |
"249 3.4168203821849374e-09\n", | |
"250 3.187831110196271e-09\n", | |
"251 2.98542235377397e-09\n", | |
"252 2.8059994328089033e-09\n", | |
"253 2.626949102690901e-09\n", | |
"254 2.4697992540012592e-09\n", | |
"255 2.3114088421039014e-09\n", | |
"256 2.1845512065965522e-09\n", | |
"257 2.0542922918309614e-09\n", | |
"258 1.9309427390368228e-09\n", | |
"259 1.8381356436947272e-09\n", | |
"260 1.7248911188261218e-09\n", | |
"261 1.6398804536521538e-09\n", | |
"262 1.5482324311477669e-09\n", | |
"263 1.4635379574912122e-09\n", | |
"264 1.3910612661760524e-09\n", | |
"265 1.317505660125562e-09\n", | |
"266 1.2539014271339965e-09\n", | |
"267 1.2011907024600532e-09\n", | |
"268 1.1445275838184443e-09\n", | |
"269 1.0953595808160799e-09\n", | |
"270 1.0399190397691882e-09\n", | |
"271 1.0023235574863065e-09\n", | |
"272 9.588654314995892e-10\n", | |
"273 9.140234680238279e-10\n", | |
"274 8.714531318787522e-10\n", | |
"275 8.341509150078252e-10\n", | |
"276 8.028351317079796e-10\n", | |
"277 7.760427300773642e-10\n", | |
"278 7.34666716351029e-10\n", | |
"279 7.090298348444435e-10\n", | |
"280 6.824536491478739e-10\n", | |
"281 6.50852871597607e-10\n", | |
"282 6.332582236368012e-10\n", | |
"283 6.151034126489208e-10\n", | |
"284 5.901392152729557e-10\n", | |
"285 5.691445092992353e-10\n", | |
"286 5.507771461132904e-10\n", | |
"287 5.32691946109054e-10\n", | |
"288 5.124157764768711e-10\n", | |
"289 4.940373665718312e-10\n", | |
"290 4.766944616818591e-10\n", | |
"291 4.6186365842970645e-10\n", | |
"292 4.470791514776806e-10\n", | |
"293 4.3581016573313036e-10\n", | |
"294 4.2106490516502504e-10\n", | |
"295 4.066213477038616e-10\n", | |
"296 3.981817098264173e-10\n", | |
"297 3.8970893179168797e-10\n", | |
"298 3.72095521061766e-10\n", | |
"299 3.6663283520255163e-10\n", | |
"300 3.573504825382656e-10\n", | |
"301 3.487677924240984e-10\n", | |
"302 3.3633168472491093e-10\n", | |
"303 3.290345773621084e-10\n", | |
"304 3.224511213595349e-10\n", | |
"305 3.122619940398863e-10\n", | |
"306 3.0483571222816863e-10\n", | |
"307 2.9414332081145744e-10\n", | |
"308 2.8667790363812173e-10\n", | |
"309 2.8062335788447967e-10\n", | |
"310 2.762368667141857e-10\n", | |
"311 2.6522151141961103e-10\n", | |
"312 2.57844356976733e-10\n", | |
"313 2.5349577992273e-10\n", | |
"314 2.478507399317209e-10\n", | |
"315 2.41882402995941e-10\n", | |
"316 2.3629426193494396e-10\n", | |
"317 2.3070964583205011e-10\n", | |
"318 2.264485127190241e-10\n", | |
"319 2.2164882429454025e-10\n", | |
"320 2.1533744232193897e-10\n", | |
"321 2.097950702051321e-10\n", | |
"322 2.064158427517171e-10\n", | |
"323 1.990261011552974e-10\n", | |
"324 1.9497652103961371e-10\n", | |
"325 1.9140954099494678e-10\n", | |
"326 1.8693983860895713e-10\n", | |
"327 1.8207133023473432e-10\n", | |
"328 1.783010128431073e-10\n", | |
"329 1.763045681668629e-10\n", | |
"330 1.7145321273837055e-10\n", | |
"331 1.6931771262829187e-10\n", | |
"332 1.6673835923075586e-10\n", | |
"333 1.6416644432748484e-10\n", | |
"334 1.6168200112076647e-10\n", | |
"335 1.5868138747432425e-10\n", | |
"336 1.5622357574240908e-10\n", | |
"337 1.5223147742382537e-10\n", | |
"338 1.4958895233618819e-10\n", | |
"339 1.4682752236261365e-10\n", | |
"340 1.436487179207191e-10\n", | |
"341 1.4157740257925155e-10\n", | |
"342 1.410829092440835e-10\n", | |
"343 1.380528469319131e-10\n", | |
"344 1.3580145341585137e-10\n", | |
"345 1.341790289988154e-10\n", | |
"346 1.3181483682345174e-10\n", | |
"347 1.3028327028319353e-10\n", | |
"348 1.2861381404327688e-10\n", | |
"349 1.2620478273550617e-10\n", | |
"350 1.2542905603041277e-10\n", | |
"351 1.238803642999997e-10\n", | |
"352 1.218108947043106e-10\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"353 1.2094443502252972e-10\n", | |
"354 1.1777642749954964e-10\n", | |
"355 1.1760770135538223e-10\n", | |
"356 1.1544241951266798e-10\n", | |
"357 1.1402352673162142e-10\n", | |
"358 1.1242246023002167e-10\n", | |
"359 1.1167516911214648e-10\n", | |
"360 1.0990484911044263e-10\n", | |
"361 1.0966394459188678e-10\n", | |
"362 1.0963921437401325e-10\n", | |
"363 1.0603967703914918e-10\n", | |
"364 1.03834024711702e-10\n", | |
"365 1.039718450224214e-10\n", | |
"366 1.0166703590108739e-10\n", | |
"367 1.0114989401621699e-10\n", | |
"368 9.834176528666916e-11\n", | |
"369 9.60351104195567e-11\n", | |
"370 9.521219923591673e-11\n", | |
"371 9.493408836824813e-11\n", | |
"372 9.308730175572322e-11\n", | |
"373 9.198990180703248e-11\n", | |
"374 9.021075553228286e-11\n", | |
"375 8.961704989207675e-11\n", | |
"376 8.807096718577156e-11\n", | |
"377 8.61079402225684e-11\n", | |
"378 8.594966405262028e-11\n", | |
"379 8.452996635988086e-11\n", | |
"380 8.392590788997012e-11\n", | |
"381 8.171358034658738e-11\n", | |
"382 8.029973908030286e-11\n", | |
"383 7.93237697749305e-11\n", | |
"384 7.853767636234465e-11\n", | |
"385 7.892052289459883e-11\n", | |
"386 7.902740961629462e-11\n", | |
"387 7.743728575038133e-11\n", | |
"388 7.645977601056231e-11\n", | |
"389 7.571010485207808e-11\n", | |
"390 7.520549460959813e-11\n", | |
"391 7.54181994633285e-11\n", | |
"392 7.461733314562125e-11\n", | |
"393 7.3604816686057e-11\n", | |
"394 7.3249614707116e-11\n", | |
"395 7.211352348601707e-11\n", | |
"396 7.178711791677728e-11\n", | |
"397 7.167869769952873e-11\n", | |
"398 7.025070108968023e-11\n", | |
"399 6.825126575016327e-11\n", | |
"400 6.793333950927405e-11\n", | |
"401 6.772558902579107e-11\n", | |
"402 6.730412061006774e-11\n", | |
"403 6.625886644906487e-11\n", | |
"404 6.461721435702117e-11\n", | |
"405 6.47099734907286e-11\n", | |
"406 6.391782242376465e-11\n", | |
"407 6.264266882993752e-11\n", | |
"408 6.173965505507084e-11\n", | |
"409 6.116362971653189e-11\n", | |
"410 6.077430919626536e-11\n", | |
"411 6.030231869402769e-11\n", | |
"412 5.947273923334606e-11\n", | |
"413 5.867167862660949e-11\n", | |
"414 5.855109452834739e-11\n", | |
"415 5.7634869099487673e-11\n", | |
"416 5.798097418852066e-11\n", | |
"417 5.714314438298729e-11\n", | |
"418 5.7525491314880384e-11\n", | |
"419 5.753773499317383e-11\n", | |
"420 5.72689846933816e-11\n", | |
"421 5.631775948367057e-11\n", | |
"422 5.598362051717487e-11\n", | |
"423 5.5867376697049664e-11\n", | |
"424 5.629419153052595e-11\n", | |
"425 5.5520841396594633e-11\n", | |
"426 5.46322084793438e-11\n", | |
"427 5.4010414196614676e-11\n", | |
"428 5.341607364761636e-11\n", | |
"429 5.3519549902958374e-11\n", | |
"430 5.247411186126705e-11\n", | |
"431 5.307565845158457e-11\n", | |
"432 5.2371048470112314e-11\n", | |
"433 5.2449510012930745e-11\n", | |
"434 5.1879195384074706e-11\n", | |
"435 5.1574945703070085e-11\n", | |
"436 5.1512395043973314e-11\n", | |
"437 5.08322897663227e-11\n", | |
"438 5.057060326052465e-11\n", | |
"439 4.9787240302689995e-11\n", | |
"440 4.823476340565236e-11\n", | |
"441 4.8459992962879284e-11\n", | |
"442 4.797246627719076e-11\n", | |
"443 4.824013410953398e-11\n", | |
"444 4.7559997606860804e-11\n", | |
"445 4.690832444698145e-11\n", | |
"446 4.619358368040949e-11\n", | |
"447 4.583974866356755e-11\n", | |
"448 4.605983650041168e-11\n", | |
"449 4.5755635391664384e-11\n", | |
"450 4.547212259509159e-11\n", | |
"451 4.5017298916372184e-11\n", | |
"452 4.5089352390670356e-11\n", | |
"453 4.437321690642371e-11\n", | |
"454 4.417576374149412e-11\n", | |
"455 4.4186581477090314e-11\n", | |
"456 4.324371069563959e-11\n", | |
"457 4.303992579002269e-11\n", | |
"458 4.266882333570088e-11\n", | |
"459 4.2662408328286716e-11\n", | |
"460 4.2061472360632735e-11\n", | |
"461 4.153392560435343e-11\n", | |
"462 4.1709347781138106e-11\n", | |
"463 4.15860852698291e-11\n", | |
"464 4.1531614952683427e-11\n", | |
"465 4.106182407981329e-11\n", | |
"466 4.0790346794716825e-11\n", | |
"467 4.062897934753451e-11\n", | |
"468 3.990976646384148e-11\n", | |
"469 4.011613263799063e-11\n", | |
"470 3.9666533946380866e-11\n", | |
"471 3.945664281412853e-11\n", | |
"472 3.835248785222234e-11\n", | |
"473 3.8230120458226935e-11\n", | |
"474 3.747497104300557e-11\n", | |
"475 3.7175280215295814e-11\n", | |
"476 3.713800100779707e-11\n", | |
"477 3.680154792018442e-11\n", | |
"478 3.743382687160235e-11\n", | |
"479 3.680740781608627e-11\n", | |
"480 3.6479482629081517e-11\n", | |
"481 3.598047901287593e-11\n", | |
"482 3.572271645158054e-11\n", | |
"483 3.5526505348659754e-11\n", | |
"484 3.541624979397362e-11\n", | |
"485 3.4761606787503396e-11\n", | |
"486 3.4769877949036854e-11\n", | |
"487 3.460022199308632e-11\n", | |
"488 3.453215144388899e-11\n", | |
"489 3.4166957457726355e-11\n", | |
"490 3.428968220475781e-11\n", | |
"491 3.439980245101282e-11\n", | |
"492 3.4364174700263206e-11\n", | |
"493 3.4181008717881767e-11\n", | |
"494 3.368537393466653e-11\n", | |
"495 3.368653272994848e-11\n", | |
"496 3.291827227469568e-11\n", | |
"497 3.2805473615393765e-11\n", | |
"498 3.3228246543171025e-11\n", | |
"499 3.3127479925898484e-11\n" | |
] | |
} | |
], | |
"source": [ | |
"for t in range(500):\n", | |
" # Forward pass: Compute predicted y by passing x to the model\n", | |
" y_pred = model(x)\n", | |
"\n", | |
" # Compute and print loss\n", | |
" loss = criterion(y_pred, y)\n", | |
" print(t, loss.item())\n", | |
"\n", | |
" # Zero gradients, perform a backward pass, and update the weights.\n", | |
" optimizer.zero_grad()\n", | |
" loss.backward()\n", | |
" optimizer.step()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.2" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment