本發(fā)明涉及人工智能,具體地,涉及一種能夠同時實現(xiàn)穩(wěn)定化和加速化全局收斂的基于nag模擬的聯(lián)邦學習方法及其穩(wěn)定加速方法,同時涉及一種相應(yīng)的裝置、系統(tǒng)、計算機終端和計算機可讀存儲介質(zhì)。
背景技術(shù):
1、現(xiàn)代分布式網(wǎng)絡(luò)日趨龐大,例如手機、可穿戴設(shè)備和自動駕駛等每天都會產(chǎn)生大量數(shù)據(jù)。然而,由于不斷增加的數(shù)據(jù)隱私問題和網(wǎng)絡(luò)通信限制,純粹的集中式數(shù)據(jù)存儲和分析方法變得不切實際。事實上,數(shù)據(jù)所有者經(jīng)常擔心與第三方共享他們的數(shù)據(jù)會造成隱私泄露。在此背景下,美國的《消費者隱私權(quán)利法案》和歐盟委員會的《通用數(shù)據(jù)保護條例》(gdpr)等嚴格立法旨在保護用戶的隱私。鑒于此因素,數(shù)據(jù)存儲和分析正在從集中式轉(zhuǎn)向分布式。這種轉(zhuǎn)變的關(guān)鍵推動技術(shù)是邊緣計算,其中邊緣客戶端如智能手機、傳感器和微型服務(wù)器等具有較高的計算和存儲能力,使它們能夠以較小的延遲分析數(shù)據(jù)和傳輸結(jié)果。
2、在此背景下,mcmahan等人在2016年發(fā)表的“communication-efficient?learningof?deep?networks?from?decentralized?data”一文中提出了聯(lián)邦學習。聯(lián)邦學習的主要想法是邊緣客戶端首先根據(jù)全局模型初始化本地模型,然后基于大量存儲的本地數(shù)據(jù)并行訓練本地機器學習模型,之后無需共享其原始數(shù)據(jù),只需發(fā)送本地機器學習模型到服務(wù)器。服務(wù)器然后聚合所有的本地模型,并且基于此更新獲得新的全局模型,重復(fù)這些步驟,直到滿足某個收斂標準。
3、nag(nesterov?accelerated?gradient?descent)加速算法實現(xiàn)了在一階優(yōu)化算法中的最優(yōu)收斂速率。因此,為了加速聯(lián)邦學習的訓練過程,yu等人在2019年的《international?conference?on?machine?learning》發(fā)表的論文“on?the?linearspeedup?analysis?of?communication?efficient?momentum?sgd?for?distributed?non-convex?optimization”中提出了在本地并行執(zhí)行標準的nag算法,然而由于數(shù)據(jù)異構(gòu)性,本地加速訓練很容易造成過擬合,造成全局模型收斂不穩(wěn)定甚至不收斂。
4、經(jīng)過檢索發(fā)現(xiàn):
5、公開號為cn117829317a的中國發(fā)明專利申請《一種基于本地模型差異的個性化聯(lián)邦學習方法》,包括以下步驟:服務(wù)器發(fā)送全局模型給本地客戶端,初始化化客戶端模型;客戶端根據(jù)本地數(shù)據(jù)進行模型訓練,并上傳模型參數(shù)給服務(wù)器;服務(wù)器根據(jù)這一輪客戶端上傳的參數(shù)進行生成模型差異矩陣;服務(wù)器依據(jù)客戶端模型進行全局模型聚合,同時利用模型差異矩陣選擇合適的鄰居模型聚合并更新各個客戶端的個性化模型;本地客戶端根據(jù)服務(wù)端返回的模型更新本地模型。該聯(lián)邦學習方法仍然存在如下技術(shù)問題:
6、該方案考慮的是通過節(jié)點選擇策略為每個節(jié)點訓練一個個性化的私有模型,無法訓練一個全局最優(yōu)模型。
7、該方案中本地訓練仍然是基于傳統(tǒng)的梯度下降算法進行本地模型更新,這種策略仍然會導(dǎo)致本地訓練陷入局部最優(yōu),導(dǎo)致過擬合;同時該方案并未起到本地加速訓練的效果。
技術(shù)實現(xiàn)思路
1、本發(fā)明針對現(xiàn)有技術(shù)不足,提供了一種能夠同時實現(xiàn)穩(wěn)定化和加速化全局收斂的基于nag模擬的聯(lián)邦學習方法及其穩(wěn)定加速方法,同時提供了一種相應(yīng)的裝置、系統(tǒng)、計算機終端和計算機可讀存儲介質(zhì)。
2、根據(jù)本發(fā)明的一個方面,提供了一種基于nag模擬的聯(lián)邦學習穩(wěn)定加速方法,應(yīng)用于客戶端,所述客戶端包括本地模型,所述方法包括:
3、接收服務(wù)器發(fā)送的全局加速項和全局模型;
4、根據(jù)所述全局模型初始化本地模型;
5、將本地訓練數(shù)據(jù)輸入經(jīng)過初始化的所述本地模型進行模型訓練,確定所述本地模型的損失函數(shù);
6、根據(jù)所述本地模型的損失函數(shù),基于反向傳播和所述全局加速項,獲得本地加速項;
7、基于所述本地加速項,對所述本地模型進行更新;
8、在所述本地模型更新定次數(shù)之后,將更新末次和初始的本地模型作差,確定本地模型目標改變量。
9、優(yōu)選地,所述本地模型的損失函數(shù)為本地數(shù)據(jù)的訓練標簽和真實標簽之間的交叉熵損失。
10、優(yōu)選地,所述根據(jù)所述本地模型的損失函數(shù),基于反向傳播和所述全局加速項,獲得本地加速項,包括:
11、對反向傳播所得梯度與所述全局加速項進行首次滑動平均;
12、將所述梯度繼續(xù)與所述首次滑動平均的結(jié)果進行二次滑動平均,獲得本地加速項。
13、優(yōu)選地,所述基于所述本地加速項,對所述本地模型進行更新,包括:
14、基于所述本地加速項乘以學習率,對所述本地模型進行更新。
15、根據(jù)本發(fā)明的另一個方面,提供了一種基于nag模擬的聯(lián)邦學習穩(wěn)定加速方法,應(yīng)用于服務(wù)器端,所述服務(wù)器端包括全局模型,所述方法包括:
16、接收客戶端上傳的本地模型目標改變量;
17、在服務(wù)器端聚合所有客戶端的本地模型目標改變量,更新全局模型;
18、基于所述本地模型目標改變量的聚合結(jié)果,更新全局加速項。
19、優(yōu)選地,所述在服務(wù)器端聚合所有客戶端的本地模型目標改變量,更新全局模型,包括:
20、通過加權(quán)求和的方法對所有客戶端的本地模型目標改變量進行聚合,所獲得的聚合結(jié)果與陳舊的全局模型線性相加,更新得到全新的全局模型。
21、優(yōu)選地,所述基于所述本地模型目標改變量的聚合結(jié)果,更新全局加速項,包括:
22、將所述本地模型目標改變量的聚合結(jié)果與陳舊的全局加速項線性相加,更新得到全新的全局加速項。
23、優(yōu)選地,上述方法,還包括:
24、將更新后的全局模型和全局加速項發(fā)送至客戶端。
25、根據(jù)本發(fā)明的第三個方面,提供了一種基于nag模擬的聯(lián)邦學習穩(wěn)定加速裝置,應(yīng)用于客戶端,所述客戶端包括用于提供本地模型的本地模型模塊,所述裝置包括:
26、客戶端接收模塊,該模塊用于接收服務(wù)器發(fā)送的全局模型和全局加速項;
27、客戶端初始化模塊,該模塊用于根據(jù)所述全局模型初始化本地模型;
28、客戶端第一確定模塊,該模塊用于將本地訓練數(shù)據(jù)輸入經(jīng)過初始化的所述本地模型進行模型訓練,確定所述本地模型的損失函數(shù);
29、客戶端第二確定模塊,該模塊用于根據(jù)所述本地模型的損失函數(shù),基于反向傳播和所述全局加速項,獲得本地加速項;
30、客戶端更新模塊,該模塊基于所述本地加速項,對所述本地模型進行更新;
31、客戶端第三確定模塊,該模塊用于在所述本地模型更新定次數(shù)之后,將更新末次和初始的本地模型作差,確定本地模型目標改變量。
32、根據(jù)本發(fā)明的第四個方面,提供了一種基于nag模擬的聯(lián)邦學習穩(wěn)定加速裝置,應(yīng)用于服務(wù)器端,所述服務(wù)器端包括用于提供全局模型的全局模型模塊,所述裝置包括:
33、服務(wù)器初始化模塊,用于初始化全局模型和全局加速項;
34、服務(wù)器發(fā)送模塊,用于將經(jīng)過所述初始化的全局模型和全局加速項發(fā)送至所述客戶端;
35、服務(wù)器接收模塊,用于接收所述客戶端發(fā)送的本地模型改變量;
36、服務(wù)器第一確定模塊,用于將所述本地模型改變量進行全局聚合處理,確定全局模型;
37、服務(wù)器第二確定模塊,用于基于所述本地模型改變量的聚合結(jié)果,確定全局加速項。
38、根據(jù)本發(fā)明的第五個方面,提供了一種基于nag模擬的聯(lián)邦學習方法,包括:
39、在客戶端,接收服務(wù)器發(fā)送的全局加速項和全局模型;根據(jù)所述全局模型初始化本地模型;將本地訓練數(shù)據(jù)輸入經(jīng)過初始化的所述本地模型進行模型訓練,確定所述本地模型的損失函數(shù);根據(jù)所述本地模型的損失函數(shù),基于反向傳播和所述全局加速項,獲得本地加速項;基于所述本地加速項,對所述本地模型進行更新;在所述本地模型更新定次數(shù)之后,將更新末次和初始的本地模型作差,確定本地模型目標改變量,并上傳至服務(wù)器端;
40、在服務(wù)器端,接收客戶端上傳的本地模型目標改變量;聚合所有客戶端的本地模型目標改變量,更新全局模型;基于所述本地模型目標改變量的聚合結(jié)果,更新全局加速項;將更新的全局模型和全局加速項傳輸至客戶端。
41、根據(jù)本發(fā)明的第六個方面,提供了一種基于nag模擬的聯(lián)邦學習系統(tǒng),包括:布置于客戶端的本地模型更新模塊和模型改變量計算模塊、布置于服務(wù)器端的全局聚合模塊和全局加速項更新模塊以及用于連接所述客戶端和所述服務(wù)器端的通信模塊;其中:
42、所述本地模型更新模塊,用于接收服務(wù)器發(fā)送的全局加速項和全局模型;根據(jù)所述全局模型初始化本地模型;將本地訓練數(shù)據(jù)輸入經(jīng)過初始化的所述本地模型進行模型訓練,確定所述本地模型的損失函數(shù);根據(jù)所述本地模型的損失函數(shù),基于反向傳播和所述全局加速項,獲得本地加速項;基于所述本地加速項,對所述本地模型進行更新;
43、所述模型改變量計算模塊,用于在所述本地模型更新定次數(shù)之后,將更新末次和初始的本地模型作差,確定本地模型目標改變量,并上傳至服務(wù)器端;
44、所述全局聚合模塊,用于接收客戶端上傳的本輪參與訓練的本地模型目標改變量;聚合所有客戶端的本地模型目標改變量,更新全局模型;
45、所述全局加速項更新模塊,用于利用所述全局聚合模塊獲得的本地模型目標改變量的聚合結(jié)果,更新全局加速項;
46、所述通信模塊,用于將客戶端所確定的本地模型目標改變量傳輸至服務(wù)器端,并將服務(wù)器端更新的全局模型和全局加速項傳輸至客戶端。
47、根據(jù)本發(fā)明的第七個方面,提供了一種計算機終端,包括存儲器、處理器及存儲在存儲器上并可在處理器上運行的計算機程序,該處理器執(zhí)行該計算機程序時可用于執(zhí)行本發(fā)明上述中任一項所述的方法,或,運行本發(fā)明上述中任一項所述的裝置。
48、根據(jù)本發(fā)明的第八個方面,提供了一種計算機可讀存儲介質(zhì),其上存儲有計算機程序,該計算機程序被處理器執(zhí)行時可用于執(zhí)行本發(fā)明上述中任一項所述的方法,或,運行本發(fā)明上述中任一項所述的裝置。
49、由于采用了上述技術(shù)方案,本發(fā)明與現(xiàn)有技術(shù)相比,具有如下至少一項的有益效果:
50、本發(fā)明提供的基于nag模擬的聯(lián)邦學習方法及其穩(wěn)定加速方法,本地模型更新通過全局加速項進行加速訓練,能夠起到加速收斂的作用,減少所需的通信輪次即可完成訓練。
51、本發(fā)明提供的基于nag模擬的聯(lián)邦學習方法及其穩(wěn)定加速方法,本地模型更新通過全局加速項進行加速訓練,能夠避免本地訓練陷入局部最優(yōu),緩解本地模型過擬合,提高了所聚合的全局模型的泛化表現(xiàn)。
52、本發(fā)明提供的基于nag模擬的聯(lián)邦學習方法及其穩(wěn)定加速方法,本地模型更新通過全局加速項進行加速訓練,在本地穩(wěn)定加速訓練中,一方面因此本地加速過程能夠獲得全局加速項的約束保證局部訓練不會陷入局部最優(yōu),同時通過全局加速項能夠?qū)崿F(xiàn)加速收斂的作用,獲得全局最優(yōu)。
53、本發(fā)明提供的基于nag模擬的聯(lián)邦學習方法及其穩(wěn)定加速方法,采用可實現(xiàn)穩(wěn)定化和加速化全局收斂聯(lián)邦學習的技術(shù),為聯(lián)邦學習場景提供了一種新的解決方案。通過在服務(wù)器端和本地客戶端分別使用模擬nag加速算法,既可以加速訓練,同時使收斂變得穩(wěn)定,避免客戶端更新對私有數(shù)據(jù)過擬合,提高了全局模型的收斂速率和泛化表現(xiàn)。
54、本發(fā)明提供的基于nag模擬的聯(lián)邦學習方法及其穩(wěn)定加速方法,通過模擬nag加速算法,既能加速模型收斂,同時能夠使收斂穩(wěn)定,提高模型收斂速率和泛化表現(xiàn),并且同時不會造成本地數(shù)據(jù)隱私泄露。
55、本發(fā)明提供的基于nag模擬的聯(lián)邦學習方法及其穩(wěn)定加速方法,可以作為一種保護用戶數(shù)據(jù)隱私的分布式機器學習方法,并進一步應(yīng)用于圖像分類和自然語言處理等多個應(yīng)用領(lǐng)域。