WHALE來了,南大周志華團隊做出更強泛化的世界模型

機器之心報道

作者:陳陳、佳琪

人類能夠在腦海中設想一個想象中的世界,以預測不同的動作可能導致不同的結果。受人類智能這一方面的啓發,世界模型被設計用於抽象化現實世界的動態,並提供這種「如果…… 會怎樣」的預測。

因此,具身智能體可以與世界模型進行交互,而不是直接與現實世界環境交互,以生成模擬數據,這些數據可以用於各種下游任務,包括反事實預測、離線策略評估、離線強化學習。

世界模型在具身環境的決策中起着至關重要的作用,使得在現實世界中成本高昂的探索成爲可能。爲了促進有效的決策,世界模型必須具備強大的泛化能力,以支持分佈外 (OOD) 區域的想象,並提供可靠的不確定性估計來評估模擬體驗的可信度,這兩者都對之前的可擴展方法提出了重大挑戰。

本文,來自南京大學、南棲仙策等機構的研究者引入了 WHALE(World models with beHavior-conditioning and retrAcing-rollout LEarning),這是一個用於學習可泛化世界模型的框架,由兩種可以與任何神經網絡架構普遍結合的關鍵技術組成。

首先,在確定策略分佈差異是泛化誤差的主要來源的基礎上,作者引入了一種行爲 - 條件(behavior-conditioning)技術來增強世界模型的泛化能力,該技術建立在策略條件模型學習的概念之上,旨在使模型能夠主動適應不同的行爲,以減輕分佈偏移引起的外推誤差。

此外,作者還提出了一種簡單而有效的技術,稱爲 retracing-rollout,以便對模型想象進行有效的不確定性估計。作爲一種即插即用的解決方案, retracing-rollout 可以有效地應用於各種實施任務中的末端執行器姿態控制,而無需對訓練過程進行任何更改。

爲了實現 WHALE 框架,作者提出了 Whale-ST,這是一個基於時空 transformer 的可擴展具身世界模型,旨在爲現實世界的視覺控制任務提供忠實的長遠想象。

爲了證實 Whale-ST 的有效性,作者在模擬的 Meta-World 基準和物理機器人平臺上進行了廣泛的實驗。

在模擬任務上的實驗結果表明,Whale-ST 在價值估計準確率和視頻生成保真度方面均優於現有的世界模型學習方法。此外,作者還證明了基於 retracing-rollout 技術的 Whale-ST 可以有效捕獲模型預測誤差並使用想象的經驗增強離線策略優化。

作爲進一步的舉措,作者引入了 Whale-X,這是一個具有 414M 參數的世界模型,該模型在 Open X-Embodiment 數據集中的 970k 個現實世界演示上進行了訓練。通過在完全沒見過的環境和機器人中的一些演示進行微調,Whale-X 在視覺、動作和任務視角中展示了強大的 OOD 通用性。此外,通過擴大預訓練數據集或模型參數,Whale-X 在預訓練和微調階段都表現出了令人印象深刻的可擴展性。

總結來說,這項工作的主要貢獻概述如下:

學習可泛化的世界模型以進行具身決策

世界模型中的序列決策通常需要智能體探索超出訓練數據集的分佈外 (OOD) 區域。這要求世界模型表現出強大的泛化能力,使其能夠做出與現實世界動態密切相關的準確預測。同時,可靠地量化預測不確定性對於穩健的決策至關重要,這可以防止離線策略優化利用錯誤的模型預測。考慮到這些問題,作者提出了 WHALE,這是一個用於學習可泛化世界模型的框架,具有增強的泛化性和高效的不確定性估計。

用於泛化的行爲 - 條件

根據公式(2)的誤差分解可知,世界模型的泛化誤差主要來源於策略分歧引起的誤差積累。

爲了解決這個問題,一種可能的解決方案是將行爲信息嵌入到世界模型中,使得模型能夠主動識別策略的行爲模式,並適應由策略引起的分佈偏移。

基於行爲 - 條件,作者引入了一個學習目標,即從訓練軌跡中獲取行爲嵌入,並整合學習到的嵌入。

作者希望將訓練軌跡 τ_H 中的決策模式提取到行爲嵌入中,這讓人聯想到以歷史 τ_h 爲條件的軌跡似然 ELBO(evidence lower bound)的最大化:

作者建議通過最大化 H 個決策步驟上的 ELBO 並調整類似於 β-VAE 的 KL 約束數量來學習行爲嵌入:

不確定性估計 Retracing-rollout

世界模型不可避免地會產生不準確和不可靠的樣本,先前的研究從理論和實驗上都證明,如果無限制地使用模型生成的數據,策略的性能可能會受到嚴重損害。因此,不確定性估計對於世界模型至關重要。

作者引入了一種新穎的不確定性估計方法,即 retracing-rollout。retracing-rollout 的核心創新在於引入了 retracing-action,它利用了具身控制中動作空間的語義結構,從而能夠更準確、更高效地估計基於 Transformer 的世界模型的不確定性。

爲了估計某一時間點 (o_t,a_t) 的不確定性,採用多種回溯步驟生成不同的回溯 - 軌跡預測結果。具體來說,要計算不同回溯 - 軌跡輸出與不使用回溯的輸出之間的「感知損失」。同時,引入動態模型的預測熵,通過將「感知損失」和預測熵相乘,得到最終的不確定性估計結果。

與基於集成的其他方法不同,retracing-rollout 方法不需要在訓練階段進行任何修改,因此相比集成方法,它顯著減少了計算成本。

作者在論文中還給出了具體的實例。圖 3 展示了 Whale-ST 的整體架構。具體來說,Whale-ST 包含三個主要組件:行爲調節模型、視頻 tokenizer 和動態模型。這些模塊採用了時空 transformer 架構。

