无线边缘网络中对数据污染具有鲁棒性的联邦学习方法

文档序号:1938685 发布日期:2021-12-07 浏览:21次 >En<

阅读说明:本技术 无线边缘网络中对数据污染具有鲁棒性的联邦学习方法 (Federal learning method with robustness to data pollution in wireless edge network ) 是由 李文玲 李钰浩 刘杨 于 2021-09-07 设计创作,主要内容包括:本发明公开了一种无线边缘网络中对数据污染具有鲁棒性的联邦学习方法,包括以下步骤:搭建模型结构,并初始化全局参数;中心服务器将全局参数广播至无线边缘网络的客户端,客户端以全局参数作为本轮训练初值;各客户端计算梯度值并进一步更新偏差系数各客户端更新迭代系数各客户端更新一阶动量与二阶动量各客户端更新模型参数重复步骤三至六,至迭代次数达到预设值;各客户端上传本地参数至中心服务器;中心服务器接收各客户端的本地参数并聚合,得到更新后的全局参数;重复步骤二至九,直至全局模型性能达到要求。该联邦学习方法能够提高算法面对有毒数据时的鲁棒性,并减少由本地模型差异性造成的性能影响。(The invention discloses a federal learning method with robustness to data pollution in a wireless edge network, which comprises the following steps: building a model structure and initializing global parameters; the central server broadcasts the global parameters to the client side of the wireless edge network, and the client side takes the global parameters as initial training values of the current round; each client calculates the gradient value And further updating the deviation coefficient Updating iteration coefficient by each client Updating first-order momentum by each client And second order motionMeasurement of Updating model parameters by each client Repeating the third step to the sixth step until the iteration times reach a preset value; uploading local parameters by each client To a central server; the central server receives and aggregates the local parameters of the clients to obtain updated global parameters; and repeating the steps from two to nine until the performance of the global model meets the requirement. The federated learning method can improve the robustness of the algorithm in the face of toxic data and reduce the performance impact caused by local model differences.)

无线边缘网络中对数据污染具有鲁棒性的联邦学习方法

技术领域

本发明属于联邦学习领域,具体涉及无线边缘网络中客户端数据集受污染时的联邦学习方法,在客户端数据受污染时具有鲁棒性,仍能使模型获得较好的性能。

背景技术

数据是机器学习的基础,作为人工智能的主要方向,机器学习需要数据去训练人工智能模型。而在大多数行业中,由于行业竞争、隐私安全、行政手续复杂等问题,数据常常是以孤岛的形式存在的,而只利用数据孤岛内的数据训练所得的人工智能模型性能往往不能满足任务需求。针对数据孤岛和数据隐私的两难问题,联邦学习方法框架应运而生。

在联邦学习方法框架下,有多个相互独立的客户端和一个中心服务器,客户端有不同且不可共享的本地数据。训练过程中,服务器向客户端广播全局参数,客户端将更新下载得到的全局模型参数用于自己的数据集上进行训练,然后只上传本地参数到服务器进行聚合,经过多次“下载-训练-上传-聚合”的过程得到最终模型参数。显然,在联邦学习框架下客户端的数据得到了保护,数据孤岛的问题也得以解决。

联邦学习方法的经典方法是联邦平均,在每个客户端上传参数至服务器后,服务器对局部参数进行加权平均,得到全局参数后服务器再将全局参数广播给各客户端。Adam算法作为SGD的一种变形,具有收敛速度快、超参数易调整的优点。利用梯度信息求取一阶、二阶动量可以使参数快速收敛,并且使学习率自适应调整,因此Adam算法被广泛应用于联邦学习方法的本地训练中。但是,在实际场景中,本地客户端的数据集若由于网络攻击或其他原因受到污染,训练过程中计算得到的随机梯度必然有异常值产生。而Adam算法在参数更新时由于一阶、二阶动量对梯度值的依赖性,使其对异常值的鲁棒性极差。此外,不同的客户端训练生成的本地模型通常具有差异性,在这种差异性下聚合得到的全局模型性能不稳定。

发明内容

