本發(fā)明屬于信息,具體涉及一種用于工業(yè)異構(gòu)設(shè)備的對(duì)比雙焦點(diǎn)知識(shí)蒸餾聯(lián)邦學(xué)習(xí)方法。
背景技術(shù):
1、隨著6g網(wǎng)絡(luò)的快速發(fā)展,工業(yè)互聯(lián)網(wǎng)正迎來(lái)深刻變革,其中邊緣計(jì)算能力的提升成為關(guān)鍵驅(qū)動(dòng)力。6g網(wǎng)絡(luò)不僅提供了高速、低延遲的信息傳輸,還通過(guò)邊緣計(jì)算技術(shù),使大量分布式客戶端設(shè)備能夠本地高效處理數(shù)據(jù)并進(jìn)行初步模型訓(xùn)練。這種能力優(yōu)化了數(shù)據(jù)處理流程,客戶端設(shè)備可以即時(shí)分析和提煉關(guān)鍵信息,再將精煉后的數(shù)據(jù)和初步模型上傳至服務(wù)器進(jìn)行進(jìn)一步融合和優(yōu)化,從而生成高性能的綜合模型。這種模式有效減少了數(shù)據(jù)傳輸量,減輕了服務(wù)器的計(jì)算負(fù)擔(dān),實(shí)現(xiàn)了資源的高效利用和優(yōu)化配置,尤其在流程工業(yè)中需求尤為突出。流程工業(yè)具有高度自動(dòng)化和復(fù)雜的生產(chǎn)流程,對(duì)實(shí)時(shí)監(jiān)控與高安全性有著嚴(yán)苛要求。
2、然而,隨著分布式智能設(shè)備的激增,如何在保障數(shù)據(jù)隱私的前提下實(shí)現(xiàn)高效協(xié)同成為新的挑戰(zhàn)。聯(lián)邦學(xué)習(xí)作為一種無(wú)需集中存儲(chǔ)數(shù)據(jù)的分布式模型訓(xùn)練技術(shù),備受關(guān)注。通過(guò)將模型訓(xùn)練任務(wù)分布至各客戶端設(shè)備,聯(lián)邦學(xué)習(xí)不僅保護(hù)了數(shù)據(jù)隱私,還充分利用了客戶端設(shè)備的計(jì)算能力,實(shí)現(xiàn)了分布式協(xié)同學(xué)習(xí)。以平局聯(lián)邦學(xué)習(xí)(federated?averagingfedavg)為代表的經(jīng)典聯(lián)邦學(xué)習(xí)算法通過(guò)各客戶端設(shè)備在本地更新模型,并將模型更新而非原始數(shù)據(jù)傳送至服務(wù)器進(jìn)行聚合,從而提升全局模型性能。然而,工業(yè)場(chǎng)景中的數(shù)據(jù)分布通常呈現(xiàn)非獨(dú)立同分布(non-independent?identically?distributed?non-iid)的特性,同時(shí)客戶端模型存在異構(gòu)性。這種情況容易在模型優(yōu)化過(guò)程中偏離全局最優(yōu)解,從而影響模型的準(zhǔn)確性。
3、例如,公開號(hào)為cn?118709753?a的中國(guó)發(fā)明專利,公開了基于聯(lián)邦學(xué)習(xí)的工業(yè)設(shè)備集群非獨(dú)立同分布數(shù)據(jù)處理框架,提出一種基于聯(lián)邦學(xué)習(xí)的工業(yè)設(shè)備集群非獨(dú)立同分布數(shù)據(jù)處理框架,旨在應(yīng)對(duì)工業(yè)設(shè)備集群中普遍存在的非獨(dú)立同分布數(shù)據(jù)挑戰(zhàn)。該框架融合了優(yōu)化的參考樣本選擇策略與聯(lián)邦學(xué)習(xí)技術(shù),旨在減輕由數(shù)據(jù)不平衡及數(shù)據(jù)特征類別缺失所帶來(lái)的不利影響。但是,該專利依賴過(guò)采樣技術(shù)來(lái)調(diào)整本地?cái)?shù)據(jù)集,以期篩選出數(shù)據(jù)量均衡且結(jié)構(gòu)相對(duì)完整的客戶端。這一策略的有效性建立在一個(gè)關(guān)鍵假設(shè)之上,即系統(tǒng)中必須存在至少一部分客戶端擁有相對(duì)完整的數(shù)據(jù)結(jié)構(gòu)。此外,該框架在設(shè)計(jì)上未考慮客戶端間可能存在的模型異構(gòu)性,即不同客戶端可能使用不同結(jié)構(gòu)或參數(shù)的模型進(jìn)行訓(xùn)練。這一限制可能限制了框架在復(fù)雜工業(yè)環(huán)境中的廣泛適用性。
4、例如,公開號(hào)為cn?118279639?a的中國(guó)發(fā)明專利,公開了神經(jīng)架構(gòu)搜索下知識(shí)蒸餾聯(lián)邦學(xué)習(xí)的醫(yī)學(xué)圖像分類方法,提出一種神經(jīng)架構(gòu)搜索下知識(shí)蒸餾個(gè)性化聯(lián)邦學(xué)習(xí)的醫(yī)學(xué)圖像分類方法,該方法通過(guò)只共享模型的輸出層,來(lái)將各個(gè)客戶端的異構(gòu)模型進(jìn)行聚合。但是,盡管該專利通過(guò)引入知識(shí)蒸餾機(jī)制,成功地在聯(lián)邦學(xué)習(xí)方法內(nèi)支持了客戶端與服務(wù)器端模型的異構(gòu)性,但在應(yīng)對(duì)那些分布場(chǎng)景差異極為顯著的工業(yè)應(yīng)用環(huán)境時(shí),其采用的傳統(tǒng)知識(shí)蒸餾方法可能會(huì)遭遇性能顯著下滑的。
5、因此,需要一種用于工業(yè)異構(gòu)設(shè)備的對(duì)比雙焦點(diǎn)知識(shí)蒸餾聯(lián)邦學(xué)習(xí)方法,能夠快速響應(yīng)、精準(zhǔn)分類工業(yè)圖像的智能分類器。
技術(shù)實(shí)現(xiàn)思路
1、本發(fā)明針對(duì)上述背景技術(shù)的不足,提出了一種用于工業(yè)異構(gòu)設(shè)備的對(duì)比雙焦點(diǎn)知識(shí)蒸餾聯(lián)邦學(xué)習(xí)方法(contrastive?bifocal?distillation?for?federated?learningcbdfl),旨在通過(guò)知識(shí)傳遞而非模型結(jié)構(gòu)共享,提升non-iid數(shù)據(jù)場(chǎng)景下的學(xué)習(xí)效率。
2、為了解決這一問(wèn)題,提出了cbdfl方法,該方法遵循聯(lián)邦學(xué)習(xí)的基本范式,主要由兩個(gè)核心組件構(gòu)成:服務(wù)器和客戶端。服務(wù)器的主要任務(wù)是接收來(lái)自多個(gè)客戶端的知識(shí),通過(guò)雙焦點(diǎn)蒸餾策略推動(dòng)服務(wù)器模型演化為更高精度、更具泛化能力的服務(wù)器端全局模型??蛻舳说闹饕蝿?wù)是計(jì)算全局模型與客戶端本地模型輸出的特征對(duì)比損失和相對(duì)熵(kullback-leibler?divergence?kl散度)損失,并將其作為知識(shí)蒸餾的損失項(xiàng)。同時(shí),客戶端還需計(jì)算本地模型與全局模型和上一輪本地模型輸出的特征對(duì)比損失,以及本地模型輸出與硬性標(biāo)簽之間的交叉熵,并將其作為本地監(jiān)督損失項(xiàng)。通過(guò)知識(shí)蒸餾損失項(xiàng)和本地監(jiān)督損失項(xiàng)的聯(lián)合訓(xùn)練,本地模型在保持原有數(shù)據(jù)分布優(yōu)勢(shì)的同時(shí),能夠獲得更高的全局準(zhǔn)確率。
3、本發(fā)明的目的通過(guò)以下技術(shù)方案來(lái)實(shí)現(xiàn),一種用于工業(yè)異構(gòu)設(shè)備的對(duì)比雙焦點(diǎn)知識(shí)蒸餾聯(lián)邦學(xué)習(xí)方法,所述方法包括以下步驟:
4、步驟s1:通過(guò)1個(gè)服務(wù)器和n個(gè)客戶端構(gòu)建對(duì)比雙焦點(diǎn)知識(shí)蒸餾聯(lián)邦學(xué)習(xí)模型,完成?k?類分類任務(wù),每個(gè)客戶端擁有一個(gè)本地?cái)?shù)據(jù)集;各個(gè)客戶端分別在本地?cái)?shù)據(jù)集上進(jìn)行預(yù)訓(xùn)練,得到初始化本地模型;
5、步驟s2:客戶端與服務(wù)器進(jìn)行聯(lián)邦學(xué)習(xí)訓(xùn)練,在客戶端與服務(wù)器進(jìn)行第q輪通信中,將各客戶端的本地模型輸出的知識(shí)在服務(wù)器進(jìn)行聚合,得到知識(shí)<msup><msup><mi>z</mi><mi>c</mi></msup><mi>q</mi></msup><mi>=[</mi><msub><msup><msup><mi>z</mi><mi>c</mi></msup><mi>q</mi></msup><mn>1</mn></msub><mi>, </mi><msub><msup><msup><mi>z</mi><mi>c</mi></msup><mi>q</mi></msup><mn>2</mn></msub><mi>, ?, </mi><msub><msup><msup><mi>z</mi><mi>c</mi></msup><mi>q</mi></msup><mi>t</mi></msub><mi>, ?, </mi><msub><msup><msup><mi>z</mi><mi>c</mi></msup><mi>q</mi></msup><mi>k</mi></msub><mi>]</mi>,用于更新全局模型,在第i次迭代中,全局模型輸出的知識(shí)為<msubsup><msup><mi>z</mi><mi>s</mi></msup><mi>i</mi><mi>q</mi></msubsup><mi>=[</mi><msub><msubsup><msup><mi>z</mi><mi>s</mi></msup><mi>i</mi><mi>q</mi></msubsup><mn>1</mn></msub><mi>, </mi><msub><msubsup><msup><mi>z</mi><mi>s</mi></msup><mi>i</mi><mi>q</mi></msubsup><mn>2</mn></msub><mi>, ?, </mi><msub><msubsup><msup><mi>z</mi><mi>s</mi></msup><mi>i</mi><mi>q</mi></msubsup><mi>t</mi></msub><mi>, ?, </mi><msub><msubsup><msup><mi>z</mi><mi>s</mi></msup><mi>i</mi><mi>q</mi></msubsup><mi>k</mi></msub><mi>]</mi>;
6、步驟s3:利用softmax函數(shù),根據(jù)聚合的知識(shí)和全局模型輸出的知識(shí),分別計(jì)算客戶端的隱性概率和顯性概率矩陣以及全局模型第i次迭代時(shí)服務(wù)器的隱性概率和顯性概率矩陣,并通過(guò)kl散度求得隱性損失項(xiàng)和顯性損失項(xiàng),用兩個(gè)超參數(shù),將隱性損失項(xiàng)和顯性損失項(xiàng)聯(lián)系起來(lái)作為雙焦點(diǎn)知識(shí)蒸餾損失項(xiàng);
7、步驟s4:在第i次迭代中,利用雙焦點(diǎn)知識(shí)蒸餾損失項(xiàng)對(duì)全局模型進(jìn)行更新,生成新一輪的全局模型,直至i=es,得到全局模型,全局模型輸出的知識(shí)為<msup><msup><mi>z</mi><mi>s</mi></msup><mi>q</mi></msup><mi>=[</mi><msub><msup><msup><mi>z</mi><mi>s</mi></msup><mi>q</mi></msup><mn>1</mn></msub><mi>, </mi><msub><msup><msup><mi>z</mi><mi>s</mi></msup><mi>q</mi></msup><mn>2</mn></msub><mi>, ?, </mi><msub><msup><msup><mi>z</mi><mi>s</mi></msup><mi>q</mi></msup><mi>k</mi></msub><mi>]</mi>,其中,es為設(shè)定的服務(wù)器的迭代次數(shù);
8、步驟s5:服務(wù)器將全局模型和其輸出的知識(shí)的發(fā)送到各個(gè)客戶端,在第f次迭代中,本地模型輸出的知識(shí)為<msubsup><msup><mi>z</mi><msub><mi>c</mi><mi>n</mi></msub></msup><mi>f</mi><mi>q+1</mi></msubsup><mi>=[</mi><msub><msubsup><msup><mi>z</mi><msub><mi>c</mi><mi>n</mi></msub></msup><mi>f</mi><mi>q+1</mi></msubsup><mn>1</mn></msub><mi>, </mi><msub><msubsup><msup><mi>z</mi><msub><mi>c</mi><mi>n</mi></msub></msup><mi>f</mi><mi>q+1</mi></msubsup><mn>2</mn></msub><mi>, ?, </mi><msub><msubsup><msup><mi>z</mi><msub><mi>c</mi><mi>n</mi></msub></msup><mi>f</mi><mi>q+1</mi></msubsup><mi>k</mi></msub><mi>]</mi>,;
9、步驟s6:利用kl散度計(jì)算全局模型輸出知識(shí)與本地模型輸出知識(shí)的kl散度損失項(xiàng);
10、步驟s7:根據(jù)全局模型和本地模型輸出的特征值計(jì)算相似矩陣m,并求得對(duì)比損失項(xiàng),其中,表示平均值;
11、步驟s8:將對(duì)比損失項(xiàng)和kl散度損失項(xiàng)利用,兩個(gè)超參數(shù)聯(lián)系起來(lái)作為客戶端蒸餾總損失項(xiàng);
12、步驟s9:在第f次迭代中,利用客戶端蒸餾總損失項(xiàng)對(duì)本地模型進(jìn)行更新,生成本地模型;
13、步驟s10:將本地?cái)?shù)據(jù)輸入到本地模型中,計(jì)算真實(shí)標(biāo)簽與預(yù)測(cè)結(jié)果的交叉熵,作為硬性指標(biāo)監(jiān)督損失項(xiàng);
14、步驟s11:將第f-1次迭代的本地模型輸出的特征值視為負(fù)樣本,將全局模型輸出值視為正樣本,計(jì)算對(duì)比監(jiān)督損失項(xiàng);
15、步驟s12:將硬性指標(biāo)監(jiān)督損失項(xiàng)和對(duì)比監(jiān)督損失項(xiàng),利用,兩個(gè)超參數(shù)聯(lián)系起來(lái)作為客戶端模型監(jiān)督總損失項(xiàng);
16、步驟s13:在第f次迭代中,利用客戶端模型監(jiān)督總損失項(xiàng),對(duì)本地模型進(jìn)行更新,生成新一輪的本地模型,
17、步驟s14,重復(fù)執(zhí)行步驟s6-s13,直至f=ec,得到本地模型,其中,ec為設(shè)定的客戶端的迭代次數(shù);
18、步驟s15:重復(fù)執(zhí)行步驟s2-s4,將各個(gè)客戶端的本地模型輸出的知識(shí)在服務(wù)器進(jìn)行聚合,得到知識(shí),用于更新服務(wù)器端全局模型;
19、步驟s16:重復(fù)步驟s2-s15,直至客戶端與服務(wù)器通信輪數(shù)達(dá)到設(shè)定值e,則完成訓(xùn)練。
20、進(jìn)一步地,步驟s3,利用softmax函數(shù),根據(jù)聚合的知識(shí)和全局模型輸出的知識(shí),分別計(jì)算客戶端的隱性概率和顯性概率矩陣以及全局模型第i次迭代時(shí)服務(wù)器的隱性概率和顯性概率矩陣,并通過(guò)kl散度求得隱性損失項(xiàng)和顯性損失項(xiàng),用兩個(gè)超參數(shù),將隱性損失項(xiàng)和顯性損失項(xiàng)聯(lián)系起來(lái)作為雙焦點(diǎn)知識(shí)蒸餾損失項(xiàng),具體為:
21、根據(jù)聚合的知識(shí)分別計(jì)算客戶端的非目標(biāo)概率和目標(biāo)概率:
22、;
23、其中,為顯性溫度系數(shù),
24、根據(jù)全局模型輸出的知識(shí)分別計(jì)算全局模型第i次迭代時(shí)服務(wù)器的非目標(biāo)概率和目標(biāo)概率:
25、;
26、則得到客戶端的顯性概率矩陣和全局模型第i次迭代時(shí)服務(wù)器的顯性概率矩陣;
27、通過(guò)kl散度求得顯性損失項(xiàng)為:
28、;
29、定義客戶端的隱性概率,其中的每個(gè)元素表示為:
30、;
31、定義全局模型第i次迭代時(shí)服務(wù)器的隱性概率,其中的每個(gè)元素表示為:
32、;
33、其中,為隱性溫度系數(shù),
34、通過(guò)kl散度求得隱性損失項(xiàng)為:
35、;
36、則雙焦點(diǎn)知識(shí)蒸餾損失項(xiàng)為:
37、。
38、進(jìn)一步地,步驟s4中,在第i次迭代中,利用雙焦點(diǎn)知識(shí)蒸餾損失項(xiàng)對(duì)全局模型進(jìn)行更新,生成新一輪的全局模型,具體為:,其中,為模型學(xué)習(xí)率。
39、進(jìn)一步地,步驟s6,利用kl散度計(jì)算全局模型輸出知識(shí)與本地模型輸出知識(shí)的kl散度損失項(xiàng),具體為:
40、定義服務(wù)器的輸出概率:
41、;
42、其中,每個(gè)元素表示為:
43、;
44、定義本地模型第f次迭代時(shí)客戶端的輸出概率:
45、;
46、其中,每個(gè)元素表示為:
47、;
48、其中,t為客戶端蒸餾的溫度系數(shù);
49、則kl散度損失項(xiàng)為:
50、。
51、進(jìn)一步地,步驟s8中,客戶端蒸餾總損失項(xiàng)的表達(dá)式為:
52、。
53、進(jìn)一步地,步驟s10中,所述硬性指標(biāo)監(jiān)督損失項(xiàng)的表達(dá)式為:
54、;
55、其中,,分別為第f次迭代中本地模型的輸出結(jié)果與客戶端的本地?cái)?shù)據(jù)集的真實(shí)標(biāo)簽。
56、進(jìn)一步地,步驟s11中,所述對(duì)比監(jiān)督損失項(xiàng)的表達(dá)式為:
57、;
58、其中,為第f次迭代的本地模型的輸出結(jié)果,為第f-1次迭代的本地模型的輸出結(jié)果,為全局模型的輸出結(jié)果,為對(duì)比監(jiān)督損失項(xiàng)的溫度系數(shù),表示余弦相似度。
59、進(jìn)一步地,步驟s12中,客戶端模型監(jiān)督總損失項(xiàng)的表達(dá)式為:
60、。
61、進(jìn)一步地,步驟s13中,在第f次迭代中,利用客戶端模型監(jiān)督總損失項(xiàng),對(duì)本地模型進(jìn)行更新,生成新一輪的本地模型,具體為:=,其中,為模型學(xué)習(xí)率。
62、本發(fā)明具有以下技術(shù)效果:(1)本發(fā)明提出cbdfl,將傳統(tǒng)知識(shí)蒸餾損失項(xiàng)分為顯性損失項(xiàng)和隱性損失項(xiàng),并引入不同的溫度系數(shù)機(jī)制,通過(guò)調(diào)節(jié)蒸餾過(guò)程的平滑度,既保留目標(biāo)知識(shí)的穩(wěn)定貢獻(xiàn),又最大化非目標(biāo)知識(shí)對(duì)模型性能提升的潛力。該方法強(qiáng)調(diào)全局模型與本地模型的協(xié)同監(jiān)督,在提升全局模型性能的同時(shí),保留并強(qiáng)化本地特性,從而為聯(lián)邦學(xué)習(xí)在non-iid場(chǎng)景中的應(yīng)用提供了新的解決思路。
63、(2)針對(duì)流程工業(yè)中客戶端設(shè)備因所處環(huán)境差異而可能采用不同模型所導(dǎo)致的異構(gòu)性挑戰(zhàn),本發(fā)明創(chuàng)新性地融合了聯(lián)邦學(xué)習(xí)與知識(shí)蒸餾技術(shù)。通過(guò)知識(shí)傳遞這一替代策略,避免了直接共享模型參數(shù)或梯度,從而成功實(shí)現(xiàn)了對(duì)模型異構(gòu)性的有效支持。這一方法不僅突破了傳統(tǒng)限制,還為流程工業(yè)中的異構(gòu)設(shè)備協(xié)同提供了全新的解決方案。
64、(3)針對(duì)當(dāng)前支持模型異構(gòu)的聯(lián)邦學(xué)習(xí)方法在處理分布差異顯著的客戶端數(shù)據(jù)集時(shí)面臨的性能下降問(wèn)題,本發(fā)明進(jìn)一步提出了雙焦點(diǎn)知識(shí)蒸餾與監(jiān)督對(duì)比知識(shí)蒸餾兩種策略。其中,雙焦點(diǎn)知識(shí)蒸餾策略使客戶端在向服務(wù)器進(jìn)行知識(shí)蒸餾時(shí),能夠更加聚焦于提取隱性支持信息,從而顯著提升服務(wù)器的全局性能。而監(jiān)督對(duì)比知識(shí)蒸餾策略則確保了服務(wù)器向客戶端進(jìn)行蒸餾時(shí),客戶端模型在逼近全局最優(yōu)解的同時(shí),依然能夠保持對(duì)本地?cái)?shù)據(jù)的良好適應(yīng)性。這兩種策略相輔相成,共同實(shí)現(xiàn)了更加高效的雙向蒸餾效果,顯著提升了聯(lián)邦學(xué)習(xí)系統(tǒng)的整體性能。