如何可视化 PyTorch 模型
更新时间:2024 年 4 月
准备模型
首先我们搭建一个简单的模型,用于演示如何可视化 PyTorch 模型。为了演示复杂模型的结构,我们在模型中加入了一个跨层连接。
import torch
import torch.nn as nn
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.cnn = nn.Sequential(
nn.Conv2d(1, 32, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Conv2d(32, 64, kernel_size=3, padding=1),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Flatten(),
)
self.mlp1 = nn.Sequential(
nn.Linear(7*7*64, 128),
nn.ReLU(),
)
self.mlp2 = nn.Sequential(
nn.Linear(7*7*64, 128),
nn.ReLU(),
)
self.fc = nn.Linear(256, 10)
def forward(self, x):
x = self.cnn(x)
x1 = self.mlp1(x)
x2 = self.mlp2(x)
x = torch.cat([x1, x2], dim=1)
x = self.fc(x)
return x
model = Model()
dummy_input = torch.randn(1, 1, 28, 28)
这里我们以 28x28 的输入为例,搭建了一个简单的卷积神经网络。
print 大法
我们可以直接使用 print 打印模型:
Model(
(cnn): Sequential(
(0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU()
(2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): ReLU()
(5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(6): Flatten(start_dim=1, end_dim=-1)
)
(mlp1): Sequential(
(0): Linear(in_features=3136, out_features=128, bias=True)
(1): ReLU()
)
(mlp2): Sequential(
(0): Linear(in_features=3136, out_features=128, bias=True)
(1): ReLU()
)
(fc): Linear(in_features=256, out_features=10, bias=True)
)
缺点:只能看到线性结构,看不到跨层连接。
Netron 在线可视化
Netron 是一个经典的模型可视化工具,官网:https://netron.app/
代码:https://github.com/lutzroeder/netron
这个工具可以直接在线可视化模型,不需要安装 python 包,你只需要将模型保存为 .onnx
格式,然后上传到网站即可。
保存为 onnx 格式
torch.onnx.export(model, dummy_input, "model.onnx")
使用 Netron 可视化 onnx 模型
模型可视化挺好看,跨层连接也很清晰。
TensorBoard
官网:https://www.tensorflow.org/tensorboard/get_started?hl=zh-cn
代码:https://github.com/tensorflow/tensorboard
代码更新频繁,每天都有更新。
安装 tensorboard
pip install tensorboard
使用 tensorboard
首先保存模型结构:
from torch.utils.tensorboard import SummaryWriter
with SummaryWriter(comment='model') as w:
w.add_graph(model, dummy_input)
然后在终端里运行 tensorboard:
tensorboard --logdir=.
最后在浏览器里打开 http://localhost:6006/
即可看到模型结构。
这个模型结构是可以交互式展开的,比如:
不仅可以看到最里面的模型,也能看到每一层的输入输出尺寸。
torchview
1 年未更新,目前可用。
安装 torchview
pip install torchview
安装 graphviz
安装 graphviz:
Mac
brew install graphviz
Ubuntu:
sudo apt-get install graphviz
参考链接:https://graphviz.readthedocs.io/en/stable/manual.html
使用 torchview
from torchview import draw_graph
model_graph = draw_graph(model, input_size=(1, 1, 28, 28), save_graph=True, expand_nested=True)
可视化结果:
torchviz
这个可视化工具比较传统,已经三年未更新:
安装 torchviz
安装 torchviz:
pip install torchviz
安装 graphviz
安装 graphviz:
Mac
brew install graphviz
Ubuntu:
sudo apt-get install graphviz
参考链接:https://graphviz.readthedocs.io/en/stable/manual.html
使用 torchviz
使用 torchviz 可视化模型:
from torchviz import make_dot
dot = make_dot(model(dummy_input), params=dict(model.named_parameters()))
dot.render("model", format="png")
你也可以存储 dot 文件,然后手动修改样式:
dot.save('vis.dot')
使用这个网站可以在线编辑 dot 文件:https://dreampuf.github.io/GraphvizOnline
缺点:看到的是反向传播的路径,不是模型结构。
其他失效工具
tensorwatch
10 个 commit 之前是四年前的代码,已不支持 PyTorch 2.x。
报错:
File ~/miniconda3/lib/python3.11/site-packages/tensorwatch/model_graph/hiddenlayer/summary_graph.py:85, in SummaryGraph.__init__(self, model, dummy_input, apply_scope_name_workarounds)
81 # Switch all instances of torch.nn.ModuleList in the model to our DistillerModuleList
82 # See documentation of _DistillerModuleList class for details on why this is done
83 model_clone, converted_module_names_map = _to_distiller_modulelist(model_clone)
---> 85 with torch.onnx.set_training(model_clone, False):
87 device = distiller.model_device(model_clone)
88 dummy_input = distiller.convert_tensors_recursively_to(dummy_input, device=device)
AttributeError: module 'torch.onnx' has no attribute 'set_training'
hiddenlayer
4 年未更新,10 个 commit 之前是 6 年前的代码。
报错:
File ~/miniconda3/lib/python3.11/site-packages/hiddenlayer/pytorch_builder.py:71, in import_graph(hl_graph, model, args, input_names, verbose)
66 def import_graph(hl_graph, model, args, input_names=None, verbose=False):
67 # TODO: add input names to graph
68
69 # Run the Pytorch graph to get a trace and generate a graph from it
70 trace, out = torch.jit._get_trace_graph(model, args)
---> 71 torch_graph = torch.onnx._optimize_trace(trace, torch.onnx.OperatorExportTypes.ONNX)
73 # Dump list of nodes (DEBUG only)
74 if verbose:
AttributeError: module 'torch.onnx' has no attribute '_optimize_trace'
总结
工具 | 是否可用 | 更新频率 | 优点 | 缺点 |
---|---|---|---|---|
Netron | 可用 | 高 | 在线可视化,不用安装 | 需要保存为 onnx 格式,看不到输入输出的尺寸 |
tensorboard | 可用 | 高 | 可交互式展开,可视化效果好 | 需要安装 tensorboard,并且启动后台服务 |
torchview | 可用 | 1 年 | 可以看到每一层的输入输出尺寸 | 需要安装 torchview 和 graphviz |
torchviz | 可用 | 3 年 | 无 | 看到的是反向传播的路径,不是模型结构 |
print 大法 | 永久可用 | 无 | 永久可用,不会失效 | 只有文字,无法展示跨层连接 |
tensorwatch | 失效 | 4 年 | 无 | 无 |
hiddenlayer | 失效 | 4 年 | 无 | 无 |