Contrastive RL: A Step-by-Step Guide to Learning Reachability
The paper "1000 Layer Networks for Self-Supervised RL" won Best Paper at NeurIPS 2025, and for good reason. It demonstrates that goal-conditioned RL can scale to 1000-layer networks—something previously thought impractical. But the real insight isn't about network depth. It's about reframing RL as a representation learning problem.
The Core Insight
Traditional RL learns a value function: "how much cumulative reward will I get from here?" This requires reward signals, which are often sparse and noisy. The credit assignment problem (figuring out which actions led to which rewards) gets harder as tasks get more complex.
Contrastive RL asks a different question: "Can I reach that goal from here?"
Instead of learning reward predictions, we learn a coordinate transformation. We train two neural networks, one for (state, action) pairs and one for goals, to map their inputs into a shared space (call it reachability space) where L2 distance encodes how reachable a goal is from a given state and action.
Observation Space              Reachability Space
s₀ ────────► g                 sa₀ ●────────● g
(complex)                      (L2 distance = reachability)
In this learned space, picking the best action becomes simple: compute the distance to the goal for each candidate action, then pick the action that gets you closest.
The Algorithm at a Glance
Here's the complete algorithm in pseudocode:
CONTRASTIVE RL TRAINING
1. COLLECT TRAJECTORIES
   Run policy in environment, store (state, action, next_state) tuples
2. SAMPLE TRAINING PAIRS
   For each (state, action), sample a future state as "goal"
   Weight sampling geometrically: nearby futures more likely
3. TRANSFORM TO REACHABILITY SPACE
   sa_embedding = SA_encoder(state, action)
   goal_embedding = G_encoder(goal)
4. COMPUTE DISTANCE MATRIX
   For all pairs in batch: distance[i,j] = ||sa_embedding[i] - goal_embedding[j]||₂
5. CALCULATE CONTRASTIVE LOSS
   Diagonal = positive pairs (we know state→goal is reachable)
   Off-diagonal = negative pairs (random pairings, likely unreachable)
   Loss = cross-entropy pushing positives together, negatives apart
6. UPDATE ACTOR
   Actor takes (state, goal) → action distribution
   Train to maximize Q = -distance(SA_encoder(s,a), G_encoder(g))
   SAC-style: balance Q-maximization with entropy for exploration
REPEAT
Let's walk through each step.
Step 1: Collect Trajectories
The agent explores the environment and records what happens:
Trajectory: s₀ →ᵃ⁰ s₁ →ᵃ¹ s₂ →ᵃ² s₃ →ᵃ³ s₄ →ᵃ⁴ s₅
Each transition (state, action, next_state) goes into a replay buffer. Nothing special here—this is standard RL data collection.
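To make the data collection concrete, here is a minimal sketch in plain Python. The toy environment (`toy_env_step`), the random policy, and the function names are all made up for illustration and are not from the paper. The one design choice worth noticing is that trajectories are stored whole, because step 2 needs to look up future states from the same trajectory.

```python
import random

def collect_trajectory(policy, env_step, start_state, horizon):
    """Roll out `policy` for `horizon` steps and return the list of transitions."""
    trajectory = []
    state = start_state
    for _ in range(horizon):
        action = policy(state)
        next_state = env_step(state, action)
        trajectory.append((state, action, next_state))  # note: no reward is stored
        state = next_state
    return trajectory

# Hypothetical toy environment: a point on a line, actions move it left or right.
def toy_env_step(state, action):
    return state + action

rng = random.Random(0)

def random_policy(state):
    return rng.choice([-1.0, 1.0])

# The replay buffer keeps trajectories intact so future states can be looked up later.
replay_buffer = [
    collect_trajectory(random_policy, toy_env_step, 0.0, horizon=5)
    for _ in range(3)
]
```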
The key difference comes in how we use this data. We don't need reward labels. Every trajectory is useful because every state the agent reaches is, by definition, a state that was reachable.
Step 2: Sample State-Action-Goal Pairs
Here's where Contrastive RL gets clever. For each (state, action) pair, we need a "goal"—but what goal?
Answer: Any future state from the same trajectory.
Trajectory: s₀ → s₁ → s₂ → s₃ → s₄ → s₅
For s₁: valid goals are {s₂, s₃, s₄, s₅}
This is hindsight relabeling, the idea behind Hindsight Experience Replay. We ask: "what if the goal had been whatever state we actually reached?" Then by definition, we succeeded.
But we don't sample uniformly. We use geometric weighting:
probability(goal = sⱼ | current = sᵢ) ∝ γ^(j-i)
With γ = 0.99, the relative (unnormalized) sampling weights are:
| Goal Distance | Sampling Weight |
|---|---|
| 1 step ahead | 0.99 |
| 10 steps ahead | 0.90 |
| 50 steps ahead | 0.61 |
| 100 steps ahead | 0.37 |
Nearby goals are sampled more often. This creates a natural curriculum—the network first learns short-range reachability, then gradually extends to longer horizons.
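The geometric weighting can be sketched in a few lines. This is an illustrative version using only the standard library; the function name `sample_goal_index` is hypothetical, and a real implementation would typically vectorize this over a whole batch.

```python
import random

def sample_goal_index(i, trajectory_length, gamma=0.99, rng=random):
    """Sample a future index j > i with probability proportional to gamma ** (j - i)."""
    future = list(range(i + 1, trajectory_length))
    if not future:
        raise ValueError("the last state of a trajectory has no future goals")
    weights = [gamma ** (j - i) for j in future]  # unnormalized; choices() normalizes
    return rng.choices(future, weights=weights, k=1)[0]

# Usage on the six-state trajectory from the text.
states = ["s0", "s1", "s2", "s3", "s4", "s5"]
rng = random.Random(0)
goal = states[sample_goal_index(1, len(states), rng=rng)]  # one of s2..s5
```

Note that on short trajectories with γ = 0.99 the weights are nearly uniform; the preference for nearby goals only becomes pronounced over horizons of tens to hundreds of steps.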
Step 3: Transform to Reachability Space
Two neural networks do the heavy lifting:
SA_encoder(state, action) → 64-dim embedding
G_encoder(goal) → 64-dim embedding
The SA_encoder takes a state-action pair and asks: "where will this action take me?"
The G_encoder takes a goal and asks: "where is this location in reachability space?"
Both networks output vectors in the same 64-dimensional space. The magic is that after training, the L2 distance between these vectors represents reachability:
distance = ||SA_encoder(s, a) - G_encoder(g)||₂
Small distance → (s, a) leads toward g
Large distance → (s, a) doesn't lead toward g
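To show the shape of the computation, here is a toy sketch in which each encoder is a single untrained random linear layer. That is a deliberate simplification: the paper's encoders are deep residual networks, and the input sizes, the concatenation of state and action, and all names below are assumptions made for illustration.

```python
import math
import random

EMBED_DIM = 64  # embedding size used in the text

def make_linear(in_dim, out_dim, rng):
    """Random weight matrix; a stand-in for a trained deep network."""
    scale = 1.0 / math.sqrt(in_dim)
    return [[rng.gauss(0.0, scale) for _ in range(in_dim)] for _ in range(out_dim)]

def apply_linear(weights, x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in weights]

def sa_encoder(weights, state, action):
    # Assumption: state and action are concatenated into one input vector.
    return apply_linear(weights, list(state) + list(action))

def g_encoder(weights, goal):
    return apply_linear(weights, list(goal))

rng = random.Random(0)
STATE_DIM, ACTION_DIM = 3, 2  # hypothetical sizes
sa_weights = make_linear(STATE_DIM + ACTION_DIM, EMBED_DIM, rng)
g_weights = make_linear(STATE_DIM, EMBED_DIM, rng)

sa_embedding = sa_encoder(sa_weights, [0.1, 0.2, 0.3], [1.0, -1.0])
goal_embedding = g_encoder(g_weights, [0.5, 0.5, 0.5])

# Both vectors live in the same 64-dim space, so their L2 distance is defined.
# With untrained weights the number is meaningless; training gives it meaning.
distance = math.dist(sa_embedding, goal_embedding)
```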
Step 4: Compute the Distance Matrix
For a batch of B training examples, we compute all pairwise distances:
sa_repr = SA_encoder(states, actions) # shape: [B, 64]
g_repr = G_encoder(goals) # shape: [B, 64]
# Pairwise L2 distances
distance_matrix[i,j] = ||sa_repr[i] - g_repr[j]||₂ # shape: [B, B]
This gives us a B×B matrix. Row i corresponds to the (state, action) pair of sample i, and column j corresponds to the goal of sample j:
| | goal₀ | goal₁ | goal₂ | goal₃ |
|---|---|---|---|---|
| sample(s,a)₀ | ✅ 0.1 | 2.3 | 1.8 | 3.1 |
| sample(s,a)₁ | 2.5 | ✅ 0.2 | 2.9 | 1.7 |
| sample(s,a)₂ | 1.9 | 3.2 | ✅ 0.15 | 2.4 |
| sample(s,a)₃ | 2.8 | 1.6 | 2.1 | ✅ 0.3 |
The diagonal entries (✅) are positive pairs. Each sample(s,a)ᵢ came from the same trajectory as goalᵢ, and goalᵢ was sampled from its future, so we know this pair is reachable.
The off-diagonal entries are treated as negative pairs. They match a (state, action) with a goal that was sampled for a different batch element, usually from a different trajectory, so most of them are unreachable.
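As a small worked example, the pairwise distance computation can be written with the standard library alone. The 2-D embeddings below are made-up numbers chosen so the distances are easy to check by hand; real embeddings would be 64-dimensional and the computation would be vectorized.

```python
import math

def pairwise_l2(sa_repr, g_repr):
    """Return a matrix with distance[i][j] = ||sa_repr[i] - g_repr[j]||_2."""
    return [[math.dist(sa, g) for g in g_repr] for sa in sa_repr]

# Hypothetical 2-D embeddings for a batch of three samples.
sa_repr = [[0.0, 0.0], [3.0, 4.0], [6.0, 8.0]]
g_repr = [[0.0, 0.0], [3.0, 4.0], [6.0, 8.0]]

distance_matrix = pairwise_l2(sa_repr, g_repr)

# Diagonal entries are the positive pairs; everything else is a negative.
positives = [distance_matrix[i][i] for i in range(len(distance_matrix))]
```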
Step 5: Calculate Contrastive Loss (InfoNCE)
Now we need a loss function. Looking at our distance matrix, the goal is clear: push the diagonal values down (closer embeddings) and the off-diagonal values up (farther embeddings).
What loss function does this? Let's think about what we already know...
Cross-Entropy Refresher
Cross-entropy loss is everywhere in classification. Given logits for each class, it:
- Finds the score for the correct class
- Compares it against all other scores
- Penalizes when wrong classes have high scores
The formula is:
loss = -log(softmax(logits)[correct_class])
But there's a numerically stable way to write this, provided logsumexp is computed with the usual max-subtraction trick. Recall that:
softmax(x)[i] = exp(x[i]) / sum(exp(x))
So:
-log(softmax(logits)[correct]) = -log(exp(logits[correct]) / sum(exp(logits)))
                               = -logits[correct] + log(sum(exp(logits)))
                               = -logits[correct] + logsumexp(logits)
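That logsumexp term is a smooth approximation of the maximum: it's dominated by the largest value but is differentiable everywhere. (Despite its name, the softmax function itself is closer to a soft argmax.) Cross-entropy is essentially saying: "penalize the amount by which the smooth maximum over all classes exceeds the score of the correct class."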
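The identity above is easy to check numerically. The sketch below compares a naive cross-entropy against the logsumexp form, where `logsumexp` subtracts the maximum before exponentiating. The function names are illustrative; libraries such as SciPy and JAX ship their own `logsumexp`.

```python
import math

def logsumexp(xs):
    """Stable log(sum(exp(x))): subtract the max so exp() never overflows."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

def cross_entropy(logits, correct):
    """Cross-entropy written as -logit[correct] + logsumexp(logits)."""
    return -logits[correct] + logsumexp(logits)

def naive_cross_entropy(logits, correct):
    """Textbook form; overflows once any logit is large."""
    exps = [math.exp(x) for x in logits]
    return -math.log(exps[correct] / sum(exps))

small = [2.0, 1.0, 0.1]
same = abs(cross_entropy(small, 0) - naive_cross_entropy(small, 0)) < 1e-12

# With a large logit the naive version raises OverflowError, the stable one does not.
large = [1000.0, 0.0, 0.0]
stable_loss = cross_entropy(large, 0)
```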
From Cross-Entropy to InfoNCE
Now look at our distance matrix again. Each row is a classification problem:
"Given sample(s,a)ᵢ, which goal did it actually reach?"
The answer is goalᵢ—the diagonal entry. So we apply cross-entropy to each row:
logits = -distance_matrix  # negate: smaller distance = higher score
loss_per_row[i] = -logits[i,i] + logsumexp(logits[i,:])
                = -score_for_correct_goal + soft_maximum_over_all_goals
loss = mean(loss_per_row)
This is InfoNCE—it's just cross-entropy where the "classes" are dynamically defined by the batch, and the correct class is always the diagonal.
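Here is a minimal sketch of that loss applied to the 4×4 example matrix from step 4. It is written in plain Python for readability rather than speed, and the function names are illustrative. Swapping in a matrix whose diagonal is large shows the loss going up, which is the behavior we want.

```python
import math

def logsumexp(xs):
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

def info_nce(distance_matrix):
    """Row-wise cross-entropy where the correct 'class' for row i is column i."""
    losses = []
    for i, row in enumerate(distance_matrix):
        logits = [-d for d in row]  # smaller distance = higher score
        losses.append(-logits[i] + logsumexp(logits))
    return sum(losses) / len(losses)

# The example matrix from the text: small diagonal, large off-diagonal.
good = [
    [0.1, 2.3, 1.8, 3.1],
    [2.5, 0.2, 2.9, 1.7],
    [1.9, 3.2, 0.15, 2.4],
    [2.8, 1.6, 2.1, 0.3],
]
# A made-up failure case: positives are far apart, negatives are close.
bad = [[3.0 if i == j else 0.2 for j in range(4)] for i in range(4)]

good_loss = info_nce(good)
bad_loss = info_nce(bad)
```

A useful sanity check: if every distance is identical, the model is guessing uniformly and the loss equals log(B).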
Why This Works
| Pair Type | What Happens |
|---|---|
| Diagonal (positive) | -logits[i,i] term pulls embeddings closer |
| Off-diagonal (negative) | logsumexp term pushes embeddings apart |
The batch provides free negatives. With B=256, each positive pair competes against 255 negatives automatically. No labeling required—the batch structure defines the classification problem.
Logsumexp regularization:
The implementation adds a penalty on the squared per-row logsumexp to keep the logits from growing without bound:
loss += 0.1 * mean(logsumexp(logits, axis=1)²)
This keeps embedding magnitudes stable during training.
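One way to see why the penalty helps: the contrastive term does not change when a constant is added to every logit in a row, so nothing in it pins down the overall logit level. The sketch below, with illustrative names and the 0.1 coefficient from the text, shows the plain loss ignoring such a shift while the penalized loss reacts to it.

```python
import math

def logsumexp(xs):
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

def info_nce_with_penalty(distance_matrix, penalty_coeff=0.1):
    """InfoNCE plus penalty_coeff * mean(logsumexp(row) ** 2)."""
    n = len(distance_matrix)
    logits = [[-d for d in row] for row in distance_matrix]
    row_lse = [logsumexp(row) for row in logits]
    contrastive = sum(lse - logits[i][i] for i, lse in enumerate(row_lse)) / n
    penalty = penalty_coeff * sum(lse ** 2 for lse in row_lse) / n
    return contrastive + penalty

base = [[0.1, 2.3], [2.5, 0.2]]
shifted = [[d + 10.0 for d in row] for row in base]  # same ranking, shifted logits

plain_base = info_nce_with_penalty(base, penalty_coeff=0.0)
plain_shifted = info_nce_with_penalty(shifted, penalty_coeff=0.0)
penalized_base = info_nce_with_penalty(base)
penalized_shifted = info_nce_with_penalty(shifted)
```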
Step 6: Update the Actor
The actor network is goal-conditioned—it takes (state, goal) as input and outputs a distribution over actions:
mean, log_std = Actor(concat(state, goal))
std = exp(log_std)
action = tanh(mean + std * noise) # reparameterized sample
To train it, we use the learned reachability space as a Q-function:
Q(state, action, goal) = -||SA_encoder(state, action) - G_encoder(goal)||₂
Higher Q (less negative) means the action leads closer to the goal.
The actor loss is SAC-style—maximize Q while maintaining entropy for exploration:
actor_loss = mean(alpha * log_prob - Q_value)
Here alpha is a learned temperature that auto-tunes to maintain a target entropy level.
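The pieces of the actor update can be sketched for a one-dimensional action. This is a simplified illustration, not the paper's implementation: alpha is held fixed rather than learned, the Q-values are made-up numbers instead of encoder outputs, and no gradients are computed. The log-probability includes the standard change-of-variables correction for the tanh squashing.

```python
import math
import random

def sample_tanh_gaussian(mean, log_std, rng):
    """Reparameterized sample a = tanh(mean + std * noise), with its log-probability."""
    std = math.exp(log_std)
    noise = rng.gauss(0.0, 1.0)
    action = math.tanh(mean + std * noise)
    # Log-density of the Gaussian sample before squashing.
    log_prob = -0.5 * noise ** 2 - log_std - 0.5 * math.log(2.0 * math.pi)
    # Change of variables for tanh; the small epsilon avoids log(0).
    log_prob -= math.log(1.0 - action ** 2 + 1e-6)
    return action, log_prob

def actor_loss(log_probs, q_values, alpha):
    """SAC-style objective: mean(alpha * log_prob - Q)."""
    terms = [alpha * lp - q for lp, q in zip(log_probs, q_values)]
    return sum(terms) / len(terms)

rng = random.Random(0)
action, log_prob = sample_tanh_gaussian(mean=0.0, log_std=0.0, rng=rng)

# Hypothetical Q-values: Q = -distance, so -1.0 is better than -3.0.
loss = actor_loss(log_probs=[0.0, 0.0], q_values=[-1.0, -3.0], alpha=0.5)
```

Lower loss means higher Q (shorter distance to the goal) or higher entropy (lower log-probability), which is the balance the text describes.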
Action selection at inference:
The default is stochastic—just sample from the policy. But optionally, you can use K-sample selection:
- Sample K candidate actions from the policy
- Compute Q-value for each
- Execute the action with highest Q
This optional mode trades exploration for exploitation when evaluating the trained policy.
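K-sample selection is short enough to sketch directly. The toy Q-function and uniform "policy" below are stand-ins invented for this example; in the real setup the candidates come from the actor and Q is the negative embedding distance from step 6.

```python
import random

def select_best_of_k(candidates, q_value):
    """Return the candidate action with the highest Q-value."""
    return max(candidates, key=q_value)

def act(sample_action, q_value, k, rng):
    """Draw k actions from the policy and execute the best one under Q."""
    candidates = [sample_action(rng) for _ in range(k)]
    return select_best_of_k(candidates, q_value)

# Hypothetical 1-D task: start at 0.0, goal at 1.0, action = displacement.
STATE, GOAL = 0.0, 1.0

def toy_q(action):
    return -abs((STATE + action) - GOAL)  # negative distance to the goal

def toy_policy_sample(rng):
    return rng.uniform(-1.0, 1.0)

rng = random.Random(0)
greedy_action = act(toy_policy_sample, toy_q, k=16, rng=rng)
```

With k=1 this reduces to ordinary stochastic sampling; larger k makes the behavior increasingly greedy with respect to the learned distance.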
Why This Scales Better Than Traditional RL
Traditional RL has fundamental scaling problems:
1. Reward sparsity. In goal-reaching tasks, you only get reward when you reach the goal. Early in training, random exploration almost never succeeds, so there's no learning signal.
2. Credit assignment. Even with rewards, figuring out which actions led to success is hard. Did action 47 matter, or was it action 203?
3. Value function instability. Bootstrapping (using your own predictions to train yourself) can diverge, especially with deep networks.
Contrastive RL sidesteps all three:
1. Dense signal from hindsight. Every trajectory generates training data. Reached state X? Great, now we have a positive example for "X is reachable from earlier states." No sparse rewards needed.
2. Direct supervision. We don't ask "what reward will I get?" We ask "did I reach this state?" The answer is binary and ground-truth—no credit assignment required.
3. No bootstrapping. The contrastive loss is self-contained: we're not regressing toward our own predictions. This is a plausible reason the paper could scale to 1000 layers, since standard supervised learning techniques (residual connections, layer norm) carry over directly.
The result: goal-conditioned agents that learn from pure self-supervision, scale to massive networks, and don't need hand-crafted reward functions.
Summary
Contrastive RL reframes goal-reaching as representation learning:
- Collect trajectories (any exploration data works)
- Sample goal pairs using hindsight (future states = valid goals)
- Learn a reachability space where distance = "how reachable?"
- Train with contrastive loss (positives close, negatives far)
- Act by minimizing distance to the goal
The algorithm doesn't predict rewards—it learns a coordinate system. In that coordinate system, achieving goals becomes as simple as walking downhill.
If you are interested in high-performance RL simulations, check out Lightspeed - a GPU-accelerated RL framework built on NVIDIA Warp.
