PyTorch学习笔记(十三)——现有网络模型的使用及修改

 以分类模型的VGG为例

 

vgg16_false = torchvision.models.vgg16(weights=False)
vgg16_true = torchvision.models.vgg16(weights=True)
print(vgg16_true)

vgg16_true.classifier.add_module("add_linear",nn.Linear(1000,10))
print(vgg16_true)

vgg16_false.classifier[6] = nn.Linear(4096,10)
print(vgg16_false)