Stabilising Small World Models With Flow Forcing

Introduction

There are a lot of arguments online about whether LLMs develop implicit world models as an emergent property of training on a next-token prediction objective.

However, I’m very interested in models that are designed as explicit world models. More concretely, they model the probability $p(s_t \mid s_{1:t-1}, a_{1:t})$, where $s_t$ and $a_t$ are the state and action at time $t$. The coolest examples of these are generated game worlds that allow real-time interaction, like Decart’s OASIS model, where you can play a game in a purely generated environment that is created on the fly. Wayve’s GAIA-1 and GAIA-2 models are also good examples of this type of model: they can generate interactive, controllable driving scenarios.

Example of OASIS model generating interactive gameplay. Source
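To make the rollout structure explicit: if the first state $s_1$ is given, and we assume that states never depend on future actions (an assumption, not something stated above), the chain rule turns the conditional $p(s_t \mid s_{1:t-1}, a_{1:t})$ into a model of a whole trajectory. Each sampled state is then fed back in as history for the next factor:

$$
p(s_{2:T} \mid s_1, a_{1:T}) = \prod_{t=2}^{T} p(s_t \mid s_{1:t-1}, a_{1:t})
$$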

Questions about long-term consistency aside, there is an issue with training these types of models using the traditional “teacher forcing” objective that is commonly used to pretrain language models. Teacher forcing roughly means taking a ground-truth partial trajectory $s_{1:t-1}, a_{1:t}$ and learning a parameterised distribution over the next state by minimising some loss between the model’s prediction and the ground-truth next state.

The problem is that autoregressive diffusion/flow matching models trained this way suffer from compounding errors. The model is only ever trained to predict the next state from ‘perfect’ historical trajectories. During autoregressive rollout at inference time, however, small errors accumulate and push the trajectory further and further away from the distribution the model saw during training. Ultimately, this produces trajectories that are unstable and degrade very quickly.
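A toy illustration of the effect, under deliberately simplistic assumptions that have nothing to do with the actual driving model: a 1-D state, constant-velocity true dynamics, and a "learned" model whose only flaw is a small constant bias. Under teacher forcing the per-step error stays at the bias. In a free rollout the model consumes its own outputs, so the error accumulates with every step.

```python
# Toy illustration of compounding errors in autoregressive rollout.
# Assumptions: 1-D state, true dynamics s_{t+1} = s_t + 1, and a model
# s_{t+1} = s_t + 1 + bias whose only error is a small constant bias.


def true_step(s: float) -> float:
    return s + 1.0


def model_step(s: float, bias: float = 0.05) -> float:
    return s + 1.0 + bias


def errors(num_steps: int, bias: float = 0.05) -> tuple[list[float], list[float]]:
    """Return per-step absolute errors under teacher forcing and under free rollout."""
    truth = [0.0]
    for _ in range(num_steps):
        truth.append(true_step(truth[-1]))

    # Teacher forcing: always predict from the ground-truth previous state.
    tf_errors = [abs(model_step(truth[t], bias) - truth[t + 1]) for t in range(num_steps)]

    # Rollout: feed the model its own previous prediction.
    rollout_errors = []
    s = truth[0]
    for t in range(num_steps):
        s = model_step(s, bias)
        rollout_errors.append(abs(s - truth[t + 1]))
    return tf_errors, rollout_errors


tf, ro = errors(num_steps=30)
print(f"teacher-forced error at step 30: {tf[-1]:.2f}")  # stays at the bias
print(f"rollout error at step 30:        {ro[-1]:.2f}")  # grows linearly with the step count
```

Real errors are not a constant bias. They also push the model's inputs off-distribution, which can make later predictions worse still. This toy only captures the accumulation.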

One way of looking at this error is as noise that is added to each new state and corrupts it. What if we could model this inference-time noise, which isn’t accounted for during training? The paper Diffusion Forcing: Next-token Prediction Meets Full-Sequence Diffusion takes this idea and cleverly changes the standard teacher forcing objective into one that accommodates this noise. I’ll refer to this objective as flow forcing. The paper goes into more detail, but the high-level idea is that a single sequence of $N$ states acts as $N$ training sub-examples: each state is noised independently, and we compute the standard flow matching objective for each frame. Critically, the model receives the noised history of previous states together with the strength of the noise applied to each state. This allows the model to selectively pull information from past states, knowing that some past states are more information-rich than others.

Read the paper for the implications and the different modes of operation this setup enables. Ultimately, though, it allows for robust autoregressive rollout. During the rollout, we can deliberately noise previously generated frames and tell the model we have done so. This “covers up” the model’s weaknesses and lets it focus on the information it can glean across the whole sequence of frames. The alternative is relying strongly on the most recent state for prediction, which is only a good strategy if that state is clean and well generated.

What follows is an attempt to test this idea by building a small driving simulator.

Model

The model architecture isn’t super interesting; if you understand the autoregressive transformer model, you’ll probably understand diffusion/flow matching transformer models.

Architecture diagram of the autoregressive diffusion/flow matching transformer model used for stable world modeling. Source

The “diff” between a standard transformer and my model is as follows:

My final model had 36 layers, 4 attention heads per layer and $d_{model} = 1536$, for a total of 2.4B parameters.

Data

The training data predominantly comes from the vista dataset, which consists of high-quality 4K driving videos. I downloaded these, downsampled them to 480p at 10 fps, and cut them into 4-second clips. Applying the same process to the nexar dashcam accident dataset added 14,422 clips. That dataset has per-second time-of-accident labels, which I fed into the model as conditioning labels. Together, this gives a total of roughly 850k 4-second clips. My main model seemed to converge after seeing 270k clips, so I stopped training there to save on compute costs.
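Below is a minimal sketch of the clip-cutting arithmetic: at 10 fps, a 4-second clip is 40 frames. The helper names are hypothetical. The exact label format isn't described above, so the sketch assumes one binary accident label per second of source video, expanded into the per-frame flags $a_{1:T}$ used in training.

```python
# Hypothetical preprocessing helpers (not the author's code): cut a 10 fps
# video into 4 s clips and expand per-second accident labels into per-frame flags.

FPS = 10
CLIP_SECONDS = 4
CLIP_LEN = FPS * CLIP_SECONDS  # 40 frames per clip


def clip_ranges(num_frames: int, clip_len: int = CLIP_LEN) -> list[tuple[int, int]]:
    """Non-overlapping [start, end) frame ranges; a trailing partial clip is dropped."""
    return [(s, s + clip_len) for s in range(0, num_frames - clip_len + 1, clip_len)]


def per_frame_flags(per_second_labels: list[int], start: int, end: int, fps: int = FPS) -> list[int]:
    """Assumption: label k applies to every frame within second k of the source video."""
    return [per_second_labels[f // fps] for f in range(start, end)]


# A 9.5 s video (95 frames) yields two full clips; the last 15 frames are dropped.
print(clip_ranges(95))  # [(0, 40), (40, 80)]
```

Dropping the trailing partial clip, rather than padding it, is one design choice among several. It keeps every training sequence the same length.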

Training

Objective

This blog is definitely not meant to be a guide to how diffusion/flow matching works; I recommend this paper for a comprehensive guide to flow matching. That said, the algorithm is relatively straightforward and is shown below. It boils down to noising each frame of the input independently, then averaging the flow matching objective across the frames in the sequence.

Algorithm: Flow Forcing Training

$$
\begin{aligned}
&\textbf{Input: } \text{video sequences } x_{1:T} = \{x_1, x_2, \ldots, x_T\},\ \text{accident flag sequences } a_{1:T} = \{a_1, a_2, \ldots, a_T\} \\
&\textbf{Output: } \text{model parameters } \theta \\[4pt]
&\textbf{for } \text{each training iteration } \textbf{do} \\
&\quad \text{sample a batch of video sequences } \{x_{1:T}^{(i)}\}_{i=1}^{B} \text{ with flags } \{a_{1:T}^{(i)}\}_{i=1}^{B} \\[4pt]
&\quad \text{// sample independent timesteps and noise for each frame} \\
&\quad \tau_{1:T} \sim \mathrm{Uniform}(0, 1)^T \quad \text{(independent for each frame)} \\
&\quad \epsilon_{1:T} \sim \mathcal{N}(0, I)^T \quad \text{(independent for each frame)} \\
&\quad \tilde{x}_{1:T} = \tau_{1:T} \odot x_{1:T} + (1 - \tau_{1:T}) \odot \epsilon_{1:T}, \quad \text{where } \odot \text{ denotes element-wise multiplication} \\[4pt]
&\quad \text{// forward pass on the noisy sequence} \\
&\quad v_\theta(\tilde{x}_{1:T}, a_{1:T}, \tau_{1:T}) \leftarrow \mathrm{FlowTransformer}(\tilde{x}_{1:T}, a_{1:T}, \tau_{1:T}) \\[4pt]
&\quad \text{// compute the flow matching loss} \\
&\quad v_{\text{target}} = x_{1:T} - \epsilon_{1:T} \\
&\quad \mathcal{L} = \frac{1}{|\mathcal{M}|} \sum_{t \in \mathcal{M}} \left\| v_\theta(\tilde{x}_t, a_t, \tau_t) - (x_t - \epsilon_t) \right\|_2^2, \quad \text{where } \mathcal{M} \text{ is the set of masked (valid) frames} \\[4pt]
&\quad \text{update } \theta \text{ using gradient descent on } \mathcal{L}
\end{aligned}
$$
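Below is a minimal NumPy sketch of the loss in the algorithm above, not the author's implementation. The `model` argument is a hypothetical callable standing in for the FlowTransformer. The shapes are also assumptions: frames are flattened to `[B, T, D]`, while flags and the validity mask are `[B, T]`. Real frames would be images or latents rather than flat vectors. Noise levels and noise are sampled outside the loss so that the function itself is deterministic.

```python
import numpy as np


def sample_noise_levels(rng: np.random.Generator, batch: int, frames: int) -> np.ndarray:
    """Independent tau ~ U(0, 1) for every frame, shape [B, T, 1] (tau = 1 means clean)."""
    return rng.uniform(0.0, 1.0, size=(batch, frames, 1))


def flow_forcing_loss(model, x, a, mask, tau, eps):
    """Per-frame flow matching loss averaged over valid frames.

    x, eps: [B, T, D] clean frames and Gaussian noise
    a:      [B, T] accident flags (conditioning)
    mask:   [B, T] bool, True for valid frames (the set M)
    tau:    [B, T, 1] per-frame noise levels
    model:  hypothetical callable (x_tilde, a, tau) -> predicted velocity [B, T, D]
    """
    x_tilde = tau * x + (1.0 - tau) * eps  # each frame noised to its own level
    v_target = x - eps                      # d(x_tilde)/d(tau) along the straight path
    v_pred = model(x_tilde, a, tau)         # model sees the whole noisy sequence + levels
    per_frame = np.sum((v_pred - v_target) ** 2, axis=-1)  # squared L2 per frame, [B, T]
    return per_frame[mask].mean()          # average over valid frames only
```

If every context frame's `tau` is fixed to 1 and the mask keeps only the last frame, this reduces to an ordinary teacher-forcing objective for that frame. Sampling `tau` independently per frame is what teaches the model to use partially noised history.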

Liger Kernels

One simple way to speed up training is to use torch.compile. I also used the custom Triton kernels provided by the Liger library for operations like RMSNorm and SwiGLU. I suspect the speed benefits of these kernels are most apparent in multi-GPU training, so they may be worth ablating on my single-GPU setup.
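For reference, here are the unfused operations that those kernels compute, sketched in NumPy. This is not Liger's API. The point of fused kernels is typically to do the same maths with fewer reads and writes to GPU memory. The weight names (`w_gate`, `w_up`, `w_down`) are illustrative.

```python
import numpy as np


def rms_norm(x: np.ndarray, weight: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """RMSNorm: divide each vector by its root-mean-square, then apply a learned gain."""
    rms = np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + eps)
    return x / rms * weight


def silu(x: np.ndarray) -> np.ndarray:
    """SiLU / swish activation: x * sigmoid(x)."""
    return x / (1.0 + np.exp(-x))


def swiglu_mlp(x: np.ndarray, w_gate: np.ndarray, w_up: np.ndarray, w_down: np.ndarray) -> np.ndarray:
    """SwiGLU feed-forward block: (silu(x @ w_gate) * (x @ w_up)) @ w_down."""
    return (silu(x @ w_gate) * (x @ w_up)) @ w_down
```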

MuP initialisation

When training a really big model, you often can’t afford a big hyperparameter sweep to find the best settings for that specific setup. Ideally, you would find optimal settings for things like the learning rate and weight init std on a smaller model, then transfer them to the bigger one. Done naively, this tends not to work, because the best learning rate changes with model scale. However, there is a specific way of scaling these parameters with model size, called Maximal Update Parametrization (MuP), that allows this transfer. I implemented it, and over a few training steps, training looks more stable across model widths with MuP:

Training loss without MuP initialization
Activation sizes across width with MuP off
Training loss with MuP initialization
Activation sizes across width with MuP on

I did, however, end up turning off the scaling of the final readout/detokenizer layer and the extra $\sqrt{d_{head}}$ scaling on $QK^T$ (i.e. dividing by $d_{head}$ rather than $\sqrt{d_{head}}$), since I found these bottlenecked performance. I suspect the problem lies more with the detokenizer scaling. If the model’s output is scaled down too much, it may struggle to reach the same scale as the ground truth, so the weights have to become very large to compensate. That is hard to achieve when weight decay opposes it. In a traditional LM this is a non-issue, since the softmax over the vocabulary normalises the output.
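The sketch below shows one common MuP formulation for Adam. It follows the conventions of Microsoft's `mup` package: hidden-weight learning rate proportional to 1/width, hidden init variance proportional to 1/fan_in, a 1/width multiplier on the readout, and $1/d_{head}$ attention scaling. It is not necessarily the author's exact implementation. The two switches correspond to the pieces turned off above. All names are hypothetical, and `base_width` is the width at which the hyperparameters were tuned.

```python
import math
from dataclasses import dataclass


@dataclass
class MuPScales:
    hidden_lr: float           # Adam learning rate for hidden (matrix-like) weights
    hidden_init_std: float     # init std for hidden weights
    readout_multiplier: float  # forward multiplier on the final readout/detokenizer
    attn_scale: float          # multiplier applied to QK^T before the softmax


def mup_scales(width: int, base_width: int, base_lr: float, base_std: float, d_head: int,
               scale_readout: bool = True, mup_attention: bool = True) -> MuPScales:
    m = width / base_width  # width multiplier relative to the tuned proxy model
    return MuPScales(
        hidden_lr=base_lr / m,                    # Adam LR ~ 1 / width
        hidden_init_std=base_std / math.sqrt(m),  # init variance ~ 1 / fan_in
        readout_multiplier=(1.0 / m) if scale_readout else 1.0,
        attn_scale=(1.0 / d_head) if mup_attention else (1.0 / math.sqrt(d_head)),
    )


# Illustrative numbers only: d_model = 1536 with 4 heads gives d_head = 384;
# base_width = 512 is a made-up proxy width, not taken from the post.
print(mup_scales(1536, 512, 3e-4, 0.02, 384, scale_readout=False, mup_attention=False))
```

At `width == base_width`, every quantity reduces to its tuned base value. That is what lets hyperparameters found on the small proxy carry over to the wider model.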

The optimal hyperparameters were found with Bayesian optimisation (Bayesopt) on a model 10x smaller than the final one. I also used a relatively small batch size of 4 for the final training run. In a single-GPU setup, small batch sizes can be reasonably efficient in terms of MFU, so gradient accumulation is not useful (see here for more details). I also had to use gradient checkpointing on the attention computations to fit the model in memory.

One gotcha with MuP is that it lets you transfer HPs for a fixed data budget. If you also scale up the amount of data along with model size, there is no guarantee that your optimal hyperparameters are still optimal! Ideally, you would run a multi-objective hyperparameter optimisation that minimises both FLOPs and loss for a fixed model size, varying both the HPs and the dataset size, and then fit scaling laws. I was cheap, though, and ended up fudging it based on advice from this great blog post.

Inference

To perform a stable rollout with the trained model, given that divergence is an issue, we have two options:

Since the “noise” from the model being bad is not likely to be Gaussian, I chose the latter option as it seemed more principled. Here is a proper explanation of how this works:

Algorithm: Stable Autoregressive Rollouts

$$
\begin{aligned}
&\textbf{Input: } \text{initial frame } x_1,\ \text{total frames } T,\ \text{regularisation strength } \alpha \\
&\textbf{Output: } \text{generated video frames } \{x_t\}_{t=1}^{T} \\[4pt]
&\mathcal{C} \leftarrow [x_1] \quad \text{// clean frames} \\
&\mathcal{N} \leftarrow [x_1] \quad \text{// noisy context frames} \\[4pt]
&\textbf{for } t = 2 \textbf{ to } T \textbf{ do} \\
&\quad \text{// prepare the regularised context} \\
&\quad \tilde{x}_{\text{context}} \leftarrow \mathcal{N} \text{ with noise levels:} \\
&\qquad \tau_1 = 1 \quad \text{(first frame stays clean)} \\
&\qquad \tau_i = 1 - \alpha \text{ for } i > 1 \quad \text{(regularisation noise)} \\[4pt]
&\quad \text{// generate a clean frame using flow matching} \\
&\quad x_t \leftarrow \mathrm{Denoise}(\tilde{x}_{\text{context}},\ \text{position} = t), \quad \text{where Denoise integrates from } \tau = 0 \text{ to } \tau = 1 \\[4pt]
&\quad \text{// update caches} \\
&\quad \mathcal{C} \leftarrow \mathcal{C} \cup \{x_t\} \\
&\quad \tilde{x}_t \leftarrow (1 - \alpha)\, x_t + \alpha\, \epsilon, \quad \epsilon \sim \mathcal{N}(0, I) \\
&\quad \mathcal{N} \leftarrow \mathcal{N} \cup \{\tilde{x}_t\} \\[4pt]
&\textbf{return } \mathcal{C}
\end{aligned}
$$
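The rollout algorithm above can be sketched in NumPy as follows. It assumes a hypothetical `velocity_fn(context, context_tau, x, tau)` that wraps the trained model, plus a plain Euler integrator. The post does not say which ODE solver or how many steps were used, so `num_steps=10` is arbitrary. The re-noising $(1 - \alpha) x_t + \alpha \epsilon$ is exactly the training interpolation at $\tau = 1 - \alpha$, which is why the context can be labelled with that noise level.

```python
import numpy as np


def euler_denoise(velocity_fn, context, context_tau, shape, rng, num_steps=10):
    """Integrate a new frame from pure noise (tau = 0) to clean (tau = 1) with Euler steps."""
    x = rng.standard_normal(shape)
    dt = 1.0 / num_steps
    for k in range(num_steps):
        tau = k * dt
        x = x + dt * velocity_fn(context, context_tau, x, tau)
    return x


def stable_rollout(velocity_fn, x1, num_frames, alpha, rng, num_steps=10):
    """Generate frames 2..T, conditioning on a deliberately re-noised context cache."""
    clean = [x1]  # C: frames returned to the viewer
    noisy = [x1]  # N: context frames the model conditions on
    for _ in range(2, num_frames + 1):  # one iteration per new frame t = 2..T
        # tau_1 = 1 (first frame stays clean), tau_i = 1 - alpha for later frames
        context_tau = np.array([1.0] + [1.0 - alpha] * (len(noisy) - 1))
        x_t = euler_denoise(velocity_fn, np.stack(noisy), context_tau, x1.shape, rng, num_steps)
        clean.append(x_t)
        eps = rng.standard_normal(x1.shape)
        noisy.append((1.0 - alpha) * x_t + alpha * eps)  # regularisation noise for context only
    return np.stack(clean)
```

Only the context cache is re-noised; the frames returned are the clean samples. With `alpha = 0` this reduces to an ordinary autoregressive rollout with a clean history. With `num_frames = 30` it produces the starting frame plus 29 new ones.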

Results

In each clip, the model receives the same (high-quality) starting image and autoregressively rolls out 29 new frames, giving a three-second clip (30 frames at 10 fps). From left to right, the regularisation level increases (the levels used are $[0.0, 0.1, 0.2, 0.3, 0.4]$).

Very qualitatively, I noticed slightly better consistency as the regularisation was increased, but this was not perfect in all cases. For example, in the following clip, the motion of the signpost looks better with higher regularisation, but the left car seems to smear more rather than being passed by as is the case with low regularisation.

Movement of signposts is correct

I particularly like the following example, where the direction of travel of the other car changes in different rollouts:

Car drives in opposite directions on different rollouts

Overall, perhaps unsurprisingly, the model does best in very simple scenes, where the benefit of the regularisation is somewhat clearer:

Example of Autoregressive Rollout

If you really squint at the low regularisation examples, it kinda looks like the right car is overtaking the left one here (but perhaps I’m reaching too much).

Example of overtaking?

Failure modes

I did notice a number of failure modes with the rollouts, regardless of how much regularisation was used.

Firstly, approaching a car from behind can cause instability: past a certain point in the trajectory, the model stretches the car out in a non-physical way. You do get a good sense of motion in the periphery, though.

relative motorcycle speed changes in different rollouts

Secondly, the model does not deal with turns well at all. I guess turns are much less represented in the training data than long stretches of straight driving, but the model really likes to drive “forward”, often inventing/shifting the direction of the road to make this more feasible.

Light turning causes issues

Model can't deal with sharp right turns

Another poor turn

I’m not going to pretend this is a particularly good world model, but for just under £200 (very roughly 100x less than the world model pretraining costs of GAIA-1) I don’t feel too bad about the quality of the samples that I ended up with.

Next steps

That said, I think there are a bunch of things that I’d like to try out next:
