Skip to content

Commit

Permalink
[docs]: 更新5.3 PyTorch修改模型.md
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhikangNiu committed Oct 2, 2023
1 parent 3549763 commit 1290af6
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions source/第五章/5.3 PyTorch修改模型.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@
我们这里以PyTorch官方视觉库torchvision预定义好的模型ResNet50为例,探索如何修改模型的某一层或者某几层。我们先看看模型的定义是怎样的:

```python
# 导入必要的package
import torch
import torch.nn as nn
from collections import OrderedDict
import torchvision.models as models
net = models.resnet50()
print(net)
Expand Down Expand Up @@ -53,7 +57,6 @@ ResNet(
假设我们要用这个resnet模型去做一个10分类的问题,就应该修改模型的fc层,将其输出节点数替换为10。另外,我们觉得一层全连接层可能太少了,想再加一层。可以做如下修改:

```python
from collections import OrderedDict
classifier = nn.Sequential(OrderedDict([('fc1', nn.Linear(2048, 128)),
('relu1', nn.ReLU()),
('dropout1',nn.Dropout(0.5)),
Expand Down Expand Up @@ -99,7 +102,6 @@ class Model(nn.Module):
之后对我们修改好的模型结构进行实例化,就可以使用了:

```python
import torchvision.models as models
net = models.resnet50()
model = Model(net).cuda()
```
Expand Down

0 comments on commit 1290af6

Please sign in to comment.