首页天道酬勤时空AI技术:AutoSTG:面向时空图预测的神经网络架构搜索| WWW 2021

时空AI技术:AutoSTG:面向时空图预测的神经网络架构搜索| WWW 2021

admin 12-01 05:39 283次浏览

时空图是一种重要的数据结构,可广泛用于描述城市中的各种数据形式,如交通数据、空气质量数据等。由于这类时空图数据包含了大量关于城市动态变化的信息,准确预测和分析这些时空图数据是城市智能化的重要一步。本文将介绍一种用于时空图预测的神经网络架构搜索算法框架,该框架来自上海交通大学和京东智能城市研究院的论文《AutoSTG: Neural Architecture Search for Predictions of Spatio-Temporal Graph》。目前,该论文已被数据挖掘领域顶级会议WWW 2021(CCF A类)接收。

一.研究背景

时空图是一种重要的数据结构,可广泛用于描述城市中各种感知数据的形式。如图1(a)所示,交通、流量和空气质量等数据可以建模为时空图结构。这些数据记录了大量关于城市发展变化的信息,对这些数据进行分析和预测可以挖掘其潜在价值。图1(b)显示了近年来在人工智能和数据挖掘领域的顶级会议上发表的与时空图预测相关的论文数量。可以发现,由于时空图预测任务背后的巨大应用价值,这类任务逐渐成为新的研究热点。

图1时空图预报任务现状

但在现实中,时空图的数据种类繁多,往往需要解决不同数据的预测任务。每当我们想要预测一种新的数据,就需要从众多时空神经网络模型中选择一种有效的网络结构,重新设计预测网络。这个过程往往需要基于大量的专业知识,只有不断尝试和调整预测网络的结构,才能得到更好的预测模型。即使你处理的是同一类数据,当面对不同的任务时,直接重用或迁移现有的网络结构是不可行的,导致模型开发成本高的问题。因此,如果能够根据采集到的时空数据高效、自动地设计神经网络结构,就可以帮助我们节约成本,提高效率。为了解决这一痛点,本文以城市交通数据预测任务为支点,探索时空数据预测任务的神经网络结构搜索算法。这项研究需要解决两个挑战:

1.如何构建网络结构搜索空间?由于数据在时间维度和空间维度上是相互关联的,并且这种时空关联作用于一般的图结构,我们需要设计搜索空间来建模数据中复杂的时空关联。

2.属性图影响的网络参数权重如何学习?时空图往往对应一个静态属性图,其中包含时空图的元知识,影响时空数据的变化趋势。以交通数据为例,道路的宽度(边缘属性)和交叉口的岔道数(点属性)反映了该地清除交通流的能力(元知识),从而影响交通数据的变化趋势。同时,这个节点/边元知识也会在属性图结构上进行交互。比如一个路口如果因为规划不合理容易发生交通堵塞,那么它周边的路口也容易发生拥堵。由于网络参数的权重反映了数据中的时空相关性,因此在学习这些参数时需要考虑属性图的影响。

二、模式方案

如图2所示,我们提出的AutoSTG框架主要包括两个模块:(a)网络结构搜索;(b)参数权重学习。在网络结构搜索中,需要设计搜索空间,使得搜索到的模型能够对数据中的时空相关性进行建模。在参数权重学习模块中,采用了元学习的方法。首先利用元学习网络从属性图中学习节点和边的元知识,然后利用元学习网络学习时空网络中的参数权重,从而建模属性图对时空相关性的影响。

图2自动转换框架

1.搜索空间

受许多基于卷积神经网络的时空神经网络模型的启发,我们的AutoSTG框架也是基于卷积运算的。如图2(a)所示,预测网络包括几个Cell或Pooling层,并且所有层的输出被聚集以预测具有完全连接的网络的未来数据。我们的目标是搜索每个细胞的网络结构。每个像元的搜索空间是一个有向无环图。每条边对应四个可能的网络:零、同一性、空间图卷积和时间卷积。我们需要确定每个边对应的网络。在这项工作中,我们使用DARTS方案[1]来学习网络的结构权重。

在时空图,数据的相关性受属性图的影响,因此我们在空间图卷积网络中使用元学习生成邻接矩阵,在时间卷积网络中使用卷积核来建模网络的参数权重与属性图中包含的元知识之间的关系。接下来,我们详细介绍元学习方法。

2.权重元素学习

首先,我们需要从属性图中学习节点和边的元知识。为了建模节点和边之间的交互,我们提出了一个图元知识学习网络。如图3所示,网络在多个步骤中迭代,每个步骤分为两部分:(a)对于每个节点,网络用图卷起,其周围节点的表示用边聚合;(b)对于每个边,聚合两个连接节点的表示。通过多次迭代,我们可以在属性图上传播节点和边的值。

属性信息,并学到有效表征,作为节点和边的元知识。

图3 图元知识学习网络

在获得节点和边的元知识后,我们就可以用一组全连接网络,将边元知识分别映射为每个空间图卷积网络的邻接矩阵;用另一组全连接网络,将点元知识分别映射为每个时间卷积网络的卷积核,从而建模属性图对时空相关性的影响。

在网络结构搜索阶段,和DARTS[1]类似,AutoSTG框架交替优化网络结构权重和元学习网络参数权重,即可输出搜索后的网络结构。最后,以该网络结构为基础重新训练网络参数权重,即可得到时空数据预测模型。