有鉴于此,本发明提供一种无线边缘网络中的对数据污染具有鲁棒性的联邦学习方法,以提高算法面对有毒数据时的鲁棒性,并减少由本地模型差异性造成的性能影响。

具体技术方案如下:

一种无线边缘网络中对数据污染具有鲁棒性的联邦学习方法,包括以下步骤:

步骤一:搭建用于学习的模型结构,并初始化全局参数,包括:全局模型参数、全局一阶动量、全局二阶动量;

步骤二:中心服务器将全局参数广播至无线边缘网络的客户端,客户端以全局参数作为本轮训练的初始值;

步骤三:所述客户端利用历史时刻模型参数获取在本地数据集上的梯度值,并获取梯度值与历史一阶动量的偏差系数

步骤四:所述客户端更新二阶动量迭代系数

步骤五:所述客户端利用偏差系数迭代系数梯度值与历史动量值更新一阶动量与二阶动量

步骤六:所述客户端利用更新后的一、二阶动量更新模型参数

步骤七:重复步骤三至步骤六,直至迭代次数达到预设迭代阈值;

步骤八:所述客户端上传本地模型参数一阶动量及二阶动量至中心服务器;

步骤九:中心服务器接收所述客户端的本地参数并进行参数聚合,得到更新后的全局参数xt、mt、vt

步骤十:重复步骤二至步骤九,直至全局模型性能达到要求。

步骤二中所述客户端以全局参数作为本轮训练的初始值表示如下:

其中,下标i表示第i个客户端,上标t′表示本轮训练的初始时刻,xt′为初始时刻全局模型参数,mt′为初始时刻全局一阶动量,vt′为初始时刻全局二阶动量。

步骤三中所述偏差系数更新方式表示如下:

其中,下标i表示第i个客户端,上标t表示当前迭代时刻,d为向量维度,下标j表示向量的第j个分量,g表示梯度值,m表示一阶动量,v表示二阶动量,梯度值 为t时刻第i个客户端的随机采样数据,Di为第i个客户端的本地数据集,为t-1时刻第i个客户端的模型参数,fi为第i个客户端的本地损失函数。

步骤四中所述迭代系数更新方式表示如下:

其中γ为预设常数。

步骤五中所述一阶动量与二阶动量更新方式表示如下:

步骤六中所述模型参数更新方式表示如下:

其中,vt′为初始时刻全局二阶动量,α为预设全局学习率。

步骤九中所述参数聚合方式为加权平均,涉及参数为一阶动量二阶动量与模型参数具体表示如下:

其中,pi为第i个客户端的权重,N为客户端的数量。

本发明与现有技术相比所具有的有益效果:

1.本发明利用偏差系数对异常梯度值进行检测,并且在出现异常值时对更新方向进行控制。具体表现为:当为异常值时,趋近于1,则进而使模型参数更新方向不受异常值影响,体现了本发明对异常梯度值的鲁棒性,从而降低了有毒数据对模型性能的影响。

2.本发明利用使算法在迭代后期计算二阶动量并利用其调整学习率时减少对梯度值的依赖,训练后期趋近于1,使二阶动量满足保证了训练后期梯度值较小造成学习率过大的问题,同时也排除了异常值的影响,体现了本发明对异常梯度值的鲁棒性,并提高了模型的性能。

3.本发明将全局二阶动量作为更新步长的分母,在不同的客户端的本地训练过程中使用相同的学习率,更新过程中本地模型差异变小,使模型性能更稳定。

附图说明

图1为本发明无线边缘网络的结构示意图。

图2为本发明无线边缘网络中的对数据污染具有鲁棒性的联邦学习方法的流程图。

图3为本发明方法与现有技术实验结果对比。

具体实施方式

下面结合附图和实施例对本发明进行进一步的详细介绍。

图1为本发明无线边缘网络的系统结构图,包括一个中心服务器和N个客户端,数据分布在N个客户端中,客户端与服务器只传递参数不传递数据,其中服务器采用全局模型,客户端采用本地模型;为获得性能更好的全局模型,采用联邦学习进行模型训练。