這些設計顯著簡化了計算需求,從相對於序列長度的二次依賴關係簡化爲線性依賴關係,從而降低了模型訓練的內存使用量和計算成本,同時提高了模型推理速度。

實驗

該團隊在模擬任務和現實世界任務上進行了廣泛的實驗,主要是爲了回答以下問題:

模擬任務中的 Whale-ST

該團隊在 Meta-World 基準測試上開展實驗。Meta-World 是一個包含多種視覺操作任務的測試集。研究者們構建了一個包含 6 萬條軌跡的訓練數據集,這些軌跡是從 20 個不同的任務中收集來的。模型學習算法需要使用這些數據從頭開始訓練。

研究團隊將 Whale-ST 與 FitVid、MCVD、DreamerV3、iVideoGPT 進行了對比。評估指標如下:

下表展示了預測準確性的結果,其中,Whale-ST 在所有三個指標上都表現出色。在 64 × 64 的分辨率下,Whale-ST 的值差與 DreamerV3 的最高分非常接近。當在更高分辨率 256 × 256 測試時,Whale-ST 的表現進一步提升,取得了最小的值差和最高的回報相關性,反映了 Whale-ST 能更細緻地理解動態環境。

表 2 展示了視頻保真度的結果,Whale-ST 在所有指標上均優於其他方法,特別是 FVD 具有顯著優勢。

不確定性估計

針對不確定性,研究團隊比較了 retracing-rollout 與兩種基準方法:

(1)基於熵的方法:研究團隊採用基於 Transformer 的動態模型,它通過計算模型輸出的預測熵來量化不確定性

(2)基於集成的方法:研究團隊訓練了三個獨立的動態模型,然後通過比較每個模型生成的圖像之間的像素級差異來估計不確定性。

具體來說,他們從模型誤差預測和離線強化學習兩個角度進行評估。

下表展示了模型誤差預測的結果,在所有 5 個任務中,retracing-rollout 均優於其他基線方法。與基於集成的方法相比,retracing-rollout 提升了 500%,與基於熵的方法相比,提高了 50%。

下圖展示了離線 MBRL 的結果,retracing-rollout 在 5 個任務中的 3 個任務中收斂得更好、具備更強的穩定性。特別是在關水龍頭和滑盤子任務中,retracing-rollout 是唯一能夠穩定收斂的方法,而其他方法在訓練後期出現了不同程度的性能下降。

Whale-X 在真實世界中的表現

爲了評估 Whale-X 在實際物理環境中的泛化能力,研究團隊在 ARX5 機器人上進行了全面實驗。

與預訓練數據不同,評估任務調整了攝像機角度和背景等,增加了對世界模型的挑戰。他們收集了每個任務 60 條軌跡的數據集用於微調,任務包括開箱、推盤、投球和移動瓶子,還設計了多個模型從未接觸過的任務來測試模型的視覺、運動和任務泛化能力。

如圖 5 所示,Whale-X 在真實世界中展現出了明顯的優勢。

具體來說:

1. 與沒有行爲 - 條件的模型相比,Whale-X 的一致性提高了 63%,表明該機制顯著提升了 OOD 泛化能力;

2. 在 97 萬個樣本上進行預訓練的 Whale-X,比從零開始訓練的模型具有更高的一致性,凸顯了大規模互聯網數據預訓練的優勢;

3. 增加模型參數能夠提升世界模型的泛化能力。Whale-X-base(203M)動態模型在三個未見任務中的一致性率是 77M 版本的三倍。

此外,視頻生成質量與一致性的結果一致,如表 4 所示。通過行爲 - 條件策略、大規模預訓練數據集和擴展模型參數,三種策略結合,顯著提高了模型的 OOD 泛化能力,尤其是在生成高質量視頻方面。

擴展性

固定視頻 token 和行爲 - 條件這兩個部分不變,僅調整模型的參數量和預訓練數據集的大小,Whale-X 的拓展性如何呢?

研究團隊在預訓練階段訓練了四個動態模型,參數數量從 39M 到 456M 不等,結果如圖 7 的前兩幅圖所示。

這些結果表明,Whale-X 展現出強大的擴展性:無論是增加預訓練數據還是增加模型參數,都會降低訓練 loss。

除此之外,研究團隊還驗證了更大的模型在微調階段是否能夠展現更好的性能。

爲此,他們微調了一系列動態模型,結果如圖 7 最左側所示。不難發現,經過微調後,更大的模型在測試數據上表現出更低的 loss,進一步突顯了 Whale-X 在真實任務中出色的擴展性。

可視化

圖 1 展示了在 Meta-World、Open X-Embodiment 和研究團隊設計的真實任務上的定性評估結果。

結果表明,Whale-ST 和 Whale-X 能夠生成高保真度的視頻軌跡,尤其是在長時間跨度的軌跡生成過程中,保持了視頻的質量和一致性。

圖 8 展示了 Whale-X 在控制性和泛化性方面的強大能力。給定一個未見過的動作序列,Whale-X 能夠生成與人類理解相符的視頻,學習動作與機器人手臂移動之間的因果聯繫。

通過 t-SNE 可視化,研究表明 Whale-X 成功地學習到行爲嵌入,能夠區分不同策略之間的差異。例如,對於同一任務,不同的策略會有不同的行爲表示,而噪聲策略的嵌入則介於專家策略和隨機策略之間,體現了模型在策略建模上的合理性。此外,專家策略在不同任務中的嵌入也能被區分,而隨機策略則無法區分,表明模型更擅長表示和區分策略,而不是任務本身。

更多研究細節,請參考原文。

參考鏈接:https://arxiv.org/abs/2411.05619