三、实验结果

1. 模型准确度对比

我们用两个真实的交通数据集PEMS-BAY和METR-LA来验证AutoSTG框架的有效性,实验设置与DCRNN[2]完全相同,实验选用平均绝对误差(简称MAE)和均方根误差(简称RMSE)评价模型的预测准确度。

如表所示,传统模型HA、GBRT在这两个数据集中的预测准确度最低。这是因为这两个模型依赖人工提取的统计信息,而无法自动从交通数据中学习高阶时空特征。鉴于交通数据的复杂性,这些模型无法有效建模时空数据。其次,GAT-Seq2Seq模型中分别用循环神经网络和图注意力神经网络建模数据中的时间及空间相关性。由于神经网络模型能够帮助学习高阶时空特征,因此它比非深度学习模型的预测准确度高很多。但是,这个模型无法利用空间信息(路网结构、距离),忽略了空间信息和数据中时空相关性的关联。因此,预测准确度仍有较大提升空间。

DCRNN和Graph WaveNet两个模型在这两个数据集上都取得了很不错的预测准确度。其原因在于它们基于先验知识,将路网信息引入图卷积神经网络的计算。虽然这种额外的信息可以显著提升准确度,但是,这需要基于对交通数据深入的理解,在实际场景中,这类方案无法直接迁移到其他时空预测任务。为了解决这个问题,ST-MetaNet+框架能直接学习空间信息与时空相关性的关联,从而取得更好的预测准确度。

与所有基准模型相比,AutoSTG框架可以自动学习时空神经网络结构,并取得与人工精心设计的模型相近的预测准确度。尤其是在PEMS-BAY数据集中,AutoSTG搜索出的结构在预测准确度指标MAE和RMSE上都提升2%,这也进一步说明,AutoSTG是一个有效的时空神经网络结构搜索框架。

2. 候选网络结构有效性实验

为了验证空间图卷积神经网络和时间卷积神经网络在AutoSTG中的有效性,我们在两个预测任务上测试四组AutoSTG框架及其变体:(1)w/o SC,即在候选网络结构集合中,删除空间图卷积神经网络;(2)w/o TC,即在候选网络结构集合中,删除时间卷积神经网络;(3)w/o SC & TC,即删掉SC和TC网络;(4)完整的AutoSTG框架。

如图4所示,在两个数据集上删除空间图卷积神经网络或时间卷积神经网络,都会降低模型的预测准确度。其中,空间图卷积神经网络对性能的影响比时间卷积神经网络更大。其原因在于,交通数据中,空间相关性较为复杂,准确预测这两个数据集需要大量表示空间相关性的特征;而网络输入中涉及到的时间片段较少,因此,直接用全连接神经网络一定程度上也可以建模时间维度上的关联。但总的来说,两种候选神经网络结构在AutoSTG中都可以带来预测准确度的提升。

图4 AutoSTG候选网络结构的有效性实验

3. 算法模块有效性实验

为了验证AutoSTG中各算法模块的有效性,我们测试四组AutoSTG框架及其变体:(1)随机搜索,从搜索空间中随机采样网络结构,直接训练网络参数并预测结果;(2)DARTS,在AutoSTG中,去掉元学习模块,即直接用Graph WaveNet中介绍的空间图卷积和时间卷积计算(不再自动学习空间图卷积神经网络的邻接矩阵和时间卷积神经网络的卷积核);(3) w/o graph,在元知识学习网络中,直接用全连接神经网络分别学习节点和边的元知识,而不再考虑节点和边之间元知识的关联;(4)完整的AutoSTG框架。

如图5所示,由于较为完善地定义了搜索空间,随机采样网络结构已经可以取得不错的预测结果。其次,在没有用参数元学习方案的时候,算法准确度下降较多,说明参数元学习是很有效的。最后,在进一步考虑节点和边之间特性关联的情况下,预测准确度还能够进一步提升。总的来说,AutoSTG框架里的每一个算法模块都能显著提升搜索出来的网络结构的性能。

图5 AutoSTG算法模块有效性实验

四、总结

在这个工作中,我们以城市交通预测任务为支点,初步研究了针对时空数据预测任务的神经网络结构自动化搜索问题。为此,我们提出了一个时空神经网络搜索框架AutoSTG,并将该框架应用于交通预测任务中,并在两个真实的路网交通速度预测任务中,验证了AutoSTG框架的有效性。未来,我们将会尝试在其他时空图建模任务中应用AutoSTG框架。

参考文献

[1] Liu, Hanxiao, Karen Simonyan, and Yiming mldrjb. "Darts: Differentiable architecture search." arXiv preprint arXiv:1806.09055 (2018).

[2] Li, Yaguang, et al. "Diffusion convolutional recurrent neural network: Data-driven traffic forecasting." arXiv preprint arXiv:1707.01926 (2017).

linux下运行文件的命令是什么-linux运维Semaphore以及CyclicBarrier基础网络 UNet防火墙启用教程iOS项目中的version和build详解IOS网络请求之AFNetWorking3.x使用详情Java设计模式之装饰模式是什么及怎么实现连接Vue登录功能如何实现如何获取three.jsLineSegments仅渲染可见线怎么用Java实现单机版五子棋游戏付费方式等。html和css算不算编程语言
系统运行质量评价维度(正方教务系统教学质量评价) solidworks保存图纸格式(solidworks零件格式)
相关内容