图2为本发明无线边缘网络中的对数据污染具有鲁棒性的联邦学习方法的流程图。开始时对全局参数进行初始化并进行广播,N个客户端利用下载到的参数在本地数据集上进行本地训练。经过本地训练后,客户端将本地参数上传至服务器进行参数的加权平均,并对此时所得全局模型进行评估,若满足性能要求则算法结束,否则继续循环。具体包括以下步骤:

步骤一:搭建用于学习的模型结构,并初始化全局参数,包括:全局模型参数、全局一阶动量、全局二阶动量;

步骤二:中心服务器将全局参数广播至无线边缘网络的客户端设备,客户端以全局参数作为本轮训练的初始值;

步骤三:各客户端利用历史时刻模型参数获取在本地数据集上的梯度值,并获取梯度值与历史一阶动量的偏差系数

步骤四:各客户端更新二阶动量迭代系数

步骤五:各客户端利用偏差系数迭代系数梯度值与历史动量值 更新一阶动量与二阶动量

步骤六:各客户端利用更新后的一、二阶动量更新模型参数

步骤七:重复步骤三至步骤六,直至迭代次数达到预设迭代阈值;

步骤八:各客户端上传本地模型参数一阶动量及二阶动量至中心服务器;

步骤九:中心服务器接收各客户端的本地参数并进行参数聚合,得到更新后的全局参数xt、mt、vt

步骤十:重复步骤二至步骤九,直至全局模型性能达到要求。

以下对本地训练的过程进行说明:

本地训练开始时客户端获取全局参数,包括全局模型参数、全局一阶动量、全局二阶动量,作为本地训练的初始参数值:

其中,下标i表示第i个客户端,上标t′表示本轮训练的初始时刻,xt′为初始时刻全局模型参数,mt′为初始时刻全局一阶动量,vt′为初始时刻全局二阶动量。。

以第i个客户端为例,在每次迭代开始时在本地数据集内进行随机采样得到部分数据,并计算梯度值 为t时刻第i个客户端的随机采样数据,Di为第i个客户端的本地数据集,为t-1时刻第i个客户端的模型参数,fi为第i个客户端的本地损失函数。利用此梯度值与上一步一阶动量构建下式得到

其中,下标i表示第i个客户端,上标t表示当前迭代时刻,d为向量维度,下标j表示向量的第j个分量,g表示梯度值,m表示一阶动量,v表示二阶动量。利用迭代时刻t构建下式计算得到

其中γ为预设常数。构建下式获得当前迭代时刻的一阶动量与二阶动量

由此可见,当异常梯度值出现时,差异性增大,则趋近于1,此时影响减小,保证了更新方向不受异常值影响,故异常值得以控制。同时,当训练步入后期,趋近于1,这保证了在参数接近最优值时学习率不会因异常值出现而过大或过小,也增强了算法鲁棒性。

利用历史迭代时刻本地模型参数与全局学习率α,一阶动量初始时刻全局二阶动量vt′计算得到当前时刻本地模型参数

在更新本地参数时,采用全局二阶动量作为学习率的分母,这保证了不同客户端有相同的更新步长,使不同客户端的本地模型差异性减小,进而提升全局模型的性能。

当本地迭代次数达到预设值时上传本地参数并进行模型融合:

其中,pi为第i个客户端的权重。

实际中,将MNIST手写数字训练集平均分配至十个客户端,同时以50%的概率在每张图片one-hot标签加上均值为0、方差为0.4的高斯噪声,训练逻辑回归模型。全局模型在测试集上的结果如图3所示,本发明所提出的方法准确度与稳定性都优于现有技术。通过本实例所述方法,实现了在受污染数据集下进行模型训练,最大程度的消除了有毒数据对模型性能的影响,精度高并具有稳定性。

以上所述仅为本发明的具体实施方式,并不用于限定本发明的保护范围,凡在本发明的精神和原则之内,所做的任何修改、等同替换、改进等,均应包含在本发明的保护范围之内。

11页详细技术资料下载
上一篇:一种医用注射器针头装配设备
下一篇:一种建筑冷热负荷预测方法、装置、设备及存储介质

网友询问留言

已有0条留言

还没有人留言评论。精彩留言会获得点赞!

精彩留言,会给你点赞!