Skip to content

Instantly share code, notes, and snippets.

import torch
import torch.nn as nn
model = NeuralNet()
print(model)

# Walk every registered parameter and report its qualified name, its Python
# class, its shape, and whether autograd currently tracks gradients for it.
for param_name, tensor in model.named_parameters():
    print(f'name:  {param_name}')
    print(type(tensor))
    print(f'param.shape:  {tensor.shape}')
    print(f'param.requires_grad:  {tensor.requires_grad}')
    print('=====')
# Freeze every parameter except those listed below: listed parameters get
# requires_grad=True, all others requires_grad=False.
# NOTE(review): assumes the model's final layer is registered as `fc` with a
# weight and a bias -- confirm against the NeuralNet definition.
trainable_names = {'fc.weight', 'fc.bias'}
for name, param in model.named_parameters():
    param.requires_grad = name in trainable_names

# Without this check, a typo or a renamed layer would silently leave no
# trainable parameters at all; fail loudly instead.
missing_names = trainable_names.difference(
    name for name, _ in model.named_parameters()
)
if missing_names:
    raise ValueError(
        f'trainable parameters not found in model: {sorted(missing_names)}'
    )

# Report the resulting trainable/frozen state of every parameter.
for name, param in model.named_parameters():
    print(name, ':', param.requires_grad)
# A bare expression's value is discarded when this runs as a script (it is
# only echoed in a REPL/notebook), so print the check explicitly.
print('isinstance(model.fc, nn.Module): ', isinstance(model.fc, nn.Module))

# Confirm that each direct child returned by named_children() is an nn.Module.
for name, child in model.named_children():
    print('name: ', name)
    print(f'isinstance({name}, nn.Module): ', isinstance(child, nn.Module))
    print('=====')