PyTorch使用Torchdyn实现连续时间神经网络的代码示例!

PyTorch使用Torchdyn实现连续时间神经网络的代码示例!

神经常微分方程(Neural ODEs)是深度学习领域的创新性模型架构,它将神经网络的离散变换扩展为连续时间动力系统,本文将基于Torchdyn(一个专门用于连续深度学习和平衡模型的PyTorch扩展库)介绍Neural ODE的实现与训练方法,需要的朋友可以参考下。

Torchdyn概述

Torchdyn是基于PyTorch构建的专业库,专注于连续深度学习和隐式神经网络模型(如Neural ODEs)的开发。该库具有以下核心特性:

  • 支持深度不变性和深度可变性的ODE模型
  • 提供多种数值求解算法(如Runge-Kutta法,Dormand-Prince法)
  • 与PyTorch Lightning框架的无缝集成,便于训练流程管理

本教程将以经典的moons数据集为例,展示Neural ODEs在分类问题中的应用。

20252593430254

数据集构建

首先,我们使用Torchdyn内置的数据集生成工具创建实验数据:

1
2
3
4
5
6
7
8
9
10
11
12
from torchdyn.datasets import ToyDataset 
import matplotlib.pyplot as plt 
# 生成示例数据
d = ToyDataset() 
X, yn = d.generate(n_samples=512, noise=1e-1, dataset_type='moons'
# 可视化数据集
colors = ['orange', 'blue'
fig, ax = plt.subplots(figsize=(3, 3)) 
for i in range(len(X)): 
ax.scatter(X[i, 0], X[i, 1], s=1, color=colors[yn[i].int()]) 
plt.show()

数据预处理

将生成的数据转换为PyTorch张量格式,并构建训练数据加载器。Torchdyn支持CPU和GPU计算,可根据硬件环境灵活选择:

1
2
3
4
5
6
7
8
import torch 
import torch.utils.data as data 
device = torch.device("cpu"# 如果使用GPU则改为'cuda'
X_train = torch.Tensor(X).to(device) 
y_train = torch.LongTensor(yn.long()).to(device) 
train = data.TensorDataset(X_train, y_train) 
trainloader = data.DataLoader(train, batch_size=len(X), shuffle=True)

Neural ODE模型构建

Neural ODEs的核心组件是向量场(vector field),它通过神经网络定义了数据在连续深度域中的演化规律。以下代码展示了向量场的基本实现:

1
2
3
4
5
6
7
8
import torch.nn as nn 
# 定义向量场f
f = nn.Sequential( 
nn.Linear(2, 16), 
nn.Tanh(), 
nn.Linear(16, 2
)

接下来,我们使用Torchdyn的

1
NeuralODE

类定义Neural ODE模型。这个类接收向量场和求解器设置作为输入。

1
2
3
4
from torchdyn.core import NeuralODE 
t_span = torch.linspace(0, 1, 5# 时间跨度
model = NeuralODE(f, sensitivity='adjoint', solver='dopri5').to(device)

类来管理训练过程:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import pytorch_lightning as pl 
class Learner(pl.LightningModule): 
def __init__(self, t_span: torch.Tensor, model: nn.Module): 
super().__init__() 
self.model, self.t_span = model, t_span 
def forward(self, x): 
return self.model(x) 
def training_step(self, batch, batch_idx): 
x, y = batch 
t_eval, y_hat = self.model(x, self.t_span) 
y_hat = y_hat[-1# 选择轨迹的最后一个点
loss = nn.CrossEntropyLoss()(y_hat, y) 
return {'loss': loss} 
def configure_optimizers(self): 
return torch.optim.Adam(self.model.parameters(), lr=0.01
def train_dataloader(self): 
return trainloader

最后训练模型:

1
2
3
learn = Learner(t_span, model) 
trainer = pl.Trainer(max_epochs=200
trainer.fit(learn)

实验结果可视化

深度域轨迹分析

训练完成后,我们可以观察数据样本在深度域(即ODE的时间维度)中的演化轨迹:

1
2
3
4
5
6
7
8
9
10
t_eval, trajectory = model(X_train, t_span) 
trajectory = trajectory.detach().cpu() 
fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(10, 2)) 
for i in range(500): 
ax0.plot(t_span, trajectory[:, i, 0], alpha=0.1, color=colors[int(yn[i])]) 
ax1.plot(t_span, trajectory[:, i, 1], alpha=0.1, color=colors[int(yn[i])]) 
ax0.set_title("维度 0"
ax1.set_title("维度 1"
plt.show()

向量场可视化

通过可视化学习得到的向量场,我们可以直观理解模型的动力学特性:

1
2
3
4
5
6
7
8
9
10
11
x = torch.linspace(trajectory[:, :, 0].min(), trajectory[:, :, 0].max(), 50
y = torch.linspace(trajectory[:, :, 1].min(), trajectory[:, :, 1].max(), 50
X, Y = torch.meshgrid(x, y) 
z = torch.cat([X.reshape(-1, 1), Y.reshape(-1, 1)], 1
f_eval = model.vf(0, z.to(device)).cpu().detach() 
fx, fy = f_eval[:, 0], f_eval[:, 1
fx, fy = fx.reshape(50, 50), fy.reshape(50, 50
fig, ax = plt.subplots(figsize=(4, 4)) 
ax.streamplot(X.numpy(), Y.numpy(), fx.numpy(), fy.numpy(), color='black'
plt.show()

Torchdyn进阶特性

Torchdyn框架的功能远不限于基础的Neural ODEs实现。它提供了丰富的高级特性,包括:

  • 高精度数值求解器
  • 平衡模型支持
  • 自定义微分方程系统

无论是物理模型的数值模拟,还是连续深度学习模型的开发,Torchdyn都提供了完整的工具链支持。

以上就是PyTorch使用Torchdyn实现连续时间神经网络的代码示例的详细内容。

 

 

学习资料见知识星球。

以上就是今天要分享的技巧,你学会了吗?若有什么问题,欢迎在下方留言。

快来试试吧,小琥 my21ke007。获取 1000个免费 Excel模板福利​​​​!

更多技巧, www.excelbook.cn

欢迎 加入 零售创新 知识星球,知识星球主要以数据分析、报告分享、数据工具讨论为主;

Excelbook.cn Excel技巧 SQL技巧 Python 学习!

你将获得:

1、价值上万元的专业的PPT报告模板。

2、专业案例分析和解读笔记。

3、实用的Excel、Word、PPT技巧。

4、VIP讨论群,共享资源。

5、优惠的会员商品。

6、一次付费只需129元,即可下载本站文章涉及的文件和软件。

文章版权声明 1、本网站名称:Excelbook
2、本站永久网址:http://www.excelbook.cn
3、本网站的文章部分内容可能来源于网络,仅供大家学习与参考,如有侵权,请联系站长王小琥进行删除处理。
4、本站一切资源不代表本站立场,并不代表本站赞同其观点和对其真实性负责。
5、本站一律禁止以任何方式发布或转载任何违法的相关信息,访客发现请向站长举报。
6、本站资源大多存储在云盘,如发现链接失效,请联系我们我们会第一时间更新。

THE END
分享
二维码
< <上一篇
下一篇>>