Contrastive RL: A Step-by-Step Guide to Learning Reachability

The paper "1000 Layer Networks for Self-Supervised RL" won Best Paper at NeurIPS 2025, and for good reason. It demonstrates that goal-conditioned RL can scale to 1000-layer networks—something previously thought impractical. But the real insight isn't about network depth. It's about reframing RL as a representation learning problem.

The Core Insight

Traditional RL learns a value function: "how much cumulative reward will I get from here?" This requires reward signals, which are often sparse and noisy. The credit assignment problem (figuring out which actions led to which rewards) gets harder as tasks get more complex.

Contrastive RL asks a different question: "Can I reach that goal from here?"

Instead of learning reward predictions, we learn a coordinate transformation. We train two neural networks, one for (state, action) pairs and one for goals, that map their inputs into a shared space (call it reachability space) where L2 distance encodes how reachable the goal is: the closer the two embeddings, the more reachable.

Observation Space          Reachability Space

  s₀ ────────► g           sa₀ ●────────● g
      (complex)                  (L2 distance = reachability)

In this learned space, picking the best action becomes trivial: for each candidate action, compute the distance from its (state, action) embedding to the goal embedding, and pick the action that lands closest.

The Algorithm at a Glance

Here's the complete algorithm in pseudocode:

CONTRASTIVE RL TRAINING

1. COLLECT TRAJECTORIES
   Run policy in environment, store (state, action, next_state) tuples

2. SAMPLE TRAINING PAIRS
   For each (state, action), sample a future state as "goal"
   Weight sampling geometrically: nearby futures more likely

3. TRANSFORM TO REACHABILITY SPACE
   sa_embedding = SA_encoder(state, action)
   goal_embedding = G_encoder(goal)

4. COMPUTE DISTANCE MATRIX
   For all pairs in batch: distance[i,j] = ||sa_embedding[i] - goal_embedding[j]||₂

5. CALCULATE CONTRASTIVE LOSS
   Diagonal = positive pairs (we know state→goal is reachable)
   Off-diagonal = negative pairs (random pairings, likely unreachable)
   Loss = cross-entropy pushing positives together, negatives apart

6. UPDATE ACTOR
   Actor takes (state, goal) → action distribution
   Train to maximize Q = -distance(SA_encoder(s,a), G_encoder(g))
   SAC-style: balance Q-maximization with entropy for exploration

REPEAT

Let's walk through each step.


Step 1: Collect Trajectories

The agent explores the environment and records what happens:

Trajectory: s₀ →ᵃ⁰ s₁ →ᵃ¹ s₂ →ᵃ² s₃ →ᵃ³ s₄ →ᵃ⁴ s₅

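Each transition (state, action, next_state) goes into a replay buffer. This is standard RL data collection, with one requirement: the buffer must remember which trajectory each transition came from, because Step 2 samples goals from later states of the same trajectory.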

The key difference comes in how we use this data. We don't need reward labels. Every trajectory is useful because every state the agent reaches is, by definition, a state that was reachable.
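Because goals later come from the future of the same trajectory, it helps to store transitions grouped by trajectory. A minimal sketch, assuming a simple in-memory list; `Trajectory` and `TrajectoryBuffer` are hypothetical names for illustration, not the paper's code. Note that no reward field is stored.

```python
from dataclasses import dataclass, field
from typing import Any, List


@dataclass
class Trajectory:
    states: List[Any]   # s_0 ... s_T
    actions: List[Any]  # a_0 ... a_{T-1}

    def transitions(self):
        # Yield the standard (state, action, next_state) tuples.
        for t, a in enumerate(self.actions):
            yield self.states[t], a, self.states[t + 1]


@dataclass
class TrajectoryBuffer:
    # Hypothetical buffer: whole trajectories are kept together so that
    # later states of the same trajectory can be looked up as goals.
    trajectories: List[Trajectory] = field(default_factory=list)

    def add(self, states, actions):
        if len(states) != len(actions) + 1:
            raise ValueError("need exactly one more state than actions")
        self.trajectories.append(Trajectory(list(states), list(actions)))

    def all_transitions(self):
        return [tr for traj in self.trajectories for tr in traj.transitions()]


buffer = TrajectoryBuffer()
buffer.add(states=["s0", "s1", "s2", "s3"], actions=["a0", "a1", "a2"])
print(buffer.all_transitions())
# [('s0', 'a0', 's1'), ('s1', 'a1', 's2'), ('s2', 'a2', 's3')]
```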


Step 2: Sample State-Action-Goal Pairs

Here's where Contrastive RL gets clever. For each (state, action) pair, we need a "goal"—but what goal?

Answer: Any future state from the same trajectory.

Trajectory: s₀ → s₁ → s₂ → s₃ → s₄ → s₅

For s₁: valid goals are {s₂, s₃, s₄, s₅}

This is the idea behind Hindsight Experience Replay (HER): we ask "what if the goal had been whatever state we actually reached?" Then, by definition, we succeeded.

But we don't sample uniformly. We use geometric weighting:

probability(goal = sⱼ | current = sᵢ) ∝ γ^(j-i),   for j > i

With γ = 0.99, the relative (unnormalized) weights are:

Goal distance      Sampling weight γ^(j-i)
1 step ahead       0.99
10 steps ahead     0.90
50 steps ahead     0.61
100 steps ahead    0.37

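Nearby goals are sampled more often. Since short-range pairs are also the easiest to tell apart, this likely acts as a natural curriculum: the network learns short-range reachability first, then extends to longer horizons. The weighting also has a more formal role: drawing the offset j - i geometrically samples goals from the discounted distribution of future states, which is what ties the learned critic to a discounted, goal-conditioned Q-function.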
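The weights in the table are just γ^(j-i). A minimal sketch of the sampling rule, assuming goals are drawn only from strictly later states of the same trajectory and the distribution is truncated at the trajectory's end (a real implementation may handle episode ends differently); `goal_weights` and `sample_goal_index` are hypothetical helpers.

```python
import random


def goal_weights(i, traj_len, gamma=0.99):
    # Unnormalized weight gamma**(j - i) for every future index j > i.
    return {j: gamma ** (j - i) for j in range(i + 1, traj_len)}


def sample_goal_index(i, traj_len, gamma=0.99, rng=random):
    """Pick a future index j > i with probability proportional to gamma**(j - i).

    The distribution is truncated at the end of the trajectory;
    rng.choices normalizes the weights, so they need not sum to one.
    """
    weights = goal_weights(i, traj_len, gamma)
    if not weights:
        raise ValueError("state i has no future states in this trajectory")
    indices = list(weights)
    return rng.choices(indices, weights=[weights[j] for j in indices], k=1)[0]


# Reproduce the relative weights in the table (k = j - i steps ahead).
for k in (1, 10, 50, 100):
    print(k, round(0.99 ** k, 2))  # prints: 1 0.99 / 10 0.9 / 50 0.61 / 100 0.37

# Sample goals for s_1 in a trajectory s_0 ... s_5 (valid goals: s_2 ... s_5).
rng = random.Random(0)
print([sample_goal_index(1, 6, rng=rng) for _ in range(10)])
```

With γ = 0.99 the preference for nearby goals is mild over short trajectories; a smaller γ concentrates goals much closer to the current state.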


Step 3: Transform to Reachability Space

Two neural networks do the heavy lifting:

SA_encoder(state, action) → 64-dim embedding
G_encoder(goal)           → 64-dim embedding

The SA_encoder takes a state-action pair and asks: "where will this action take me?"

The G_encoder takes a goal and asks: "where is this location in reachability space?"

Both networks output vectors in the same 64-dimensional space. The magic is that after training, the L2 distance between these vectors represents reachability:

distance = ||SA_encoder(s, a) - G_encoder(g)||₂

Small distance → (s, a) leads toward g
Large distance → (s, a) doesn't lead toward g
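To make the shape contract concrete, here is a small NumPy sketch. It assumes plain two-layer ReLU MLPs with random weights and goals that live in the same space as states; the paper's encoders are much deeper residual networks, and every name here (`init_mlp`, `sa_encoder`, `g_encoder`, `reachability_distance`) is hypothetical. The point is only that both encoders land in the same 64-dimensional space, so an L2 distance between their outputs is well defined.

```python
import numpy as np

REPR_DIM = 64  # shared embedding size, as in the text


def init_mlp(in_dim, hidden_dim, out_dim, rng):
    # Random weights for a 2-layer MLP (a stand-in, not the paper's architecture).
    return {
        "W1": rng.normal(0.0, 1.0 / np.sqrt(in_dim), size=(in_dim, hidden_dim)),
        "b1": np.zeros(hidden_dim),
        "W2": rng.normal(0.0, 1.0 / np.sqrt(hidden_dim), size=(hidden_dim, out_dim)),
        "b2": np.zeros(out_dim),
    }


def mlp(params, x):
    h = np.maximum(0.0, x @ params["W1"] + params["b1"])  # ReLU hidden layer
    return h @ params["W2"] + params["b2"]


def sa_encoder(params, state, action):
    # The state-action encoder sees the concatenated (s, a) vector.
    return mlp(params, np.concatenate([state, action], axis=-1))


def g_encoder(params, goal):
    return mlp(params, goal)


def reachability_distance(sa_params, g_params, state, action, goal):
    # ||SA_encoder(s, a) - G_encoder(g)||_2, one distance per batch element
    diff = sa_encoder(sa_params, state, action) - g_encoder(g_params, goal)
    return np.linalg.norm(diff, axis=-1)


rng = np.random.default_rng(0)
state_dim, action_dim = 4, 2
sa_params = init_mlp(state_dim + action_dim, 256, REPR_DIM, rng)
g_params = init_mlp(state_dim, 256, REPR_DIM, rng)

states = rng.normal(size=(8, state_dim))
actions = rng.uniform(-1.0, 1.0, size=(8, action_dim))
goals = rng.normal(size=(8, state_dim))
print(sa_encoder(sa_params, states, actions).shape)  # (8, 64)
print(reachability_distance(sa_params, g_params, states, actions, goals).shape)  # (8,)
```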

Step 4: Compute the Distance Matrix

For a batch of B training examples, we compute all pairwise distances:

sa_repr = SA_encoder(states, actions)  # shape: [B, 64]
g_repr = G_encoder(goals)               # shape: [B, 64]

# Pairwise L2 distances
distance_matrix[i,j] = ||sa_repr[i] - g_repr[j]||₂  # shape: [B, B]

This gives us a B×B matrix. Row i is the (state, action) pair from sample i; column j is the goal that was drawn for sample j. For example, with B = 4:

                 goal₀   goal₁   goal₂   goal₃
sample(s,a)₀     0.1     2.3     1.8     3.1
sample(s,a)₁     2.5     0.2     2.9     1.7
sample(s,a)₂     1.9     3.2     0.15    2.4
sample(s,a)₃     2.8     1.6     2.1     0.3

The diagonal entries (0.1, 0.2, 0.15, 0.3) are positive pairs. Each goalᵢ was sampled from the future of the same trajectory as sample(s,a)ᵢ, so we know this pair is reachable.

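The off-diagonal entries are negative pairs. These pair a (state, action) with another sample's goal, usually from a different trajectory, so they are treated as unreachable. Occasionally such a pair is in fact reachable; the loss simply treats it as a negative anyway.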
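The pairwise distance computation from this step can be written in a few lines of NumPy with broadcasting. A sketch, assuming the embeddings are already computed as `[B, D]` arrays; `pairwise_l2` is a hypothetical helper name.

```python
import numpy as np


def pairwise_l2(sa_repr, g_repr):
    """distance[i, j] = ||sa_repr[i] - g_repr[j]||_2 for every (i, j) in the batch.

    sa_repr: [B, D] state-action embeddings, g_repr: [B, D] goal embeddings.
    Returns a [B, B] matrix whose diagonal holds the positive pairs.
    """
    diff = sa_repr[:, None, :] - g_repr[None, :, :]  # broadcast to [B, B, D]
    return np.linalg.norm(diff, axis=-1)


rng = np.random.default_rng(0)
B, D = 4, 64
sa_repr = rng.normal(size=(B, D))
g_repr = rng.normal(size=(B, D))
distance_matrix = pairwise_l2(sa_repr, g_repr)
print(distance_matrix.shape)  # (4, 4)
```

Broadcasting materializes a `[B, B, D]` intermediate. For large batches, a common alternative expands ||x - y||² = ||x||² + ||y||² - 2x·y so that a single matrix multiply does the work, at the cost of clamping tiny negative values (from rounding) before the square root.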


Step 5: Calculate Contrastive Loss (InfoNCE)

Now we need a loss function. Looking at our distance matrix, the goal is clear: push the diagonal values down (closer embeddings) and the off-diagonal values up (farther embeddings).

What loss function does this? Let's think about what we already know...

Cross-Entropy Refresher

Cross-entropy loss is everywhere in classification. Given logits for each class, it:

  1. Finds the score for the correct class
  2. Compares it against all other scores
  3. Penalizes when wrong classes have high scores

The formula is:

loss = -log(softmax(logits)[correct_class])

But there's an equivalent form that is easier to compute in a numerically stable way. Recall that:

softmax(x)[i] = exp(x[i]) / sum(exp(x))

So:

-log(softmax(logits)[correct]) = -log(exp(logits[correct]) / sum(exp(logits)))
                                = -logits[correct] + log(sum(exp(logits)))
                                = -logits[correct] + logsumexp(logits)

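That logsumexp term is a soft maximum: it's dominated by the largest value but is differentiable everywhere, and it is never smaller than logits[correct]. So cross-entropy is essentially saying: "penalize by how much the soft maximum over all classes exceeds the correct class's score." The loss approaches zero only when the correct class dominates.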
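The identity is easy to check numerically. A NumPy sketch (the function names are hypothetical): the naive version computes the softmax directly, while the rewritten version uses a logsumexp that subtracts the maximum before exponentiating, the standard trick that keeps `exp` from overflowing.

```python
import numpy as np


def logsumexp(x):
    # Stable log(sum(exp(x))): shift by the max so the largest exponent is 0.
    m = np.max(x)
    return m + np.log(np.sum(np.exp(x - m)))


def cross_entropy_naive(logits, correct):
    probs = np.exp(logits) / np.sum(np.exp(logits))  # softmax
    return -np.log(probs[correct])


def cross_entropy_stable(logits, correct):
    # -logits[correct] + logsumexp(logits)
    return -logits[correct] + logsumexp(logits)


logits = np.array([2.0, 0.5, -1.0])
print(cross_entropy_naive(logits, 0), cross_entropy_stable(logits, 0))  # agree up to float error

big = np.array([1000.0, 999.0, 998.0])
print(cross_entropy_stable(big, 0))  # finite (about 0.41); exp(1000) directly would overflow to inf
```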

From Cross-Entropy to InfoNCE

Now look at our distance matrix again. Each row is a classification problem:

"Given sample(s,a)ᵢ, which goal did it actually reach?"

The answer is goalᵢ—the diagonal entry. So we apply cross-entropy to each row:

logits = -distance_matrix  # negate: smaller distance = higher score

loss_per_row = -logits[i,i] + logsumexp(logits[i,:])
             = -score_for_correct_goal + soft_max_over_all_goals

loss = mean(loss_per_row)

This is InfoNCE—it's just cross-entropy where the "classes" are dynamically defined by the batch, and the correct class is always the diagonal.
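Here is the same computation as runnable NumPy, applied to the example distance matrix from Step 4. A sketch: `infonce_loss` is a hypothetical name, and the batch is tiny (B = 4) so the numbers are easy to inspect. A useful sanity check: if every distance is equal, each row is a uniform guess over B goals and the loss is exactly log(B).

```python
import numpy as np


def row_logsumexp(x):
    # Stable logsumexp over each row of a [B, B] matrix.
    m = np.max(x, axis=1, keepdims=True)
    return m[:, 0] + np.log(np.sum(np.exp(x - m), axis=1))


def infonce_loss(distance_matrix):
    # Row i is a B-way classification whose correct class is column i.
    logits = -distance_matrix                       # smaller distance -> higher score
    loss_per_row = -np.diag(logits) + row_logsumexp(logits)
    return np.mean(loss_per_row)


dist = np.array([
    [0.1, 2.3, 1.8, 3.1],
    [2.5, 0.2, 2.9, 1.7],
    [1.9, 3.2, 0.15, 2.4],
    [2.8, 1.6, 2.1, 0.3],
])
print(infonce_loss(dist))                   # well below log(4): positives are closest
print(infonce_loss(np.ones((4, 4))))        # log(4) ≈ 1.386: no information
print(infonce_loss(dist[:, [1, 0, 3, 2]]))  # above log(4): positives moved off the diagonal
```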

Why This Works

Pair type                  What happens
Diagonal (positive)        the -logits[i,i] term pulls embeddings closer
Off-diagonal (negative)    the logsumexp term pushes embeddings apart

The batch provides free negatives. With B=256, each positive pair competes against 255 negatives automatically. No labeling required—the batch structure defines the classification problem.

Logsumexp regularization:

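The implementation adds a term to keep the logits from drifting to unbounded magnitudes:

loss += 0.1 * mean(logsumexp(logits, axis=1)²)   # logsumexp taken per row

InfoNCE is unchanged if every logit in a row shifts by the same constant, so on its own it does not constrain their overall level; the penalty does, which keeps embedding magnitudes stable during training.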
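A small NumPy sketch makes this concrete: shifting every logit by the same constant leaves the InfoNCE value unchanged but changes the penalty. The coefficient 0.1 comes from the text; the helper names are hypothetical.

```python
import numpy as np


def row_logsumexp(logits):
    m = logits.max(axis=1, keepdims=True)
    return m[:, 0] + np.log(np.exp(logits - m).sum(axis=1))


def infonce_from_logits(logits):
    # Cross-entropy per row with the diagonal as the correct class.
    return np.mean(-np.diag(logits) + row_logsumexp(logits))


def logsumexp_penalty(logits, coeff=0.1):
    # coeff * mean over rows of logsumexp(logits[i, :])**2
    return coeff * np.mean(row_logsumexp(logits) ** 2)


logits = -np.array([[0.1, 2.3], [2.5, 0.2]])  # negated distances
for shift in (0.0, 5.0):
    shifted = logits + shift
    print(shift, infonce_from_logits(shifted), logsumexp_penalty(shifted))
# The InfoNCE value is the same for both shifts; only the penalty grows.
```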


Step 6: Update the Actor

The actor network is goal-conditioned—it takes (state, goal) as input and outputs a distribution over actions:

mean, log_std = Actor(concat(state, goal))
std = exp(log_std)
noise ~ Normal(0, I)               # same shape as mean
action = tanh(mean + std * noise)  # reparameterized sample, squashed into [-1, 1]
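The sampling step, together with the log-probability that the SAC-style loss needs, can be sketched in NumPy. This assumes the tanh-squashed Gaussian common in SAC implementations, including the change-of-variables correction for `tanh`; `sample_tanh_gaussian` is a hypothetical name, and the `Actor` network itself is left out (its outputs `mean` and `log_std` are passed in directly).

```python
import numpy as np


def sample_tanh_gaussian(mean, log_std, rng):
    """Reparameterized sample a = tanh(mean + std * eps) and its log-probability.

    log pi(a) = sum_d [log N(u_d; mean_d, std_d) - log(1 - tanh(u_d)**2)],
    where the second term is the change-of-variables correction for tanh.
    """
    std = np.exp(log_std)
    eps = rng.standard_normal(mean.shape)  # eps ~ N(0, I); with autodiff, gradients would flow via mean/std
    u = mean + std * eps                   # pre-squash Gaussian sample
    action = np.tanh(u)
    gauss_logp = -0.5 * (eps ** 2 + 2.0 * log_std + np.log(2.0 * np.pi))
    # log(1 - tanh(u)^2) rewritten as 2 * (log 2 - u - softplus(-2u)) to avoid log(0)
    log_det = 2.0 * (np.log(2.0) - u - np.logaddexp(0.0, -2.0 * u))
    log_prob = np.sum(gauss_logp - log_det, axis=-1)
    return action, log_prob


rng = np.random.default_rng(0)
mean = np.zeros((3, 2))            # batch of 3 (state, goal) inputs, action_dim = 2
log_std = np.full((3, 2), -0.5)
action, log_prob = sample_tanh_gaussian(mean, log_std, rng)
print(action.shape, log_prob.shape)  # (3, 2) (3,)
```

Written with an autodiff framework's array ops instead of NumPy, the same computation makes `action` and `log_prob` differentiable with respect to `mean` and `log_std`, which is the point of the reparameterization.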

To train it, we use the learned reachability space as a Q-function:

Q(state, action, goal) = -||SA_encoder(state, action) - G_encoder(goal)||₂

Higher Q (less negative) means the action leads closer to the goal.

The actor loss is SAC-style—maximize Q while maintaining entropy for exploration:

actor_loss = mean(alpha * log_prob - Q_value)

Here alpha is a learned temperature that is tuned automatically to keep the policy's entropy near a target level.
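The two losses can be written down directly. A NumPy sketch of the values only (no gradients): `actor_loss` follows the formula above, while `temperature_loss` is the temperature objective commonly used in SAC implementations, an assumption here since the text only says alpha auto-tunes. A common choice for `target_entropy` is minus the action dimension.

```python
import numpy as np


def actor_loss(alpha, log_prob, q_value):
    # mean(alpha * log_prob - Q): lower when Q is high and the policy stays random.
    return np.mean(alpha * log_prob - q_value)


def temperature_loss(log_alpha, log_prob, target_entropy):
    # Assumed SAC-style objective. Gradient descent on log_alpha raises alpha when
    # the policy's entropy (estimated as -log_prob) is below target_entropy,
    # and lowers it when the entropy is above the target.
    return -np.mean(log_alpha * (log_prob + target_entropy))


log_prob = np.array([1.0, -1.0])   # log pi(a|s,g) for two sampled actions
q = np.array([-3.0, -1.0])         # Q = -distance for the same actions
print(actor_loss(0.2, log_prob, q))  # 2.0
```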

Action selection at inference:

The default is stochastic—just sample from the policy. But optionally, you can use K-sample selection:

  1. Sample K candidate actions from the policy
  2. Compute Q-value for each
  3. Execute the action with highest Q

This optional mode gives up some exploration in exchange for exploitation, which is useful when evaluating the trained policy.
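K-sample selection is only a few lines. A sketch, assuming a policy sampler `sample_actions(state, goal, k)` and a critic `q_fn(state, action, goal)`; both are hypothetical callables, and the toy versions below (a 1-D action and a made-up Q that prefers actions landing near the goal) exist only to make the example runnable.

```python
import numpy as np


def select_action_k_samples(sample_actions, q_fn, state, goal, k=16):
    """Draw k candidate actions from the policy and return the one with highest Q."""
    candidates = sample_actions(state, goal, k)                   # shape [k, action_dim]
    q_values = np.array([q_fn(state, a, goal) for a in candidates])
    return candidates[int(np.argmax(q_values))]


rng = np.random.default_rng(0)


def toy_sample_actions(state, goal, k):
    # Stand-in for the stochastic policy: uniform actions in [-1, 1].
    return rng.uniform(-1.0, 1.0, size=(k, 1))


def toy_q(state, action, goal):
    # Stand-in for Q = -distance: pretend (s, a) lands at state + action.
    return -np.linalg.norm((state + action) - goal)


state, goal = np.array([0.0]), np.array([0.6])
best = select_action_k_samples(toy_sample_actions, toy_q, state, goal, k=16)
print(best)  # the candidate closest to 0.6
```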


Why This Scales Better Than Traditional RL

Traditional RL has fundamental scaling problems:

1. Reward sparsity. In goal-reaching tasks, you only get reward when you reach the goal. Early in training, random exploration almost never succeeds, so there's no learning signal.

2. Credit assignment. Even with rewards, figuring out which actions led to success is hard. Did action 47 matter, or was it action 203?

3. Value function instability. Bootstrapping (using your own predictions to train yourself) can diverge, especially with deep networks.

Contrastive RL sidesteps all three:

1. Dense signal from hindsight. Every trajectory generates training data. Reached state X? Great, now we have a positive example for "X is reachable from earlier states." No sparse rewards needed.

2. Direct supervision. We don't ask "what reward will I get?" We ask "did I reach this state?" The answer is binary and ground-truth—no credit assignment required.

3. No bootstrapping. The contrastive loss is self-contained: we're not training on our own predictions. This is a large part of why the paper could scale to 1000 layers, since standard deep-learning techniques (residual connections, layer norm) carry over from supervised learning.

The result: goal-conditioned agents that learn from pure self-supervision, scale to massive networks, and don't need hand-crafted reward functions.


Summary

Contrastive RL reframes goal-reaching as representation learning:

  1. Collect trajectories (any exploration data works)
  2. Sample goal pairs using hindsight (future states = valid goals)
  3. Learn a reachability space where distance = "how reachable?"
  4. Train with contrastive loss (positives close, negatives far)
  5. Act by minimizing distance to the goal

The algorithm doesn't predict rewards—it learns a coordinate system. In that coordinate system, achieving goals becomes as simple as walking downhill.


If you are interested in high-performance RL simulations, check out Lightspeed - a GPU-accelerated RL framework built on NVIDIA Warp.