Improving Token-Based World Models with Parallel Observation Prediction

2402.05643

Published 5/30/2024 by Lior Cohen, Kaixin Wang, Bingyi Kang, Shie Mannor

Abstract

Motivated by the success of Transformers when applied to sequences of discrete symbols, token-based world models (TBWMs) were recently proposed as sample-efficient methods. In TBWMs, the world model consumes agent experience as a language-like sequence of tokens, where each observation constitutes a sub-sequence. However, during imagination, the sequential token-by-token generation of next observations results in a severe bottleneck, leading to long training times, poor GPU utilization, and limited representations. To resolve this bottleneck, we devise a novel Parallel Observation Prediction (POP) mechanism. POP augments a Retentive Network (RetNet) with a novel forward mode tailored to our reinforcement learning setting. We incorporate POP in a novel TBWM agent named REM (Retentive Environment Model), showcasing a 15.4x faster imagination compared to prior TBWMs. REM attains superhuman performance on 12 out of 26 games of the Atari 100K benchmark, while training in less than 12 hours. Our code is available at https://github.com/leor-c/REM.

Overview

  • Researchers propose a novel "Parallel Observation Prediction" (POP) mechanism to address the bottleneck in token-based world models (TBWMs), which consume agent experiences as a sequence of tokens.
  • The POP mechanism is incorporated into a new TBWM agent called REM (Retentive Environment Model), which achieves superhuman performance on 12 out of 26 Atari 100K benchmark games in under 12 hours of training.
  • REM demonstrates a 15.4x faster imagination process compared to prior TBWMs, which shortens training and improves GPU utilization.

Plain English Explanation

Researchers have found success using Transformer models to process sequences of discrete symbols, like words in a language. This led to the development of token-based world models (TBWMs), which consume an agent's experiences as a sequence of tokens, where each observation is a sub-sequence.

However, the sequential, token-by-token generation of the next observations during the imagination process (where the model predicts future states) created a major bottleneck. This resulted in slow training, inefficient use of GPUs, and limited representation capabilities.

To address this issue, the researchers devised a novel "Parallel Observation Prediction" (POP) mechanism, which predicts the tokens of the next observation in parallel rather than one at a time. POP is incorporated into a new TBWM agent called REM (Retentive Environment Model), which can imagine future states 15.4 times faster than previous TBWMs. Faster imagination shortens training and makes better use of the GPU.

Notably, REM is able to achieve superhuman performance on 12 out of 26 games in the Atari 100K benchmark, all within less than 12 hours of training. This demonstrates the powerful capabilities of the POP-enhanced REM agent.

Technical Explanation

The researchers propose a novel Parallel Observation Prediction (POP) mechanism to address the bottleneck in token-based world models (TBWMs). In TBWMs, the world model consumes agent experiences as a language-like sequence of tokens, where each observation constitutes a sub-sequence.
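The token layout described above can be sketched as follows. This is a minimal illustration, not the paper's code: the summary only states that each observation forms a sub-sequence, so the interleaved action token, the `action_offset` parameter, and the function name `flatten_trajectory` are assumptions made here for the example.

```python
from typing import List, Sequence


def flatten_trajectory(
    obs_tokens: Sequence[Sequence[int]],
    actions: Sequence[int],
    action_offset: int,
) -> List[int]:
    """Flatten a trajectory into one token sequence.

    Each observation contributes a block of tokens. As an assumption of this
    sketch, every block is followed by a single action token, shifted by
    `action_offset` so action ids do not collide with observation ids.
    """
    if len(obs_tokens) != len(actions):
        raise ValueError("need exactly one action per observation")
    sequence: List[int] = []
    for block, action in zip(obs_tokens, actions):
        sequence.extend(block)                   # observation sub-sequence
        sequence.append(action_offset + action)  # hypothetical action token
    return sequence


# Two observations of 4 tokens each, with a hypothetical vocabulary of 512.
obs_tokens = [[5, 9, 2, 7], [5, 9, 3, 7]]
actions = [1, 0]
sequence = flatten_trajectory(obs_tokens, actions, action_offset=512)
print(sequence)  # [5, 9, 2, 7, 513, 5, 9, 3, 7, 512]
```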

During the imagination process, the sequential token-by-token generation of next observations results in a severe bottleneck, leading to long training times, poor GPU utilization, and limited representations. To resolve this, the researchers incorporate POP into a new TBWM agent called REM (Retentive Environment Model).
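To make the bottleneck concrete, the toy sketch below counts sequential forward passes for the two generation styles. `StubWorldModel` and both `imagine_*` functions are hypothetical stand-ins that emit random tokens; they do not reproduce REM or POP, and actions are omitted for brevity. The ratio of call counts is not the wall-clock speedup, which the paper reports as 15.4x.

```python
import random
from typing import List


class StubWorldModel:
    """Stand-in for a learned world model; it only counts forward passes."""

    def __init__(self, vocab_size: int, seed: int = 0) -> None:
        self.vocab_size = vocab_size
        self.rng = random.Random(seed)
        self.calls = 0

    def next_token(self, context: List[int]) -> int:
        # One forward pass yields one token.
        self.calls += 1
        return self.rng.randrange(self.vocab_size)

    def next_observation(self, context: List[int], k: int) -> List[int]:
        # One forward pass yields all k tokens of the next observation.
        self.calls += 1
        return [self.rng.randrange(self.vocab_size) for _ in range(k)]


def imagine_token_by_token(model: StubWorldModel, context: List[int],
                           horizon: int, k: int) -> List[int]:
    context = list(context)
    for _ in range(horizon):
        for _ in range(k):  # k dependent steps per observation
            context.append(model.next_token(context))
    return context


def imagine_parallel(model: StubWorldModel, context: List[int],
                     horizon: int, k: int) -> List[int]:
    context = list(context)
    for _ in range(horizon):  # one step per observation
        context.extend(model.next_observation(context, k))
    return context


sequential_model = StubWorldModel(vocab_size=512)
parallel_model = StubWorldModel(vocab_size=512)
seq_out = imagine_token_by_token(sequential_model, [0], horizon=3, k=4)
par_out = imagine_parallel(parallel_model, [0], horizon=3, k=4)
print(sequential_model.calls, parallel_model.calls)  # 12 3
```

Both loops produce sequences of the same length; only the number of dependent model calls differs, which is what limits GPU utilization in the token-by-token case.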

POP augments a Retentive Network (RetNet) with a novel forward mode tailored to the reinforcement learning setting. This allows REM to imagine future states 15.4x faster than prior TBWMs, which shortens training and improves GPU utilization.
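For background, a RetNet layer can be evaluated either in a parallel form over a whole sequence or in a recurrent form that carries a fixed-size state. The sketch below checks that the two forms agree for a simplified single-head retention. It is written in pure Python for readability, omits the positional rotation, scaling, and normalization used in the full architecture, and does not implement the POP forward mode itself. The function names and the decay value `gamma=0.9` are choices made for this example.

```python
import random
from typing import List

Matrix = List[List[float]]


def retention_parallel(Q: Matrix, K: Matrix, V: Matrix, gamma: float) -> Matrix:
    """out[n] = sum over m <= n of gamma**(n - m) * (Q[n] . K[m]) * V[m]."""
    d_v = len(V[0])
    out: Matrix = []
    for n in range(len(Q)):
        row = [0.0] * d_v
        for m in range(n + 1):  # causal: only past and current positions
            score = sum(q * k for q, k in zip(Q[n], K[m])) * gamma ** (n - m)
            for j in range(d_v):
                row[j] += score * V[m][j]
        out.append(row)
    return out


def retention_recurrent(Q: Matrix, K: Matrix, V: Matrix, gamma: float) -> Matrix:
    """S_n = gamma * S_{n-1} + K[n]^T V[n];  out[n] = Q[n] S_n."""
    d_k, d_v = len(K[0]), len(V[0])
    S = [[0.0] * d_v for _ in range(d_k)]  # fixed-size recurrent state
    out: Matrix = []
    for q, k, v in zip(Q, K, V):
        S = [[gamma * S[i][j] + k[i] * v[j] for j in range(d_v)]
             for i in range(d_k)]
        out.append([sum(q[i] * S[i][j] for i in range(d_k))
                    for j in range(d_v)])
    return out


def random_matrix(rows: int, cols: int, rng: random.Random) -> Matrix:
    return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
Q, K, V = (random_matrix(6, 3, rng) for _ in range(3))
parallel_out = retention_parallel(Q, K, V, gamma=0.9)
recurrent_out = retention_recurrent(Q, K, V, gamma=0.9)
max_diff = max(abs(a - b)
               for row_p, row_r in zip(parallel_out, recurrent_out)
               for a, b in zip(row_p, row_r))
print(max_diff < 1e-9)
```

The recurrent form is what makes step-by-step imagination cheap, since the state `S` summarizes the whole history; the parallel form is what makes training over full sequences efficient.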

Experiments show that REM attains superhuman performance on 12 out of 26 games of the Atari 100K benchmark, while training in less than 12 hours. This showcases the effectiveness of the POP mechanism in enhancing the capabilities of token-based world models.
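"Superhuman" on Atari is conventionally judged with the human-normalized score, where a value above 1 means the agent outscored the human reference. The sketch below assumes that convention; the game names and all score values are made-up placeholders, not results from the paper.

```python
from typing import Dict, Tuple


def human_normalized_score(agent: float, random_play: float, human: float) -> float:
    """(agent - random) / (human - random); values above 1 exceed the human score."""
    return (agent - random_play) / (human - random_play)


def count_superhuman(scores: Dict[str, Tuple[float, float, float]]) -> int:
    """Count games whose human-normalized score is above 1."""
    return sum(
        1
        for agent, random_play, human in scores.values()
        if human_normalized_score(agent, random_play, human) > 1.0
    )


# Placeholder (agent, random, human) scores for two hypothetical games.
placeholder_scores = {
    "game_a": (120.0, 10.0, 100.0),  # above the human reference
    "game_b": (40.0, 10.0, 100.0),   # below the human reference
}
print(count_superhuman(placeholder_scores))  # 1
```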

Critical Analysis

The researchers acknowledge that while REM demonstrates impressive performance, there are still some limitations and areas for further research. For example, the paper mentions that the POP mechanism may not scale as efficiently to more complex environments or tasks with longer-range dependencies.

Additionally, the researchers note that the REM agent's strong performance on the Atari 100K benchmark may not necessarily translate to real-world applications, where the dynamics and visual complexity could be significantly different.

Further research could address these limitations by investigating multi-modality scene tokenization and motion prediction, or by learning more robust latent representations. Incorporating locality-sensitive sparse encoding, or using large language models as multimodal world models, could also be valuable avenues for future work.

Conclusion

The researchers have developed a novel Parallel Observation Prediction (POP) mechanism that addresses the bottleneck in token-based world models (TBWMs), enabling faster imagination, shorter training times, and better GPU utilization.

By incorporating POP into their Retentive Environment Model (REM) agent, the researchers have demonstrated superhuman performance on 12 of the 26 Atari 100K benchmark games, all within a training time of less than 12 hours.

This work represents an important step forward in enhancing the capabilities of token-based world models, paving the way for more efficient and powerful reinforcement learning agents that can tackle a wide range of complex environments and tasks.



This summary was produced with help from an AI and may contain inaccuracies - check out the links to read the original source documents!

Related Papers

Cognitively Inspired Energy-Based World Models

Alexi Gladstone, Ganesh Nanduru, Md Mofijul Islam, Aman Chadha, Jundong Li, Tariq Iqbal

One of the predominant methods for training world models is autoregressive prediction in the output space of the next element of a sequence. In Natural Language Processing (NLP), this takes the form of Large Language Models (LLMs) predicting the next token; in Computer Vision (CV), this takes the form of autoregressive models predicting the next frame/token/pixel. However, this approach differs from human cognition in several respects. First, human predictions about the future actively influence internal cognitive processes. Second, humans naturally evaluate the plausibility of predictions regarding future states. Based on this capability, and third, by assessing when predictions are sufficient, humans allocate a dynamic amount of time to make a prediction. This adaptive process is analogous to System 2 thinking in psychology. All these capabilities are fundamental to the success of humans at high-level reasoning and planning. Therefore, to address the limitations of traditional autoregressive models lacking these human-like capabilities, we introduce Energy-Based World Models (EBWM). EBWM involves training an Energy-Based Model (EBM) to predict the compatibility of a given context and a predicted future state. In doing so, EBWM enables models to achieve all three facets of human cognition described. Moreover, we developed a variant of the traditional autoregressive transformer tailored for Energy-Based models, termed the Energy-Based Transformer (EBT). Our results demonstrate that EBWM scales better with data and GPU Hours than traditional autoregressive transformers in CV, and that EBWM offers promising early scaling in NLP. Consequently, this approach offers an exciting path toward training future models capable of System 2 thinking and intelligently searching across state spaces.

6/14/2024

Reinforcement Learning from Delayed Observations via World Models

Armin Karamzade, Kyungmin Kim, Montek Kalsi, Roy Fox

In standard reinforcement learning settings, agents typically assume immediate feedback about the effects of their actions after taking them. However, in practice, this assumption may not hold true due to physical constraints and can significantly impact the performance of learning algorithms. In this paper, we address observation delays in partially observable environments. We propose leveraging world models, which have shown success in integrating past observations and learning dynamics, to handle observation delays. By reducing delayed POMDPs to delayed MDPs with world models, our methods can effectively handle partial observability, where existing approaches achieve sub-optimal performance or degrade quickly as observability decreases. Experiments suggest that one of our methods can outperform a naive model-based approach by up to 250%. Moreover, we evaluate our methods on visual delayed environments, for the first time showcasing delay-aware reinforcement learning continuous control with visual observations.

6/27/2024

Decentralized Transformers with Centralized Aggregation are Sample-Efficient Multi-Agent World Models

Yang Zhang, Chenjia Bai, Bin Zhao, Junchi Yan, Xiu Li, Xuelong Li

Learning a world model for model-free Reinforcement Learning (RL) agents can significantly improve the sample efficiency by learning policies in imagination. However, building a world model for Multi-Agent RL (MARL) can be particularly challenging due to the scalability issue in a centralized architecture arising from a large number of agents, and also the non-stationarity issue in a decentralized architecture stemming from the inter-dependency among agents. To address both challenges, we propose a novel world model for MARL that learns decentralized local dynamics for scalability, combined with a centralized representation aggregation from all agents. We cast the dynamics learning as an auto-regressive sequence modeling problem over discrete tokens by leveraging the expressive Transformer architecture, in order to model complex local dynamics across different agents and provide accurate and consistent long-term imaginations. As the first pioneering Transformer-based world model for multi-agent systems, we introduce a Perceiver Transformer as an effective solution to enable centralized representation aggregation within this context. Results on Starcraft Multi-Agent Challenge (SMAC) show that it outperforms strong model-free approaches and existing model-based methods in both sample efficiency and overall performance.

6/26/2024

Efficient World Models with Context-Aware Tokenization

Vincent Micheli, Eloi Alonso, François Fleuret

Scaling up deep Reinforcement Learning (RL) methods presents a significant challenge. Following developments in generative modelling, model-based RL positions itself as a strong contender. Recent advances in sequence modelling have led to effective transformer-based world models, albeit at the price of heavy computations due to the long sequences of tokens required to accurately simulate environments. In this work, we propose $\Delta$-IRIS, a new agent with a world model architecture composed of a discrete autoencoder that encodes stochastic deltas between time steps and an autoregressive transformer that predicts future deltas by summarizing the current state of the world with continuous tokens. In the Crafter benchmark, $\Delta$-IRIS sets a new state of the art at multiple frame budgets, while being an order of magnitude faster to train than previous attention-based approaches. We release our code and models at https://github.com/vmicheli/delta-iris.

6/28/2024