1.一种联邦学习方法,其特征在于,包括:
步骤S1,中央服务系统将初始全局模型发送给所有客户端,所有客户端对初始全局模型进行训练获得初始本地模型,客户端上传初始本地模型至中央服务系统;
步骤S2,中央服务系统根据客户端上传的初始本地模型对客户端进行聚类获得一个以上客户端类;
步骤S3,对全局模型进行多轮迭代训练直到达到迭代停止条件,设t为正整数,第t轮迭代训练过程为:步骤S301,从每个客户端类中选取至少一个客户端参与第t轮迭代训练;
步骤S302,发送第t轮全局模型至选取的客户端,选取的客户端对接收的第t轮全局模型进行训练并返回第t轮本地模型和第t轮损失函数值至中央服务系统;
步骤S303,中央服务系统基于客户端返回的第t轮本地模型和第t轮损失函数值判断参与第t轮迭代训练的客户端之间是否存在梯度冲突,并根据梯度冲突情况获取累积模型差异;
所述步骤S303具体包括:
步骤A,设第t轮迭代训练选取了M个客户端参与训练,M为正整数;分别计算第t轮全局模型与M个客户端返回的第t轮本地模型的差异获得M个第t轮差异向量,所述M个第t轮差异向量组成序列步骤B,按照相应的第t轮损失函数值从小到大的顺序对M个客户端的第t轮差异向量进行排序并将排序结果保存在序列W0中,步骤C,进行M个客户端之间的梯度冲突判断和获取累积模型差异,具体包括:设从序列W中获取第m个客户端的第t轮差异向量表示为 m表示参与第t轮迭代训练的客户端在M个客户端中的索引,m为正整数,m∈[1,M];设 对应的修正向量为 的初始值为序列W0中的向量排列顺序依次与 进行是否存在梯度冲突判断,当判断结果为存在梯度冲突时,对 进行修正,进行序列W0中下一向量的冲突判断;设 为W0序列中的向量,若 成立,则认为 与 存在梯度冲突,p表示W0序列中向量的索引,p为正整数,p∈[1,M],按照公式 修正
步骤D,求取M个客户端第t轮差异向量的修正向量的和获得累积模型差异;
所述累积模型差异为: 其中,nm表示参与第t轮迭代训练的M个客户端中第m个客户端的数据量,NUMM表示参与第t轮迭代训练中的M个客户端的数据总量;
步骤S304,利用所述累积模型差异更新第t轮全局模型,将更新后的第t轮全局模型作为第t+1轮全局模型。
2.如权利要求1所述的联邦学习方法,其特征在于,所述步骤S2包括:步骤S201,中央服务系统计算初始全局模型与每个客户端的初始本地模型的差异,记为初始差异,对所述初始差异进行向量化,获得初始差异向量;
步骤S202,基于所有客户端的差异向量进行聚类获得多个客户端类。
3.如权利要求1或2所述的联邦学习方法,其特征在于,所述步骤S301具体包括:对于第j个客户端类,从第j个客户端类中抽取 个客户端参与第t轮迭代训练;其中,k表示客户端总数;j表示客户端类的索引,j为正整数;Nj表示第j个客户端类的客户端数量;
在抽取客户端的过程中,每个客户端的抽取概率与其拥有的数据量正相关。
4.如权利要求3所述的联邦学习方法,其特征在于,第i个客户端的抽取概率为ni/NUMj;
其中,i表示客户端在所有客户端中的索引,i为正整数,i∈[1,k];ni表示第i个客户端拥有的数据量;设第i个客户端属于第j个客户端类,NUMj表示第i个客户端所属的第j个客户端类的数据总量。
5.如权利要求1所述的联邦学习方法,其特征在于,在所述步骤S304中,按照如下公式t更新全局模型:ωt+1=ωt+WD;其中,ωt+1表示第t+1轮迭代训练的全局模型,ωt表示第t轮迭代训练的全局模型。
6.一种联邦学习系统,其特征在于,包括中央服务系统和与所述中央服务系统通信连接的多个客户端,所述中央服务系统和多个客户端按照权利要求1‑5之一所述的方法进行联邦学习。
7.如权利要求6所述的联邦学习系统,其特征在于,所述中央服务系统为区块链系统。