ATLAS v2 + Le WorldModel β Fine-Tuned JEPA Architecture
A clean recipe to train JEPA world models β Implementation of Yann LeCun's Le-WM innovations into ATLAS v2 NWM.
π Table of Contents
- Model Description
- Architecture Details
- Training Procedure
- Evaluation Results
- Usage
- Files and Notebooks
- Citation
- References
π Model Description
ATLAS v2 + Le-WM is a self-supervised world model integrating innovations from "A clean recipe to train JEPA world models" (Le-WM, Yann LeCun et al., 2026) into the ATLAS v2 NWM (Neural World Model) architecture.
Key Innovations
| Feature | Phase 1 | Phase 2 |
|---|---|---|
| Masked Block Prediction | β 40% spatial masking | β 40% spatial masking |
| VICReg Regularization | β Ξ»_var=25.0, Ξ»_cov=1.0 | β Ξ»_var=25.0, Ξ»_cov=1.0 |
| Multi-Horizon | β Single horizon (t+1) | β Multi-horizon (t+1, t+2, t+4) |
| Temporal Weighting | - | β [1.0, 0.8, 0.5] |
| Stable Training | β Cosine LR + EMA | β Cosine LR + EMA |
I-JEPA Maturity Progression
Baseline ATLAS v2: 31.7%
After Phase 1: 50.0% (+18.3%)
After Phase 2: 65.0% (+15.0%)
Objective: Achieve I-JEPA maturity of 88% through 4-phase integration (Phases 3-4 in progress).
ποΈ Architecture Details
Phase 1 Architecture
Total Parameters: 9,920
| Component | Parameters | Description |
|---|---|---|
| ContextEncoder | 4,960 | Encodes input candidates (3D β 32D latent) |
| TargetEncoder | 4,960 | Encodes target profiles (3D β 32D latent, EMA) |
| Predictor | - | Implicit masked prediction |
Configuration:
- Latent dim: 32
- Mask ratio: 40%
- Block size: 4
- Batch size: 128
- Training steps: 200
Phase 2 Architecture
Total Parameters: 68,800 (6.9Γ Phase 1)
| Component | Parameters | Description |
|---|---|---|
| ContextEncoder | 9,152 | Encodes input candidates (3D β 64D latent) |
| TargetEncoder | 9,152 | Encodes target profiles (3D β 64D latent, EMA) |
| MultiHorizonPredictor | 50,496 | 3 parallel branches (h1, h2, h4) |
Configuration:
- Latent dim: 64 (2Γ Phase 1)
- Horizons: [1, 2, 4]
- Horizon weights: [1.0, 0.8, 0.5]
- Mask ratio: 40%
- Block size: 8
- Batch size: 64
- Training steps: 500
MultiHorizonPredictor Architecture:
Input: context_emb (64D)
βββ h1 branch: Linear(64β128) + LayerNorm + GELU + Linear(128β64) β pred_h1
βββ h2 branch: Linear(64β128) + LayerNorm + GELU + Linear(128β64) β pred_h2
βββ h4 branch: Linear(64β128) + LayerNorm + GELU + Linear(128β64) β pred_h4
π― Training Procedure
Hardware
- Platform: Google Colab (Free Tier)
- GPU: NVIDIA Tesla T4 (15.6 GB VRAM)
- PyTorch: 2.10.0+cu128
- CUDA: 12.8
Phase 1: Masked Prediction + VICReg
Objective: Implement foundational Le-WM components (masked prediction 40%, VICReg regularization, stable training).
Training Hyperparameters:
{
'latent_dim': 32,
'mask_ratio': 0.4,
'block_size': 4,
'batch_size': 128,
'num_steps': 200,
'lr_max': 1e-4,
'lr_min': 1e-5,
'warmup_steps': 20,
'ema_momentum': 0.996,
'lambda_var': 25.0,
'lambda_cov': 1.0,
'lambda_vicreg': 0.1,
}
Training Time: 10 seconds (200 steps)
Loss Components:
- Prediction Loss: MSE between predicted and target embeddings (masked regions only)
- VICReg Loss: Variance loss (target std=1.0) + Covariance loss (off-diagonal β 0)
- Total Loss:
loss_pred + 0.1 Γ loss_vicreg
Training Curves (Phase 1):
| Metric | Initial | Final | Improvement |
|---|---|---|---|
| Loss Totale | 3.670 | 1.574 | -57.1% |
| Pred Loss | 1.460 | 0.384 | -73.6% |
| VICReg Loss | 22.10 | 11.90 | -46.2% |
| Variance | 0.116 | 0.576 | +396% |
Stable Training Strategies:
- Cosine LR Schedule: Warm-up (20 steps) β Annealing (180 steps)
- Gradient Clipping:
max_norm=1.0 - EMA Target Encoder:
momentum=0.996 - VICReg Regularization: Prevents embedding collapse
Dataset: 1,000 synthetic CRISPR candidates (3D: Stability, Alignment, Cost)
Phase 2: Multi-Horizon Prediction
Objective: Extend to long-term temporal prediction (t+1, t+2, t+4) with multi-horizon loss.
Training Hyperparameters:
{
'latent_dim': 64,
'horizons': [1, 2, 4],
'horizon_weights': [1.0, 0.8, 0.5],
'mask_ratio': 0.4,
'block_size': 8,
'batch_size': 64,
'num_steps': 500,
'lr_max': 1e-4,
'lr_min': 1e-5,
'warmup_steps': 50,
'ema_momentum': 0.996,
'lambda_var': 25.0,
'lambda_cov': 1.0,
'lambda_vicreg': 0.1,
}
Training Time: 2-3 minutes (500 steps)
Loss Components:
- Multi-Horizon Loss: Weighted sum of per-horizon MSE losses
loss_mh = w1 Γ MSE(pred_h1, target_h1) + w2 Γ MSE(pred_h2, target_h2) + w4 Γ MSE(pred_h4, target_h4) - VICReg Loss: Same as Phase 1
- Total Loss:
loss_mh + 0.1 Γ loss_vicreg
Training Curves (Phase 2):
| Metric | Initial | Final | Improvement |
|---|---|---|---|
| Loss Totale | 3.501 | ~1.250 | -64.3% |
| H1 Loss (t+1) | 1.191 | ~0.044 | -96.3% |
| H2 Loss (t+2) | 1.256 | ~0.041 | -96.7% |
| H4 Loss (t+4) | 1.301 | ~0.042 | -96.8% |
| VICReg Loss | 22.63 | ~12.04 | -46.8% |
| Variance | 0.095 | 0.670 | +606% |
Dataset: 1,000 CRISPR temporal trajectories (t=0 β t+4)
- t=0: Initial state
- t+1: Stability +0.05, Cost -0.03
- t+2: Stability +0.08, Alignment +0.02, Cost -0.05
- t+4: Stability +0.12, Alignment +0.05, Cost -0.08
π Evaluation Results
Phase 1 Validation
Validation Set: 200 samples
| Metric | Train | Validation | Gap |
|---|---|---|---|
| Loss Totale | 1.574 | 1.591 | +1.1% |
| Pred Loss | 0.384 | 0.401 | +4.4% |
| VICReg Loss | 11.90 | 11.90 | 0.0% |
| Variance | 0.576 | 0.578 | +0.3% |
Embeddings Statistics (Validation):
- Mean: +0.0005 (β centered)
- Std: 1.0099 (β VICReg target achieved!)
- Range: [-2.92, 3.12] (Β±3Ο gaussian)
Generalization: β Excellent (gap <5%, no overfitting)
Phase 2 Validation
Validation Set: 200 temporal trajectories
| Metric | Value |
|---|---|
| Loss Totale | 1.246 |
| Loss Multi-Horizon | 0.042 |
| - H1 (t+1) | 0.044 |
| - H2 (t+2) | 0.041 |
| - H4 (t+4) | 0.042 |
| Loss VICReg | 12.04 |
| Variance | 0.670 |
Embeddings Statistics (Validation):
- Mean: +0.0017 (β centered)
- Std: 1.0221 (β VICReg target maintained!)
- Range: [-2.52, 2.91] (Β±3Ο gaussian)
Horizon Difficulty Analysis:
- H2/H1 ratio: ~0.93 (t+2 slightly easier than t+1)
- H4/H1 ratio: ~0.95 (t+4 comparable to t+1)
Note: Low difficulty ratios indicate the model successfully learned long-term temporal dependencies.
Generalization: β Excellent (64.1% training improvement maintained on validation)
π Usage
Installation
pip install torch torchvision einops
Loading Checkpoints
Phase 1 Model
import torch
# Load checkpoint
checkpoint = torch.load('atlas_phase1_checkpoint.pt', map_location='cpu')
# Extract components
context_encoder_state = checkpoint['context_encoder']
target_encoder_state = checkpoint['target_encoder']
config = checkpoint['config']
history = checkpoint['history']
print(f"Phase 1 β Latent dim: {config['latent_dim']}")
print(f"Final loss: {history['loss_total'][-1]:.4f}")
Phase 2 Model
import torch
# Load checkpoint
checkpoint = torch.load('atlas_phase2_checkpoint.pt', map_location='cpu')
# Extract components
context_encoder_state = checkpoint['context_encoder']
target_encoder_state = checkpoint['target_encoder']
multi_horizon_predictor_state = checkpoint['multi_horizon_predictor']
config = checkpoint['config']
history = checkpoint['history']
print(f"Phase 2 β Latent dim: {config['latent_dim']}")
print(f"Horizons: {config['horizons']}")
print(f"Final loss: {history['loss_total'][-1]:.4f}")
Inference Example
import torch
import torch.nn as nn
# Define ContextEncoder (Phase 2)
class ContextEncoder(nn.Module):
def __init__(self, input_dim=3, latent_dim=64, hidden=128):
super().__init__()
self.net = nn.Sequential(
nn.Linear(input_dim, hidden),
nn.LayerNorm(hidden),
nn.GELU(),
nn.Linear(hidden, latent_dim),
nn.LayerNorm(latent_dim),
)
def forward(self, x):
return self.net(x)
# Load Phase 2 model
checkpoint = torch.load('atlas_phase2_checkpoint.pt', map_location='cpu')
context_encoder = ContextEncoder(latent_dim=64)
context_encoder.load_state_dict(checkpoint['context_encoder'])
context_encoder.eval()
# Inference
candidate = torch.tensor([[0.75, 0.82, 0.28]]) # [Stability, Alignment, Cost]
with torch.no_grad():
embedding = context_encoder(candidate)
print(f"Embedding shape: {embedding.shape}") # (1, 64)
print(f"Embedding norm: {embedding.norm().item():.4f}")
Multi-Horizon Prediction
# Define MultiHorizonPredictor
class MultiHorizonPredictor(nn.Module):
def __init__(self, latent_dim=64, hidden=128, horizons=[1, 2, 4]):
super().__init__()
self.horizons = horizons
self.predictors = nn.ModuleDict({
f'h{h}': nn.Sequential(
nn.Linear(latent_dim, hidden),
nn.LayerNorm(hidden),
nn.GELU(),
nn.Linear(hidden, latent_dim),
)
for h in horizons
})
def forward(self, context_emb):
return {h: self.predictors[f'h{h}'](context_emb) for h in self.horizons}
# Load predictor
predictor = MultiHorizonPredictor(latent_dim=64)
predictor.load_state_dict(checkpoint['multi_horizon_predictor'])
predictor.eval()
# Predict future states
with torch.no_grad():
context_emb = context_encoder(candidate)
predictions = predictor(context_emb)
print("Predicted embeddings:")
for h, pred_emb in predictions.items():
print(f" t+{h}: {pred_emb.shape}, norm={pred_emb.norm().item():.4f}")
π¦ Files and Notebooks
Checkpoints
| File | Size | Description |
|---|---|---|
atlas_phase1_checkpoint.pt |
0.15 MB | Phase 1 trained model (latent_dim=32, 9.9K params) |
atlas_phase2_checkpoint.pt |
0.90 MB | Phase 2 trained model (latent_dim=64, 68.8K params) |
Notebooks (Google Colab Ready)
| File | Description |
|---|---|
ATLAS_LeWM_Phase1_Colab.ipynb |
Phase 1 training (masked prediction + VICReg, 10 sec on T4) |
ATLAS_LeWM_Phase2_Colab.ipynb |
Phase 2 training (multi-horizon prediction, 2-3 min on T4) |
How to use:
- Open notebook in Google Colab
- Runtime β Change runtime type β GPU (T4)
- Execute all cells sequentially
- Download checkpoint from Files panel
Documentation
PHASE1_TRAINING_REPORT.mdβ Detailed Phase 1 training analysisJEPA_WORLD_MODEL_INTEGRATION_PLAN.mdβ 4-phase integration roadmapCOLAB_DEPLOYMENT_GUIDE.mdβ Step-by-step Colab setup instructions
π Citation
If you use ATLAS v2 + Le-WM in your research, please cite:
@software{atlas_lewm_2026,
title={ATLAS v2 + Le WorldModel: Fine-Tuned JEPA Architecture},
author={Guennoune, Abderrahim and Contributors},
year={2026},
month={March},
publisher={HuggingFace},
doi={10.5281/zenodo.17764881},
url={https://huggingface.co/aguennoune/atlas-v2-lewm}
}
@article{lecun2026leworldmodel,
title={A Clean Recipe to Train JEPA World Models},
author={LeCun, Yann and others},
journal={arXiv preprint arXiv:2603.19312},
year={2026},
url={https://le-wm.github.io/}
}
π References
- Le WorldModel β https://le-wm.github.io/
- Le-WM Paper β arXiv:2603.19312v1
- VICReg β Bardes et al., ICLR 2022 (Paper)
- I-JEPA β Meta AI, ICLR 2023 (Paper)
- ATLAS v2 NWM β DOI 10.5281/zenodo.17764881
π License
Apache License 2.0 β See LICENSE for details.
π Acknowledgments
- Yann LeCun and the Le-WM team for the clean JEPA training recipe
- Google Colab for providing free Tesla T4 GPU access
- Meta AI for I-JEPA and VICReg innovations
- BioContinuum-OS community for ATLAS architecture
Model Version: 2.1-alpha (Le-WM Integration Phase 2)
Release Date: March 24, 2026
Status: β
Phase 1+2 Complete | β³ Phase 3+4 In Progress
- Downloads last month
- 79