<?xml version="1.0" encoding="utf-8"?><feed xmlns="http://www.w3.org/2005/Atom" xml:lang="en"><generator uri="https://jekyllrb.com/" version="4.3.3">Jekyll</generator><link href="https://elanapearl.github.io/feed.xml" rel="self" type="application/atom+xml"/><link href="https://elanapearl.github.io/" rel="alternate" type="text/html" hreflang="en"/><updated>2026-03-23T04:06:29+00:00</updated><id>https://elanapearl.github.io/feed.xml</id><title type="html">Elana Simon</title><subtitle>(matrices &amp; molecules) a collection of deep-dives into ML x bio </subtitle><entry><title type="html">the bug that taught me more about PyTorch than years of using it</title><link href="https://elanapearl.github.io/blog/2025/the-bug-that-taught-me-pytorch/" rel="alternate" type="text/html" title="the bug that taught me more about PyTorch than years of using it"/><published>2025-10-22T00:00:00+00:00</published><updated>2025-10-22T00:00:00+00:00</updated><id>https://elanapearl.github.io/blog/2025/the-bug-that-taught-me-pytorch</id><content type="html" xml:base="https://elanapearl.github.io/blog/2025/the-bug-that-taught-me-pytorch/"><![CDATA[<p><code class="language-plaintext highlighter-rouge">Expected to fix: my hyperparameters. Actually had to fix: PyTorch backend.</code></p> <p>My training loss plateaued and wouldn’t budge. Obviously I’d screwed something up. I tried every hyperparameter combination, rewrote my loss function, spent days assuming I’d made some stupid mistake. Because it’s always user error.</p> <p>This time, it wasn’t. It was a niche PyTorch bug that forced me through layers of abstraction I normally never think about: optimizer internals, memory layouts, dispatch systems, kernel implementations. Taught me more about the framework than years of using it.</p> <p>I had a surprisingly fun time with this bug hunt and wrote up the whole investigation step-by-step, explaining framework internals as they become necessary to crack the case. 
If you enjoy debugging mysteries or find that tracking down bugs teaches you more than docs ever could, this might resonate. 🕵️‍♀️</p> <p>Debugging post-mortems sometimes make me worry I wouldn’t have been smart enough to figure them out myself. So I structured this walkthrough to show the reasoning behind each step: what clues suggested each move, why I tested that hypothesis, why certain results pointed where they did. While the investigation took time and persistence, it didn’t require any particular expertise or wizardry— just observation and willingness to keep digging. I’ve included background knowledge exactly when you need it to understand the next step—think of it as an excuse to learn (or re-learn) PyTorch internals through a real problem. If you’d prefer to jump straight to reproducing the bug yourself, check out the <a href="https://github.com/ElanaPearl/pytorch-mps-noncontiguous-bug">minimal reproduction script and walkthrough</a> on GitHub. Otherwise, join me on the investigation!</p> <p><strong>Table of Contents:</strong> 🤔 <a href="#the-mystery-a-plateauing-loss">The Mystery: A Plateauing Loss</a>…… 🔎 <a href="#isolating-the-problem">Isolating the Problem</a>…… 💻 <a href="#device-specific-differences">Device-Specific Differences</a>…… ⌺ <a href="#tensor-memory-layouts">Tensor Memory Layouts</a>…… 💔 <a href="#identifying-the-broken-operations">Identifying the Broken Operations</a>……. 
🍎 <a href="#inside-the-kernel-implementation">Inside the Kernel Implementation</a>…… 🕵️‍♀️ <a href="#case-closed">Case Closed</a></p> <details> <summary><b>TL;DR - Just tell me the bug</b></summary> <div> <p><strong>The Bug:</strong> A bug in PyTorch’s Apple Silicon GPU kernels made writes to non-contiguous memory fail silently, which froze my model’s encoder weights during training (MPS backend, PyTorch &lt;2.4).</p> <p><strong>The Technical Details:</strong> PyTorch’s MPS (Apple Silicon GPU) backend had a kernel bug where <code class="language-plaintext highlighter-rouge">addcmul_</code> and <code class="language-plaintext highlighter-rouge">addcdiv_</code> operations silently failed when writing to non-contiguous output tensors.</p> <p><strong>Why It Caused the Training Plateau:</strong></p> <ul> <li>Encoder weights initialized as transpose of decoder → non-contiguous memory layout</li> <li>Adam’s state tensors inherited this layout (<code class="language-plaintext highlighter-rouge">exp_avg</code> and <code class="language-plaintext highlighter-rouge">exp_avg_sq</code> became non-contiguous)</li> <li>MPS kernels for <code class="language-plaintext highlighter-rouge">addcmul_</code>/<code class="language-plaintext highlighter-rouge">addcdiv_</code> don’t handle non-contiguous outputs correctly</li> <li>Results were computed but written to a temporary buffer instead of the actual tensor</li> <li>For the non-contiguous encoder’s Adam parameters, <code class="language-plaintext highlighter-rouge">exp_avg_sq.addcmul_()</code> doesn’t update → value stays zero, then the parameter update via <code class="language-plaintext highlighter-rouge">addcdiv_</code> also fails → complete silent freeze</li> </ul> <p><strong>The Fix:</strong></p> <ul> <li><strong>Adjust your code:</strong> Make weights contiguous at initialization</li> <li><strong>Upgrade PyTorch:</strong> Upgrade to PyTorch ≥2.4 (fixes <code class="language-plaintext highlighter-rouge">addcmul_</code>/<code 
class="language-plaintext highlighter-rouge">addcdiv_</code>)</li> <li><strong>(Complete fix) Upgrade your Operating System:</strong> Upgrade to macOS 15+ (native non-contiguous tensor support)</li> </ul> <p><strong>Current Status:</strong> Random operations (<code class="language-plaintext highlighter-rouge">normal_</code>, <code class="language-plaintext highlighter-rouge">uniform_</code>, etc.) still have this bug on macOS &lt; 15 as of PyTorch 2.10 (I submitted a <a href="https://github.com/pytorch/pytorch/pull/165267">PR</a> to fix this). Other MPS operations may be affected.</p> <p><strong>Reproduction:</strong> A minimal reproduction script &amp; walkthrough is available at <a href="https://github.com/ElanaPearl/pytorch-mps-noncontiguous-bug">https://github.com/ElanaPearl/pytorch-mps-noncontiguous-bug</a>.</p> </div> </details> <h2 id="the-mystery-a-plateauing-loss">The Mystery: A Plateauing Loss</h2> <div class="l-body"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/the_bug_that_taught_me_pytorch_post/loss_plateau-480.webp 480w,/assets/img/the_bug_that_taught_me_pytorch_post/loss_plateau-800.webp 800w,/assets/img/the_bug_that_taught_me_pytorch_post/loss_plateau-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/the_bug_that_taught_me_pytorch_post/loss_plateau.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </div> <p>Training loss plateaued way too early. This felt like a standard hyperparameter issue- but I’d trained this same architecture on similar data with similar hyperparameters countless times and hit much lower losses.</p> <p>What had changed? Those runs were months old. I tried reproducing them exactly, but couldn’t pin down the exact environment—the codebase had evolved through multiple projects, refactors, and dependency updates. 
Without a clean “before vs after,” I had to debug forward.</p> <p>The architecture itself is straightforward: a two-layer sparse autoencoder (encoder –&gt; sparse hidden layer –&gt; decoder). However, it has a few training quirks that <em>could</em> be culprits:</p> <ul> <li>The hidden layer uses TopK sparsity: only the k largest activations are kept, and the rest are zeroed</li> <li>Training includes some manual gradient adjustments (gradient clipping for stability and modifications to decoder weight gradients)</li> <li>An auxiliary loss term encourages feature activation</li> </ul> <p>Even though I thought my initial hyperparameters were already well-tested, I tried everything: varied learning rates, different schedules, different k values and hidden dimensions, and different auxiliary loss coefficients.</p> <p>Nothing made a difference.</p> <p>Meanwhile, my actual research sat on hold while I was stuck second-guessing everything: was my code broken? My data corrupted? And the creeping doubt: I’ve been doing ML for years, so why can’t I make a simple two-layer autoencoder train properly?</p> <p>The model was small enough that I was training on my MacBook (using the Apple Silicon GPU) and simple enough that I could actually inspect every parameter. So after the standard checks turned up nothing, I started looking at the weights directly.</p> <p>I visualized the weights at initialization and after the first few training steps. The decoder weights were updating- values shifting, gradients being applied, nothing crazy. <strong>But the encoder weights… weren’t updating at all.</strong> No NaNs, no suspicious patterns… they just… weren’t changing. They stayed exactly at their initialized values, down to the last decimal place.</p> <p>Both layers participate in the same forward and backward pass. 
Why would one update and the other freeze completely?</p> <h2 id="isolating-the-problem">Isolating the Problem</h2> <h3 id="are-gradients-flowing">Are Gradients Flowing?</h3> <p>First check: are gradients even making it back to the encoder? The TopK sparsity should make gradients sparse: only the k activated features receive gradients through backprop, and the rest are zeroed. But maybe I messed up the implementation so that <em>no</em> encoder gradients flow at all? Or maybe the manual gradient adjustments I was making somehow blocked everything?</p> <p>After <code class="language-plaintext highlighter-rouge">loss.backward()</code>, the gradient statistics were:</p> <table> <thead> <tr> <th> </th> <th><strong>Encoder</strong></th> <th><strong>Decoder</strong></th> </tr> </thead> <tbody> <tr> <td><strong>Max Grad</strong></td> <td>2.35e6</td> <td>6.64e6</td> </tr> <tr> <td><strong>Sparsity</strong></td> <td>88.5% zeros</td> <td>88.5% zeros</td> </tr> </tbody> </table> <p>The encoder gradients were there- and they were pretty big (as intended for my dataset)! They were also sparse (mostly zeros), which was expected, but there were still plenty of non-zero gradients. So gradients were definitely being computed.</p> <h3 id="is-it-the-optimizer">Is It the Optimizer?</h3> <p>Since the gradients exist but the weights aren’t updating, the optimizer became the prime suspect. So I swapped Adam for a simpler optimizer, stochastic gradient descent (SGD), both as a manual update and via PyTorch’s built-in implementation:</p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># Manual SGD update
</span><span class="k">with</span> <span class="n">torch</span><span class="p">.</span><span class="nf">no_grad</span><span class="p">():</span>
    <span class="n">model</span><span class="p">.</span><span class="n">encoder</span><span class="p">.</span><span class="n">weight</span> <span class="o">-=</span> <span class="mf">0.001</span> <span class="o">*</span> <span class="n">model</span><span class="p">.</span><span class="n">encoder</span><span class="p">.</span><span class="n">weight</span><span class="p">.</span><span class="n">grad</span>
<span class="c1"># Encoder weights change! ✓
</span>
<span class="c1"># Torch SGD update
</span><span class="n">sgd_optimizer</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">optim</span><span class="p">.</span><span class="nc">SGD</span><span class="p">(</span><span class="n">model</span><span class="p">.</span><span class="nf">parameters</span><span class="p">(),</span> <span class="n">lr</span><span class="o">=</span><span class="mf">0.001</span><span class="p">)</span>
<span class="n">sgd_optimizer</span><span class="p">.</span><span class="nf">step</span><span class="p">()</span>
<span class="c1"># Encoder weights change! ✓
</span>
<span class="c1"># But with Adam...
</span><span class="n">optimizer</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">optim</span><span class="p">.</span><span class="nc">Adam</span><span class="p">(</span><span class="n">model</span><span class="p">.</span><span class="nf">parameters</span><span class="p">(),</span> <span class="n">lr</span><span class="o">=</span><span class="mf">0.001</span><span class="p">)</span>
<span class="n">optimizer</span><span class="p">.</span><span class="nf">step</span><span class="p">()</span>
<span class="c1"># Encoder weights don't change! ✗
</span></code></pre></div></div> <div style=" background: var(--global-card-bg-color); border: 1px solid var(--global-divider-color); border-left: 4px solid var(--global-theme-color); border-radius: 8px; padding: 1.5rem; margin: 2rem 0; box-shadow: 0 2px 8px rgba(0, 0, 0, 0.08); position: relative; transition: box-shadow 0.3s ease; "> <div style=" position: absolute; top: 6px; right: 10px; font-size: 1.6rem; opacity: 0.75; transform: rotate(8deg); pointer-events: none; ">🤔</div> <div style=" padding-right: 2.5rem; font-size: 1.25rem; font-weight: 600; line-height: 1.4; color: var(--global-text-color); "> The issue is localized to Adam specifically! But why would Adam fail on the encoder but work perfectly on the decoder? </div> </div> <hr/> <h3 id="how-adam-works">How Adam Works</h3> <p>To understand what might be breaking, I need to understand what Adam actually does differently from simple gradient descent.</p> <details open=""> <summary><b>Understanding Adam's Algorithm (click to collapse if familiar)</b></summary> <div> <h3 id="problems-with-vanilla-sgd">Problems with Vanilla SGD</h3> <p>SGD updates all parameters the same way:</p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># SGD: one learning rate for everything
</span><span class="n">param</span> <span class="o">=</span> <span class="n">param</span> <span class="o">-</span> <span class="n">learning_rate</span> <span class="o">*</span> <span class="n">gradient</span>
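</code></pre></div> </div> <p>To make this concrete, here is a minimal sketch, run on CPU with made-up gradient values. It applies a single <code class="language-plaintext highlighter-rouge">torch.optim.SGD</code> step to two parameters whose gradients differ by five orders of magnitude. For comparison, it also applies a single <code class="language-plaintext highlighter-rouge">torch.optim.Adam</code> step, which is explained below:</p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import torch

lr = 1e-3
# Hypothetical gradients for two parameters, five orders of magnitude apart.
grads = torch.tensor([1000.0, 0.01])

# SGD: the same learning rate for both, so each step is proportional to its gradient.
sgd_param = torch.nn.Parameter(torch.zeros(2))
sgd = torch.optim.SGD([sgd_param], lr=lr)
sgd_param.grad = grads.clone()
sgd.step()
print("SGD step sizes: ", (-sgd_param.detach()).tolist())

# Adam: the first step moves both parameters by roughly lr, whatever the gradient scale.
adam_param = torch.nn.Parameter(torch.zeros(2))
adam = torch.optim.Adam([adam_param], lr=lr)
adam_param.grad = grads.clone()
adam.step()
print("Adam step sizes:", (-adam_param.detach()).tolist())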
</code></pre></div> </div> <p>This has a few problems:<d-footnote>SGD has other problems too (hence all the optimizer research), but these are the ones Adam addresses.</d-footnote></p> <ol> <li> <p><strong>Different parameters need different learning rates.</strong> Some parameters might consistently get gradients around 1000 while others get 0.01. With SGD’s fixed learning rate, you’re stuck: either you move too slowly on small gradients or you overshoot wildly on large ones.</p> </li> <li> <p><strong>The learning rate needs to change over time.</strong> Early in training, you want big steps to explore the space. Later, you need tiny steps to settle into a minimum. SGD requires manually decaying the learning rate on a schedule.</p> </li> </ol> <h3 id="adams-solution-adaptive-learning-rates-via-gradient-magnitude-tracking">Adam’s Solution: Adaptive Learning Rates via Gradient Magnitude Tracking</h3> <p>Adam maintains two pieces of state per parameter and uses two hyperparameters to control how these states evolve:</p> <p><strong>State variables</strong> (initialized to zero for each parameter):</p> <ul> <li><code class="language-plaintext highlighter-rouge">exp_avg</code>: Running average of gradients (first moment)</li> <li><code class="language-plaintext highlighter-rouge">exp_avg_sq</code>: Running average of squared gradients (second moment)</li> </ul> <p><strong>Hyperparameters</strong> (typically beta_1=0.9, beta_2=0.999):</p> <ul> <li><code class="language-plaintext highlighter-rouge">beta_1</code>: Decay rate for first moment (momentum)</li> <li><code class="language-plaintext highlighter-rouge">beta_2</code>: Decay rate for second moment (gradient magnitude history)</li> </ul> <p><strong>Here’s the simplified algorithm:</strong></p> <p>Initialize state (done once per parameter)</p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">exp_avg</span> <span class="o">=</span> <span 
class="nf">zeros_like</span><span class="p">(</span><span class="n">param</span><span class="p">)</span>
<span class="n">exp_avg_sq</span> <span class="o">=</span> <span class="nf">zeros_like</span><span class="p">(</span><span class="n">param</span><span class="p">)</span>
<span class="n">step</span> <span class="o">=</span> <span class="mi">0</span>
</code></pre></div> </div> <p>Each training step:</p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># Update moments with exponential moving averages
</span><span class="n">exp_avg</span> <span class="o">=</span> <span class="n">beta_1</span> <span class="o">*</span> <span class="n">exp_avg</span> <span class="o">+</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">beta_1</span><span class="p">)</span> <span class="o">*</span> <span class="n">grad</span>
<span class="n">exp_avg_sq</span> <span class="o">=</span> <span class="n">beta_2</span> <span class="o">*</span> <span class="n">exp_avg_sq</span> <span class="o">+</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">beta_2</span><span class="p">)</span> <span class="o">*</span> <span class="n">grad</span><span class="o">**</span><span class="mi">2</span>

<span class="c1"># Update step count
# (It effectively starts at 1 to avoid division by zero in bias correction)
</span><span class="n">step</span> <span class="o">+=</span> <span class="mi">1</span>

<span class="c1"># Bias correction
</span><span class="n">exp_avg_corrected</span> <span class="o">=</span> <span class="n">exp_avg</span> <span class="o">/</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">beta_1</span><span class="o">**</span><span class="n">step</span><span class="p">)</span>
<span class="n">exp_avg_sq_corrected</span> <span class="o">=</span> <span class="n">exp_avg_sq</span> <span class="o">/</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">beta_2</span><span class="o">**</span><span class="n">step</span><span class="p">)</span>

<span class="c1"># Adaptive parameter update
</span><span class="n">param</span> <span class="o">=</span> <span class="n">param</span> <span class="o">-</span> <span class="n">lr</span> <span class="o">*</span> <span class="n">exp_avg_corrected</span> <span class="o">/</span> <span class="p">(</span><span class="nf">sqrt</span><span class="p">(</span><span class="n">exp_avg_sq_corrected</span><span class="p">)</span> <span class="o">+</span> <span class="n">ε</span><span class="p">)</span>
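</code></pre></div> </div> <p>As a sanity check that this pseudocode matches what PyTorch computes, here is a minimal sketch. It runs on CPU with random stand-in gradients, performs five hand-written Adam steps alongside <code class="language-plaintext highlighter-rouge">torch.optim.Adam</code>, and compares the results. As far as I can tell, PyTorch rearranges the arithmetic (folding the bias corrections into the step size and the denominator), so the two versions agree up to floating-point rounding rather than bit-for-bit:</p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import torch

torch.manual_seed(0)
lr, beta_1, beta_2, eps = 1e-3, 0.9, 0.999, 1e-8

# A small stand-in parameter, plus a hand-tracked copy of it.
param = torch.nn.Parameter(torch.randn(4, 3))
initial = param.detach().clone()
manual = param.detach().clone()
exp_avg = torch.zeros_like(manual)
exp_avg_sq = torch.zeros_like(manual)

optimizer = torch.optim.Adam([param], lr=lr, betas=(beta_1, beta_2), eps=eps)

for step in range(1, 6):
    grad = torch.randn(4, 3)  # random stand-in for a real gradient

    # Reference: let torch.optim.Adam update `param` in place.
    param.grad = grad.clone()
    optimizer.step()

    # The same step, written exactly as in the pseudocode above.
    exp_avg = beta_1 * exp_avg + (1 - beta_1) * grad
    exp_avg_sq = beta_2 * exp_avg_sq + (1 - beta_2) * grad**2
    exp_avg_corrected = exp_avg / (1 - beta_1**step)
    exp_avg_sq_corrected = exp_avg_sq / (1 - beta_2**step)
    manual = manual - lr * exp_avg_corrected / (exp_avg_sq_corrected.sqrt() + eps)

print(torch.allclose(param.detach(), manual, atol=1e-6))  # True: both versions agree
print(torch.equal(param.detach(), initial))              # False: the weights moved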
</code></pre></div> </div> <p><strong>What Each Moment Does:</strong></p> <ul> <li> <p><strong>First moment (<code class="language-plaintext highlighter-rouge">exp_avg</code>)</strong>: Smooths out noisy gradients by averaging recent directions—like momentum in physics. When gradients oscillate (+10, -10, +8, -9…), the positive and negative values cancel out, revealing there’s no consistent direction. Beta_1=0.9 means “keep 90% of old momentum, add 10% of new gradient.” This smoothed momentum is what gets multiplied by the learning rate in the parameter update: <code class="language-plaintext highlighter-rouge">lr * exp_avg</code>.</p> </li> <li> <p><strong>Second moment (<code class="language-plaintext highlighter-rouge">exp_avg_sq</code>)</strong>: Tracks typical gradient <strong>magnitude</strong> for each parameter by averaging squared gradients. Squaring removes the +/- sign (both +10 and -10 become 100), preventing cancellation. Beta_2=0.999 means “keep 99.9% of magnitude history, add 0.1% of new squared gradient.” This magnitude normalizes the momentum-based update: <code class="language-plaintext highlighter-rouge">lr * exp_avg / sqrt(exp_avg_sq)</code>. Parameters with consistently large gradients get their updates scaled down (large denominator), while parameters with small gradients get boosted (small denominator). This is how Adam achieves <strong>adaptive per-parameter learning rates</strong>.</p> </li> <li> <p><strong>Epsilon (<code class="language-plaintext highlighter-rouge">ε=1e-8</code>)</strong>: Prevents division by zero.</p> </li> </ul> <p><strong>Bias Correction:</strong></p> <p>Both moments start at zero, causing early estimates to be biased toward zero. The correction factor <code class="language-plaintext highlighter-rouge">(1 - β**step)</code> provides a large boost early to counteract this, effectively “warming up” the optimizer over the first ~1000-3000 steps. 
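</p> <p>To make “warming up” concrete, here is a quick sketch in plain Python, using the default betas from above. It prints the bias-correction denominators <code class="language-plaintext highlighter-rouge">(1 - β**step)</code> at a few step counts. Each value is the fraction of the true moment that the uncorrected running average holds, assuming the gradient is constant:</p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>beta_1, beta_2 = 0.9, 0.999

def correction(beta, step):
    # Bias-correction denominator (1 - beta**step): the fraction of the true
    # moment that the uncorrected running average holds after `step` updates
    # when the gradient is constant.
    return 1 - beta ** step

for step in [1, 10, 50, 100, 1000, 3000, 10000]:
    print(f"step {step}: 1 - beta_1**step = {correction(beta_1, step):.4f}, "
          f"1 - beta_2**step = {correction(beta_2, step):.4f}")
</code></pre></div> </div> <p>With these betas, the first-moment factor is above 0.99 by step 50, while the second-moment factor is still only about 0.63 at step 1000. 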
As training progresses, the correction approaches 1 and has negligible effect.</p> <div class="l-body"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/the_bug_that_taught_me_pytorch_post/bias_correction_early-480.webp 480w,/assets/img/the_bug_that_taught_me_pytorch_post/bias_correction_early-800.webp 800w,/assets/img/the_bug_that_taught_me_pytorch_post/bias_correction_early-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/the_bug_that_taught_me_pytorch_post/bias_correction_early.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </div> <p>The second moment works similarly. Without correction, <code class="language-plaintext highlighter-rouge">exp_avg_sq</code> would be only 0.1% of gradient² at step 1, but bias correction restores it to the full value.</p> <p>For a deeper dive into Adam’s design and intuition, as well as other optimizers that use momentum and adaptive learning rates (RMSprop, AdaGrad, etc.), check out <a href="https://cs231n.github.io/neural-networks-3/#update">Stanford’s CS231n notes on optimization</a>.</p> </div> </details> <p>Knowing what Adam <em>should</em> be doing, let’s look at the state it’s maintaining (those <code class="language-plaintext highlighter-rouge">exp_avg</code> and <code class="language-plaintext highlighter-rouge">exp_avg_sq</code> tensors that track momentum and squared-gradient magnitude) to see what it’s <em>actually</em> doing.</p> <h3 id="examining-adams-state">Examining Adam’s State</h3> <p>Here is the maximum value in each of Adam’s state tensors for the frozen encoder, with the working decoder for comparison:</p> <table> <thead> <tr> <th> </th> <th><strong>Encoder</strong></th> <th><strong>Decoder</strong></th> </tr> </thead> <tbody> <tr> <td><strong>exp_avg</strong></td> <td>1.96e+05</td> <td>1.70e+06</td> </tr> <tr> <td><strong>exp_avg_sq</strong></td> <td><span style="display: inline-block; 
border: 2px solid var(--global-theme-color); border-radius: 4px; padding: 1px 10px; font-weight: bold;">0</span></td> <td>1.18e+11</td> </tr> </tbody> </table> <p>Wait, WHAT?! The encoder’s <code class="language-plaintext highlighter-rouge">exp_avg_sq</code> is zero despite having momentum accumulated in <code class="language-plaintext highlighter-rouge">exp_avg</code>.</p> <p>This feels mathematically impossible… The second moment (<code class="language-plaintext highlighter-rouge">exp_avg_sq</code>) is zero despite non-zero gradients. Since <code class="language-plaintext highlighter-rouge">exp_avg_sq</code> stores squared gradients, it should NEVER be zero if gradients are non-zero.</p> <p>And if it truly were zero, we’d see massive weight updates.</p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">param_update</span> <span class="o">=</span> <span class="n">lr</span> <span class="o">*</span> <span class="n">exp_avg</span> <span class="o">/</span> <span class="p">(</span><span class="nf">sqrt</span><span class="p">(</span><span class="n">exp_avg_sq</span><span class="p">)</span> <span class="o">+</span> <span class="n">ε</span><span class="p">)</span> 
             <span class="o">=</span> <span class="mf">0.001</span> <span class="o">*</span> <span class="mf">1.96e5</span> <span class="o">/</span> <span class="p">(</span><span class="nf">sqrt</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="o">+</span> <span class="mf">1e-8</span><span class="p">)</span>
             <span class="o">=</span> <span class="mi">196</span> <span class="o">/</span> <span class="mf">1e-8</span>
             <span class="o">=</span> <span class="mf">1.96e10</span>  <span class="c1"># &lt;-- HUGE!
</span></code></pre></div></div> <p>This would be <strong>huge</strong>! Yet we see NO updates… this paradox points to a deeper issue.</p> <h3 id="testing-hypotheses">Testing Hypotheses</h3> <h4 id="could-it-be-bias-correction">Could it be bias correction?</h4> <p>Adam uses bias correction to counteract zero initialization. I had previously run into subtle training issues caused by Adam bias correction bugs, so I wondered if the correction might be broken here. <d-footnote>💡If you haven't been hurt by a bias correction bug before, check out <a href="https://stats.stackexchange.com/questions/232741/why-is-it-important-to-include-a-bias-correction-term-for-the-adam-optimizer-for">these</a> <a href="https://stats.stackexchange.com/questions/237169/why-are-non-zero-centered-activation-functions-a-problem-in-backpropagation/237282#237282">examples</a> to learn the importance of getting this step right!</d-footnote></p> <p>Recall that bias correction just divides each moment by a step-dependent factor, <code class="language-plaintext highlighter-rouge">(1 - β**step)</code>. So if the issue had to do with bias correction, it should depend on the beta values or the step index.</p> <p>I tested different beta values, different steps, and even beta_2=0 (which bypasses the exponential average entirely, making <code class="language-plaintext highlighter-rouge">exp_avg_sq = grad**2</code> directly). The encoder’s <code class="language-plaintext highlighter-rouge">exp_avg_sq</code> still stayed zero, which made bias correction seem an unlikely culprit.</p> <p>Plus, <code class="language-plaintext highlighter-rouge">exp_avg</code> updated correctly despite using the same bias correction mechanism. So maybe something else was preventing <code class="language-plaintext highlighter-rouge">exp_avg_sq</code> from updating.</p> <h4 id="is-it-a-precision-issue">Is it a precision issue?</h4> <p>My largest gradients were big (1e6), and squared that’s 1e12. 
While that <em>is</em> quite large, it shouldn’t overflow in float32. However, I’ve also been hurt by precision bugs before<d-footnote>Floating point precision issues have a fun habit of causing silent failures/degradations like this one (where it completes but produces incorrect values). Always worth checking, even when it seems unlikely.</d-footnote>, so I had to try it anyway.</p> <p>I moved everything to float64… <strong>AND IT STARTED WORKING!</strong></p> <div style=" margin: 2rem 0; padding: 2rem; background: repeating-linear-gradient( 45deg, color-mix(in srgb, var(--global-theme-color) 8%, var(--global-bg-color)), color-mix(in srgb, var(--global-theme-color) 8%, var(--global-bg-color)) 10px, color-mix(in srgb, var(--global-theme-color) 12%, var(--global-bg-color)) 10px, color-mix(in srgb, var(--global-theme-color) 12%, var(--global-bg-color)) 20px ); border: 1px solid color-mix(in srgb, var(--global-theme-color) 20%, transparent); border-radius: 8px; font-family: 'Comic Sans MS', cursive, sans-serif; color: var(--global-text-color); line-height: 1.7; position: relative; overflow: hidden; "> <div style=" position: absolute; top: 10px; right: 15px; font-size: 2rem; opacity: 0.4; transform: rotate(15deg); ">😵‍💫</div> <span style="font-size: 1.2em; font-weight: bold; color: var(--global-theme-color);">Wait... how could this possibly be a precision issue?!</span> <p style="margin: 1rem 0; font-style: italic; color: color-mix(in srgb, var(--global-text-color) 85%, transparent);"> I asked Claude to help me understand the situation &amp; was told there are intermediate calculations in Adam that might overflow...</p> <p style="color: var(--global-text-color);">...but I couldn't find these mysterious intermediates in the code. And how would an overflow produce exact zeros instead of inf/NaN? Maybe we divide by the inf somewhere? Or there's an error correction step? Or we're underflowing? But that shouldn't give ALL zeros?!?! 
</p> <p style="margin: 1rem 0; font-weight: bold; color: var(--global-theme-color);"> ...Going to fp64 <em>DID</em> fix it though, and LLMs probably know PyTorch better than I do, so maybe I'm missing something obvious? But where was this secret intermediate? I couldn't find it anywhere... </p> <div style="text-align: center; margin-top: 1.5rem; font-size: 1.1em; color: var(--global-theme-color);"> <em>so now what???</em> </div> </div> <p>After a few more minutes of spiraling<d-footnote> You're probably not reading this for the mid-debugging-self-doubt, but every debugging adventure has a spiraling moment (at least for me) so feels disingenuous to skip this step. And maybe one of these theories could've actually been correct! </d-footnote>, I realized something: when I switched to float64, I <em>also</em> had to switch from MPS (Apple Silicon GPU) to CPU, since MPS doesn’t support float64. <strong>I’d changed two variables at once.</strong></p> <p>Testing with float32 on CPU… <strong>the weights update!!</strong></p> <div style=" background: var(--global-card-bg-color); border: 1px solid var(--global-divider-color); border-left: 4px solid var(--global-theme-color); border-radius: 8px; padding: 1.5rem; margin: 2rem 0; box-shadow: 0 2px 8px rgba(0, 0, 0, 0.08); position: relative; transition: box-shadow 0.3s ease; "> <div style=" position: absolute; top: 6px; right: 10px; font-size: 1.6rem; opacity: 0.75; transform: rotate(8deg); pointer-events: none; ">💡</div> <div style=" padding-right: 2.5rem; font-size: 1.25rem; font-weight: 600; line-height: 1.4; color: var(--global-text-color); "> Turns out, precision wasn't the culprit, it was <code style="background: var(--global-code-bg-color); color: var(--global-theme-color); padding: 0.2rem 0.4rem; border-radius: 4px; font-size: 0.9em;">device-specific</code>! The exact same float32 code updates weights on CPU but fails on MPS. 
This was progress: same code, same datatypes, but different devices meant different implementations—and different bugs. </div> </div> <p>﹡ This is progress!!</p> <p>﹡ Note to self… simpler explanations are more likely correct- even (and especially!) when LLMs confidently assert complicated theories that are hard to understand / verify</p> <p>﹡ Now I just need to figure out why the bug only occurs with MPS</p> <h2 id="device-specific-differences">Device-Specific Differences</h2> <h3 id="why-the-same-operation-behaves-differently-on-different-chips">Why the Same Operation Behaves Differently on Different Chips</h3> <p>PyTorch’s device abstraction lets you write the same code and run it on CPUs, GPUs, and even Apple Silicon. It <em>feels</em> like the same computation is running everywhere — but under the hood, each device has its own entirely separate implementation.</p> <p>When you call a tensor operation like <code class="language-plaintext highlighter-rouge">matmul</code>, PyTorch looks at the tensor’s metadata (e.g. 
device, dtype, shape) and dispatches to a <strong>specialized kernel</strong>: a device-specific, highly optimized implementation tailored for that particular hardware backend.</p> <details><summary><b>Understanding Apple's GPU Stack and "Kernel" Terminology</b></summary> <div> <p><strong>Apple’s GPU Stack:</strong></p> <ul> <li><strong>Metal</strong> - Apple’s low-level graphics/compute API (like CUDA for NVIDIA)</li> <li><strong>MPS (Metal Performance Shaders)</strong> - High-level optimized functions built on Metal (like cuDNN for CUDA)</li> <li><strong>PyTorch’s MPS backend</strong> - PyTorch’s integration that uses both Metal directly and MPS functions</li> </ul> <p><strong>On “Kernel” Terminology:</strong></p> <p>Typically, “kernel” refers to low-level GPU code that runs directly on hardware: functions that explicitly manage parallelism across thousands of GPU cores, handle device memory allocation, and are written in chip-specific languages like CUDA or Metal Shading Language.</p> <p>However, PyTorch seems to also use “kernel” to describe a higher-level abstraction: the framework’s implementation code (C++, Objective-C++, or CUDA files in the <code class="language-plaintext highlighter-rouge">native/</code> directory) that handles specific operations for specific backends. These PyTorch kernels sit above the hardware level- they might call optimized libraries like MPS or cuDNN (which then use those low-level GPU kernels underneath), or they might contain hand-written GPU code.</p> <p>In this post, we end up primarily exploring PyTorch kernels (e.g. 
the C++/Objective-C++ code in <code class="language-plaintext highlighter-rouge">BinaryOps.mm</code> that orchestrates MPS operations) rather than the Metal compute shaders executing on GPU cores beneath them.</p> <p>I was surprised that these higher-level implementations are also called “kernels.” I may have my terminology muddled here, but I didn’t have a better name for them, so I mostly use “PyTorch kernel” or just “operation” to describe them, though the terminology does get blurry in places.</p> </div> </details> <p>So when you write something like <code class="language-plaintext highlighter-rouge">result = tensor_a @ tensor_b</code>, you’re not invoking a universal multiply function. PyTorch uses the tensors’ metadata to select a device- and dtype-specific kernel that performs the actual computation.</p> <p>Multiplying two tensors on the CPU uses a completely different kernel than on MPS or CUDA. Even on the same device, changing the dtype or layout can trigger a different kernel. PyTorch maintains a large set of these implementations to support all the combinations.</p> <p>We’ll see exactly how this dispatch system works in C++ later when we dive into the source code. For now, the important point is: <strong><em>even with identical Python code</em>, different tensor metadata → different kernel code → different performance and different bugs.</strong></p> <p>In my case, because I’m running this on my M3 MacBook Pro, I’m using MPS (Metal Performance Shaders), which is the GPU backend for Apple Silicon. While it feels a bit crazy to assume that my training plateau is due to an internal kernel-level bug, it’s a bit less unreasonable with MPS, since it’s newer and less mature than the CPU and CUDA backends. (And honestly, most people training/debugging ML models are not doing it on their MacBooks.)</p> <h3 id="why-does-only-the-encoder-hit-this-bug">Why Does Only the Encoder Hit This Bug?</h3> <p>The Adam bug appears when working with the encoder on MPS.
What makes the encoder different from the decoder that would trigger different behavior?</p> <p>I tested everything I could think of that might differentiate the two tensors:</p> <ul> <li>Different gradient scales</li> <li>Dense vs sparse gradient patterns</li> <li>Removing decoder-specific gradient transformations</li> <li>Making encoder and decoder gradients statistically identical</li> </ul> <p>Nothing helped. Even when both tensors had similar gradient statistics, only the encoder’s <code class="language-plaintext highlighter-rouge">exp_avg_sq</code> stayed frozen. The difference wasn’t in the <em>values</em> of the tensor - something else about the encoder tensor itself was triggering the bug.</p> <p><strong>What properties does a PyTorch tensor even have?</strong> I asked Claude what attributes could differ between two tensors and checked them one-by-one:</p> <table> <thead> <tr> <th> </th> <th><strong>Encoder</strong></th> <th><strong>Decoder</strong></th> <th><strong>Same?</strong></th> </tr> </thead> <tbody> <tr> <td><strong>Device</strong></td> <td>mps:0</td> <td>mps:0</td> <td>✓</td> </tr> <tr> <td><strong>Dtype</strong></td> <td>float32</td> <td>float32</td> <td>✓</td> </tr> <tr> <td><strong>Shape</strong></td> <td>[1536, 384]</td> <td>[384, 1536]</td> <td>❌</td> </tr> <tr> <td><strong>Requires_grad</strong></td> <td>True</td> <td>True</td> <td>✓</td> </tr> <tr> <td><strong>Stride</strong></td> <td>(1, 1536)</td> <td>(1536, 1)</td> <td>❌</td> </tr> <tr> <td><strong>Contiguous</strong></td> <td>False</td> <td>True</td> <td>❌</td> </tr> </tbody> </table> <p>Three differences! The encoder and decoder have different shapes (they’re transposes of each other)<d-footnote>PyTorch's <code>nn.Linear</code> stores weights as [out_features, in_features], so the encoder (384→1536) has shape [1536, 384] and the decoder (1536→384) has shape [384, 1536].</d-footnote>, different stride patterns, and different contiguity. 
These properties are all related (more on that below).</p> <p>The shape difference itself can’t cause different behavior (PyTorch operations handle any shape). But contiguity? That’s a low-level memory detail that could be relevant. Maybe the MPS Adam bug only affects non-contiguous tensors? Worth a shot:</p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">model</span><span class="p">.</span><span class="n">encoder</span><span class="p">.</span><span class="n">weight</span><span class="p">.</span><span class="n">data</span> <span class="o">=</span> <span class="n">model</span><span class="p">.</span><span class="n">encoder</span><span class="p">.</span><span class="n">weight</span><span class="p">.</span><span class="nf">contiguous</span><span class="p">()</span>
<span class="n">optimizer</span><span class="p">.</span><span class="nf">step</span><span class="p">()</span>
<span class="c1"># Encoder updates!! ✓
</span></code></pre></div></div> <p><strong>IT WORKS!</strong> But <em>why</em>?</p> <h2 id="tensor-memory-layouts">Tensor Memory Layouts</h2> <h3 id="what-does-contiguous-even-mean">What Does “Contiguous” Even Mean?</h3> <p>Your computer’s memory is just a flat, 1D array of bytes, but tensors represent multi-dimensional grids. When you index <code class="language-plaintext highlighter-rouge">tensor[i, j]</code>, PyTorch needs to find that element in the flat memory. The tensor’s <strong>stride</strong> tells it how to do this conversion: element <code class="language-plaintext highlighter-rouge">[i, j]</code> sits at offset <code class="language-plaintext highlighter-rouge">i × stride[0] + j × stride[1]</code> from the start of the tensor’s data. Strides are counted in elements, so the actual distance in bytes also depends on the dtype, which sets how much memory each element takes up.</p> <p>Think of stride as <strong>navigation instructions</strong>: “to get from one row to the next, skip this many elements.” By default, memory is stored row-wise—each row is stored sequentially, then the next row comes after. If you read through a row, you skip over 1 element at a time; to go to the next row, you move row-length elements over. (This is why reading across a row is usually faster than reading down a column: neighboring elements in a row sit next to each other in memory.)</p> <p>However, the memory layout doesn’t have to match the logical layout we use to think about the tensor. We can change how the user views the tensor without moving any data!
For example, when we run transpose (<code class="language-plaintext highlighter-rouge">.T</code>), we don’t need to move around any data—we just change the stride!</p> <div class="l-body"> <div class="row"> <div class="col-sm mt-2 mt-md-0"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/the_bug_that_taught_me_pytorch_post/memory_layout_contig-480.webp 480w,/assets/img/the_bug_that_taught_me_pytorch_post/memory_layout_contig-800.webp 800w,/assets/img/the_bug_that_taught_me_pytorch_post/memory_layout_contig-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/the_bug_that_taught_me_pytorch_post/memory_layout_contig.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </div> <div class="col-sm mt-2 mt-md-0"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/the_bug_that_taught_me_pytorch_post/memory_layout_non_contig-480.webp 480w,/assets/img/the_bug_that_taught_me_pytorch_post/memory_layout_non_contig-800.webp 800w,/assets/img/the_bug_that_taught_me_pytorch_post/memory_layout_non_contig-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/the_bug_that_taught_me_pytorch_post/memory_layout_non_contig.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </div> </div> </div> <p>As we see in the images, reading all the elements row-by-row in the contiguous tensor is easy and linear, but the same row-wise pattern in the non-contiguous tensor is much jumpier. This jumping pattern makes the tensor “non-contiguous.”</p> <p>While there’s only one way for a tensor to be contiguous (the “natural” layout), there are many ways to become non-contiguous. 
By default, tensors are initialized as contiguous, but operations like slicing (<code class="language-plaintext highlighter-rouge">tensor[::2, :]</code>), reshaping, and dimension reordering (<code class="language-plaintext highlighter-rouge">permute</code>) can all create different non-contiguous stride patterns.</p> <p><strong>Why design tensors this way?</strong> Wouldn’t it be simpler to always keep data in the “natural” contiguous layout? The answer is performance: by just adjusting the tensor’s metadata, operations like transpose, slice, and reshape can be nearly <strong>instant</strong>— no data movement or memory allocation required. Keeping everything contiguous would mean expensive copying every time you reorganize dimensions.</p> <h3 id="how-my-encoder-became-non-contiguous">How My Encoder Became Non-Contiguous</h3> <p>Looking at the weight initialization code:</p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">self</span><span class="p">.</span><span class="n">encoder</span><span class="p">.</span><span class="n">weight</span><span class="p">.</span><span class="n">data</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="n">decoder</span><span class="p">.</span><span class="n">weight</span><span class="p">.</span><span class="n">T</span><span class="p">.</span><span class="nf">clone</span><span class="p">()</span>
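# Added sketch (not part of the original model code): a toy, pure-Python model
# of the stride metadata this line produces. The helper names row_major_strides,
# offset, transpose and is_contiguous are hypothetical. Strides are counted in
# elements, as in PyTorch. .clone() with the default preserve_format would keep
# the transposed strides computed below.
def row_major_strides(shape):
    # Default ("contiguous") layout: the last dim moves 1 element at a time,
    # each earlier dim moves by the product of the sizes after it.
    strides, step = [], 1
    for size in reversed(shape):
        strides.append(step)
        step *= size
    return tuple(reversed(strides))

def offset(index, strides):
    # Position of tensor[i, j, ...] in the flat buffer.
    return sum(i * s for i, s in zip(index, strides))

def transpose(shape, strides):
    # .T on a 2-D tensor swaps the shape and stride entries; no data moves.
    return shape[::-1], strides[::-1]

def is_contiguous(shape, strides):
    # Simplified check: strides must equal the row-major strides
    # (PyTorch's real check also ignores strides of size-1 dims).
    return strides == row_major_strides(shape)

dec_shape = (384, 1536)                      # decoder weight [out, in]
dec_strides = row_major_strides(dec_shape)   # (1536, 1)
enc_shape, enc_strides = transpose(dec_shape, dec_strides)
print(enc_shape, enc_strides, is_contiguous(enc_shape, enc_strides))
# (1536, 384) (1, 1536) False
print(offset((0, 1), dec_strides), offset((1, 0), enc_strides))
# 1 1  (encoder[1, 0] reads the same buffer slot as decoder[0, 1])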
</code></pre></div></div> <p>The <code class="language-plaintext highlighter-rouge">.T</code> creates a non-contiguous view, and <code class="language-plaintext highlighter-rouge">.clone()</code> preserves the stride pattern.</p> <details> <summary><b>Why does <code>.clone()</code> preserve stride patterns?</b></summary> <div> <p>At first this felt counterintuitive to me: if we’re already paying the cost to copy the data (the whole point of non-contiguous layouts is to avoid copying), why not copy it into the “better” contiguous layout?</p> <p>But this actually makes sense from a design perspective: <code class="language-plaintext highlighter-rouge">.clone()</code> should create an exact copy with all properties preserved, including memory layout. The tensor might be non-contiguous for a reason—maybe you’re about to transpose it back, or the layout is optimized for some operation. Silently reorganizing memory would be surprising behavior. (The optional <code class="language-plaintext highlighter-rouge">memory_format</code> argument, which takes a <a href="https://docs.pytorch.org/docs/stable/tensor_attributes.html#torch.memory_format"><code class="language-plaintext highlighter-rouge">torch.memory_format</code></a> value and defaults to <code class="language-plaintext highlighter-rouge">torch.preserve_format</code>, makes this choice explicit.)</p> <p>As a bonus, preserving the layout is also faster. Even though both options allocate new memory and copy the data, rearranging it into a new order still slows things down:</p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">x_t</span> <span class="o">=</span> <span class="n">x</span><span class="p">.</span><span class="n">T</span>  <span class="c1"># Start with non-contiguous
</span><span class="n">y_noncontig</span> <span class="o">=</span> <span class="n">x_t</span><span class="p">.</span><span class="nf">clone</span><span class="p">()</span>              <span class="c1"># Preserves non-contiguous (1.919ms)
</span><span class="n">y_contig</span> <span class="o">=</span> <span class="n">x_t</span><span class="p">.</span><span class="nf">clone</span><span class="p">(</span><span class="n">memory_format</span><span class="o">=</span><span class="n">torch</span><span class="p">.</span><span class="n">contiguous_format</span><span class="p">)</span>  <span class="c1"># Force contiguous (4.401ms)
</span></code></pre></div> </div> </div> </details> <p><strong>Okay, so we now know this initialization is why only the encoder is non-contiguous, and thus why only the encoder has training issues!</strong></p> <p><em>While I could just call <code class="language-plaintext highlighter-rouge">.contiguous()</code> on my encoder, declare victory, and get back to the research this bug was blocking me from doing… I felt like I was only scratching the surface, and I feared it would haunt me until I fully figured out WHAT happened and WHY.</em></p> <div style=" background: var(--global-card-bg-color); border: 1px solid var(--global-divider-color); border-left: 4px solid var(--global-theme-color); border-radius: 8px; padding: 1.5rem; margin: 2rem 0; box-shadow: 0 2px 8px rgba(0, 0, 0, 0.08); position: relative; transition: box-shadow 0.3s ease; "> <div style=" position: absolute; top: 6px; right: 10px; font-size: 1.6rem; opacity: 0.75; transform: rotate(8deg); pointer-events: none; ">🔎</div> <div style=" padding-right: 2.5rem; font-size: 1.25rem; font-weight: 600; line-height: 1.4; color: var(--global-text-color); "> Why does a non-contiguous encoder weight cause a zero second moment and no parameter updates with Adam on MPS?? </div> </div> <h2 id="identifying-the-broken-operations">Identifying the Broken Operations</h2> <h3 id="what-operations-does-adam-use">What Operations Does Adam Use?</h3> <p>When Adam updates parameters, what operations does it perform? Let’s look at <a href="https://github.com/pytorch/pytorch/blob/main/torch/optim/adam.py">PyTorch’s Adam implementation</a>.</p> <p>Fair warning: this file is over 1000 lines!
To find what we need, search for where <code class="language-plaintext highlighter-rouge">exp_avg</code> and <code class="language-plaintext highlighter-rouge">exp_avg_sq</code> are defined and updated.</p> <p>Here are the critical lines (<a href="https://github.com/pytorch/pytorch/blob/39901f229520a5256505ec24782f716ee7ddc843/torch/optim/adam.py#L101">lines 101, 391-407</a>):</p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># State initialization (line 101)
</span><span class="n">state</span><span class="p">[</span><span class="sh">"</span><span class="s">exp_avg</span><span class="sh">"</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nf">zeros_like</span><span class="p">(</span><span class="n">param</span><span class="p">,</span> <span class="n">memory_format</span><span class="o">=</span><span class="n">torch</span><span class="p">.</span><span class="n">preserve_format</span><span class="p">)</span>
<span class="n">state</span><span class="p">[</span><span class="sh">"</span><span class="s">exp_avg_sq</span><span class="sh">"</span><span class="p">]</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nf">zeros_like</span><span class="p">(</span><span class="n">param</span><span class="p">,</span> <span class="n">memory_format</span><span class="o">=</span><span class="n">torch</span><span class="p">.</span><span class="n">preserve_format</span><span class="p">)</span>

<span class="c1"># ... [300 lines of setup and parameter group handling] ...
</span>
<span class="c1"># First moment update (line 391)
</span><span class="n">exp_avg</span><span class="p">.</span><span class="nf">lerp_</span><span class="p">(</span><span class="n">grad</span><span class="p">,</span> <span class="mi">1</span> <span class="o">-</span> <span class="n">beta1</span><span class="p">)</span>

<span class="c1"># Second moment update (line 392)
</span><span class="n">exp_avg_sq</span><span class="p">.</span><span class="nf">mul_</span><span class="p">(</span><span class="n">beta2</span><span class="p">).</span><span class="nf">addcmul_</span><span class="p">(</span><span class="n">grad</span><span class="p">,</span> <span class="n">grad</span><span class="p">,</span> <span class="n">value</span><span class="o">=</span><span class="mi">1</span> <span class="o">-</span> <span class="n">beta2</span><span class="p">)</span>

<span class="c1"># ... [bias correction calculations] ...
</span>
<span class="c1"># Parameter update (line 407)
</span><span class="n">param</span><span class="p">.</span><span class="nf">addcdiv_</span><span class="p">(</span><span class="n">exp_avg</span><span class="p">,</span> <span class="n">denom</span><span class="p">,</span> <span class="n">value</span><span class="o">=-</span><span class="n">step_size</span><span class="p">)</span>
</code></pre></div></div> <p>Look at that initialization! <code class="language-plaintext highlighter-rouge">memory_format=torch.preserve_format</code> means the state tensors inherit their stride pattern from <code class="language-plaintext highlighter-rouge">param</code>. So when our encoder weight is non-contiguous, both <code class="language-plaintext highlighter-rouge">exp_avg</code> and <code class="language-plaintext highlighter-rouge">exp_avg_sq</code> are also non-contiguous.</p> <p>But they’re BOTH non-contiguous - so why does only one break?</p> <p>Well, while both are computed via addition and multiplication, they don’t use exactly the same operations to do it. Any of these operations could be a suspect, so let’s test each one individually!</p> <p>For operations like <code class="language-plaintext highlighter-rouge">output.addcmul_(input1, input2)</code>, the <strong>output tensor</strong><d-footnote>In PyTorch, when a function name ends with an underscore (like <code>mul_</code>), that indicates that it is performing an <b>in-place operation</b> to modify a tensor directly in memory. Just as different devices can have distinct kernels, so can the in-place and out-of-place versions of the same operation!</d-footnote> is modified while <strong>input tensors</strong> are read from.
In our case, we know the output tensor is non-contiguous, so let’s test if that is sufficient to cause our bug.</p> <h3 id="testing-the-broken-operations">Testing the Broken Operations</h3> <p>Testing each Adam operation with non-contiguous output tensors on MPS:</p> <table> <thead> <tr> <th><strong>Operation</strong></th> <th><strong>Function</strong></th> <th><strong>Result</strong></th> </tr> </thead> <tbody> <tr> <td>Linear interpolation</td> <td><code class="language-plaintext highlighter-rouge">lerp_()</code></td> <td>Updates ✓</td> </tr> <tr> <td>Scalar multiply</td> <td><code class="language-plaintext highlighter-rouge">mul_()</code></td> <td>Updates ✓</td> </tr> <tr> <td>Add + multiply</td> <td><code class="language-plaintext highlighter-rouge">addcmul_()</code></td> <td>Stays zero ✗</td> </tr> <tr> <td>Add + divide</td> <td><code class="language-plaintext highlighter-rouge">addcdiv_()</code></td> <td>Stays zero ✗</td> </tr> </tbody> </table> <div style=" background: var(--global-card-bg-color); border: 1px solid var(--global-divider-color); border-left: 4px solid var(--global-theme-color); border-radius: 8px; padding: 1.5rem; margin: 2rem 0; box-shadow: 0 2px 8px rgba(0, 0, 0, 0.08); position: relative; transition: box-shadow 0.3s ease; "> <div style=" position: absolute; top: 6px; right: 10px; font-size: 1.6rem; opacity: 0.75; transform: rotate(8deg); pointer-events: none; ">‼️</div> <div style=" padding-right: 2.5rem; font-size: 1.25rem; font-weight: 600; line-height: 1.4; color: var(--global-text-color); "> Found it! <code style="background: var(--global-code-bg-color); color: var(--global-theme-color); padding: 0.2rem 0.4rem; border-radius: 4px; font-size: 0.9em;">addcmul_()</code> and <code style="background: var(--global-code-bg-color); color: var(--global-theme-color); padding: 0.2rem 0.4rem; border-radius: 4px; font-size: 0.9em;">addcdiv_()</code> both fail silently when writing to non-contiguous outputs on MPS. 
</div> </div> <p>Interestingly, <em>input contiguity doesn’t matter</em>, only the output! Whether <code class="language-plaintext highlighter-rouge">grad</code>, <code class="language-plaintext highlighter-rouge">exp_avg</code>, or <code class="language-plaintext highlighter-rouge">denom</code> are contiguous makes no difference. The bug is purely in how these kernels write to <em>non-contiguous output buffers</em>.</p> <p>The broken operations aren’t producing zeros or NaNs. They’re simply not modifying the output tensor at all. This wasn’t immediately obvious since <code class="language-plaintext highlighter-rouge">exp_avg_sq</code> was initialized to zeros, making “stays at zero” and “never updates” look identical. But testing with a non-zero, non-contiguous output tensor confirms that after calling <code class="language-plaintext highlighter-rouge">addcmul_</code> or <code class="language-plaintext highlighter-rouge">addcdiv_</code>, the values remain unchanged. No update happens.</p> <p>Yet timing shows MPS <em>is</em> doing substantial work. Non-contiguous operations take &gt;2x longer than contiguous ones, which suggests the kernels are computing <em>something</em>, yet those results never make it to the output tensor. On CPU, each of these operations works correctly regardless of memory layout.
This is purely an MPS-specific bug.</p> <p>With the broken operations identified, we can trace the complete chain of events that triggers our failure:</p> <h3 id="putting-the-pieces-together">Putting the Pieces Together</h3> <div class="l-body"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/the_bug_that_taught_me_pytorch_post/complete_bug_chain-480.webp 480w,/assets/img/the_bug_that_taught_me_pytorch_post/complete_bug_chain-800.webp 800w,/assets/img/the_bug_that_taught_me_pytorch_post/complete_bug_chain-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/the_bug_that_taught_me_pytorch_post/complete_bug_chain.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </div> <details> <summary><b>Show the complete bug chain in code</b></summary> <div> <p><strong>Step 1: Initialization</strong></p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># Creates non-contiguous encoder weight (stride: 1, 1536)
</span><span class="n">encoder</span><span class="p">.</span><span class="n">weight</span> <span class="o">=</span> <span class="n">decoder</span><span class="p">.</span><span class="n">weight</span><span class="p">.</span><span class="n">T</span><span class="p">.</span><span class="nf">clone</span><span class="p">()</span>
</code></pre></div> </div> <p><strong>Step 2: Adam State Creation</strong></p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># Both state tensors inherit non-contiguous layout from param
</span><span class="n">state</span><span class="p">[</span><span class="sh">"</span><span class="s">exp_avg</span><span class="sh">"</span><span class="p">]</span> <span class="o">=</span> <span class="nf">zeros_like</span><span class="p">(</span><span class="n">param</span><span class="p">,</span> <span class="n">memory_format</span><span class="o">=</span><span class="n">torch</span><span class="p">.</span><span class="n">preserve_format</span><span class="p">)</span>
<span class="n">state</span><span class="p">[</span><span class="sh">"</span><span class="s">exp_avg_sq</span><span class="sh">"</span><span class="p">]</span> <span class="o">=</span> <span class="nf">zeros_like</span><span class="p">(</span><span class="n">param</span><span class="p">,</span> <span class="n">memory_format</span><span class="o">=</span><span class="n">torch</span><span class="p">.</span><span class="n">preserve_format</span><span class="p">)</span>
</code></pre></div> </div> <p><strong>Step 3: Optimization Loop</strong></p> <p><em>First moment update:</em></p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">exp_avg</span><span class="p">.</span><span class="nf">lerp_</span><span class="p">(</span><span class="n">grad</span><span class="p">,</span> <span class="mi">1</span><span class="o">-</span><span class="n">beta_1</span><span class="p">)</span>  <span class="c1"># ✓ Works fine
</span></code></pre></div> </div> <p><em>Second moment update:</em></p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">exp_avg_sq</span><span class="p">.</span><span class="nf">mul_</span><span class="p">(</span><span class="n">beta_2</span><span class="p">)</span>                        <span class="c1"># ✓ Works fine
</span><span class="n">exp_avg_sq</span><span class="p">.</span><span class="nf">addcmul_</span><span class="p">(</span><span class="n">grad</span><span class="p">,</span> <span class="n">grad</span><span class="p">,</span> <span class="mi">1</span><span class="o">-</span><span class="n">beta_2</span><span class="p">)</span>      <span class="c1"># ✗ No update - stays zero!
</span></code></pre></div> </div> <p><strong>Step 4: Parameter Update</strong></p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># Should update param, does nothing, leading to silent failure
</span><span class="n">param</span><span class="p">.</span><span class="nf">addcdiv_</span><span class="p">(</span><span class="n">exp_avg</span><span class="p">,</span> <span class="n">denom</span><span class="p">,</span> <span class="n">value</span><span class="o">=-</span><span class="n">step_size</span><span class="p">)</span>  <span class="c1"># ✗ No update!
</span></code></pre></div> </div> </div> </details> <p>If <em>only</em> <code class="language-plaintext highlighter-rouge">exp_avg_sq.addcmul_()</code> failed, the zero <code class="language-plaintext highlighter-rouge">exp_avg_sq</code> would produce massive weight explosions: Adam’s denominator <code class="language-plaintext highlighter-rouge">√exp_avg_sq + ε</code> would shrink to just <code class="language-plaintext highlighter-rouge">ε</code>, so each update would be roughly <code class="language-plaintext highlighter-rouge">lr × exp_avg / ε</code>, making the bug immediately obvious. But <code class="language-plaintext highlighter-rouge">param.addcdiv_()</code> <em>also</em> failed, producing no updates at all!</p> <p>The second bug masked the first, creating a silent failure: the spookiest type of error. The model appeared to be learning (the decoder was training normally), but progress stalled because the encoder stayed frozen. A subtle plateau that looked exactly like a hyperparameter issue 🙃</p> <details> <summary><b>Side note: Why did forward and backward passes work fine with non-contiguous weights?</b></summary> <div> <p>If non-contiguous tensors can cause operations to silently fail on MPS, why didn’t the forward pass or backward pass break?</p> <p>The forward and backward passes for <code class="language-plaintext highlighter-rouge">F.linear</code> use <code class="language-plaintext highlighter-rouge">matmul</code> for their matrix multiplications, which handle non-contiguous tensors correctly on MPS.
Testing confirms that both <code class="language-plaintext highlighter-rouge">matmul</code> (the <code class="language-plaintext highlighter-rouge">@</code> operator) and <code class="language-plaintext highlighter-rouge">F.linear</code> work correctly with non-contiguous input tensors and non-contiguous weight matrices on MPS, including during the backward pass where gradients flow through non-contiguous weights without issues.</p> <p>The bug is specific to the fused in-place operations that Adam uses for state updates: <code class="language-plaintext highlighter-rouge">addcmul_</code> and <code class="language-plaintext highlighter-rouge">addcdiv_</code>. These operations fail silently when writing to non-contiguous output tensors, while other in-place operations like <code class="language-plaintext highlighter-rouge">lerp_</code> and <code class="language-plaintext highlighter-rouge">mul_</code> work correctly.</p> </div> </details> <p><strong>While we have made so much progress on this case, we’re still not done yet!!</strong></p> <div style=" background: var(--global-card-bg-color); border: 1px solid var(--global-divider-color); border-left: 4px solid var(--global-theme-color); border-radius: 8px; padding: 1.5rem; margin: 2rem 0; box-shadow: 0 2px 8px rgba(0, 0, 0, 0.08); position: relative; transition: box-shadow 0.3s ease; "> <div style=" position: absolute; top: -12px; left: 1rem; background: var(--global-theme-color); color: var(--global-hover-text-color); padding: 0.4rem 1rem; border-radius: 1rem; font-size: 0.75rem; font-weight: 600; text-transform: uppercase; letter-spacing: 0.5px; box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1); ">Remaining Question</div> <div style=" position: absolute; top: 6px; right: 10px; font-size: 1.6rem; opacity: 0.75; transform: rotate(8deg); pointer-events: none; ">🔍</div> <div style=" margin-top: 0.5rem; padding-right: 2.5rem; font-size: 1.25rem; font-weight: 600; line-height: 1.4; color: var(--global-text-color); "> Why do <code 
style="background: var(--global-code-bg-color); color: var(--global-theme-color); padding: 0.2rem 0.4rem; border-radius: 4px; font-size: 0.9em;">addcmul_</code> and <code style="background: var(--global-code-bg-color); color: var(--global-theme-color); padding: 0.2rem 0.4rem; border-radius: 4px; font-size: 0.9em;">addcdiv_</code> fail to update non-contiguous outputs while <code style="background: var(--global-code-bg-color); color: var(--global-theme-color); padding: 0.2rem 0.4rem; border-radius: 4px; font-size: 0.9em;">mul_</code> and <code style="background: var(--global-code-bg-color); color: var(--global-theme-color); padding: 0.2rem 0.4rem; border-radius: 4px; font-size: 0.9em;">lerp_</code> work fine? </div> </div> <h2 id="inside-the-kernel-implementation">Inside the Kernel Implementation</h2> <p>To understand why some operations work and others don’t, I needed to look at PyTorch’s source code for the buggy kernels.</p> <p>While I normally trace through a Python codebase by jumping to definitions in my IDE, that doesn’t work with <code class="language-plaintext highlighter-rouge">tensor.addcmul_()</code>. When you call this function, there’s no Python source code executing - instead, Python immediately jumps into compiled C++ code for performance. And since PyTorch ships this as a pre-compiled binary, I can’t see that C++ implementation.</p> <details> <summary><b>How can Python call C++ functions? (a brief aside on bindings)</b></summary> <div> <p>How can a Python tensor object have methods that execute C++ code? I skipped over this earlier but even though I know PyTorch isn’t the only framework to do this and everything is just machine code if you zoom in close enough… it still feels a bit magical to casually call another language.</p> <p>The explanation is <strong>Python bindings</strong>.</p> <p>When you install PyTorch, you’re not just getting Python files. 
You’re also getting compiled C++ libraries (.so files on Linux/Mac, .dll on Windows) that contain the actual mathematical operations. The Python part is essentially a wrapper that:</p> <ol> <li>Takes your Python arguments (<code class="language-plaintext highlighter-rouge">tensor</code>, <code class="language-plaintext highlighter-rouge">other_tensor</code>, etc.)</li> <li>Converts them to C++ data structures</li> <li>Calls the appropriate C++ function</li> <li>Converts the C++ result back to a Python tensor</li> <li>Returns it to your Python code</li> </ol> <p>PyTorch generates this wrapper code automatically when it is built: for tensor methods like this one, its own code generator emits the bindings, while other parts of PyTorch use <a href="https://pybind11.readthedocs.io/">pybind11</a>, a popular library for the same job. For example, the C++ function signature:</p> <div class="language-cpp highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">Tensor</span><span class="o">&amp;</span> <span class="n">addcmul_</span><span class="p">(</span><span class="n">Tensor</span><span class="o">&amp;</span> <span class="n">self</span><span class="p">,</span> <span class="k">const</span> <span class="n">Tensor</span><span class="o">&amp;</span> <span class="n">tensor1</span><span class="p">,</span> <span class="k">const</span> <span class="n">Tensor</span><span class="o">&amp;</span> <span class="n">tensor2</span><span class="p">,</span> <span class="k">const</span> <span class="n">Scalar</span><span class="o">&amp;</span> <span class="n">value</span><span class="p">)</span>
</code></pre></div> </div> <p>Gets automatically wrapped so you can call it from Python as:</p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">tensor</span><span class="p">.</span><span class="nf">addcmul_</span><span class="p">(</span><span class="n">tensor1</span><span class="p">,</span> <span class="n">tensor2</span><span class="p">,</span> <span class="n">value</span><span class="o">=</span><span class="mf">1.0</span><span class="p">)</span>
</code></pre></div> </div> <p>This is why PyTorch operations are fast despite being called from Python - the heavy lifting happens in optimized C++ code, with Python just handling the interface.</p> </div> </details> <p>And as we discussed earlier, PyTorch dispatches based on tensor metadata, so there isn’t just <em>one</em> implementation - there are device-specific kernels for CPU, CUDA, MPS, etc. Since my PyTorch installation only contains the compiled binaries, investigating the actual implementations meant cloning PyTorch’s repository.</p> <h3 id="pytorchs-dispatch-system">PyTorch’s Dispatch System</h3> <p>All kernels are listed in an <strong>operation registry</strong> - a YAML file that maps operation names (like <code class="language-plaintext highlighter-rouge">addcmul_</code>) to their backend-specific C++ implementations. When PyTorch is compiled (normally before you ever install it), this registry is used to automatically generate hundreds of C++ source files that implement the actual dispatching based on these entries. But if we just want to know which kernel our tensor ends up calling, we can read the registry directly.</p> <p>Searching for “addcmul_” in the registry <code class="language-plaintext highlighter-rouge">native_functions.yaml</code>:</p> <div class="language-yaml highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="pi">-</span> <span class="na">func</span><span class="pi">:</span> <span class="s">addcmul_(Tensor(a!) self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -&gt; Tensor(a!)</span>
  <span class="c1"># our addcmul_ function just points us to the yaml for addcmul.out</span>
  <span class="na">structured_delegate</span><span class="pi">:</span> <span class="s">addcmul.out</span>

<span class="c1"># The function addcmul_ points to:</span>
<span class="pi">-</span> <span class="na">func</span><span class="pi">:</span> <span class="s">addcmul.out(...)</span>
  <span class="na">dispatch</span><span class="pi">:</span>
    <span class="s">CPU, CUDA</span><span class="err">:</span> <span class="s">addcmul_out</span>
    <span class="s">MPS</span><span class="err">:</span> <span class="s">addcmul_out_mps</span>  <span class="c1"># Different function for MPS!</span>
</code></pre></div></div> <p>Now that we have the device-specific operation names, we can search for them among PyTorch’s <a href="https://github.com/pytorch/pytorch/blob/v2.2.1/aten/src/ATen/native/mps/">MPS implementations</a>, where we find <code class="language-plaintext highlighter-rouge">addcmul_out_mps</code> in <a href="https://github.com/pytorch/pytorch/blob/v2.2.1/aten/src/ATen/native/mps/operations/PointwiseOps.mm"><code class="language-plaintext highlighter-rouge">PointwiseOps.mm</code></a>. On a first skim of the code, I realized I had no clue how to read the MPS codebase. There were too many unknown variables and constructs, and I wasn’t sure what to look for in this implementation. I’d written a CUDA kernel before, and was pretty good with C about a decade ago, but as it turns out, neither of those helped here :(</p> <h3 id="comparing-broken-vs-working-implementations">Comparing Broken vs Working Implementations</h3> <p>Rather than trying to decode unfamiliar code in isolation, I decided to find something similar that works correctly and compare the two. <code class="language-plaintext highlighter-rouge">mul_</code> was the perfect comparison since, like <code class="language-plaintext highlighter-rouge">addcmul_</code>, it is a simple element-wise in-place operation. 
The registry pointed me to <code class="language-plaintext highlighter-rouge">binaryOpTensor</code> in <a href="https://github.com/pytorch/pytorch/blob/v2.2.1/aten/src/ATen/native/mps/operations/BinaryOps.mm"><code class="language-plaintext highlighter-rouge">BinaryOps.mm</code></a>.</p> <p>Now I had my comparison:</p> <ul> <li><strong>Broken:</strong> <code class="language-plaintext highlighter-rouge">addc_mul_div_out_mps</code> in <code class="language-plaintext highlighter-rouge">PointwiseOps.mm</code> (used by <code class="language-plaintext highlighter-rouge">addcmul_</code>)</li> <li><strong>Working:</strong> <code class="language-plaintext highlighter-rouge">binaryOpTensor</code> in <code class="language-plaintext highlighter-rouge">BinaryOps.mm</code> (used by <code class="language-plaintext highlighter-rouge">mul_</code>)</li> </ul> <p>I opened both side by side, scanning specifically for differences in how they handle the output tensor. My experiments had already narrowed the search: I knew both operations were computing <em>something</em> (timing proved that), so the bug had to be in how results get written back to non-contiguous outputs. That meant looking for anything related to contiguity checks or special output handling.</p> <p><strong>Broken version (<code class="language-plaintext highlighter-rouge">addcmul_</code>):</strong></p> <div class="language-objc highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">static</span> <span class="kt">void</span> <span class="nf">addc_mul_div_out_mps</span><span class="p">(...,</span> <span class="n">Tensor</span><span class="o">&amp;</span> <span class="n">output</span><span class="p">,</span> <span class="p">...)</span> <span class="p">{</span>
  <span class="c1">// ... setup code ...</span>
  <span class="n">Placeholder</span> <span class="n">outputPlaceholder</span> <span class="o">=</span> <span class="n">Placeholder</span><span class="p">(</span><span class="n">output</span><span class="p">);</span>
  <span class="n">runMPSGraph</span><span class="p">(...);</span>
  <span class="c1">// That's it - no additional handling</span>
<span class="p">}</span>
</code></pre></div></div> <p><strong>Working version (<code class="language-plaintext highlighter-rouge">mul_</code>):</strong></p> <div class="language-objc highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">static</span> <span class="kt">void</span> <span class="nf">binaryOpTensor</span><span class="p">(...,</span> <span class="k">const</span> <span class="n">Tensor</span><span class="o">&amp;</span> <span class="n">output_</span><span class="p">,</span> <span class="p">...)</span> <span class="p">{</span>
  <span class="c1">// ... setup code ...</span>
  
  <span class="n">Tensor</span> <span class="n">output</span> <span class="o">=</span> <span class="n">output_</span><span class="p">;</span>  <span class="c1">// Local handle, may be swapped for a temporary</span>
  <span class="n">bool</span> <span class="n">needsCopyToOutput</span> <span class="o">=</span> <span class="o">!</span><span class="n">output_</span><span class="p">.</span><span class="n">is_contiguous</span><span class="p">();</span>
  <span class="k">if</span> <span class="p">(</span><span class="n">needsCopyToOutput</span><span class="p">)</span> <span class="p">{</span>
    <span class="c1">// Create temporary contiguous tensor</span>
    <span class="n">output</span> <span class="o">=</span> <span class="n">at</span><span class="o">::</span><span class="n">empty</span><span class="p">(...);</span>
  <span class="p">}</span>
  
  <span class="n">Placeholder</span> <span class="n">outputPlaceholder</span> <span class="o">=</span> <span class="n">Placeholder</span><span class="p">(</span><span class="n">output</span><span class="p">);</span>
  <span class="n">runMPSGraph</span><span class="p">(...);</span>
  
  <span class="k">if</span> <span class="p">(</span><span class="n">needsCopyToOutput</span><span class="p">)</span> <span class="p">{</span>
    <span class="n">output_</span><span class="p">.</span><span class="n">copy_</span><span class="p">(</span><span class="n">output</span><span class="p">);</span>  <span class="c1">// Copy results back!</span>
  <span class="p">}</span>
<span class="p">}</span>
</code></pre></div></div> <p>The working version explicitly checks whether the output is contiguous (<code class="language-plaintext highlighter-rouge">is_contiguous()</code>) and adds extra handling: it creates a temporary contiguous tensor, runs the operation, then copies results back. The broken version just passes the output directly to <code class="language-plaintext highlighter-rouge">Placeholder</code> and calls it a day.</p> <p>But this raises a new question: if non-contiguous memory layouts need this kind of explicit handling, why doesn’t <code class="language-plaintext highlighter-rouge">addcmul_</code> just crash or throw an error instead of silently failing?</p> <h3 id="the-memory-conversion-problem">The Memory Conversion Problem</h3> <p>The answer lies in understanding what <code class="language-plaintext highlighter-rouge">Placeholder</code> does. PyTorch tensors and Metal (Apple’s GPU framework) use different memory formats, so PyTorch needs a converter when running operations on Apple Silicon. <code class="language-plaintext highlighter-rouge">Placeholder</code> handles this conversion - it takes PyTorch tensors and wraps them in Metal-compatible buffers, handles different data types, manages memory layouts, and sets up the compute pipeline.</p> <p>For most tensors, this conversion is straightforward. But for non-contiguous tensors, Metal can’t work with the scattered memory layout directly. Looking at the Placeholder code:</p> <div class="language-objc highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">if</span> <span class="p">(</span><span class="o">!</span><span class="n">src</span><span class="p">.</span><span class="n">is_contiguous</span><span class="p">())</span> <span class="p">{</span>
    <span class="n">_tensor</span> <span class="o">=</span> <span class="n">src</span><span class="p">.</span><span class="n">clone</span><span class="p">(</span><span class="n">MemoryFormat</span><span class="o">::</span><span class="n">Contiguous</span><span class="p">);</span>  <span class="c1">// Create contiguous copy</span>
    <span class="n">srcBuf</span> <span class="o">=</span> <span class="n">getMTLBufferStorage</span><span class="p">(</span><span class="n">_tensor</span><span class="p">);</span>          <span class="c1">// Point Metal to the copy</span>
<span class="p">}</span>
</code></pre></div></div> <p>When Placeholder encounters a non-contiguous tensor, it automatically creates a contiguous copy and points Metal to that copy instead. This happens transparently - the broken kernels have no idea they’re working with a temporary.</p> <p>This automatic copying is perfect for <strong>input tensors</strong> - Metal reads from the copy, computation proceeds normally, and nobody cares what happens to the temporary afterward.</p> <p>But it’s disastrous for <strong>output tensors</strong>, where the goal is in-place editing. The computation succeeds and writes results to the temporary copy, but those results never make it back to the original tensor that’s supposed to be updated.</p> <details> <summary><b>Why is this MPS-specific?</b></summary> <div> <p>If non-contiguous tensors are so problematic, why do CPU and CUDA backends handle them fine?</p> <p><strong>CPU:</strong> Can handle arbitrary strides natively. When iterating through a non-contiguous tensor, the CPU just follows the stride pattern—jumping around memory is slower than sequential access, but it works correctly.</p> <p><strong>CUDA:</strong> NVIDIA’s CUDA framework has always supported strided memory access in kernels. Operations can read/write to non-contiguous layouts directly, though with some performance penalty.</p> <p><strong>MPS:</strong> Apple’s Metal Performance Shaders framework initially didn’t support strided access. Kernels expected contiguous memory layouts, period. This forced PyTorch to implement the gather-scatter workaround pattern we saw in the working kernels.</p> <p>The bug occurred because some MPS operations implemented this workaround (like <code class="language-plaintext highlighter-rouge">mul_</code>), while others didn’t (like <code class="language-plaintext highlighter-rouge">addcmul_</code>). The abstraction (Placeholder) that was supposed to hide this complexity actually made it worse by silently copying outputs without a way to copy results back. 
As we’ll see later, newer macOS versions have improved this.</p> </div> </details> <h3 id="the-complete-bug-mechanism">The Complete Bug Mechanism</h3> <div class="l-body"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/the_bug_that_taught_me_pytorch_post/placeholder_bug_mechanism-480.webp 480w,/assets/img/the_bug_that_taught_me_pytorch_post/placeholder_bug_mechanism-800.webp 800w,/assets/img/the_bug_that_taught_me_pytorch_post/placeholder_bug_mechanism-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/the_bug_that_taught_me_pytorch_post/placeholder_bug_mechanism.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </div> <p>The broken kernels work perfectly with contiguous tensors and silently fail with non-contiguous ones. The working kernels detect this situation and add an explicit copy-back step to move results from the temporary to the original tensor.</p> <h3 id="the-fix">The Fix</h3> <p>Understanding the bug made the solution clear - apply the same pattern the working kernels use:</p> <style>.diff-container{background:#f8f9fa;border-radius:8px;overflow-x:auto;border:1px solid #e0e0e0;max-width:100%;margin:20px 0}.diff-line{display:flex;font-family:'Monaco','Menlo','Ubuntu Mono',monospace;font-size:13px;line-height:1.6;border-bottom:1px solid #e8e8e8}.diff-line:last-child{border-bottom:0}.line-number{padding:4px 12px;background:#f0f0f0;color:#999;text-align:right;user-select:none;min-width:50px;border-right:1px solid #e0e0e0}.line-content{padding:4px 12px;flex:1;white-space:pre}.added{background:#e6ffed;color:#24292e}.added .line-number{background:#cdffd8;color:#22863a}.removed{background:#ffeef0;color:#24292e}.removed .line-number{background:#ffdce0;color:#d73a49}.unchanged{background:white}</style> <div class="diff-container"> <div 
class="diff-line added"> <div class="line-number">+</div> <div class="line-content">Tensor output = output_;</div> </div> <div class="diff-line added"> <div class="line-number">+</div> <div class="line-content">bool needsCopyToOutput = false;</div> </div> <div class="diff-line added"> <div class="line-number">+</div> <div class="line-content"> </div> </div> <div class="diff-line added"> <div class="line-number">+</div> <div class="line-content">if (!output_.is_contiguous()) {</div> </div> <div class="diff-line added"> <div class="line-number">+</div> <div class="line-content"> output = at::empty(...); // Create contiguous buffer WE manage</div> </div> <div class="diff-line added"> <div class="line-number">+</div> <div class="line-content"> needsCopyToOutput = true;</div> </div> <div class="diff-line added"> <div class="line-number">+</div> <div class="line-content">}</div> </div> <div class="diff-line added"> <div class="line-number">+</div> <div class="line-content"> </div> </div> <div class="diff-line unchanged"> <div class="line-number">1</div> <div class="line-content">@autoreleasepool {</div> </div> <div class="diff-line unchanged"> <div class="line-number">2</div> <div class="line-content"> Placeholder outputPlaceholder = Placeholder(output);</div> </div> <div class="diff-line unchanged"> <div class="line-number">3</div> <div class="line-content"> runMPSGraph(...);</div> </div> <div class="diff-line unchanged"> <div class="line-number">4</div> <div class="line-content">}</div> </div> <div class="diff-line removed"> <div class="line-number">-</div> <div class="line-content">// No copy-back - results vanish when Placeholder dies</div> </div> <div class="diff-line added"> <div class="line-number">+</div> <div class="line-content"> </div> </div> <div class="diff-line added"> <div class="line-number">+</div> <div class="line-content">if (needsCopyToOutput) {</div> </div> <div class="diff-line added"> <div class="line-number">+</div> <div class="line-content"> 
output_.copy_(output); // Copy results back</div> </div> <div class="diff-line added"> <div class="line-number">+</div> <div class="line-content">}</div> </div> </div> <p>I tested this locally and it worked! The encoder weights finally updated and the model trained successfully 🎉🎉</p> <p>You can see the complete reproduction, debugging experiments, and fix at <a href="https://github.com/ElanaPearl/pytorch-mps-noncontiguous-bug">https://github.com/ElanaPearl/pytorch-mps-noncontiguous-bug</a>.</p> <h2 id="case-closed">Case Closed</h2> <h3 id="a-lesson-in-version-control">A Lesson in Version Control</h3> <p>Testing a change to a Python package just means installing a locally editable copy of it, but testing my PyTorch fix meant rebuilding PyTorch from source. That was more work than expected, and it <em>also</em> made me acutely aware that this whole time I had been working on PyTorch v2.2.1<d-footnote>I was working on a research codebase with dependency conflicts that blocked upgrading PyTorch. Common enough situation, but lesson learned: always check versions early in debugging, even if you can't immediately update!</d-footnote>, because building a version that old meant downgrading tools like CMake and working through weird version conflicts.</p> <p>Checking the latest version revealed the bug was already fixed in v2.4, patched by an ML engineer at Apple last year using almost the exact same approach I’d used.<d-footnote>The official fix uses slightly different syntax but the same core pattern: detect non-contiguous output, create a contiguous temporary buffer, perform the computation, then copy results back to the original tensor.</d-footnote> This updated code even informed me that in macOS 15+, MPS now handles non-contiguous tensors natively! 
<d-footnote>In macOS 15, Apple added native strided array support to MPSGraph via the <code>arrayView</code> API (see <a href="https://developer.apple.com/videos/play/wwdc2024/10218/">WWDC 2024 session</a> at timestamp 13:41). Instead of the gather-scatter workaround, Metal can now read/write directly from non-contiguous memory using stride metadata. This means on macOS 15+, PyTorch can skip the manual copy workarounds entirely. The performance gap between contiguous and non-contiguous tensors is now much smaller, though contiguous is still faster due to better cache utilization.</d-footnote></p> <div style=" background: var(--global-card-bg-color); border: 1px solid var(--global-divider-color); border-left: 4px solid var(--global-theme-color); border-radius: 8px; padding: 1.5rem; margin: 2rem 0; box-shadow: 0 2px 8px rgba(0, 0, 0, 0.08); position: relative; transition: box-shadow 0.3s ease; "> <div style=" position: absolute; top: 6px; right: 10px; font-size: 1.6rem; opacity: 0.75; transform: rotate(8deg); pointer-events: none; ">🤦‍♀️</div> <div style=" padding-right: 2.5rem; font-size: 1.25rem; font-weight: 600; line-height: 1.4; color: var(--global-text-color); "> While I now felt silly for diving so deep on an already-fixed bug, the process was still very fun, educational, and so worth the effort.<br/><br/>In hindsight, I maybe could've tried upgrading PyTorch earlier...<br/><br/> ...But as it turns out, <code style="background: var(--global-code-bg-color); color: var(--global-theme-color); padding: 0.2rem 0.4rem; border-radius: 4px; font-size: 0.9em;">the story wasn't over just yet!</code> </div> </div> <h3 id="the-pattern-strikes-again">The Pattern Strikes Again</h3> <p>While writing this up, I added some more tests for my kernel fix to confirm it really worked, and one of the tests failed! 
I looked into it more and realized I’d stumbled upon <strong>the same failure pattern</strong> in the <code class="language-plaintext highlighter-rouge">random_</code> operation (in the most up-to-date PyTorch this time!)</p> <p><strong>Turns out, all random in-place operations</strong> (<code class="language-plaintext highlighter-rouge">normal_</code>, <code class="language-plaintext highlighter-rouge">uniform_</code>, <code class="language-plaintext highlighter-rouge">exponential_</code>, <code class="language-plaintext highlighter-rouge">random_</code>, <code class="language-plaintext highlighter-rouge">bernoulli_</code>) <strong>silently fail when called on non-contiguous tensors on MPS</strong>.</p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="n">torch</span>

<span class="n">x</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nf">zeros</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">10</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="s">"mps"</span><span class="p">).</span><span class="n">T</span>  <span class="c1"># Non-contiguous
</span><span class="n">x</span><span class="p">.</span><span class="nf">normal_</span><span class="p">()</span>  <span class="c1"># Should fill with random values
</span><span class="nf">print</span><span class="p">(</span><span class="n">x</span><span class="p">.</span><span class="nf">max</span><span class="p">())</span>  <span class="c1"># Prints 0.0 - the operation silently failed!
</span></code></pre></div></div> <p>Yet again, the operations complete without error, but the tensor remains unchanged—the kernel computes random values into a temporary contiguous buffer but never copies them back.</p> <p>Having just traced through this exact bug pattern, I recognized it immediately and knew exactly how to fix it. I filed an <a href="https://github.com/pytorch/pytorch/issues/165257">issue</a> and opened a <a href="https://github.com/pytorch/pytorch/pull/165267">PR</a> applying the same solution.</p> <p>I suspect there are other similar bugs lying around, as none of these fixes actually address the underlying quirk that <strong>the Placeholder abstraction itself is problematic when used with output tensors</strong>.</p> <p>The core issue: Placeholder’s constructor silently creates a temporary contiguous copy for non-contiguous tensors, but it has no way to know if it’s wrapping an input (where the copy is fine: we just read from it) or an output (where the copy is broken: results get written to it, then lost). This means <strong>every single operation that uses Placeholder for outputs must manually implement the same workaround pattern</strong> or else it has this silent failure:</p> <div class="language-objc highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1">// Every MPS operation must remember to do this:</span>
<span class="n">bool</span> <span class="n">needsCopy</span> <span class="o">=</span> <span class="o">!</span><span class="n">output</span><span class="p">.</span><span class="n">is_contiguous</span><span class="p">();</span>
<span class="n">Tensor</span> <span class="n">temp</span> <span class="o">=</span> <span class="n">needsCopy</span> <span class="p">?</span> <span class="n">at</span><span class="p">:</span><span class="o">:</span><span class="n">empty</span><span class="p">(...)</span> <span class="o">:</span> <span class="n">output</span><span class="p">;</span>
<span class="k">@autoreleasepool</span> <span class="p">{</span>
    <span class="n">Placeholder</span> <span class="n">p</span><span class="p">(</span><span class="n">temp</span><span class="p">);</span>
    <span class="n">runGraph</span><span class="p">();</span>
<span class="p">}</span>
<span class="k">if</span> <span class="p">(</span><span class="n">needsCopy</span><span class="p">)</span>
  <span class="n">output</span><span class="p">.</span><span class="n">copy_</span><span class="p">(</span><span class="n">temp</span><span class="p">);</span>
</code></pre></div></div> <p>This is a leaky abstraction<d-footnote>A "leaky abstraction" is when an abstraction that's supposed to hide implementation details forces you to understand and work around those details anyway. Placeholder is supposed to abstract Metal buffer management, but its internal copying leaks through, forcing every caller to manually handle non-contiguous outputs. See Joel Spolsky's <a href="https://www.joelonsoftware.com/2002/11/11/the-law-of-leaky-abstractions/">The Law of Leaky Abstractions</a> for the canonical explanation.</d-footnote>: the internal implementation detail that “Placeholder makes temporary copies” has leaked out to every caller, making it each operation’s responsibility to work around. A better design would be:</p> <ul> <li>Placeholder knows input vs output: Pass a flag so Placeholder can handle the copy-back itself</li> <li>Separate abstractions: Different wrapper types for inputs (InputPlaceholder) and outputs (OutputPlaceholder)</li> <li>Make the temporary explicit: Don’t hide the copy inside Placeholder—make callers explicitly create and manage contiguous temporaries (this is what I used in the fixes for addcmul_/addcdiv_/the random ops)</li> </ul> <p>The good news: macOS 15+ Metal now handles non-contiguous tensors natively, making this entire issue obsolete for newer systems. 
But for anyone on older macOS versions or maintaining PyTorch’s MPS backend, this abstraction continues to cause issues.</p> <p>So ideally, the Placeholder class would be redesigned to handle output tensors correctly by default, but given that Apple’s software stack is moving to handle this natively anyway, the pragmatic fix is probably just to audit and patch the remaining operations using the established pattern.</p> <h3 id="practical-takeaways-for-your-code">Practical Takeaways for Your Code</h3> <p><strong>Performance Considerations</strong></p> <p>Even with the code fixes, non-contiguous tensors on MPS involve: Allocate temporary buffer -&gt; Copy to contiguous layout -&gt; Compute -&gt; Copy back. Making tensors contiguous once at initialization avoids thousands of copies during training! And even when your OS can skip the temporary contiguous copy, operating on non-contiguous memory is still slower, which adds up for tensors you use many times.</p> <p><strong>When to Call <code class="language-plaintext highlighter-rouge">.contiguous()</code></strong></p> <div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="c1"># When to call .contiguous() - General Principles
</span><span class="kn">import</span> <span class="n">torch</span>

<span class="c1"># 1. After operations that change memory layout:
</span><span class="n">tensor</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nf">randn</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">tensor</span><span class="p">.</span><span class="nf">transpose</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>   <span class="c1"># Non-contiguous view
</span><span class="c1"># x.view(-1)                  # Would raise a RuntimeError: view() needs compatible strides
</span><span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="p">.</span><span class="nf">contiguous</span><span class="p">().</span><span class="nf">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>  <span class="c1"># Safe (x.reshape(-1) also works and copies only if needed)
</span>
<span class="c1"># 2. Before operations that might not handle strides:
# - Custom CUDA/Metal kernels
# - Newer backend features
# - Operations that failed mysteriously on certain devices
</span>
<span class="c1"># 3. For performance on repeated operations:
</span><span class="n">weights</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nf">randn</span><span class="p">(</span><span class="mi">128</span><span class="p">,</span> <span class="mi">256</span><span class="p">).</span><span class="n">T</span>  <span class="c1"># Used in every forward pass
</span><span class="n">weights</span> <span class="o">=</span> <span class="n">weights</span><span class="p">.</span><span class="nf">contiguous</span><span class="p">()</span>            <span class="c1"># Pay copy cost once, not every iteration
</span>
<span class="c1"># But don't overuse it!
</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nf">randn</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">torch</span><span class="p">.</span><span class="nf">randn</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>
<span class="n">c</span> <span class="o">=</span> <span class="n">a</span> <span class="o">+</span> <span class="n">b</span>           <span class="c1"># Contiguous inputs give a contiguous result
</span><span class="n">c</span> <span class="o">=</span> <span class="n">c</span><span class="p">.</span><span class="nf">contiguous</span><span class="p">()</span>  <span class="c1"># Unnecessary: c is already contiguous, so this just returns c
</span></code></pre></div></div> <p><strong>For MPS specifically:</strong> If on macOS &lt;15, make sure all your parameters are contiguous!</p> <h3 id="what-i-learned">What I Learned</h3> <p><strong>Isolate to specific, measurable symptoms.</strong> The most standard advice and for such good reason. Everything got easier once I had a concrete target: “<code class="language-plaintext highlighter-rouge">exp_avg_sq</code> stays at zero” is infinitely more debuggable than “the loss plateaus mysteriously.” Once I had a specific symptom, I could strip away components and test the minimal case that triggered it.</p> <p><strong>When debugging tensor issues, check metadata not just values.</strong> I was checking for NaNs, visualizing weights, inspecting gradients—all focused on the numbers inside tensors. The actual problem was the tensor’s <em>stride pattern</em>. Device, dtype, contiguity, memory layout—these aren’t just performance details, they can cause silent correctness bugs. <code class="language-plaintext highlighter-rouge">tensor.is_contiguous()</code> is now part of my debugging checklist.</p> <p><strong>When I’m confused, I might have changed two things—or there might be two bugs.</strong> Switching to fp64 “fixed” it, but I’d also switched from MPS to CPU. Untangling that revealed the real culprit. And <code class="language-plaintext highlighter-rouge">exp_avg_sq</code> staying zero <em>should</em> have caused explosions, but the parameter update <em>also</em> failed—one bug perfectly masked the other.</p> <p><strong>Documentation makes more sense when I need it.</strong> I’d skimmed PyTorch internals docs before and nothing stuck—dispatch systems, stride patterns, kernel implementations all felt overwhelming. But once I <em>had</em> to understand how <code class="language-plaintext highlighter-rouge">addcmul_</code> dispatches to MPS kernels, everything clicked. Now PyTorch feels less like a black box. 
And when I hit the random ops bug weeks later, I wasn’t intimidated—I knew exactly how to trace through the source.</p> <p><strong>Explore the system before exploring the code.</strong> When I needed to debug <code class="language-plaintext highlighter-rouge">addcmul_out_mps</code> in unfamiliar MPS code, I ran experiments first: which operations fail? Do they run at all? What triggers the bug? By the time I opened the source, I knew to compare <code class="language-plaintext highlighter-rouge">addcmul_</code> (broken) against <code class="language-plaintext highlighter-rouge">mul_</code> (working) and scan specifically for differences in output handling. Without that context, I’d have been lost in Objective-C++ with no idea what mattered. LLMs were also very helpful with unfamiliar constructs like <code class="language-plaintext highlighter-rouge">MPSGraphTensor</code> or <code class="language-plaintext highlighter-rouge">@autoreleasepool</code>, although they’re still less reliable with MPS than with better-documented frameworks.</p> <p><strong>Write post-mortems, even for yourself.</strong> Forcing myself to explain <em>why</em> I tried each debugging step was as educational as the original investigation. It’s like experience replay in RL: you explore many failed paths, find one that works, then replay that successful trajectory to reinforce the policy. Writing it down builds pattern recognition—when I’m in “situation A”, what hypotheses are worth trying? I’ve written lower-effort debugging debriefs before, but making this one readable for an external audience forced me to articulate why each step made sense, deepening my understanding of what actually worked.</p> <p>What started as a frustrating research roadblock became a surprisingly fun &amp; educational detour. It forced a closer look at things normally taken for granted: Adam’s momentum mechanics, stride patterns, kernel dispatch. 
Understanding why each operation behaved differently revealed more about PyTorch’s architecture than typical usage ever does.</p> <hr/> <p>If you made it this far, thanks for joining! Hope you had fun and/or learned something &amp; happy debugging!</p> <p>Special thanks to <a href="https://x.com/nickevanjoseph">Nicholas Joseph</a>, <a href="https://www.benkuhn.net/">Ben Kuhn</a>, <a href="https://blog.nelhage.com/">Nelson Elhage</a> and <a href="https://www.alextamkin.com/">Alex Tamkin</a> for giving feedback on this 💜</p>]]></content><author><name>Elana Simon</name></author><summary type="html"><![CDATA[a loss plateau that looked like my mistake turned out to be a PyTorch bug. tracking it down meant peeling back every layer of abstraction, from optimizer internals to GPU kernels.]]></summary></entry><entry><title type="html">The Illustrated AlphaFold</title><link href="https://elanapearl.github.io/blog/2024/the-illustrated-alphafold/" rel="alternate" type="text/html" title="The Illustrated AlphaFold"/><published>2024-07-10T00:00:00+00:00</published><updated>2024-07-10T00:00:00+00:00</updated><id>https://elanapearl.github.io/blog/2024/the-illustrated-alphafold</id><content type="html" xml:base="https://elanapearl.github.io/blog/2024/the-illustrated-alphafold/"><![CDATA[<h1 id="introduction">Introduction</h1> <h3 id="who-should-read-this">Who should read this</h3> <p>Do you want to understand exactly how AlphaFold3 works? The architecture is quite complicated and the description in the paper can be overwhelming, so we made a much more friendly (but just as detailed!) visual walkthrough.</p> <p>This is mostly written for an ML audience and multiple points assume familiarity with the steps of attention. If you’re rusty, see Jay Alammar’s <a href="https://jalammar.github.io/illustrated-transformer/">The Illustrated Transformer</a> for a thorough visual explanation. 
That post is one of the best explanations of a model architecture at the level of individual matrix operations and also the inspiration for the diagrams and naming.</p> <p>There are already many great explanations of the motivation for protein structure prediction, the CASP competition, model failure modes, debates about evaluations, implications for biotech, etc., so we don’t focus on any of that. Instead we explore the <em>how</em>.</p> <p><em>How are these molecules represented in the model and what are all of the operations that convert them into a predicted structure?</em></p> <p>This is probably more exhaustive than most people are looking for, but if you want to understand all the details and you like learning via diagrams, this should help :)</p> <h3 id="architecture-overview">Architecture Overview</h3> <p>We’ll start by pointing out that the goals of the model are a bit different from those of previous AlphaFold models: instead of just predicting the structure of individual protein sequences (AF2) or protein complexes (AF-multimer), it predicts the structure of a protein, optionally complexed with other proteins, nucleic acids, or small molecules, all from sequence alone. So while previous AF models only had to represent sequences of standard amino acids, AF3 has to represent more complex input types, and thus there is a more complex featurization/tokenization scheme.
Tokenization is described in its own section, but for now just know that when we say “token” it either represents a single amino acid (for proteins), nucleotide (for DNA/RNA), or an individual atom if that atom is not part of a standard amino acid/nucleotide.</p> <div class="l-body"> <p style="text-align: center;"><b>Interactive Table of Contents</b></p> <figure class="image-figure"> <div class="image-container"> <img src="/assets/img/af3_post/full_arch_for_labeling.png" alt="Full Architecture" class="img-fluid rounded z-depth-1" usemap="#image-map"/> <div class="hover-overlay" id="overlay-1"></div> <div class="hover-overlay" id="overlay-2"></div> <div class="hover-overlay" id="overlay-3"></div> <div class="hover-overlay" id="overlay-4"></div> <div class="hover-overlay" id="overlay-5"></div> <div class="hover-overlay" id="overlay-6"></div> <div class="hover-overlay" id="overlay-7"></div> <div class="hover-overlay" id="overlay-8"></div> <div class="hover-overlay" id="overlay-9"></div> <div class="hover-overlay" id="overlay-10"></div> <div class="hover-overlay" id="overlay-11"></div> <div class="hover-overlay" id="overlay-12"></div> <div class="hover-overlay" id="overlay-13"></div> <div class="hover-overlay" id="overlay-14"></div> </div> <map name="image-map"> <area shape="rect" alt="Tokenization" title="Tokenization" coords="0.58,50.38,8.35,74.06" href="#tokenization" data-target="overlay-1"/> <area shape="rect" alt="Retrieval" title="Retrieval" coords="8.35,3.26,22.16,40.98" href="#retrieval-create-msa-and-templates" data-target="overlay-2"/> <area shape="rect" alt="Create atom-level representations" title="Create atom-level representations" coords="8.35,41.48,22.16,68.92" href="#create-atom-level-representations" data-target="overlay-3"/> <area shape="rect" alt="Atom Transformer" title="Atom Transformer" coords="22.23,23.43,28.16,69.80" href="#update-atom-level-representations-atom-transformer" data-target="overlay-4"/> <area shape="rect" alt="Atom-Level to 
Token-Level" title="Atom-Level to Token-Level" coords="28.16,40.23,31.89,69.80" href="#aggregate-atom-level--token-level" data-target="overlay-5"/> <area shape="rect" alt="Template module" title="Template module" coords="34.56,34.59,40.56,56.02" href="#template-module" data-target="overlay-6"/> <area shape="rect" alt="MSA module" title="MSA module" coords="42.15,34.21,48.16,56.02" href="#msa-module" data-target="overlay-7"/> <area shape="rect" alt="Pairformer" title="Pairformer" coords="49.67,34.21,63.52,69.80" href="#pairformer-module" data-target="overlay-8"/> <area shape="rect" alt="Diffusion module" title="Diffusion module" coords="68.37,41.48,90.20,70.18" href="#diffusion-module" data-target="overlay-9"/> <area shape="rect" alt="Confidence and Loss" title="Confidence and loss" coords="90.96,20.05,99.71,47.37" href="#4-loss-function-and-other-training-details" data-target="overlay-10"/> <area shape="rect" alt="Other training details" title="Other training details" coords="90.96,47.37,99.71,66.17" href="#other-training-details" data-target="overlay-11"/> <area shape="rect" alt="Input preparation" title="Input preparation" coords="8.35,87.84,23.46,97.74" href="#1-input-preparation" data-target="overlay-12"/> <area shape="rect" alt="Representation learning" title="Representation learning" coords="38.76,87.84,58.97,97.49" href="#2-representation-learning" data-target="overlay-13"/> <area shape="rect" alt="Structure prediction" title="Structure prediction" coords="71.69,88.22,89.37,96.74" href="#3-structure-prediction" data-target="overlay-14"/> </map> </figure> <style>.image-figure{margin-bottom:1rem}.image-container{position:relative;width:100%}.hover-overlay{position:absolute;background-color:rgba(136,0,255,0.3);opacity:0;transition:opacity .3s;pointer-events:none}.figure-caption{text-align:center;margin-top:.5rem}</style> <script>document.addEventListener("DOMContentLoaded",function(){function t(){var t=e.width,r=e.height;n.forEach(function(e,n){var 
a=e.getAttribute("coords").split(",").map(function(e,n){return n%2==0?Math.round(e*t/100):Math.round(e*r/100)});e.coords=a.join(",");var i=o[n];i.style.left=a[0]/t*100+"%",i.style.top=a[1]/r*100+"%",i.style.width=(a[2]-a[0])/t*100+"%",i.style.height=(a[3]-a[1])/r*100+"%"})}var e=document.querySelector(".image-container img"),n=document.querySelector('map[name="image-map"]').querySelectorAll("area"),o=document.querySelectorAll(".hover-overlay");t(),window.addEventListener("resize",t),n.forEach(function(t){t.addEventListener("mouseover",function(){var t=this.getAttribute("data-target");document.getElementById(t).style.opacity="1"}),t.addEventListener("mouseout",function(){var t=this.getAttribute("data-target");document.getElementById(t).style.opacity="0"})})});</script> <div class="caption">Full architecture. If you click on any part of the architecture, it will take you to that section of the post. If you resize the page, you might need to refresh to keep the interactive part working. (Diagram modified from AF3 paper)</div> </div> <div class="l-gutter"> Throughout the post, we highlight where you are in this diagram so you don't get lost! </div> <p>The model can be broken down into 3 main sections:</p> <ol> <li><a href="#1-input-preparation"><strong>Input Preparation</strong></a> The user provides sequences of some molecules to predict structures for and these need to be embedded into numerical tensors. Furthermore, the model retrieves a collection of other molecules that are presumed to have similar structures to the user-provided molecules. 
The input preparation step identifies these molecules and also embeds these as their own tensors.</li> <li><a href="#2-representation-learning"><strong>Representation learning</strong></a> Given the Single and Pair tensors created in section 1, we use many variants of attention to update these representations.</li> <li><a href="#3-structure-prediction"><strong>Structure prediction</strong></a> We use these improved representations and the original inputs created in section 1 to predict the structure using conditional diffusion.</li> </ol> <div class="l-gutter"> Skip to a specific section via its name here or by clicking the relevant part of the architecture in the diagram above. </div> <p>We also have additional sections describing 4. <a href="#4-loss-function-and-other-training-details"><strong>the loss function, confidence heads, and other relevant training details</strong></a> and 5. <a href="#ml-musings"><strong>some thoughts on the model from an ML trends perspective</strong></a>.</p> <h3 id="notes-on-the-variables-and-diagrams">Notes on the variables and diagrams</h3> <p>Throughout the model, a protein complex is represented in two primary forms: the “single” representation, which represents all the tokens in our protein complex, and a “pair” representation, which represents the relationships (e.g. distance, potential interactions) between all pairs of amino acids / atoms in the complex.
Each of these can be represented at an atom-level or a token-level, and will always be shown with these names (as established in the AF3 paper) and colors:</p> <div class="l-body"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/af3_post/single_and_pair_rep-480.webp 480w,/assets/img/af3_post/single_and_pair_rep-800.webp 800w,/assets/img/af3_post/single_and_pair_rep-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/af3_post/single_and_pair_rep.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> <p> </p> </div> <ul> <li>The diagrams abstract away the model weights and only visualize how the shapes of activations change.</li> <li>The activation tensors are always labeled with the dimension names used in the paper, and the sizes in the diagrams roughly track how these dimensions grow or shrink. <d-footnote>The hidden dimension names usually start with "c" for "channel". For reference the main dimensions used are c<sub>z</sub>=128, c<sub>m</sub>=64, c<sub>atom</sub>=128, c<sub>atompair</sub>=16, c<sub>token</sub>=768, c<sub>s</sub>=384.</d-footnote></li> <li>Whenever possible, the names above the tensors in this (and every) diagram match the names of the tensors used in the AF3 supplement. Typically, a tensor maintains its name as it goes through the model. However, in some cases, we use different names to distinguish between versions of a tensor at different stages of processing.
For example, in the atom-level single representation, <strong><span style="color: #A056A7;">c</span></strong> represents the initial atom-level single representation while <strong><span style="color: #A056A7;">q</span></strong> represents the updated version of this representation as it progresses through the Atom Transformer.</li> <li>We also ignore most of the LayerNorms for simplicity but they are used <em>everywhere</em>.</li> </ul> <hr/> <h1 id="1-input-preparation">1. Input Preparation</h1> <div class="l-body"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/af3_post/input_prep-480.webp 480w,/assets/img/af3_post/input_prep-800.webp 800w,/assets/img/af3_post/input_prep-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/af3_post/input_prep.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </div> <p>The actual input a user provides to AF3 is the sequence of one protein and optionally additional molecules. The goal of this section is to convert these sequences into a series of 6 tensors that will be used as the input to the main trunk of the model as outlined in this diagram. 
These tensors are <strong><span style="color: #F5ACFB;">s</span></strong>, our token-level single representation, <strong><span style="color: #7CC9F4;">z</span></strong>, our token-level pair representation, <strong><span style="color: #A056A7;">q</span></strong>, our atom-level single representation, <strong><span style="color: #087CBE;">p</span></strong>, our atom-level pair representation, <strong><span style="color: #FDC38D;">m</span></strong>, our MSA representation, and <strong><span style="color: #2EAF88;">t</span></strong>, our template representation.</p> <p>This section contains:</p> <ul> <li><a href="#tokenization"><strong>Tokenization</strong></a> describes how molecules are tokenized and clarifies the difference between atom-level and token-level.</li> <li><a href="#retrieval-create-msa-and-templates"><strong>Retrieval (Create MSA and Templates)</strong></a> explains why and how we include additional inputs to the model. It creates our MSA (<strong><span style="color: #FDC38D;">m</span></strong>) and structure templates (<strong><span style="color: #2EAF88;">t</span></strong>).</li> <li><a href="#create-atom-level-representations"><strong>Create Atom-Level Representations</strong></a> creates our first atom-level representations <strong><span style="color: #A056A7;">q</span></strong> (single) and <strong><span style="color: #087CBE;">p</span></strong> (pair) and includes information about generated conformers of the molecules.</li> <li><a href="#update-atom-level-representations-atom-transformer"><strong>Update Atom-Level Representations (Atom Transformer)</strong></a> is the main “Input Embedder” block, also called the “Atom Transformer”, which gets repeated 3 times and updates the atom-level single representation (<strong><span style="color: #A056A7;">q</span></strong>).
The building blocks introduced here (<a href="#1-adaptive-layernorm"><strong>Adaptive LayerNorm</strong></a>, <a href="#2-attention-with-pair-bias"><strong>Attention with Pair Bias</strong></a>, <a href="#3-conditioned-gating"><strong>Conditioned Gating</strong></a>, and <a href="#4-conditioned-transition"><strong>Conditioned Transition</strong></a>) are also relevant later in the model.</li> <li><a href="#aggregate-atom-level--token-level"><strong>Aggregate Atom-Level -&gt; Token-Level</strong></a> takes our atom-level representations (<strong><span style="color: #A056A7;">q</span></strong>, <strong><span style="color: #087CBE;">p</span></strong>) and aggregates all the atoms that are part of multi-atom tokens to create token-level representations <strong><span style="color: #F5ACFB;">s</span></strong> (single) and <strong><span style="color: #7CC9F4;">z</span></strong> (pair) and includes information from the MSA (<strong><span style="color: #FDC38D;">m</span></strong>) and any user-provided information about known bonds that involve ligands.</li> </ul> <h2 id="tokenization">Tokenization</h2> <div class="l-gutter"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/af3_post/summaries/tokenize-480.webp 480w,/assets/img/af3_post/summaries/tokenize-800.webp 800w,/assets/img/af3_post/summaries/tokenize-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/af3_post/summaries/tokenize.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> <div class="caption">See where this fits into the full architecture</div> </div> <div class="l-body"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/af3_post/tokens-480.webp 480w,/assets/img/af3_post/tokens-800.webp 800w,/assets/img/af3_post/tokens-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img
src="/assets/img/af3_post/tokens.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </div> <p>In AF2, as the model only represented proteins with a fixed set of amino acids, each amino acid was represented with its own token. This is maintained in AF3, but additional tokens are also introduced for the additional molecule types that AF3 can handle:</p> <ul> <li>Standard amino acid: 1 token (as per AF2)</li> <li>Standard nucleotide: 1 token</li> <li>Non-standard amino acids or nucleotides (methylated nucleotide, amino acid with post-translational modification, etc.): 1 token <em>per atom</em></li> <li>Other molecules: 1 token <em>per atom</em></li> </ul> <p>As a result, we can think of some tokens (like those for amino acids) as being associated with multiple atoms, while other tokens (like those for an atom in a ligand) are associated with only a single atom. 
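</p> <p>These rules can be sketched in a few lines of Python. This is a minimal illustration under our own assumptions, not AF3’s featurization code. The <code class="language-plaintext highlighter-rouge">(kind, n_atoms)</code> input format, the kind labels, and the <code class="language-plaintext highlighter-rouge">tokenize</code> function are hypothetical, and the per-residue atom count is made up for illustration.</p>

```python
# Hypothetical input format (not AF3's real featurization code): each chemical
# component is a (kind, n_atoms) pair, where kind is "standard_residue"
# (standard amino acid or nucleotide), "nonstandard_residue" (e.g. a modified
# amino acid), or "ligand".


def tokenize(components):
    """Return a list giving the number of atoms attached to each token."""
    atoms_per_token = []
    for kind, n_atoms in components:
        if kind == "standard_residue":
            # One multi-atom token for the whole residue.
            atoms_per_token.append(n_atoms)
        elif kind in ("nonstandard_residue", "ligand"):
            # One single-atom token per atom.
            atoms_per_token.extend([1] * n_atoms)
        else:
            raise ValueError(f"unknown component kind: {kind!r}")
    return atoms_per_token


# Illustrative numbers: pretend each residue has 19 atoms (hydrogens included).
peptide = [("standard_residue", 19)] * 35
ligand = [("ligand", 35)]

print(len(tokenize(peptide)), sum(tokenize(peptide)))  # 35 665
print(len(tokenize(ligand)), sum(tokenize(ligand)))    # 35 35
```

<p>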
So, while a protein with 35 standard amino acids (likely &gt; 600 atoms) would be represented by 35 tokens, a ligand with 35 atoms would also be represented by 35 tokens.</p> <h2 id="retrieval-create-msa-and-templates">Retrieval (Create MSA and Templates)</h2> <div class="l-gutter"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/af3_post/summaries/retrieval-480.webp 480w,/assets/img/af3_post/summaries/retrieval-800.webp 800w,/assets/img/af3_post/summaries/retrieval-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/af3_post/summaries/retrieval.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> <div class="caption">See where this fits into the full architecture</div> </div> <p>One of the key early steps in AF3 is something akin to Retrieval Augmented Generation (<a href="https://aws.amazon.com/what-is/retrieval-augmented-generation">RAG</a>) in language models.
We find sequences similar to our protein and RNA sequences of interest (collected into a multiple sequence alignment, “MSA”), and any structures related to those (called the “templates”), then include them as additional inputs to the model called <strong><span style="color: #FDC38D;">m</span></strong> and <strong><span style="color: #2EAF88;">t</span></strong>, respectively.</p> <div class="l-gutter"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/af3_post/MSA_and_templates-480.webp 480w,/assets/img/af3_post/MSA_and_templates-800.webp 800w,/assets/img/af3_post/MSA_and_templates-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/af3_post/MSA_and_templates.jpg" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> <div class="caption">(Image from AF2)</div> </div> <details> <summary>Why do we want to include MSA and templates?</summary> <p> Versions of the same protein found in different species can be quite structurally and sequentially similar. By aligning these together into a Multiple Sequence Alignment (MSA), we can look at how an individual position in a protein sequence has changed throughout evolution. You can think about an MSA for a given protein as a matrix where each row is the sequence of the analogous protein from a different species. It has been shown that the conservation patterns found along the column of a specific position in the protein can reflect how critical it is for that position to have certain amino acids present, and the relationships between different columns reflect relationships between amino acids (e.g., if two amino acids are physically interacting, substitutions at those two positions will likely be correlated across evolution).
Thus, MSAs are often used to enrich representations of single proteins.</p> <p> Similarly, if any of these proteins have known structures, those are also likely to inform the structure of this protein. Instead of searching for full structures, only individual chains of the proteins are used. This resembles the practice of homology modeling, in which the structure of a query protein is modeled based on templates from known protein structures that are presumed to be similar. </p> </details> <details> <summary>So how are these sequences and structures retrieved?</summary> First, a genetic search is run for any protein or RNA chains that resemble the input protein or RNA chains. This does not involve any training and relies upon existing Hidden Markov Model (HMM)-based methods<d-footnote>Specifically, they use jackhmmer, HHBlits, and nhmmer.</d-footnote> to scan multiple protein databases and RNA databases for relevant hits. Then these sequences are aligned to each other to construct an MSA with N<sub>MSA</sub> sequences. As the computational complexity of the model scales with N<sub>MSA</sub>, they limit this to N<sub>MSA</sub> &lt; 2<sup>14</sup>. Typically, MSAs are constructed from individual protein chains but, as described in <a href="https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2.full.pdf">AF-multimer</a>, instead of just concatenating the separate MSAs together into a block diagonal matrix, certain chains from the same species can be 'paired' as described <a href="https://www.biorxiv.org/content/10.1101/240754v3.full.pdf">here</a>. This way, the MSA does not have to be as large and sparse, and evolutionary information can be learned about relationships between chains.
<div class="l-body"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/af3_post/multi_chain_MSA-480.webp 480w,/assets/img/af3_post/multi_chain_MSA-800.webp 800w,/assets/img/af3_post/multi_chain_MSA-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/af3_post/multi_chain_MSA.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </div> Then, for each protein chain, they use another HMM-based method (hmmsearch) to find sequences in the Protein Data Bank (PDB) that resemble the constructed MSA. The highest-quality structures are selected and up to 4 of these are sampled to be included as "templates". </details> <p>The only new part of these retrieval steps compared to AF-multimer is the fact that we now do this retrieval for RNA sequences in addition to protein sequences. Note that this is not traditionally called “retrieval”, as using structural templates to guide protein structure modeling had been common practice in the field of <a href="https://en.wikipedia.org/wiki/Homology_modeling">homology modeling</a> long before the term RAG existed. However, even though AlphaFold doesn’t explicitly refer to this process as retrieval, it closely resembles what has now been popularized as RAG.</p> <p><strong>How do we represent these templates?</strong></p> <p>From our template search, we have a 3D structure for each of our templates and information about which tokens are in which chains. First, the Euclidean distances between all pairs of tokens in a given template are calculated. For tokens associated with multiple atoms, a representative <span style="color: #0094FF">“center atom”</span> is used to calculate distances.
This would be the <span style="color: #0094FF">C<sub>α</sub></span> atom for amino acids and <span style="color: #0094FF">C<sup>1</sup>’</span> atom for standard nucleotides.</p> <div class="l-body"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/af3_post/center_atoms-480.webp 480w,/assets/img/af3_post/center_atoms-800.webp 800w,/assets/img/af3_post/center_atoms-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/af3_post/center_atoms.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> <div class="caption">Highlighting <span style="color: #0094FF">"center atoms"</span> in single-token building blocks</div> </div> <p>This generates an N<sub>token</sub> x N<sub>token</sub> matrix for each template. However, instead of representing each distance as a numerical value, the distances get discretized into a “distogram” (a histogram of distances).<d-footnote>Specifically, the values are binned into 38 bins between 3.15 Å and 50.75 Å, and there's 1 additional bin for any distances bigger than that.</d-footnote></p> <p>To each distogram, we then append metadata about which chain <d-footnote>In molecular complexes, a chain refers to a distinct molecule or part of a molecule. This can be a protein chain (a sequence of amino acids), a DNA or RNA chain (a sequence of nucleotides), or other biomolecules. AlphaFold uses chain information to differentiate between parts of a complex, helping it predict how these parts interact to form the overall structure.</d-footnote> each token belongs to, whether this token was resolved in the crystal structure, and information about local distances within each amino acid.
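</p> <p>Below is a minimal pure-Python sketch of the distogram step, using the bin range from the footnote above (38 bins between 3.15 Å and 50.75 Å, plus one overflow bin). Three choices here are our own assumptions: equal-width bins, clamping distances below 3.15 Å into the first bin, and the hypothetical <code class="language-plaintext highlighter-rouge">distogram</code> helper itself. The chain metadata and masking are left out.</p>

```python
import bisect
import math


def bin_edges(lo, hi, n_bins):
    """Equal-width bin edges from lo to hi (n_bins + 1 values)."""
    step = (hi - lo) / n_bins
    edges = [lo + i * step for i in range(n_bins)]
    edges.append(hi)  # set the last edge exactly to avoid float drift
    return edges


def distogram(coords, lo=3.15, hi=50.75, n_bins=38):
    """One-hot distogram of shape N x N x (n_bins + 1) as nested lists.

    Bin i covers [edges[i], edges[i + 1]); the final extra bin holds every
    distance >= hi. Distances below lo are clamped into bin 0 (assumption).
    """
    edges = bin_edges(lo, hi, n_bins)
    n_classes = n_bins + 1
    grid = []
    for a in coords:
        row = []
        for b in coords:
            d = math.dist(a, b)  # Euclidean distance between center atoms
            idx = max(bisect.bisect_right(edges, d) - 1, 0)
            one_hot = [0] * n_classes
            one_hot[idx] = 1
            row.append(one_hot)
        grid.append(row)
    return grid


# Three hypothetical center atoms (coordinates in Angstroms).
coords = [(0.0, 0.0, 0.0), (10.0, 0.0, 0.0), (0.0, 60.0, 0.0)]
dgram = distogram(coords)
print(len(dgram[0][1]))      # 39 classes
print(dgram[0][1].index(1))  # 5  (10 A falls in [9.41, 10.67))
print(dgram[0][2].index(1))  # 38 (overflow bin, distance 60 A)
```

<p>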
We then mask this matrix so that we only look at distances within each chain (e.g., we ignore the distances between chain A and chain B), since the authors “make no attempt to select templates… to gain information about inter-chain interactions”<d-footnote> It is not specified why, but note that while there are no inter-chain interactions in the templates, they do incorporate them in the MSA construction. </d-footnote>.</p> <h2 id="create-atom-level-representations">Create Atom-Level Representations</h2> <div class="l-gutter"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/af3_post/summaries/make-atom-level-480.webp 480w,/assets/img/af3_post/summaries/make-atom-level-800.webp 800w,/assets/img/af3_post/summaries/make-atom-level-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/af3_post/summaries/make-atom-level.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> <div class="caption">See where this fits into the full architecture</div> </div> <div class="l-body"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/af3_post/make_atom_rep-480.webp 480w,/assets/img/af3_post/make_atom_rep-800.webp 800w,/assets/img/af3_post/make_atom_rep-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/af3_post/make_atom_rep.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </div> <p>To create <strong><span style="color: #A056A7;">q</span></strong>, our atom-level single representation, we need to gather all our atom-level features. The first step is to calculate a “reference conformer” for each amino acid, nucleotide, and ligand.
While we do not yet know the structure of the entire complex, we have strong priors on the local structures of each individual component. A conformer (short for <a href="https://www.sciencedirect.com/topics/chemistry/conformational-isomer#:~:text=Conformations%20or%20conformational%20isomers%20have,the%20same%20configuration%2C%20if%20chiral.">conformational isomer</a>) is a 3D arrangement of atoms in a molecule that is generated by sampling rotations about single bonds. Each amino acid has a “standard” conformer which is just one of the low-energy conformations this amino acid can exist in, which can be retrieved through a look-up. However, each small molecule requires its own conformation generation. These are generated with <a href="https://rdkit.org/docs/RDKit_Book.html#conformer-generation">RDKit’s ETKDGv3</a>, an algorithm that combines experimental data and torsion angle preferences to produce 3D conformers.</p> <p>Then we concatenate the information from this conformer (relative location) with each atom’s charge, atomic number, and other identifiers. Matrix <strong><span style="color: #A056A7;">c</span></strong> stores this information for all the atoms in our sequences<d-footnote>In the AF3 supplement, the atom-level matrices (<b><span style="color: #A056A7;">c</span></b>, and <b><span style="color: #A056A7;">q</span></b>) are typically referred to in their vector forms (<i>e.g.</i> <b><span style="color: #A056A7;">c<sub>l</sub></span></b> or <b><span style="color: #A056A7;">c<sub>m</sub></span></b>), where l and m are used to index atoms.</d-footnote>. We then use <strong><span style="color: #A056A7;">c</span></strong> to initialize our atom-level pair representation <strong><span style="color: #087CBE;">p</span></strong> to store the relative distances between atoms. 
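</p> <p>Leaving out the learned layers, the raw pair features can be sketched as follows. This minimal pure-Python version computes three things for each atom pair (l, m): the offset between reference positions, the mask <strong>v</strong> (1 only when both atoms come from the same reference conformer, as the next paragraph explains), and an inverse-square distance feature. Assumptions: the inputs <code class="language-plaintext highlighter-rouge">ref_pos</code> and <code class="language-plaintext highlighter-rouge">conformer_id</code> and the function name are hypothetical. We read “inverse square” as 1/(1 + d²), where the +1 keeps the diagonal finite. The projections of c<sub>l</sub> and c<sub>m</sub> and the residual layers are omitted.</p>

```python
def atom_pair_features(ref_pos, conformer_id):
    """Return an N x N grid of [dx, dy, dz, inv_sq_dist, v] feature lists.

    ref_pos: per-atom (x, y, z) positions from the reference conformers.
    conformer_id: per-atom label of the conformer the atom came from
    (hypothetical input; offsets are only meaningful within one conformer).
    """
    n = len(ref_pos)
    grid = []
    for l_idx in range(n):
        row = []
        for m_idx in range(n):
            if conformer_id[l_idx] == conformer_id[m_idx]:
                v = 1.0  # mask v_lm: same conformer, distance is known
                offset = [a - b for a, b in zip(ref_pos[l_idx], ref_pos[m_idx])]
            else:
                v = 0.0  # different conformers: no reliable relative position
                offset = [0.0, 0.0, 0.0]
            sq_dist = sum(x * x for x in offset)
            # Assumed form of the inverse-square feature: 1 / (1 + d^2), masked.
            inv_sq_dist = v / (1.0 + sq_dist)
            row.append(offset + [inv_sq_dist, v])
        grid.append(row)
    return grid


# Two atoms from one residue's conformer and one atom from a ligand conformer.
pos = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (5.0, 5.0, 5.0)]
ids = [0, 0, 1]
feats = atom_pair_features(pos, ids)
print(feats[0][1])  # [-1.0, 0.0, 0.0, 0.5, 1.0]
print(feats[0][2])  # [0.0, 0.0, 0.0, 0.0, 0.0]
```

<p>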
Because we only know reference distances within each token, we use a mask (<strong>v</strong>) to ensure this initial distance matrix only represents distances we’ve calculated in the conformer generation. We also include a linear embedding of the inverse square of the distances, add to it a projection of <strong><span style="color: #A056A7;">c<sub>l</sub></span></strong> and <strong><span style="color: #A056A7;">c<sub>m</sub></span></strong>, and update this with a few more linear layers with residual connections<d-footnote>The AF3 paper doesn't explain why this additional inverse distance step is performed or include ablations of its effect; so, as with many of the steps we will discuss, we can only assume it was empirically shown to be useful.</d-footnote><d-footnote>In the AF3 supplement, the <b><span style="color: #087CBE;">p</span></b> tensor is typically referred to in its vector form <b><span style="color: #087CBE;">p<sub>l,m</sub></span></b> (where this represents the relationship between atom l and atom m).</d-footnote>.</p> <p>Finally, we make a copy of our atom-level single representation, calling this copy <strong><span style="color: #A056A7;">q</span></strong>.
This matrix <strong><span style="color: #A056A7;">q</span></strong> is what we will be updating going forward, but <strong><span style="color: #A056A7;">c</span></strong> does get saved and used later.</p> <h2 id="update-atom-level-representations-atom-transformer">Update Atom-Level Representations (Atom Transformer)</h2> <div class="l-gutter"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/af3_post/summaries/atom-transformer-480.webp 480w,/assets/img/af3_post/summaries/atom-transformer-800.webp 800w,/assets/img/af3_post/summaries/atom-transformer-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/af3_post/summaries/atom-transformer.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> <div class="caption">See where this fits into the full architecture</div> </div> <div class="l-body"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/af3_post/atom_transformer-480.webp 480w,/assets/img/af3_post/atom_transformer-800.webp 800w,/assets/img/af3_post/atom_transformer-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/af3_post/atom_transformer.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </div> <p>Having generated <strong><span style="color: #A056A7;">q</span></strong> (representation of all the atoms) and <strong><span style="color: #087CBE;">p</span></strong> (representation of each pair of atoms), we now want to update these representations based on other atoms nearby. Anytime AF3 applies attention at the atom-level, we use a module called the Atom Transformer. 
The Atom Transformer is a series of blocks that use attention to update <strong><span style="color: #A056A7;">q</span></strong> using both <strong><span style="color: #087CBE;">p</span></strong> and the original representation of <strong><span style="color: #A056A7;">q</span></strong>, called <strong><span style="color: #A056A7;">c</span></strong>. As <strong><span style="color: #A056A7;">c</span></strong> does not get updated by the Atom Transformer, it can be thought of as a residual connection to the starting representation.</p> <p>The Atom Transformer mostly follows a standard transformer structure: layer norm, attention, then an MLP transition. However, each step has been adapted to include additional input from <strong><span style="color: #A056A7;">c</span></strong> and <strong><span style="color: #087CBE;">p</span></strong> (including a secondary input like this is sometimes referred to as “conditioning”). There is also a “gating” step between the attention and MLP blocks. Going through each of these 4 steps in more detail:</p> <h3 id="1-adaptive-layernorm">1. 
Adaptive LayerNorm</h3> <div class="l-body"> <div class="row"> <div class="col-sm mt-2 mt-md-0"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/af3_post/standard_ln-480.webp 480w,/assets/img/af3_post/standard_ln-800.webp 800w,/assets/img/af3_post/standard_ln-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/af3_post/standard_ln.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </div> <div class="col-sm mt-2 mt-md-0"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/af3_post/adaptive_ln-480.webp 480w,/assets/img/af3_post/adaptive_ln-800.webp 800w,/assets/img/af3_post/adaptive_ln-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/af3_post/adaptive_ln.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </div> </div> </div> <p>Adaptive LayerNorm (AdaNorm) is a variant of LayerNorm with one simple extension. Recall that for a given input matrix, traditional LayerNorm learns two parameters (a scaling factor gamma and a bias factor beta) that adjust the mean and standard deviation of each of the channels in our matrix. Instead of learning fixed parameters for gamma and beta, AdaNorm learns a function to generate gamma and beta adaptively based on the input matrix. 
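</p> <p>A minimal NumPy sketch of this idea is below. Here gamma and beta are computed per row from a conditioning tensor <code>cond</code> by hypothetical learned weights (<code>w_gamma</code>, <code>b_gamma</code>, <code>w_beta</code>). Passing gamma through a sigmoid and normalizing <code>cond</code> first are our assumptions about the exact form, so treat this as a sketch rather than AF3's code. Which tensor plays the role of <code>cond</code> in the Atom Transformer is explained next.</p>

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    """Normalize each row across its channels, with no learned scale or shift."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def adaptive_layer_norm(x, cond, w_gamma, b_gamma, w_beta):
    """LayerNorm whose gamma and beta are predicted, row by row, from cond."""
    x_hat = layer_norm(x)
    cond_hat = layer_norm(cond)
    gamma = sigmoid(cond_hat @ w_gamma + b_gamma)   # per-row, per-channel scale
    beta = cond_hat @ w_beta                        # per-row, per-channel shift
    return gamma * x_hat + beta

rng = np.random.default_rng(0)
n_atoms, c_x, c_cond = 5, 8, 8
x = rng.normal(size=(n_atoms, c_x))
cond = rng.normal(size=(n_atoms, c_cond))
w_gamma = rng.normal(size=(c_cond, c_x))
b_gamma = np.zeros(c_x)
w_beta = rng.normal(size=(c_cond, c_x))

out = adaptive_layer_norm(x, cond, w_gamma, b_gamma, w_beta)
print(out.shape)  # (5, 8)
```

<p>Compared with standard LayerNorm, the only change is that <code>gamma</code> and <code>beta</code> are outputs of a function of <code>cond</code> rather than fixed parameter vectors, so every row can be re-scaled differently.</p> <p>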
However, rather than generating the parameters from the input being re-scaled (in the Atom Transformer this is <strong><span style="color: #A056A7;">q</span></strong>), AdaNorm uses a secondary input (<strong><span style="color: #A056A7;">c</span></strong> in the Atom Transformer) to predict the gamma and beta that re-scale the mean and standard deviation of <strong><span style="color: #A056A7;">q</span></strong>.</p> <h3 id="2-attention-with-pair-bias">2. Attention with Pair Bias</h3> <div class="l-body"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/af3_post/atom_attn_w_pair_bias-480.webp 480w,/assets/img/af3_post/atom_attn_w_pair_bias-800.webp 800w,/assets/img/af3_post/atom_attn_w_pair_bias-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/af3_post/atom_attn_w_pair_bias.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </div> <p>Atom-Level Attention with Pair-Bias can be thought of as an extension of self-attention. As in self-attention, the queries, keys, and values all come from the same 1D sequence (our single representation, <strong><span style="color: #A056A7;">q</span></strong>). However, there are 3 differences:</p> <ol> <li> <p><strong>Pair-biasing</strong>: after the dot products of the queries and keys are calculated, a linear projection of the pair representation is added to them as a bias before the softmax, shifting the attention weights. Note that this operation does not use any information from <strong><span style="color: #A056A7;">q</span></strong> to update <strong><span style="color: #087CBE;">p</span></strong>; it is just a one-way flow from the pair representation to <strong><span style="color: #A056A7;">q</span></strong>. 
The reasoning for this is that atoms with a stronger pairwise relationship should attend to each other more strongly, and <strong><span style="color: #087CBE;">p</span></strong> is effectively already encoding an attention map.</p> </li> <li> <p><strong>Gating</strong>: In addition to the queries, keys, and values, we create an additional projection of <strong><span style="color: #A056A7;">q</span></strong> that is passed through a sigmoid to squash the values between 0 and 1. Our output is multiplied by this “gate” right before all the heads are re-combined. This lets the model suppress some of what it computed in this attention step. This type of gating appears frequently in AF3 and is discussed more in the ML-musings section. To briefly elaborate, because the model is constantly adding the outputs of each section to the residual stream, this gating mechanism can be thought of as the model’s way to specify what information does or does not get saved in this residual stream. It is presumably named a “gate” after the similar “gates” in LSTMs, which use a sigmoid to learn a filter for what inputs get added to the running cell state.</p> </li> <li> <p><strong>Sparse attention</strong>:</p> </li> </ol> <table style="border-collapse: collapse; border: none;"> <tr> <td width="2%" style="border: none;"> </td> <td width="76%" style="border: none;"> Because the number of atoms can be much larger than the number of tokens, we do not run full attention at this step. Instead, we use a type of sparse attention (called sequence-local atom attention) in which attention is run in local blocks: each group of 32 atoms attends to a window of 128 nearby atoms. Sparse attention patterns are more thoroughly described <a href="https://medium.com/@vishal09vns/sparse-attention-dad17691478">elsewhere on the internet</a>. 
</td> <td width="22%" style="border: none;"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/af3_post/sparse_attn_pattern-480.webp 480w,/assets/img/af3_post/sparse_attn_pattern-800.webp 800w,/assets/img/af3_post/sparse_attn_pattern-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/af3_post/sparse_attn_pattern.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </td> </tr> </table> <h3 id="3-conditioned-gating">3. Conditioned Gating</h3> <div class="l-body"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/af3_post/conditioned_gating-480.webp 480w,/assets/img/af3_post/conditioned_gating-800.webp 800w,/assets/img/af3_post/conditioned_gating-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/af3_post/conditioned_gating.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </div> <p>We apply another gate to our data, but this time the gate is generated from our original atom-level single matrix, <strong><span style="color: #A056A7;">c</span></strong><d-footnote>As with so many steps, it is unclear why it is done this way and what benefit conditioning on the original representation <b><span style="color: #A056A7;">c</span></b> provides, as opposed to learning the gate from the primary single representation <b><span style="color: #A056A7;">q</span></b></d-footnote>.</p> <h3 id="4-conditioned-transition">4. 
Conditioned Transition</h3> <p>This step is equivalent to the MLP layers in a transformer, and is called “conditioned” because the MLP is sandwiched between Adaptive LayerNorm (Step 1 of the Atom Transformer) and Conditioned Gating (Step 3 of the Atom Transformer), which both depend on <strong><span style="color: #A056A7;">c</span></strong>.</p> <p>The only other piece of note in this section is that AF3 uses SwiGLU in the transition block instead of ReLU. The switch from ReLU → SwiGLU happened with AF2 → AF3 and has been a common change in many recent architectures, so we visualize it here.</p> <p>With a ReLU-based transition layer (as in AF2), we take the activations, project them up to 4x the size, apply a ReLU, then down-project them back to their original size. When using SwiGLU (in AF3), the input activation is used to create two intermediate up-projections, one of which goes through a swish non-linearity (a smooth variant of ReLU); the two are then multiplied element-wise before down-projecting. The diagram below shows the differences:</p> <div class="l-body"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/af3_post/swiglu-480.webp 480w,/assets/img/af3_post/swiglu-800.webp 800w,/assets/img/af3_post/swiglu-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/af3_post/swiglu.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </div> <h2 id="aggregate-atom-level--token-level">Aggregate Atom-Level → Token-Level</h2> <div class="l-gutter"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/af3_post/summaries/atom-to-token-level-480.webp 480w,/assets/img/af3_post/summaries/atom-to-token-level-800.webp 800w,/assets/img/af3_post/summaries/atom-to-token-level-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/af3_post/summaries/atom-to-token-level.png" 
class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> <div class="caption">See where this fits into the full architecture</div> </div> <div class="l-body"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/af3_post/aggregate_atom_to_token-480.webp 480w,/assets/img/af3_post/aggregate_atom_to_token-800.webp 800w,/assets/img/af3_post/aggregate_atom_to_token-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/af3_post/aggregate_atom_to_token.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </div> <p>While the data so far has all been stored at the atom level, the representation learning section of AF3 from here onwards operates at the token level. To create these token-level representations, we first project our atom-level representation to a larger dimension (c<sub>atom</sub>=128, c<sub>token</sub>=384). Then, we take the mean over all atoms assigned to the same token. Note that this averaging only has an effect for standard amino acids and nucleotides, whose tokens contain several atoms; every other atom is its own token and so passes through unchanged<d-footnote>The AF3 paper describes these molecule types as having a representative atom per token (the center atom). Recall that this is the C<sub>α</sub> atom for amino acids and C<sup>1</sup>' atom for standard nucleotides. 
So while we mostly refer to this reduced representation as "token space", we can also view each token as representing a single atom (either a representative C<sub>α</sub>/C<sup>1</sup>' atom or an individual atom).</d-footnote>.</p> <p>Now that we are working in “token space”, we concatenate this aggregated representation with token-level features and statistics from our MSA (where available)<d-footnote>e.g., the amino acid type (dim = 32), the distribution of amino acids at this position in our MSA (dim = 32), and the deletion mean at this token (dim = 1) from our MSA. Note that these values will be zero for ligand atoms not associated with an MSA.</d-footnote>. This matrix, <strong><span style="color: #F5ACFB;">s<sup>inputs</sup></span></strong>, having grown a bit from these concatenations, is projected back down to c<sub>token</sub> and called <strong><span style="color: #F5ACFB;">s<sup>init</sup></span></strong>: the starting representation of our sequence that will be updated in the representation learning section. 
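</p> <p>The aggregation and concatenation steps can be sketched in NumPy as follows. This is a toy example under stated assumptions: the weights are random stand-ins, the atom-to-token mapping and feature values are made up, and any nonlinearity AF3 may apply before averaging is omitted. <code>np.add.at</code> performs the per-token sum and <code>np.bincount</code> counts the atoms in each token; the function name <code>atoms_to_tokens</code> is hypothetical.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
c_atom, c_token = 128, 384

def atoms_to_tokens(q, atom_to_token, n_tokens, w_up):
    """Project atoms up to c_token channels, then average the atoms of each token."""
    a = q @ w_up                                    # (N_atoms, c_token)
    sums = np.zeros((n_tokens, a.shape[1]))
    np.add.at(sums, atom_to_token, a)               # scatter-add each atom into its token
    counts = np.bincount(atom_to_token, minlength=n_tokens)[:, None]
    return sums / counts                            # (N_tokens, c_token)

# Toy input: a 4-atom residue, a 3-atom residue, then two single-atom ligand tokens.
atom_to_token = np.array([0, 0, 0, 0, 1, 1, 1, 2, 3])
n_tokens = 4
q = rng.normal(size=(len(atom_to_token), c_atom))
w_up = rng.normal(size=(c_atom, c_token)) / np.sqrt(c_atom)
a_token = atoms_to_tokens(q, atom_to_token, n_tokens, w_up)

# Token-level features from the footnote (toy values; zero rows for the ligand tokens).
restype = np.zeros((n_tokens, 32))
restype[[0, 1], [5, 12]] = 1.0                      # arbitrary residue types
profile = np.zeros((n_tokens, 32))
profile[:2] = rng.dirichlet(np.ones(32), size=2)    # MSA amino-acid distribution
deletion_mean = np.zeros((n_tokens, 1))
deletion_mean[:2] = rng.random((2, 1))

s_inputs = np.concatenate([a_token, restype, profile, deletion_mean], axis=-1)
w_init = rng.normal(size=(s_inputs.shape[1], c_token)) / np.sqrt(s_inputs.shape[1])
s_init = s_inputs @ w_init                          # back down to c_token
print(s_inputs.shape, s_init.shape)  # (4, 449) (4, 384)
```

<p>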
Note that <strong><span style="color: #F5ACFB;">s<sup>init</sup></span></strong> gets updated in the representation learning section, but <strong><span style="color: #F5ACFB;">s<sup>inputs</sup></span></strong> is saved to be used later in the structure prediction section.</p> <div class="l-body"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/af3_post/make_token_pair-480.webp 480w,/assets/img/af3_post/make_token_pair-800.webp 800w,/assets/img/af3_post/make_token_pair-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/af3_post/make_token_pair.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </div> <p>Now that we have created <strong><span style="color: #F5ACFB;">s<sup>init</sup></span></strong>, our initialized single representation, the next step is to initialize our pair representation <strong><span style="color: #7CC9F4;">z<sup>init</sup></span></strong>. The pair representation is a three-dimensional tensor, but it’s easiest to think of it as a heatmap-like 2D matrix with an implicit depth dimension of c<sub>z</sub>=128 channels. So, entry <strong><span style="color: #7CC9F4;">z<sub>i,j</sub></span></strong> of our pair representation is a c<sub>z</sub>-dimensional vector meant to store information about the relationship between token i and token j in our token sequence. 
We already created an analogous atom-level matrix <strong><span style="color: #087CBE;">p</span></strong>, and we follow a similar process here at the token level.</p> <p>To initialize <strong><span style="color: #7CC9F4;">z<sub>i,j</sub></span></strong>, we use a linear projection to make the channel dimension of our sequence representation match that of the pair representation (384 → 128) and add the resulting <strong><span style="color: #F5ACFB;">s<sub>i</sub></span></strong> and <strong><span style="color: #F5ACFB;">s<sub>j</sub></span></strong>. To this, we add a relative positional encoding, <strong><span style="color: #087CBE;">p<sub>i,j</sub></span></strong><d-footnote>This encoding consists of a<sup>rel_pos</sup>, a one-hot encoding of the offset between the two tokens' residue indices (set to the maximum bin, 65, if the two tokens are not on the same chain); a<sup>rel_token</sup>, a one-hot encoding of the offset of the two token ids in token space (set to 65 if the tokens are part of different amino acids or nucleotides); and a<sup>rel_chain</sup>, encoding the offset of the two chains the tokens are on. We then project this concatenated encoding to the channel dimension of <b><span style="color: #7CC9F4;">z</span></b>.</d-footnote>. 
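</p> <p>A minimal NumPy sketch of this initialization is below. It makes several assumptions: separate random projections for token i and token j, only the a<sup>rel_pos</sup> part of the relative encoding (computed from residue indices, as we understand the AF3 supplement, and clipped at r<sub>max</sub> = 32, which yields the overflow bin 2·32 + 1 = 65 for cross-chain pairs), and no bond embedding. Names such as <code>init_pair</code> and <code>relative_position_one_hot</code> are hypothetical.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
c_token, c_z, r_max = 384, 128, 32

def relative_position_one_hot(residue_index, chain_id, r_max=32):
    """One-hot of the residue offset i - j, clipped to [-r_max, r_max];
    pairs on different chains go to the extra bin 2 * r_max + 1 (65 here)."""
    offset = residue_index[:, None] - residue_index[None, :]
    d = np.clip(offset + r_max, 0, 2 * r_max)
    same_chain = chain_id[:, None] == chain_id[None, :]
    d = np.where(same_chain, d, 2 * r_max + 1)
    return np.eye(2 * r_max + 2)[d]                # (N, N, 66)

def init_pair(s_init, residue_index, chain_id):
    """z_init[i, j] = proj_i(s_i) + proj_j(s_j) + embedded relative position."""
    w_i = rng.normal(size=(c_token, c_z)) / np.sqrt(c_token)   # 384 -> 128, row token i
    w_j = rng.normal(size=(c_token, c_z)) / np.sqrt(c_token)   # 384 -> 128, column token j
    z = (s_init @ w_i)[:, None, :] + (s_init @ w_j)[None, :, :]
    a_rel_pos = relative_position_one_hot(residue_index, chain_id, r_max)
    w_rel = rng.normal(size=(a_rel_pos.shape[-1], c_z))
    return z + a_rel_pos @ w_rel                   # (N, N, c_z)

# Toy complex: a 3-residue chain followed by a 2-residue chain.
s_init = rng.normal(size=(5, c_token))
residue_index = np.array([0, 1, 2, 0, 1])
chain_id = np.array([0, 0, 0, 1, 1])
z_init = init_pair(s_init, residue_index, chain_id)
print(z_init.shape)  # (5, 5, 128)
```

<p>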
If the user has also specified particular bonds between tokens, those are linearly embedded here and added to that entry in the pair representation.</p> <p>Now we’ve successfully created and embedded all of the inputs that will be used in the rest of our model:</p> <div class="l-body"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/af3_post/input_prep_summary-480.webp 480w,/assets/img/af3_post/input_prep_summary-800.webp 800w,/assets/img/af3_post/input_prep_summary-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/af3_post/input_prep_summary.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </div> <p>For Step 2, we will set aside the atom-level representations (<strong><span style="color: #A056A7;">c</span></strong>, <strong><span style="color: #A056A7;">q</span></strong>, <strong><span style="color: #087CBE;">p</span></strong>) and focus on updating our token-level representations <strong><span style="color: #F5ACFB;">s</span></strong> and <strong><span style="color: #7CC9F4;">z</span></strong> in the next section (with the help of <strong><span style="color: #FDC38D;">m</span></strong> and <strong><span style="color: #2EAF88;">t</span></strong>).</p> <h1 id="2-representation-learning">2. 
Representation Learning</h1> <div class="l-body"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/af3_post/rep_learning_arch-480.webp 480w,/assets/img/af3_post/rep_learning_arch-800.webp 800w,/assets/img/af3_post/rep_learning_arch-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/af3_post/rep_learning_arch.jpg" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> <div class="caption">(Diagram modified from full AF3 architecture diagram)</div> </div> <p>This section makes up the majority of the model and is often referred to as the “trunk”, as it is where most of the computation is done. We call it the representation learning section of the model, as the goal is to learn improved representations of our token-level “single” (<strong><span style="color: #F5ACFB;">s</span></strong>) and “pair” (<strong><span style="color: #7CC9F4;">z</span></strong>) tensors initialized above. <d-footnote> Recall that when we refer to the "single" sequence representation, it is not necessarily the sequence of one protein, but rather the concatenated sequence of all the atoms or tokens in our structure (which could contain multiple separate molecules).</d-footnote></p> <p>This section contains:</p> <ol> <li><strong>Template module</strong> updates <strong><span style="color: #7CC9F4;">z</span></strong> using the structure templates <strong><span style="color: #2EAF88;">t</span></strong></li> <li><strong>MSA module</strong> first updates the MSA <strong><span style="color: #FDC38D;">m</span></strong>, then adds it to the token-level pair representation <strong><span style="color: #7CC9F4;">z</span></strong>. 
In this section we spend significant time on two operations: <ul> <li><a href="#outer-product-mean">The Outer Product Mean</a> enables <strong><span style="color: #FDC38D;">m</span></strong> to influence <strong><span style="color: #7CC9F4;">z</span></strong></li> <li><a href="#row-wise-gated-self-attention-using-only-pair-bias">MSA Row-wise Gated Self-Attention Using Only Pair Bias</a> updates <strong><span style="color: #FDC38D;">m</span></strong> based on <strong><span style="color: #7CC9F4;">z</span></strong> and is a simplified version of attention with pair-bias (intended for MSAs)</li> </ul> </li> <li><strong>Pairformer</strong> updates <strong><span style="color: #F5ACFB;">s</span></strong> and <strong><span style="color: #7CC9F4;">z</span></strong> with geometry-inspired (triangle) attention. This section mostly describes the triangle operations (used extensively throughout both AF2 and AF3). <ul> <li><a href="#why-look-at-triangles">Why look at triangles?</a> explains some intuition for the triangle operations</li> <li><a href="#triangle-updates">Triangle Updates</a> and <a href="#triangle-attention">Triangle Attention</a> both update <strong><span style="color: #7CC9F4;">z</span></strong> using methods similar to self-attention, but inspired by the triangle inequality</li> <li><a href="#single-attention-with-pair-bias">Single Attention With Pair Bias</a> updates <strong><span style="color: #F5ACFB;">s</span></strong> based on <strong><span style="color: #7CC9F4;">z</span></strong> and is the token-level equivalent of attention with pair-bias (intended for single sequences)</li> </ul> </li> </ol> <p>Each individual block is repeated multiple times, and then the output of the whole section is fed back in as input and the process is repeated (this is called recycling).</p> <h2 id="template-module">Template Module</h2> <div class="l-gutter"> <figure> <picture> <source class="responsive-img-srcset" 
srcset="/assets/img/af3_post/summaries/templates-480.webp 480w,/assets/img/af3_post/summaries/templates-800.webp 800w,/assets/img/af3_post/summaries/templates-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/af3_post/summaries/templates.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> <div class="caption">See where this fits into the full architecture</div> </div> <div class="l-body"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/af3_post/template_module-480.webp 480w,/assets/img/af3_post/template_module-800.webp 800w,/assets/img/af3_post/template_module-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/af3_post/template_module.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </div> <p>Each template (N<sub>templates</sub>=2 in the diagram) goes through a linear projection and is added together with a linear projection of our pair representation (<strong><span style="color: #7CC9F4;">z</span></strong>). This newly combined matrix goes through a series of operations called the Pairformer Stack (described in depth later). Finally, all of the templates are averaged together and go through (you guessed it) another linear layer.<d-footnote>This is called both the template module and the template embedder depending on where you look in the AF3 supplement, but both names seem to refer to the same thing.</d-footnote> Interestingly, this last linear layer has a ReLU as the non-linearity, which wouldn’t be particularly notable except that it is one of only two places ReLU is used as the non-linearity in AF3. 
As always, we can only hypothesize as to why this was selected.</p> <h2 id="msa-module">MSA Module</h2> <div class="l-gutter"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/af3_post/summaries/msa-480.webp 480w,/assets/img/af3_post/summaries/msa-800.webp 800w,/assets/img/af3_post/summaries/msa-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/af3_post/summaries/msa.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> <div class="caption">See where this fits into the full architecture</div> </div> <div class="l-body"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/af3_post/msa_module-480.webp 480w,/assets/img/af3_post/msa_module-800.webp 800w,/assets/img/af3_post/msa_module-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/af3_post/msa_module.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> <div class="caption">Architecture of MSA Module (diagram from AF3)</div> </div> <p>This module closely resembles what was called the “Evoformer” in AF2, and its goal is to simultaneously improve the MSA and pair representations. 
It performs a series of operations on each of these two representations independently, and also enables cross-talk between them.</p> <p>The first step is to subsample the rows of the MSA rather than use all of the previously generated rows (which could be up to 16k), then add a projected version of our single representation to this subsampled MSA.</p> <h3 id="outer-product-mean">Outer Product Mean</h3> <div class="l-body"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/af3_post/outer_product_mean-480.webp 480w,/assets/img/af3_post/outer_product_mean-800.webp 800w,/assets/img/af3_post/outer_product_mean-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/af3_post/outer_product_mean.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </div> <p>Next, we take the MSA representation and incorporate it into the pair representation via the “Outer Product Mean”. Comparing two columns of the MSA reveals information about the relationship between two positions in the sequence (<em>e.g.</em> how correlated these two positions are across evolution). For each pair of token indices i, j, we take the outer product of <strong><span style="color: #FDC38D;">m<sub>s,i</sub></span></strong> and <strong><span style="color: #FDC38D;">m<sub>s,j</sub></span></strong> for every evolutionary sequence s, then average these outer products across all the sequences. We then flatten this averaged outer product, project it back down, and add it to the pair representation <strong><span style="color: #7CC9F4;">z<sub>i,j</sub></span></strong> (full details in the diagram). While each outer product only compares values <em>within</em> a given sequence <strong><span style="color: #FDC38D;">m<sub>s</sub></span></strong>, taking their mean mixes information <em>across</em> sequences. 
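</p> <p>The outer product mean can be sketched in a few lines of NumPy with <code>np.einsum</code>. As an assumption, each MSA entry is first projected to a small hidden size (<code>c_hidden</code>, here 32) by two separate projections before the outer product, which keeps the flattened c<sub>hidden</sub>² vector manageable; the function name, dimensions, and random weights are illustrative only.</p>

```python
import numpy as np

rng = np.random.default_rng(0)

def outer_product_mean(m, w_a, w_b, w_out):
    """m: (N_seq, N_tok, c_m) MSA representation -> (N_tok, N_tok, c_z) pair update."""
    a = m @ w_a                                         # (S, N, c_hidden)
    b = m @ w_b                                         # (S, N, c_hidden)
    # For every sequence s: outer product of a[s, i] and b[s, j]; then average over s.
    o = np.einsum("sic,sjd->ijcd", a, b) / m.shape[0]   # (N, N, c_hidden, c_hidden)
    o = o.reshape(o.shape[0], o.shape[1], -1)           # flatten each outer product
    return o @ w_out                                    # project down to c_z

n_seq, n_tok, c_m, c_hidden, c_z = 4, 6, 64, 32, 128
m = rng.normal(size=(n_seq, n_tok, c_m))
w_a = rng.normal(size=(c_m, c_hidden)) / np.sqrt(c_m)
w_b = rng.normal(size=(c_m, c_hidden)) / np.sqrt(c_m)
w_out = rng.normal(size=(c_hidden * c_hidden, c_z)) / c_hidden

z = np.zeros((n_tok, n_tok, c_z))       # stand-in for the current pair representation
z = z + outer_product_mean(m, w_a, w_b, w_out)
print(z.shape)  # (6, 6, 128)
```

<p>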
<em>This is the only point in the model where information is shared across evolutionary sequences.</em> This is a significant change from the Evoformer in AF2, which also applied attention across sequences, and it reduces the computational cost.</p> <h3 id="row-wise-gated-self-attention-using-only-pair-bias">Row-wise gated self-attention using only pair bias</h3> <div class="l-body"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/af3_post/row_wise_gated_self_attn-480.webp 480w,/assets/img/af3_post/row_wise_gated_self_attn-800.webp 800w,/assets/img/af3_post/row_wise_gated_self_attn-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/af3_post/row_wise_gated_self_attn.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </div> <p>Having updated the pair representation based on the MSA, the model next updates the MSA based on the pair representation. This specific update pattern is called <strong>row-wise gated self attention <em>using only pair bias</em></strong>, and is a simplified version of <strong>self attention <em>with pair bias</em></strong> (discussed in the Atom Transformer section) applied to every sequence (row) in the MSA independently. 
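</p> <p>Here is a single-head NumPy sketch of the idea (the function and weight names are hypothetical, and AF3 uses multiple heads). The key difference from attention with pair bias is in the logits: there they would be q·k/√d + b<sub>ij</sub>, while here only the pair term b<sub>ij</sub>, projected from <strong><span style="color: #7CC9F4;">z</span></strong>, remains. The sigmoid gate computed from the MSA row itself is a simplified assumption.</p>

```python
import numpy as np

rng = np.random.default_rng(0)

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def row_attention_pair_bias_only(m, z, w_v, w_b, w_g):
    """Single-head, row-wise gated 'attention' whose weights come only from z.

    m: (S, N, c_m) MSA rows; z: (N, N, c_z) pair representation.
    """
    logits = (z @ w_b)[..., 0]             # each z[i, j] -> one scalar; no query/key dot product
    weights = softmax(logits, axis=-1)     # (N, N): how much token i attends to token j
    v = m @ w_v                            # (S, N, c_m) values, computed per MSA row
    out = np.einsum("ij,sjc->sic", weights, v)   # same weights reused for every row s
    gate = sigmoid(m @ w_g)                # values in (0, 1)
    return gate * out

S, N, c_m, c_z = 3, 5, 16, 8
m = rng.normal(size=(S, N, c_m))
z = rng.normal(size=(N, N, c_z))
w_v = rng.normal(size=(c_m, c_m)) / np.sqrt(c_m)
w_b = rng.normal(size=(c_z, 1))
w_g = rng.normal(size=(c_m, c_m)) / np.sqrt(c_m)

update = row_attention_pair_bias_only(m, z, w_v, w_b, w_g)
print(update.shape)  # (3, 5, 16)
```

<p>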
This operation is inspired by attention, but instead of using queries and keys to determine which other positions each token should attend to, it simply uses the existing relationships between tokens stored in our pair representation <strong><span style="color: #7CC9F4;">z</span></strong>.</p> <div class="l-body"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/af3_post/attn_score_from_bias-480.webp 480w,/assets/img/af3_post/attn_score_from_bias-800.webp 800w,/assets/img/af3_post/attn_score_from_bias-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/af3_post/attn_score_from_bias.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </div> <p>In the pair representation, each <strong><span style="color: #7CC9F4;">z<sub>i,j</sub></span></strong> is a vector containing information about the relationship between tokens i and j. When the tensor <strong><span style="color: #7CC9F4;">z</span></strong> is projected down to a matrix, each <strong><span style="color: #7CC9F4;">z<sub>i,j</sub></span></strong> vector becomes a scalar that can be used to determine how much token i should attend to token j. After a row-wise softmax, these scalars act as attention weights, which are used to take a weighted average of the values, just as a typical attention map would.</p> <p>Note that no information is shared across the evolutionary sequences in the MSA here, as the operation runs independently for each row.</p> <h3 id="updates-to-pair-representation">Updates to pair representation</h3> <p>The last step of the MSA module is to update the pair representation through a series of steps referred to as triangle updates and triangle attention. These triangle operations are described below in the Pairformer section, where they are used again. 
There are also some transition blocks that use SwiGLU to up/down project the matrix as was done in the Atom Transformer.</p> <h2 id="pairformer-module">Pairformer module</h2> <div class="l-gutter"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/af3_post/summaries/pairformer-480.webp 480w,/assets/img/af3_post/summaries/pairformer-800.webp 800w,/assets/img/af3_post/summaries/pairformer-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/af3_post/summaries/pairformer.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> <div class="caption">See where this fits into the full architecture</div> </div> <div class="l-body"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/af3_post/pairformer_module-480.webp 480w,/assets/img/af3_post/pairformer_module-800.webp 800w,/assets/img/af3_post/pairformer_module-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/af3_post/pairformer_module.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> <p>Diagram from AF3 supplement</p> </div> <p>Having used the templates and MSA to update our pair representation, we now ignore them for the rest of the model. Instead, only the updated pair representation (<strong><span style="color: #7CC9F4;">z</span></strong>) and single representation (<strong><span style="color: #F5ACFB;">s</span></strong>) enter the Pairformer and are used to update each other. As the transition blocks have already been described, this section focuses on the Triangle Updates and Triangle Attention, then briefly explains how the Single Attention with Pair Bias differs from the variant described earlier. 
These triangle-based layers were first introduced in AF2 and are one of the pieces that not only remained in AF3 but are now even more prominent in the architecture, so they get quite a bit of attention here.</p> <h3 id="why-look-at-triangles">Why look at triangles?</h3> <p>The guiding principle here is the idea of the triangle inequality: “the sum of any two sides of a triangle is greater than or equal to the third side”. Recall that each <strong><span style="color: #7CC9F4;">z<sub>i,j</sub></span></strong> in the pair tensor encodes the relationship between positions i and j in the sequence. While it does not literally encode the physical distances between pairs of tokens, let’s think about it for a moment as if it did. If we imagine that each <strong><span style="color: #7CC9F4;">z<sub>i,j</sub></span></strong> is the distance between two amino acids, and we know <strong><span style="color: #7CC9F4;">z<sub>i,j</sub></span></strong>=1 and <strong><span style="color: #7CC9F4;">z<sub>j,k</sub></span></strong>=1, then by the triangle inequality <strong><span style="color: #7CC9F4;">z<sub>i,k</sub></span></strong> cannot be larger than 2. Knowing two of the distances constrains what the third distance can be. The goal of triangle updates and triangle attention is to encode these geometric constraints into the model.</p> <p>The triangle inequality is not enforced in the model; rather, it is encouraged by ensuring each position <strong><span style="color: #7CC9F4;">z<sub>i,j</sub></span></strong> is updated by looking at all possible triplets of positions (<strong>i</strong>,<strong>j</strong>,<strong>k</strong>) at a time. So <strong><span style="color: #7CC9F4;">z<sub>i,j</sub></span></strong> is updated based on <strong><span style="color: #7CC9F4;">z<sub>j,k</sub></span></strong> and <strong><span style="color: #7CC9F4;">z<sub>i,k</sub></span></strong> for all other tokens k. 
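</p> <p>To make the intuition concrete, here is a toy NumPy illustration, not the learned triangle update itself: treat the pair matrix as literal distances, mark one entry as unknown with <code>np.inf</code>, and for every pair (i, j) aggregate over all third tokens k using the (i, k) and (j, k) entries, the same edges the update for z<sub>i,j</sub> looks at. The function name <code>triangle_upper_bounds</code> is hypothetical.</p>

```python
import numpy as np

def triangle_upper_bounds(d):
    """Tightest upper bound on each d[i, j] implied by the triangle inequality
    through a third node k: min over k of d[i, k] + d[j, k]."""
    # via[i, j, k] = d[i, k] + d[j, k], the length of the path i -> k -> j
    # (distances are symmetric, so d[j, k] equals d[k, j])
    via = d[:, None, :] + d[None, :, :]
    return via.min(axis=-1)

# Distances between three tokens; the 0-2 distance is unknown (inf).
d = np.array([[0.0, 1.0, np.inf],
              [1.0, 0.0, 1.0],
              [np.inf, 1.0, 0.0]])
bounds = triangle_upper_bounds(d)
print(bounds[0, 2])  # 2.0
```

<p>Knowing z<sub>0,1</sub> = 1 and z<sub>1,2</sub> = 1 is enough to recover the bound z<sub>0,2</sub> ≤ 2 from the example above. The learned triangle operations work on feature vectors rather than distances, so this is only an analogy.</p> <p>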
Because <strong><span style="color: #7CC9F4;">z</span></strong> represents the complex physical relationship between these tokens, rather than merely their distance, these relationships can be directional. So for <strong><span style="color: #7CC9F4;">z<sub>i,j</sub></span></strong>, we also want to encourage consistency with <strong><span style="color: #7CC9F4;">z<sub>k,i</sub></span></strong> and <strong><span style="color: #7CC9F4;">z<sub>k,j</sub></span></strong> for all atoms k. If we think of the atoms as a graph, with <strong><span style="color: #7CC9F4;">z</span></strong> as a directed adjacency matrix, it makes sense that AlphaFold calls these “outgoing edges” and “incoming edges”.</p> <p>Consider row i=0 of this adjacency matrix, and let’s say we want to update <strong><span style="color: #7CC9F4;">z<sub>0,2</sub></span></strong>, which has been highlighted in purple. The idea behind the update is that if we know the distances between 0→1 and 2→1, that gives us some constraints on what 0→2 can be. Similarly, if we know the distances between 0→3 and 2→3, this also gives us a constraint on 0→2. 
The same applies for every atom k.</p> <div class="l-body"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/af3_post/adjacency_matrix-480.webp 480w,/assets/img/af3_post/adjacency_matrix-800.webp 800w,/assets/img/af3_post/adjacency_matrix-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/af3_post/adjacency_matrix.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </div> <p>So, in the triangle updates and attention, we effectively look at all directed paths through 3 nodes in this graph (a.k.a. triangles, hence the name!).</p> <div class="l-body"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/af3_post/triangle_paths-480.webp 480w,/assets/img/af3_post/triangle_paths-800.webp 800w,/assets/img/af3_post/triangle_paths-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/af3_post/triangle_paths.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </div> <h3 id="triangle-updates">Triangle Updates</h3> <p>Having looked at the triangle operations from a graph-theory perspective, we can now see how they are implemented with tensor operations. 
In the outgoing update, every position <strong><span style="color: #7CC9F4;">z<sub>i,j</sub></span></strong> in the pair representation gets updated independently based on a weighted combination of the other elements in the same row (<strong><span style="color: #7CC9F4;">z<sub>i,k</sub></span></strong>), where the weighting of each <strong><span style="color: #7CC9F4;">z<sub>i,k</sub></span></strong> is based on the third element in its outgoing edge triangle (<strong><span style="color: #7CC9F4;">z<sub>j,k</sub></span></strong>).</p> <div class="l-body"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/af3_post/triangle_update_outgoing-480.webp 480w,/assets/img/af3_post/triangle_update_outgoing-800.webp 800w,/assets/img/af3_post/triangle_update_outgoing-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/af3_post/triangle_update_outgoing.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </div> <p>Practically, we take three linear projections of <strong><span style="color: #7CC9F4;">z</span></strong> (called a, b, and g). To update <strong><span style="color: #7CC9F4;">z<sub>i,j</sub></span></strong>, we take an element-wise multiplication of <strong>row i from a</strong> and <strong>row j from b</strong>. We then sum this product over its entries (the different values of k) and gate the result with our g projection.</p> <div class="l-gutter"> At this point you might notice that gating is used all throughout this architecture! 
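</div> <p>To make the outgoing update concrete, here is a minimal NumPy sketch. It is a simplification rather than the AF3 implementation: a, b, and g are plain linear projections with hypothetical weight matrices <code>w_a</code>, <code>w_b</code>, and <code>w_g</code>, and the normalization and extra projections used in the full layer are omitted. The <code>outgoing=False</code> branch covers the incoming variant described next, which multiplies columns instead of rows.</p>

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def triangle_update(z, w_a, w_b, w_g, outgoing=True):
    """Simplified triangle multiplicative update (a sketch, not the AF3 layer).

    z: pair representation of shape (N, N, C).
    w_a, w_b, w_g: hypothetical (C, C) projection weights for a, b, and the gate g.
    """
    a = z @ w_a                  # (N, N, C) projection "a"
    b = z @ w_b                  # (N, N, C) projection "b"
    g = sigmoid(z @ w_g)         # (N, N, C) gate with values in (0, 1)
    if outgoing:
        # update[i, j] = sum_k a[i, k] * b[j, k]  (row i of a times row j of b)
        update = np.einsum("ikc,jkc->ijc", a, b)
    else:
        # update[i, j] = sum_k a[k, i] * b[k, j]  (column i of a times column j of b)
        update = np.einsum("kic,kjc->ijc", a, b)
    return g * update


# Usage: check one entry of the outgoing update against an explicit loop over k.
rng = np.random.default_rng(0)
N, C = 5, 4
z = rng.normal(size=(N, N, C))
w_a, w_b, w_g = (rng.normal(size=(C, C)) for _ in range(3))

out = triangle_update(z, w_a, w_b, w_g, outgoing=True)
i, j = 0, 2  # the z_{0,2} example from the adjacency-matrix discussion
a, b = z @ w_a, z @ w_b
manual = sigmoid(z[i, j] @ w_g) * sum(a[i, k] * b[j, k] for k in range(N))
print(np.allclose(out[i, j], manual))  # True
```

<p>Writing the sum over k as an <code>einsum</code> makes the row-versus-column distinction explicit: the outgoing subscripts <code>ikc,jkc->ijc</code> pair row i of a with row j of b, while the incoming subscripts <code>kic,kjc->ijc</code> pair column i with column j.</p> <div>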
</div> <div class="l-body"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/af3_post/triangle_update_incoming-480.webp 480w,/assets/img/af3_post/triangle_update_incoming-800.webp 800w,/assets/img/af3_post/triangle_update_incoming-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/af3_post/triangle_update_incoming.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </div> <p>For the incoming update, we effectively do the same thing but swap rows for columns, so to update <strong><span style="color: #7CC9F4;">z<sub>i,j</sub></span></strong> we take a weighted sum of the other elements in the same column (<strong><span style="color: #7CC9F4;">z<sub>k,j</sub></span></strong>), where the weighting of each <strong><span style="color: #7CC9F4;">z<sub>k,j</sub></span></strong> is based on the third element in its incoming edge triangle (<strong><span style="color: #7CC9F4;">z<sub>k,i</sub></span></strong>). After creating the same linear projections, we take an element-wise multiplication of <strong>column</strong> i from a and <strong>column</strong> j from b, and sum over all the <strong>rows of this matrix</strong> (the different values of k). You’ll find that these operations exactly mirror the graph-theory adjacency view described above.</p> <h3 id="triangle-attention">Triangle Attention</h3> <p>After our two triangle update steps, we also update each <strong><span style="color: #7CC9F4;">z<sub>i,j</sub></span></strong> using <strong>triangle attention</strong> for the outgoing edges and triangle attention for the incoming edges. 
The AF3 paper refers to the “outgoing edges” as attention “around starting node” and “incoming edges” as attention “around ending node”.</p> <div class="l-body"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/af3_post/triangle_attn_starting-480.webp 480w,/assets/img/af3_post/triangle_attn_starting-800.webp 800w,/assets/img/af3_post/triangle_attn_starting-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/af3_post/triangle_attn_starting.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </div> <p>To build up to triangle attention, it can be helpful to start with typical self-attention over a 1D sequence. Recall that queries, keys, and values are all transformations of the original 1D sequence. An attention variant called <a href="https://arxiv.org/abs/1912.12180">axial attention</a> extends this to matrices by applying independent 1D self-attention over the different axes of a 2D matrix (the rows, then the columns). Triangle attention adds the triangle principle we discussed earlier to this, updating <strong><span style="color: #7CC9F4;">z<sub>i,j</sub></span></strong> by incorporating <strong><span style="color: #7CC9F4;">z<sub>i,k</sub></span></strong> and <strong><span style="color: #7CC9F4;">z<sub>j,k</sub></span></strong> for all atoms k. 
Specifically, in the “starting node” case, to calculate the attention scores along row i (to determine how much <strong><span style="color: #7CC9F4;">z<sub>i,j</sub></span></strong> should be influenced by <strong><span style="color: #7CC9F4;">z<sub>i,k</sub></span></strong>), we do a query-key comparison between <strong><span style="color: #7CC9F4;">z<sub>i,j</sub></span></strong> and <strong><span style="color: #7CC9F4;">z<sub>i,k</sub></span></strong> as usual, then bias the attention based on <strong><span style="color: #7CC9F4;">z<sub>j,k</sub></span></strong> as is shown above.</p> <div class="l-body"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/af3_post/triangle_attn_ending-480.webp 480w,/assets/img/af3_post/triangle_attn_ending-800.webp 800w,/assets/img/af3_post/triangle_attn_ending-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/af3_post/triangle_attn_ending.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </div> <p>For the “ending node” case, we again swap rows for columns. For <strong><span style="color: #7CC9F4;">z<sub>i,j</sub></span></strong>, the keys and values will both come from column j of <strong><span style="color: #7CC9F4;">z</span></strong>, while the bias will come from column i. So, when comparing the query <strong><span style="color: #7CC9F4;">z<sub>i,j</sub></span></strong> with the key <strong><span style="color: #7CC9F4;">z<sub>k,j</sub></span></strong>, we bias that attention score based on <strong><span style="color: #7CC9F4;">z<sub>k,i</sub></span></strong>. 
Then, once we have attention scores over all k, we use our value vectors from column j.</p> <h3 id="single-attention-with-pair-bias">Single Attention with Pair Bias</h3> <div class="l-body"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/af3_post/single_attn_w_pair_bias-480.webp 480w,/assets/img/af3_post/single_attn_w_pair_bias-800.webp 800w,/assets/img/af3_post/single_attn_w_pair_bias-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/af3_post/single_attn_w_pair_bias.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </div> <p>Now that we’ve updated our pair representation with these four triangle steps, we pass the pair representation through a Transition block as described above. Finally, we want to update our single representation (<strong><span style="color: #F5ACFB;">s</span></strong>) using this new updated pair representation (<strong><span style="color: #7CC9F4;">z</span></strong>), so we will use single attention with pair bias, pictured above. This is identical to Single Attention with Pair Bias described<d-footnote>For reference, in the AF3 supplement, Single Attention with Pair Bias is also referred to as "Attention Pair Bias"</d-footnote> in the Atom Transformer section, but at the token-level. As it operates on the token-level, it uses full attention as opposed to the block-wise sparse pattern used when operating at the atom-level.</p> <p>We repeat the Pairformer for 48 blocks, eventually creating <strong><span style="color: #F5ACFB;">s<sup>trunk</sup></span></strong> and <strong><span style="color: #7CC9F4;">z<sup>trunk</sup></span></strong>.</p> <h1 id="3-structure-prediction">3. 
Structure Prediction</h1> <h2 id="basics-of-diffusion">Basics of Diffusion</h2> <p>Now, with these refined representations, we are ready to use <strong><span style="color: #F5ACFB;">s</span></strong> and <strong><span style="color: #7CC9F4;">z</span></strong> to predict the structure of our complex. One of the changes introduced in AF3 is that the entire structure prediction is based on atom-level diffusion. Existing posts more thoroughly explain the <a href="https://www.superannotate.com/blog/diffusion-models">intuition</a> and <a href="https://lilianweng.github.io/posts/2021-07-11-diffusion-models/">math</a> for diffusion, but the basic idea of a Diffusion Model is to start with real data, add random noise to it, then train a model to predict what noise was added. Noise is iteratively added to the data over a series of T timesteps to create a sequence of T variants of each datapoint. We call the original data point x<sub>t=0</sub> and the fully noised version x<sub>t=T</sub>. During training, at timestep t, the model is given x<sub>t</sub> and predicts what noise was added between x<sub>t-1</sub> and x<sub>t</sub>. We then take a gradient step on a loss that compares the predicted noise with the noise that was actually added.</p> <p>Then, at inference time, we simply start with random noise, which is equivalent to x<sub>t=T</sub>. At every time step, we predict the noise the model thinks has been added and remove that predicted noise. After a pre-specified number of timesteps, we end up with a fully “denoised” datapoint that should resemble the original data from our dataset.</p> <p>Conditional Diffusion lets the model ‘condition’ these de-noising predictions on some input. 
Practically, this means that at each step the model takes three inputs:</p> <ol> <li>The current noisy iteration of our generation</li> <li>A representation of the current time step</li> <li>The information we want to condition on (this could be a caption for an image to generate, or properties for a protein).</li> </ol> <p>As a result, the final generation is not just a random example that resembles the training data distribution, but should specifically match the information represented by this conditioning vector.</p> <p>With AF3, the data we learn to de-noise is a matrix <strong><span style="color: #F4DD65;">x</span></strong> with the x,y,z coordinates of all the atoms in our sequences. During training, we add Gaussian noise to these coordinates until they are effectively fully random. Then at inference time, we start with random coordinates. At each time step, we first randomly rotate and translate our entire predicted complex. This data augmentation teaches the model that any rotation and translation of our complex is equally valid, and replaces the much more complicated Invariant Point Attention used in AF2. <d-footnote>AF2 had developed a complicated architecture called Invariant Point Attention meant to enforce equivariance to translations and rotations. This led to a vigorous debate over the importance of IPA in AF2's success. In AF3, this is dropped in favor of a much simpler approach: applying random rotations and translations as data augmentations to help the model learn such equivariances naturally. So here we simply randomly rotate all atoms' coordinates around the center of our current generation (the mean over all atoms' coordinates), and randomly sample a translation in each dimension (x, y, and z) from a N(0,1) Gaussian. It appears from the algorithm that the translation is universal, that is, the same translation is applied to every atom in our current generation. 
This type of data augmentation was popularized with CNNs, but in the past few years, equivariant architectures like IPA have been considered a more efficient and elegant approach to the same problem. Thus, when AF3 replaced equivariant attention with data augmentation, it sparked a lot of discussion online.</d-footnote> We then add a small amount of noise to the coordinates to encourage more heterogeneous generations.<d-footnote>It benefits us for the model to generate several slightly different variations. At inference time, we can score each using our confidence head, and return only the generation with the highest score.</d-footnote> Finally, we predict a de-noising step using the Diffusion Module. We cover this module in more detail below:</p> <div class="l-gutter"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/af3_post/coordinates_for_diffusion-480.webp 480w,/assets/img/af3_post/coordinates_for_diffusion-800.webp 800w,/assets/img/af3_post/coordinates_for_diffusion-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/af3_post/coordinates_for_diffusion.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> Data (coordinates) to get de-noised </div> <h2 id="diffusion-module">Diffusion Module</h2> <div class="l-gutter"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/af3_post/summaries/diffusion-480.webp 480w,/assets/img/af3_post/summaries/diffusion-800.webp 800w,/assets/img/af3_post/summaries/diffusion-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/af3_post/summaries/diffusion.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> <div class="caption">See where this fits into the full 
architecture</div> </div> <div class="l-body"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/af3_post/diffusion_module-480.webp 480w,/assets/img/af3_post/diffusion_module-800.webp 800w,/assets/img/af3_post/diffusion_module-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/af3_post/diffusion_module.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </div> <p>In each de-noising diffusion step, we condition our prediction on multiple representations of the input sequences:</p> <ul> <li>the outputs of the trunk (our post-Pairformer updated <strong><span style="color: #F5ACFB;">s</span></strong> and <strong><span style="color: #7CC9F4;">z</span></strong>, now called <strong><span style="color: #F5ACFB;">s<sup>trunk</sup></span></strong> and <strong><span style="color: #7CC9F4;">z<sup>trunk</sup></span></strong>)</li> <li>the initial atom and token-level representations of the sequence created in the input embedder that have not gone through the trunk (<strong><span style="color: #F5ACFB;">s<sup>inputs</sup></span></strong>, <strong><span style="color: #A056A7;">c<sup>inputs</sup></span></strong>)</li> </ul> <p>The AF3 paper breaks down its diffusion process into 4 steps that involve moving from tokens to atoms, back to tokens, and back to atoms:</p> <ol> <li><a href="#1-prepare-token-level-conditioning-tensors"><strong>Prepare token-level conditioning tensors</strong></a></li> <li><a href="#2-prepare-atom-level-tensors-apply-atom-level-attention-and-aggregate-back-to-token-level"><strong>Prepare atom-level conditioning tensors, update them using the Atom Transformer, and aggregate them back to token-level</strong></a></li> <li><a href="#3-apply-attention-at-the-token-level"><strong>Apply attention at the token-level, and project back to atoms</strong></a></li> <li><a 
href="#4-apply-attention-at-the-atom-level-to-predict-atom-level-noise-updates"><strong>Apply attention at the atom-level to predict atom-level noise updates</strong></a></li> </ol> <h3 id="1-prepare-token-level-conditioning-tensors">1. Prepare token-level conditioning tensors</h3> <div class="l-body"> <div class="row"> <div class="col-sm mt-2 mt-md-0"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/af3_post/make_token_level_single_cond-480.webp 480w,/assets/img/af3_post/make_token_level_single_cond-800.webp 800w,/assets/img/af3_post/make_token_level_single_cond-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/af3_post/make_token_level_single_cond.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </div> <div class="col-sm mt-2 mt-md-0"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/af3_post/make_token_level_pair_cond-480.webp 480w,/assets/img/af3_post/make_token_level_pair_cond-800.webp 800w,/assets/img/af3_post/make_token_level_pair_cond-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/af3_post/make_token_level_pair_cond.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </div> </div> </div> <p>To initialize our token-level pair conditioning representation, we concatenate <strong><span style="color: #7CC9F4;">z<sup>trunk</sup></span></strong> to the relative positional encodings, then project this larger representation back down and pass it through several residual-connection transition blocks.</p> <p>Similarly, for our token-level single representation, we concatenate the very first representation of the input created at the start of the model (<strong><span style="color: 
#F5ACFB;">s<sup>inputs</sup></span></strong>) and our current representation (<strong><span style="color: #F5ACFB;">s<sup>trunk</sup></span></strong>), then project it back down to its original size. We then create a Fourier embedding based on the current diffusion time step<d-footnote>More specifically, the amount of noise associated with this timestep in the Noise Schedule</d-footnote>, add that to our single representation, and pass that combination through several Transition blocks. Including the diffusion time step in the conditioning input ensures the model knows where it is in the diffusion process when making de-noising predictions, so it predicts the right scale of noise to remove at this timestep.</p> <h3 id="2-prepare-atom-level-tensors-apply-atom-level-attention-and-aggregate-back-to-token-level">2. Prepare atom-level tensors, apply atom-level attention, and aggregate back to token-level</h3> <p>At this point, our conditioning vectors are storing information at a per-token level, but we also want to run attention at the atom-level. 
To address this, we take our initial atom-level representations of the input created in the Embedding section (<strong><span style="color: #A056A7;">c</span></strong> and <strong><span style="color: #087CBE;">p</span></strong>), and update them based on the current token-level representations, to create atom-level conditioning tensors.</p> <div class="l-body"> <div class="row"> <div class="col-sm mt-2 mt-md-0"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/af3_post/make_atom_level_single_cond-480.webp 480w,/assets/img/af3_post/make_atom_level_single_cond-800.webp 800w,/assets/img/af3_post/make_atom_level_single_cond-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/af3_post/make_atom_level_single_cond.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </div> <div class="col-sm mt-2 mt-md-0"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/af3_post/make_atom_level_pair_cond-480.webp 480w,/assets/img/af3_post/make_atom_level_pair_cond-800.webp 800w,/assets/img/af3_post/make_atom_level_pair_cond-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/af3_post/make_atom_level_pair_cond.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </div> </div> </div> <p>Next, we divide the atoms’ current noisy coordinates (<strong><span style="color: #F4DD65;">x</span></strong>) by √(t̂² + σ<sub>data</sub>²), the standard deviation we expect them to have given the spread of the data (σ<sub>data</sub>) and the current noise level (t̂). This effectively creates “dimensionless” coordinates with unit variance (called <strong><span style="color: #F4DD65;">r</span></strong>). 
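</p> <p>A minimal NumPy sketch of this scaling follows, together with the matching rescaling of the predicted update at the end of step 4. It follows our reading of the AF3 supplement, in which the noisy coordinates are divided by √(t̂² + σ<sub>data</sub>²) and the final coordinates mix a skip connection to the noisy input with the rescaled network output. The value σ<sub>data</sub> = 16 Å and the helper names <code>to_dimensionless</code> and <code>apply_update</code> are assumptions for illustration.</p>

```python
import numpy as np

SIGMA_DATA = 16.0  # assumed data standard deviation, in Angstrom


def to_dimensionless(x_noisy, t_hat, sigma_data=SIGMA_DATA):
    """Scale noisy coordinates of shape (n_atoms, 3) to roughly unit variance.

    If clean coordinates have std sigma_data and Gaussian noise with std t_hat
    is added, the noisy coordinates have std sqrt(t_hat**2 + sigma_data**2).
    """
    return x_noisy / np.sqrt(t_hat**2 + sigma_data**2)


def apply_update(x_noisy, r_update, t_hat, sigma_data=SIGMA_DATA):
    """Map a dimensionless update back to Angstrom and combine it with x_noisy."""
    denom = sigma_data**2 + t_hat**2
    c_skip = sigma_data**2 / denom                 # weight kept on the noisy input
    c_out = sigma_data * t_hat / np.sqrt(denom)    # scale applied to the network output
    return c_skip * x_noisy + c_out * r_update


# Usage: noisy coordinates built from data with std SIGMA_DATA plus noise with std t_hat
rng = np.random.default_rng(0)
t_hat = 40.0
x_clean = rng.normal(scale=SIGMA_DATA, size=(100_000, 3))
x_noisy = x_clean + rng.normal(scale=t_hat, size=x_clean.shape)
r = to_dimensionless(x_noisy, t_hat)
print(round(float(r.std()), 2))  # approximately 1.0
```

<p>As t̂ shrinks toward 0, the weight on the network’s update (<code>c_out</code>) also shrinks toward 0 and the output approaches the current coordinates, which matches the idea that updates get smaller as de-noising progresses.</p> <p>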
We then update <strong><span style="color: #A056A7;">q</span></strong> based on <strong><span style="color: #F4DD65;">r</span></strong> such that <strong><span style="color: #A056A7;">q</span></strong> is now aware of the atom’s current location. Finally, we update <strong><span style="color: #A056A7;">q</span></strong> with the Atom Transformer (which also takes the pair representation as input), and aggregate the atoms back to tokens as we’ve previously seen.<d-footnote>Recall from the input preparation section that the Atom Transformer runs sparse attention over the atoms, and all steps (layer norm, attention, gating) are conditioned on the conditioning tensor <b><span style="color: #A056A7;">c</span></b>.</d-footnote></p> <div class="l-body"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/af3_post/agg_back_to_token_level-480.webp 480w,/assets/img/af3_post/agg_back_to_token_level-800.webp 800w,/assets/img/af3_post/agg_back_to_token_level-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/af3_post/agg_back_to_token_level.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </div> <p>At the end of this step, we return:</p> <ul> <li><strong><span style="color: #A056A7;">q</span></strong>: updated atom representation after incorporating information about the atom’s current coordinates</li> <li><strong><span style="color: #F5ACFB;">a</span></strong>: token-level aggregated form of <span style="color: #A056A7;">q</span>, capturing coordinates and sequence information</li> <li><strong><span style="color: #A056A7;">c</span></strong>: atom representation for conditioning based on the trunk</li> <li><strong><span style="color: #087CBE;">p</span></strong>: our updated atom-pair representation for conditioning</li> </ul> <h3 id="3-apply-attention-at-the-token-level">3. 
Apply attention at the token-level</h3> <div class="l-body"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/af3_post/diffusion_transformer-480.webp 480w,/assets/img/af3_post/diffusion_transformer-800.webp 800w,/assets/img/af3_post/diffusion_transformer-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/af3_post/diffusion_transformer.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </div> <p>The goal of this step is to apply attention to update our token-level representation of the atom coordinates and sequence information, <span style="color: #F5ACFB;">a</span>. This step uses the Diffusion Transformer visualized during input preparation, which mirrors the Atom Transformer but for tokens.</p> <h3 id="4-apply-attention-at-the-atom-level-to-predict-atom-level-noise-updates">4. Apply attention at the atom-level to predict atom-level noise updates</h3> <p>Now, we return to atom space. We use our updated <strong><span style="color: #F5ACFB;">a</span></strong> (token-level representations based on current “center atom” locations) to update <strong><span style="color: #A056A7;">q</span></strong> (atom-level representation of all atoms based on current location) using the Atom Transformer. As was done in step 3, we broadcast our token representation to match the number of atoms we started with (selectively duplicating the tokens that represent multiple atoms), and run the Atom Transformer. Most importantly, one last linear layer maps this atom-level representation <strong><span style="color: #A056A7;">q</span></strong> back to R<sup>3</sup>. This is the key step: we’ve used all these conditioning representations to generate coordinate updates <strong><span style="color: #F4DD65;">r<sup>update</sup></span></strong> for all atoms. 
Now, because we generated these in the “dimensionless” space <span style="color: #F4DD65;">r<sub>l</sub></span>, we carefully re-scale<d-footnote>This careful scaling involves both the variance of our data, and the noise schedule based on our current timestep, so that our updates are smaller and smaller as we get deeper into the de-noising process.</d-footnote> the updates from <strong><span style="color: #F4DD65;">r<sup>update</sup></span></strong> to their form with non-unit variance, <strong><span style="color: #F4DD65;">x<sup>update</sup></span></strong>, and apply the updates to <strong><span style="color: #F4DD65;">x<sub>l</sub></span></strong>.</p> <p>With that, we’ve completed our tour through the main architecture of AlphaFold 3! Now we provide some additional information about the loss function, auxiliary confidence heads, and training details.</p> <h1 id="4-loss-function-and-other-training-details">4. Loss Function and Other Training Details</h1> <h2 id="loss-function-and-confidence-heads">Loss function and confidence heads</h2> <p>L<sub>loss</sub> = L<sub>distogram</sub> * α<sub>distogram</sub> + L<sub>diffusion</sub> * α<sub>diffusion</sub> + L<sub>confidence</sub> * α<sub>confidence</sub></p> <p>The loss is a weighted sum of 3 terms:</p> <ul> <li><strong>L<sub>distogram</sub></strong> which evaluates the accuracy of the predicted distogram at a token-level</li> <li><strong>L<sub>diffusion</sub></strong> which evaluates the accuracy of the predicted distogram at an atom-level. 
It looks at all pairwise distances, then includes additional terms to prioritize distances between nearby atoms and between atoms involved in protein-ligand bonds.</li> <li><strong>L<sub>confidence</sub></strong> which evaluates the model’s self-awareness about which structures are likely to be inaccurate</li> </ul> <h3 id="ldistogram">L<sub>distogram</sub></h3> <p>The output of our model is atom-level coordinates, which can easily be used to create an atom-level distogram<d-footnote>Recall how the distograms were initially created by binning pairwise distances between atoms</d-footnote>. However, this loss evaluates a token-level distogram. To get the xyz coordinates for tokens, we just use the coordinate of the “center atom”. As the distogram bins are categorical, the predicted distogram is then compared to the true distogram via cross entropy.</p> <h3 id="ldiffusion">L<sub>diffusion</sub></h3> <p>The diffusion loss itself is a weighted sum of three terms, each computed over the atom positions, with the first two additionally scaled by the amount of noise<d-footnote>t̂, the sampled noise level for the current time step, and σ<sub>data</sub>, the variance of the data which scales the amount of noise at each time step</d-footnote> added at the current time step:</p> <p>L<sub>diffusion</sub> = (L<sub>MSE</sub> + L<sub>bond</sub> * α<sub>bond</sub>) * (t̂² + σ<sub>data</sub>²)/(t̂+σ<sub>data</sub>)² + L<sub>smooth_lddt</sub></p> <ul> <li><strong>L<sub>MSE</sub></strong> is a version of the distogram loss we just discussed, but over all atoms rather than just “center atoms” (and with DNA, RNA, and ligand atoms upweighted). 
Additionally, it looks at the mean squared error between positions, rather than binning them into a distogram.</li> <li><strong>L<sub>bond</sub></strong> aims to ensure the accuracy of bond lengths for protein-ligand bonds by adding an MSE loss on the difference between predicted and ground-truth distances for atom-pairs that are part of protein-ligand bonds.<d-footnote>Training proceeds in several stages, and α<sub>bond</sub> is set to 0 in the initial stages, so this term is only introduced later.</d-footnote></li> <li><strong>L<sub>smooth_LDDT</sub></strong> (smoothed local distance difference test) is yet another variant of the distogram loss that tries to capture the accuracy of local distances. An atom-pair’s predicted distance “passes the test” if it is within a given threshold of the atom-pair’s true distance. To make this metric smooth and differentiable, we pass the difference between predicted and ground-truth distances through a sigmoid centered on the test’s threshold. We can think of this as generating a probability (between 0 and 1) that this atom-pair passes the test. We take the average of four “tests” with increasingly tight thresholds (4, 2, 1, and 0.5 Å). Using this loss encourages the model to reduce the probability of failing each test. Finally, to make the test “local”, we ignore the loss for an atom-pair if that atom-pair’s ground-truth distance is large, as we only want the model to focus on accurately predicting an atom’s distances to nearby atoms<d-footnote>Specifically, for an atom-pair l,m, we ignore the loss for l and m if they are more than 30 Å apart and atom m is part of a nucleotide. 
We ignore the loss for l and m if they are more than 15 Å away from each other and m is not a nucleotide (so is part of a protein or ligand).</d-footnote>.</li> </ul> <div class="l-body"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/af3_post/smooth_lddt-480.webp 480w,/assets/img/af3_post/smooth_lddt-800.webp 800w,/assets/img/af3_post/smooth_lddt-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/af3_post/smooth_lddt.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </div> <h3 id="lconfidence">L<sub>confidence</sub></h3> <p>The goal of this loss is not to improve the accuracy of the structure, but rather to teach the model to predict its own accuracy. This loss is a weighted sum of 4 terms that each correspond to a method of evaluating the quality of a predicted structure:</p> <p>L<sub>confidence</sub> = L<sub>plDDT</sub> + L<sub>PDE</sub> + L<sub>resolved</sub> + L<sub>PAE</sub> * α<sub>PAE</sub></p> <ul> <li> <p><strong>lDDT</strong> Atom-level “local distance difference test”, capturing the expected accuracy of an atom’s predicted distances to nearby atoms.</p> </li> <li> <p><strong>PAE</strong> Predicted alignment error between token i’s predicted and the ground-truth positions. We first rotate and translate the predicted token i and ground-truth token i into the frame of token j. 
That is, if we assume for a moment that token j is exactly in its ground-truth position, we predict how close token i is to where it should be, based on its relation to token j.</p> </li> <li> <p><strong>PDE</strong> Predicted distance error between tokens, capturing the accuracy of predicted distances between all pairs of tokens.</p> </li> <li> <p><strong>Experimentally resolved prediction</strong> The model predicts which atoms were experimentally resolved (not every atom is experimentally resolved in every crystal structure).</p> </li> </ul> <p>For each of these metrics, AF3 predicts the value of the error metric, the actual error metric is then calculated on the predicted structure, and the loss is based on the difference between the two. So even if the structure is quite incorrect and the PAE is high, L<sub>PAE</sub> will be low as long as the predicted PAE is also high.</p> <div class="l-body"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/af3_post/confidence_arch-480.webp 480w,/assets/img/af3_post/confidence_arch-800.webp 800w,/assets/img/af3_post/confidence_arch-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/af3_post/confidence_arch.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </div> <p>These confidence predictions are generated mid-way through the diffusion process. At a selected diffusion step t, the predicted coordinates <strong><span style="color: #F4DD65;">r<sup>t</sup></span></strong> are used to update the single and pair representations created in the representation learning trunk. The predicted errors are then calculated from linear projections of the updated pair representation (for PAE and PDE) or the updated single representation (for pLDDT and experimentally resolved).
Then, the actual error metrics are calculated based on the same generated atom coordinates (process described below, if interested) for comparison.</p> <p>While these terms are included in the confidence head loss, gradients from these terms are only used to update the confidence prediction heads and do not affect the rest of the model.</p> <details> <summary>How are the actual error metrics calculated?</summary> <p><b>pLDDT:</b> The LDDT for atom l is calculated in the following way: in the current predicted structure, we calculate the distance between atom l and a set of atoms R that is indexed by m, and compare this to the ground truth equivalent. To be in this set, an atom m must be part of a polymer chain, within 15 or 30 Å of l depending on the molecule m is a part of, and the center atom of a token. We then calculate four binary distance tests with increasingly tight thresholds (4, 2, 1, and .5 Å) and take the average pass rate across these tests and across the atoms in R. We bin this percentage into 50 bins between 0 and 1.</p> <p>At inference time, we have a pLDDT head. This head takes the single representation of a given token, repeats it out across all the atoms "attached" to this token<d-footnote>Technically, the max number of atoms attached to any token, so that we can stack tensors</d-footnote>, and projects all those atom-level representations to the 50 bins of our pLDDT_l. We treat these as logits across the 50 "classes", use a softmax to convert to probabilities, and take a multi-class classification loss across the bins.</p> <p><b>Predicted Alignment Error (PAE):</b> Every token is considered to have a frame, that is, a 3D coordinate frame created from three atoms (called a, b, c) involved in that token. Atom b within those three atoms forms the origin in this frame.
In cases where a token has only a single atom "attached", the center atom of the frame is the single atom of the token, and the two other nearest tokens of the same entity (e.g., same ligand) form the basis of the frame. For every token pair (i,j) we re-express the predicted coordinates of the center atom for token_i using the frame of token_j. We do the same for the ground-truth coordinates of the center atom of token_i. The Euclidean distance between these transformed true and predicted coordinates of the center atom of token_i is our alignment error, binned into 64 bins. We predict this alignment error from the pair representation <b><span style="color: #7CC9F4;">z<sub>i,j</sub></span></b>, projecting it to 64 dimensions that we treat as logits and convert to probabilities with a softmax. We train this head with a classification loss, with each bin as a class. See <a href="https://www.ebi.ac.uk/training/online/courses/alphafold/inputs-and-outputs/evaluating-alphafolds-predicted-structures-using-confidence-scores/pae-a-measure-of-global-confidence-in-alphafold-predictions/">here</a> for additional details.</p> <p>Third, AF3 predicts the distance error (PDE) between tokens. The true distance error is calculated by taking the distance between center atoms for every token pair in both the predicted and the ground-truth structures, and binning the absolute difference between these two distances into 64 uniformly-sized bins from 0 Å to 32 Å. The predicted distance error comes from projecting the pair representation <b><span style="color: #7CC9F4;">z<sub>i,j</sub></span></b> plus the pair representation <b><span style="color: #7CC9F4;">z<sub>j,i</sub></span></b> into 64 dimensions that we again treat as logits, and again convert to probabilities with a softmax.</p> <p>Finally, AF3 predicts whether each atom was experimentally resolved in the ground-truth structure.
Similar to the pLDDT head, we repeat the <b><span style="color: #F5ACFB;">s<sub>i</sub></span></b> single representation out for the number of atoms this token represents, project it to 2 dimensions, and use a binary classification loss.</p> </details> <hr/> <h2 id="other-training-details">Other Training Details</h2> <p>Now that the architecture is covered, the last pieces are some of the additional training details.</p> <h3 id="recycling">Recycling</h3> <p>As introduced in AF2, AF3 recycles its weights; that is, rather than making the model deeper, the model weights are re-used and inputs are run through the modules multiple times to continually improve the representations. Diffusion inherently uses recycling at inference time, as the model is trained to incorporate the timestep information and use the same model weights for every time step.</p> <h3 id="cross-distillation">Cross-distillation</h3> <p>AF3 is trained on a mix of synthetic data generated both by itself (via self-distillation) and by AF2 (via <a href="https://link.springer.com/article/10.1007/s11263-024-02002-0">cross-distillation</a>). The authors note that, by switching to the diffusion-based generative module, the model stopped producing the characteristic “spaghetti” regions that allowed users of AF2 to visually identify low-confidence and likely disordered regions. Looking at the diffusion-based generations alone, all regions appeared equally confident, making it more difficult to identify potential hallucinations.</p> <p>To solve this problem, they included generations from AF2 and AF-Multimer in the training data for AF3. Wherever AF2 was not confident in its prediction, it produced these unfolded regions, and training on those structures “instructs” AF3 to do the same.<d-footnote>Nucleic acids and small molecules in distillation datasets had to be removed as they could not be processed by AF2 and AF-multimer.
However, once the previous models had generated new predicted structures and these structures were aligned to the originals, the removed molecules were added back in. If adding these back in created new atom clashes, the whole structure was excluded, to avoid accidentally teaching the model to accept clashes.</d-footnote></p> <div class="l-body"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/af3_post/cross_distillation-480.webp 480w,/assets/img/af3_post/cross_distillation-800.webp 800w,/assets/img/af3_post/cross_distillation-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/af3_post/cross_distillation.jpg" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> <div class="caption">(Diagram from AF3 paper)</div> </div> <h3 id="cropping-and-training-stages">Cropping and Training Stages</h3> <p>While no part of the model has an explicit restriction on the length of the input sequences, the memory and compute requirements increase significantly with sequence length (recall the multiple O(N<sub>tokens</sub><sup>3</sup>) operations). Thus, for efficiency, the proteins are randomly cropped. As introduced in AF-multimer, because we want to model the interactions between multiple chains, the random cropping needs to include all of these chains.
They use 3 methods for cropping, mixed in different proportions depending on the training data (e.g., PDB crystal structure vs. disordered PDB complex vs. distillation).</p> <ul> <li>Contiguous cropping: Contiguous sequences of amino acids are selected for each chain.</li> <li>Spatial cropping: Amino acids are selected based on distance to a reference atom (typically this atom is part of a specific chain or binding interface of interest).</li> <li>Spatial interface cropping: Similar to spatial cropping, but based on distances to atoms that are specifically at a binding interface.</li> </ul> <p>While a model trained on random crops of 384 tokens can be applied to longer sequences, it is iteratively fine-tuned on larger sequence lengths to improve its ability to handle them. The mix of datasets and other training details is also varied in each training stage, as shown in the table below.</p> <div class="l-body"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/af3_post/training_stages-480.webp 480w,/assets/img/af3_post/training_stages-800.webp 800w,/assets/img/af3_post/training_stages-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/af3_post/training_stages.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> <div class="caption">(Table from AF3 supplement)</div> </div> <h3 id="clashing">Clashing</h3> <p>The authors note that AF3’s loss does not include a clash penalty for overlapping atoms. While switching to a diffusion-based structure module means the model could in theory predict two atoms to be in the same location, in practice this seems to happen rarely after training.
That said, AF3 does employ a clashing penalty when ranking generated structures.</p> <h3 id="batch-sizes">Batch sizes</h3> <p>Although the diffusion process sounds quite involved, it is still significantly less computationally expensive than the trunk of the model. Thus, the AF3 authors found that it is more efficient from a training perspective to expand the batch size of the model after the trunk. So each input structure is run through the embedding and trunk once, then 48 independent data-augmented versions of the structure are created, and the diffusion module is trained on all 48 in parallel.</p> <p><strong>That’s it for the training process!</strong> There are some other small details, but this is probably already more than you need, and if you’ve made it this far, the rest should be easy to pick up from reading the AF3 supplement.</p> <h1 id="ml-musings">ML Musings</h1> <p>Having walked so thoroughly through the architecture of AF3 and its comparisons to AF2, it is interesting to see how the choices made by the authors fit into broader Machine Learning trends.</p> <h3 id="alphafold-as-retrieval-augmented-generation">AlphaFold as Retrieval-Augmented Generation</h3> <p>At the time AF2 was released, it was not common to include retrievals from the training set at inference time. In the case of AF, this retrieval takes the form of an MSA and a template search. MSA-based methods were being used for protein modeling, but this type of retrieval was less used in other areas of Deep Learning (e.g., in Computer Vision, ResNets do not embed relevant training images at inference time when classifying a new image).
Although AF3 reduces the emphasis on the MSA compared to AF2 (it is no longer operated on and updated in the 48 blocks of the Evoformer/Pairformer), they still incorporate both the MSA and templates, even as other protein prediction models such as ESMFold have dropped retrieval in favor of fully parametric inference.</p> <p>Interestingly, some of the largest and most successful Deep Learning models now often include similar additional information at inference time. While the details of the retrieval systems are not always disclosed, Large Language Models routinely use Retrieval Augmented Generation systems such as a traditional web search at inference time to orient the model toward relevant information (even if that information was likely already in its training data) that should guide inference. It will be interesting to see how the use of directly relevant examples at inference time develops in the future.</p> <h3 id="pair-bias-attention">Pair-Bias Attention</h3> <p>One of the major components of AF2 that is even more present in AF3 is Pair-Bias Attention. That is, attention where the queries, keys, and values all originate from the same source (like in self-attention), but where there is a bias term added to the attention map from another source. This effectively acts as a light-touch version of information sharing, without full cross-attention. Pair-Bias Attention appears in almost every module. While this type of attention is now used in other protein modeling architectures, we have not seen this particular type of cross-biasing used in other fields (although that does not mean it hasn’t been done!). 
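</p> <p>To make the mechanism concrete, here is a minimal single-head sketch in NumPy, assuming random weights and toy dimensions. The names <code>pair_bias_attention</code>, <code>w_q</code>, <code>w_k</code>, <code>w_v</code>, and <code>w_b</code> are hypothetical, and AF3's actual module adds multiple heads, layer norms, and gating that are left out here. Queries, keys, and values all come from the single representation, while each token pair (i, j) contributes one scalar bias, projected from z<sub>i,j</sub>, that is added to the attention logits before the softmax.</p>

```python
import numpy as np


def softmax(x, axis=-1):
    # Subtract the row-wise max before exponentiating for numerical stability.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def pair_bias_attention(s, z, w_q, w_k, w_v, w_b):
    """Single-head self-attention over tokens, biased by the pair representation.

    s:   (N, c_s) single representation, source of queries, keys, and values
    z:   (N, N, c_z) pair representation, source of one additive bias per pair
    w_q, w_k, w_v: (c_s, c) hypothetical projection matrices
    w_b: (c_z,) hypothetical projection from each pair embedding to a scalar
    """
    q, k, v = s @ w_q, s @ w_k, s @ w_v          # each (N, c)
    logits = q @ k.T / np.sqrt(q.shape[-1])      # (N, N) ordinary attention scores
    bias = z @ w_b                               # (N, N) scalar bias for each pair (i, j)
    attn = softmax(logits + bias, axis=-1)       # row i: how much token i attends to each j
    return attn @ v, attn


rng = np.random.default_rng(0)
n_tokens, c_s, c_z, c = 5, 8, 4, 8
s = rng.normal(size=(n_tokens, c_s))
z = rng.normal(size=(n_tokens, n_tokens, c_z))
w_q, w_k, w_v = (rng.normal(size=(c_s, c)) for _ in range(3))
w_b = rng.normal(size=c_z)

out, attn = pair_bias_attention(s, z, w_q, w_k, w_v, w_b)
print(out.shape, attn.shape)  # (5, 8) (5, 5)
```

<p>Setting the pair representation to zero recovers ordinary self-attention, which is a handy sanity check: in this sketch the pair bias only reshapes where each token looks and never contributes values of its own.</p> <p>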
Perhaps it only works well here because the pair-representation is naturally analogous to a self-attention map already, but it is an intriguing alternative to pure self or pure cross-attention.</p> <h3 id="self-supervised-training">Self-supervised training</h3> <p>Self-supervised models like ESM have been able to achieve impressive results at predicting protein structure by replacing the MSA embedding with a “probabilistic MSA” using self-supervised pre-training. In AF2, the model had an additional task of predicting masked tokens in the MSA, a similar form of self-supervision, but this task was removed in AF3. We have not seen commentary from the authors on why they did not use any self-supervised language modeling pre-training approach on the MSA, and in fact decreased the compute used to process the MSA. Three possible reasons self-supervised learning is not used to initialize the MSA embeddings are: 1) they viewed the massive pre-training phase as a suboptimal use of compute; 2) they tried it and found that including a small MSA module outperformed pre-trained embeddings and was worth the additional inference-time cost; or 3) a mix of pre-trained embeddings for amino acid tokens and randomly initialized embeddings for DNA/RNA/ligands was incompatible with, or underperformed, fully supervised training on their hybrid atom-token structure. <d-footnote>By focusing on self-supervision tasks, models in the ESM family are also much simpler than AF3 (although they don't handle DNA/RNA/ligands and have slightly different goals). It is interesting that while some models aim to maximize architectural simplicity, AlphaFold remains this complicated! </d-footnote></p> <h3 id="classification-vs-regression">Classification vs. Regression</h3> <p>As in AF2, AF3 continues to use a mix of MSE and binned classification losses.
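</p> <p>To make this concrete, the sketch below scores the same toy predictions with a binned cross-entropy loss and with a squared error on the implied distances. It assumes 64 bins, a true bin of 30, and a hypothetical 0.5 Å bin width used only to convert bin offsets into distances; these are illustrative values, not AF3's exact bin layout. The "model" is simulated as confidently putting almost all of its probability mass on a single bin, either one bin away from the truth or 29 bins away.</p>

```python
import numpy as np

N_BINS, TRUE_BIN = 64, 30
BIN_WIDTH = 0.5  # hypothetical bin width in angstroms, only for the squared-error comparison


def cross_entropy(logits, true_bin):
    # Negative log-probability of the true bin under a softmax over all bins.
    shifted = logits - logits.max()
    log_probs = shifted - np.log(np.exp(shifted).sum())
    return -log_probs[true_bin]


def confident_logits(pred_bin, strength=10.0):
    # Simulated confident prediction: almost all probability mass sits on one bin.
    logits = np.zeros(N_BINS)
    logits[pred_bin] = strength
    return logits


results = {}
for pred_bin in (31, 59):  # one bin off vs. 29 bins off
    ce = cross_entropy(confident_logits(pred_bin), TRUE_BIN)
    sq_err = ((pred_bin - TRUE_BIN) * BIN_WIDTH) ** 2  # squared distance error in angstroms^2
    results[pred_bin] = (ce, sq_err)
    print(f"predicted bin {pred_bin}: cross-entropy = {ce:.3f}, squared error = {sq_err:.2f}")
```

<p>Both predictions receive essentially the same cross-entropy, because the loss only looks at the probability assigned to the true bin, while the squared error grows with how far off the prediction is.</p> <p>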
The classification components are interesting because a prediction that lands in a distogram bin only one away from the true bin gets no more “credit” than one that is way off. It is unclear what informed this design decision, but perhaps the authors found the gradients to be more stable than with several different MSE losses, and perhaps the per-atom losses saw so many gradient steps that the additional signal from a continuous loss would not have proven beneficial.</p> <h3 id="similarities-to-recurrent-architectures-eg-lstms">Similarities to Recurrent Architectures (e.g., LSTMs)</h3> <p>AF3’s architecture incorporates several design elements reminiscent of recurrent neural networks that are not typically found in traditional transformers:</p> <ul> <li>Extensive Gating: AF3 uses gating mechanisms throughout its architecture to control information flow in the residual stream. This is more akin to the gating in LSTMs or GRUs than the standard feed-forward nature of normal transformer layers.</li> <li>Iterative Processing with Weight Reuse: AF3 applies the same weights multiple times to progressively refine its predictions. This process, involving both recycling and the diffusion model, resembles how recurrent networks process sequential data over time steps using a shared set of weights. It differs from standard transformers, which typically make predictions in a single forward pass. This approach allows AF3 to iteratively improve its protein structure predictions without increasing the number of parameters.</li> <li>Adaptive Computation: The recycling is also similar to the iterative updating used in diffusion and quite related to the idea of Adaptive Computation Time <a href="https://arxiv.org/abs/1603.08983">(ACT)</a>, originally introduced to dynamically determine how much compute to use for RNNs and more recently used in <a href="https://arxiv.org/pdf/2404.02258">Mixture-of-Depths</a> to achieve a similar goal with transformers.
This contrasts with the fixed depth of standard transformers and would theoretically allow the model to apply more processing to challenging inputs.</li> </ul> <p>It was shown in the AF2 ablations that the recycling was important, but there was little discussion of the importance of gating. Presumably it helps with training stability, as in LSTMs, but it is interesting that it is so prevalent here yet not in many other transformer-based architectures.</p> <h3 id="cross-distillation-1">Cross-distillation</h3> <p>The use of AF2 generations to re-introduce its distinctive style specifically for low-confidence regions is very interesting. If there is a lesson here, it may be the most practical of all: If your previous model is doing one specific thing better than your new model, you can try cross-distillation to get the best of both worlds!</p> <hr/> <p>If you’ve made it this far, thanks for reading and hope it was helpful!! If you have any questions / comments / corrections / feedback, feel free to reach out to us on twitter (<a href="https://twitter.com/ElanaPearl/">E</a>, <a href="https://twitter.com/JakeSilberg">J</a>) or email (<a href="mailto:epsimon@stanford.edu" target="_blank">E</a>, <a href="mailto:jsilberg@stanford.edu" target="_blank">J</a>) !</p> <p>Special thanks to <a href="https://x.com/kristyacarp">Kristy Carpenter</a>, <a href="https://twitter.com/nickevanjoseph">Nicholas Joseph</a>, <a href="https://swansonkyle.com/">Kyle Swanson</a>, and <a href="https://karamarieliu.github.io/">Kara Liu</a> for giving feedback on this 💜</p>]]></content><author><name>Elana Simon</name></author><summary type="html"><![CDATA[A visual walkthrough of the AlphaFold3 architecture, with more details and diagrams than you were probably looking for.]]></summary></entry><entry><title type="html">Mapping Chemical Space with
UMAP"/><published>2021-04-06T11:46:00+00:00</published><updated>2021-04-06T11:46:00+00:00</updated><id>https://elanapearl.github.io/blog/2021/umap</id><content type="html" xml:base="https://elanapearl.github.io/blog/2021/umap/"><![CDATA[<div class="l-page"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/umap_post/download-2-480.webp 480w,/assets/img/umap_post/download-2-800.webp 800w,/assets/img/umap_post/download-2-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/umap_post/download-2.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </div> <p><em>Note: this was originally written for the Reverie Labs blog which got taken down after acquisition so now it’s re-posted here</em></p> <p>This blog post discusses why it is important to visualize the latent space of chemical datasets, what makes UMAP a useful tool for this purpose, and how we use UMAP at Reverie Labs. As an example, we use the Blood Brain Barrier Permeability (BBBP) dataset from <a href="http://moleculenet.ai/datasets-1">MoleculeNet</a> for our visualizations and the code tutorial. This dataset has measurements from over 2000 unique compounds, many of which are approved drugs, each labeled as “permeable” or “not permeable”. 
We look into the details of how this dataset gets embedded by various dimensionality reduction methods and reveal some fascinating properties of UMAP.</p> <p>For a walkthrough of how to use UMAP to visualize chemical space with this dataset, see <a href="https://colab.research.google.com/gist/ElanaPearl/444b3331f61485bbe8862db27cb2b968/mapping-chemical-space-with-umap.ipynb">this</a> Colab notebook:</p> <script src="https://gist.github.com/ElanaPearl/444b3331f61485bbe8862db27cb2b968.js"></script> <h2 id="motivation">Motivation</h2> <p>A fundamental assumption behind most machine learning methods is that data are independent and identically distributed (IID). However, in drug discovery datasets, compounds are almost never sampled independently, as they are typically extracted from experiments for specific therapeutic programs. Measurements often follow the patterns of the drug development efforts that generate them. Any biases in the data-generation process can also sneak into the training and evaluation of models. In practice, this means that open source and industry datasets are often “clumpy”, consisting of measurements for compounds that are very similar to one another and that cover chemical space non-uniformly. Visualizing the chemical space of a dataset does not solve these issues, but it helps us better understand them within the context of our datasets.</p> <p>At Reverie Labs, we work with dozens of datasets spanning a range of different properties. When ingesting a new dataset, we begin with a systematized analysis that identifies biases and helps determine how best to prepare and use the dataset for modeling. As with most machine learning problems, we want to understand the distribution of the measured property we are modeling (What is the class balance within a classification dataset? Does our regression dataset have truncated measurements due to assay limits? etc.).
We want to make sure the measurements look reasonable and compare their distributions in our training and test data. However, we cannot stop there. We must also look at the distribution of the chemical structures that the measurements are taken from, because the compounds themselves are typically selected through a biased generation process.</p> <div class="l-gutter"> <p> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/umap_post/data_doodle-480.webp 480w,/assets/img/umap_post/data_doodle-800.webp 800w,/assets/img/umap_post/data_doodle-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/umap_post/data_doodle.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </p> </div> <p>Visualizing the distribution of our compounds in chemical space allows us to gauge how much we expect models trained on a given dataset to generalize to new chemical matter. These visualizations can help with manual inspection of Structure-Activity / Structure-Property Relationships (SAR/SPR), expose potential quirks or biases in the dataset, and reveal insights for how we might want to split the dataset into training and evaluation sets. There are many methods we can use for visualizing chemical space, but at Reverie, we have selected a default procedure involving UMAP that balances accuracy, speed, and ease of use.</p> <h2 id="embeddings">Embeddings</h2> <p>That sounds great, but how do we actually visualize these distributions? The key here is that we need to embed our compounds into a low-dimensional vector form that can be easily interpreted (in this post we stick to 2D for simplicity’s sake, but 3D would also work).
Representing our molecules as <a href="https://pubs.acs.org/doi/10.1021/ci100050t">2048-bit Extended-Connectivity Fingerprints (ECFPs)</a> gives us high-dimensional vectors that we can then project into a 2-dimensional space for visualization. <a href="https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.PCA.html">PCA</a> and <a href="https://scikit-learn.org/stable/modules/generated/sklearn.manifold.TSNE.html">t-SNE</a> are commonly used for this kind of dimensionality reduction and have been widely applied in biology and chemistry. Since usage of these tools has been thoroughly documented, here we focus on the utility of a more recent addition to the dimensionality reduction repertoire: <a href="https://arxiv.org/pdf/1802.03426.pdf">UMAP</a>.</p> <p>Uniform Manifold Approximation and Projection (UMAP) constructs a high-dimensional graph representation of the entire dataset, then tries to re-create a low-dimensional version of this initial graph that maintains as much of the local and global structure as possible. This method is somewhat similar to t-SNE, but with a few key differences that lead to important advantages for our purposes.
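</p> <p>As a rough illustration of that first step, here is a NumPy sketch of building a k-nearest-neighbor graph from fingerprint bit vectors. It assumes Tanimoto (Jaccard) distance, a common choice for binary fingerprints, and uses random sparse bit vectors as stand-ins for real 2048-bit ECFPs (in practice these would come from a cheminformatics toolkit such as RDKit). The function names are hypothetical. This covers only the neighbor search: the umap-learn library itself uses approximate nearest-neighbor search, turns these distances into weighted fuzzy edges, and then optimizes the low-dimensional layout, none of which is shown here.</p>

```python
import numpy as np


def tanimoto_distances(fps):
    """Pairwise Tanimoto (Jaccard) distances between rows of a 0/1 fingerprint matrix."""
    fps = fps.astype(np.float64)
    shared = fps @ fps.T                                  # bits set in both fingerprints
    counts = fps.sum(axis=1)
    union = counts[:, None] + counts[None, :] - shared    # bits set in either fingerprint
    # Treat two all-zero fingerprints as identical instead of dividing 0 by 0.
    similarity = np.divide(shared, union, out=np.ones_like(union), where=union > 0)
    return 1.0 - similarity


def knn_graph(dist, k):
    """For each compound, the indices of its k nearest neighbors (excluding itself)."""
    d = dist.copy()
    np.fill_diagonal(d, np.inf)          # a compound is not its own neighbor
    return np.argsort(d, axis=1)[:, :k]


rng = np.random.default_rng(0)
# Stand-ins for 2048-bit ECFPs: 100 sparse random bit vectors, not real molecules.
fps = (rng.random((100, 2048)) < 0.05).astype(np.uint8)
dist = tanimoto_distances(fps)
neighbors = knn_graph(dist, k=15)
print(neighbors.shape)  # (100, 15)
```

<p>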
The technical differences between UMAP and t-SNE would take up a full blog post, so we will not detail them here, but more details can be found in these resources: <a href="https://arxiv.org/pdf/1802.03426.pdf">the original paper</a>, <a href="https://pair-code.github.io/understanding-umap/">Understanding UMAP</a>, or <a href="https://towardsdatascience.com/how-exactly-umap-works-13e3040e1668">How Exactly UMAP Works</a>.</p> <p>The most relevant benefits UMAP provides us are speed, the ability to maintain some of the local / global structure of the data, and an easy interface for applying an embedding from one dataset to a different dataset.</p> <h3 id="speed">Speed</h3> <p>To evaluate the speed of these methods, we compute ECFPs of various sizes and embed each using PCA, t-SNE, and UMAP:</p> <div class="l-body"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/umap_post/timing_plot-480.webp 480w,/assets/img/umap_post/timing_plot-800.webp 800w,/assets/img/umap_post/timing_plot-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/umap_post/timing_plot.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </div> <p>We can see that PCA is the most efficient, UMAP is slightly less efficient, and t-SNE is by far the least efficient. For small datasets, these time differences are not particularly relevant, but they really add up as the number of compounds grows. To be fair, we used the most generic implementations (PCA and t-SNE from <a href="https://scikit-learn.org/stable/modules/generated/sklearn.manifold.TSNE.html">sklearn</a>), and there are more efficient variants of t-SNE.
However, the <a href="https://umap-learn.readthedocs.io/en/latest/performance.html">UMAP Docs</a> contain a similar performance analysis on <a href="https://en.wikipedia.org/wiki/MNIST_database">MNIST</a>, including a wider variety of the performant methods, and find similar results.</p> <p>Additionally, the creators of UMAP recently released a new version of the algorithm, <a href="https://arxiv.org/abs/2009.12981">ParametricUMAP</a>, that uses a neural network to reduce the dimensions of the graphical embedding, giving it even greater speed improvements. For this post we stick to the original, non-parametric UMAP, but if you want to optimize the speed of embedding new compounds into a pre-fit UMAP model, ParametricUMAP can be quite useful.</p> <h3 id="local--global-structure">Local / Global Structure</h3> <p>If PCA is so fast, why don’t we just use that? Our goal is to understand both the local structure (organization of similar compounds) and global structure (organization of groups of different compounds) of complex datasets. Speed is important, but we also evaluate the methods according to how informative they are for those tasks. 
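</p> <p>One simple way to put a number on local structure (our own illustrative choice, not a metric used elsewhere in this post) is neighbor overlap: the average fraction of each compound's k nearest neighbors in fingerprint space that are also among its k nearest neighbors in the 2D embedding. The NumPy sketch below computes it for a PCA projection, assuming Euclidean distance on bit vectors for simplicity and using synthetic "scaffold" clusters rather than the BBBP data; the helper names are hypothetical. scikit-learn's <code>sklearn.manifold.trustworthiness</code> is a related, rank-based measure.</p>

```python
import numpy as np


def knn_indices(points, k):
    """Indices of each row's k nearest neighbors under Euclidean distance, excluding itself."""
    sq_norms = (points ** 2).sum(axis=1)
    d2 = sq_norms[:, None] + sq_norms[None, :] - 2.0 * points @ points.T
    np.fill_diagonal(d2, np.inf)          # a point is not its own neighbor
    return np.argsort(d2, axis=1)[:, :k]


def neighbor_overlap(high_dim, low_dim, k=10):
    """Mean fraction of each point's high-dimensional k-NN that are also its low-dimensional k-NN."""
    hi = knn_indices(high_dim, k)
    lo = knn_indices(low_dim, k)
    shared = [len(set(a).intersection(b)) for a, b in zip(hi, lo)]
    return float(np.mean(shared)) / k


def pca_2d(x):
    # Project mean-centered data onto its top two principal directions via SVD.
    centered = x - x.mean(axis=0)
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:2].T


rng = np.random.default_rng(0)
# Synthetic stand-in for a "clumpy" dataset: 3 hypothetical scaffolds, 30 analogs each.
scaffolds = rng.random((3, 2048)) < 0.05
analogs = np.repeat(scaffolds, 30, axis=0)
edits = rng.random(analogs.shape) < 0.01          # small random bit flips per analog
fps = (analogs ^ edits).astype(np.float64)

score = neighbor_overlap(fps, pca_2d(fps), k=10)
print(f"PCA neighbor overlap (k=10): {score:.2f}")
```

<p>The same function can score a t-SNE or UMAP embedding of the same fingerprints, turning a visual impression of local structure into a comparable number. It does not capture global structure, that is, how the clusters are arranged relative to one another.</p> <p>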
We can use the example BBBP dataset to explore the local and global structure of the embeddings created by the various methods:</p> <div class="l-body"> <div class="row"> <div class="col-sm mt-3 mt-md-0"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/umap_post/pca_original-2-480.webp 480w,/assets/img/umap_post/pca_original-2-800.webp 800w,/assets/img/umap_post/pca_original-2-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/umap_post/pca_original-2.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </div> <div class="col-sm mt-3 mt-md-0"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/umap_post/tsne_original--1-480.webp 480w,/assets/img/umap_post/tsne_original--1-800.webp 800w,/assets/img/umap_post/tsne_original--1-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/umap_post/tsne_original--1.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </div> <div class="col-sm mt-3 mt-md-0"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/umap_post/umap_original-3-480.webp 480w,/assets/img/umap_post/umap_original-3-800.webp 800w,/assets/img/umap_post/umap_original-3-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/umap_post/umap_original-3.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </div> </div> </div> <p>Click these links for interactive versions of the plots [<a href="https://reverie-public.s3.amazonaws.com/interactive_BBBP_plots/PCA_original.html">PCA</a>, <a 
href="https://reverie-public.s3.amazonaws.com/interactive_BBBP_plots/TSNE_original.html">t-SNE</a>, <a href="https://reverie-public.s3.amazonaws.com/interactive_BBBP_plots/UMAP_original.html">UMAP</a>]</p> <p>At first glance, we notice some high-level differences in the overall structure of the embeddings. PCA looks like the intersection of two orthogonal lines, with the compounds distributed fairly uniformly along them. t-SNE has a flowery shape that resembles a 2D Gaussian, with a variety of isolated clusters around the edges. UMAP has many disjoint, tight clusters that do not follow a specific pattern. Beyond these surface-level observations, we cannot learn much more about the embeddings from the plots alone. To do that, we need to look more closely at the actual compounds represented in these images.</p> <p>These links [<a href="https://reverie-public.s3.amazonaws.com/interactive_BBBP_plots/PCA_original.html">PCA</a>, <a href="https://reverie-public.s3.amazonaws.com/interactive_BBBP_plots/TSNE_original.html">t-SNE</a>, <a href="https://reverie-public.s3.amazonaws.com/interactive_BBBP_plots/UMAP_original.html">UMAP</a>] take you to interactive views of these plots, which is how we typically examine our datasets. Exploring the data interactively helps build intuition for how the different algorithms organize the chemical space of the dataset.
To help guide this interactive exploration, we’ve selected a few clusters from the UMAP plot and highlighted where each of their compounds is embedded in t-SNE and PCA:</p> <div class="l-body"> <div class="row"> <div class="col-sm mt-3 mt-md-0"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/umap_post/pca_drugtype-480.webp 480w,/assets/img/umap_post/pca_drugtype-800.webp 800w,/assets/img/umap_post/pca_drugtype-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/umap_post/pca_drugtype.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </div> <div class="col-sm mt-3 mt-md-0"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/umap_post/tsne_drugtype-480.webp 480w,/assets/img/umap_post/tsne_drugtype-800.webp 800w,/assets/img/umap_post/tsne_drugtype-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/umap_post/tsne_drugtype.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </div> <div class="col-sm mt-3 mt-md-0"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/umap_post/umap_drugtype-480.webp 480w,/assets/img/umap_post/umap_drugtype-800.webp 800w,/assets/img/umap_post/umap_drugtype-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/umap_post/umap_drugtype.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </div> </div> <p> </p> </div> <p>We see that local areas of the embedding contain compounds that not only look similar, but also belong to the same drug class.
The clusters we’ve selected here contain <a href="https://en.wikipedia.org/wiki/Steroid"><span style="color:#a48cf4;">steroids</span></a>, <a href="https://en.wikipedia.org/wiki/Tetracycline_antibiotics"><span style="color:#36ada4;">tetracycline antibiotics</span></a>, and <a href="https://en.wikipedia.org/wiki/%CE%92-lactam_antibiotic"><span style="color:#97a431;">β-lactam antibiotics</span></a>. The details of these clusters and their respective drug types make a great case study for understanding and comparing the structure of the embeddings.</p> <h2 id="case-study">Case study</h2> <p>In the <strong><span style="color:#a48cf4;">steroid</span></strong> cluster, there are many compounds with the 4-ring system that is characteristic of steroids, as well as some non-steroid compounds with a similar structure. Each method separates these compounds from the rest, although UMAP appears to isolate and group them most strongly.</p> <div class="l-gutter"> <p> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/umap_post/steroid_examples-480.webp 480w,/assets/img/umap_post/steroid_examples-800.webp 800w,/assets/img/umap_post/steroid_examples-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/umap_post/steroid_examples.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </p> </div> <p>In a nearby but fully isolated cluster, we find a collection of <strong><span style="color:#36ada4;">tetracycline antibiotics</span></strong>.
These compounds also contain a fused 4-ring system; however, the <span style="color:#36ada4;">tetracycline</span> rings differ from the <span style="color:#a48cf4;">steroid</span> rings in their shape, relative positioning, and bond orders.</p> <div class="l-gutter"> <p> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/umap_post/tetracycline_examples-480.webp 480w,/assets/img/umap_post/tetracycline_examples-800.webp 800w,/assets/img/umap_post/tetracycline_examples-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/umap_post/tetracycline_examples.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </p> </div> <p>Comparing the embeddings of the <span style="color:#a48cf4;">steroid</span> and <span style="color:#36ada4;">tetracycline</span> clusters, we see:</p> <ol> <li> <p><span style="color:#a48cf4;">Steroids</span> appear more isolated than the <span style="color:#36ada4;">tetracyclines</span>. If we assume the global distances between clusters are meaningful, this implies that the <span style="color:#a48cf4;">steroids</span> are more distinct from the rest of the dataset than the <span style="color:#36ada4;">tetracyclines</span> are.</p> </li> <li> <p>While the <span style="color:#a48cf4;">steroid</span> compounds are spread out within their cluster, the <span style="color:#36ada4;">tetracyclines</span> are all embedded practically on top of each other. If we assume the spread of a cluster reflects its local diversity, this implies that the <span style="color:#a48cf4;">steroids</span> are more diverse than the <span style="color:#36ada4;">tetracyclines</span>.</p> </li> <li> <p>The two clusters are relatively near each other. If we assume the relationships between clusters are meaningful, this implies that these clusters are more similar to each other than they are to the rest of the dataset.</p> </li> </ol> <p>To evaluate whether these assumptions about global and local structure hold, we can test the claims they imply. Should <span style="color:#a48cf4;">steroids</span> be placed farther away from the rest of the compounds than the <span style="color:#36ada4;">tetracyclines</span> are? Both groups appear to stand out from the rest of the dataset, but it is difficult to compare their levels of distinctiveness by eye. Should <span style="color:#36ada4;">tetracyclines</span> and <span style="color:#a48cf4;">steroids</span> actually be placed near each other? Given that both groups contain fused 4-ring systems, it seems reasonable for them to be located near each other, but this similarity could be superficial. To address these questions more quantitatively, we look at the measured similarities between <span style="color:#a48cf4;">steroids</span>, <span style="color:#36ada4;">tetracyclines</span>, and the rest of the compounds in the dataset. For each <span style="color:#a48cf4;">steroid</span> compound, we calculate its average similarity to the other steroids, to the <span style="color:#36ada4;">tetracyclines</span>, and to every other compound. We repeat this for the <span style="color:#36ada4;">tetracyclines</span>. We define the chemical similarity between two compounds as the Tanimoto similarity between their ECFPs, the same metric used to create the embeddings.
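</p> <p>As a minimal sketch of this calculation (not the exact code behind the plots), the Python below computes Tanimoto similarity and the per-compound group averages described above. It assumes each fingerprint is a set of on-bit indices; in practice these would come from an ECFP implementation such as RDKit’s Morgan fingerprints. The helper names <code class="language-plaintext highlighter-rouge">tanimoto</code>, <code class="language-plaintext highlighter-rouge">avg_similarity_within</code>, and <code class="language-plaintext highlighter-rouge">avg_similarity_between</code> are hypothetical, and the toy fingerprints are invented rather than taken from the BBBP dataset.</p>

```python
from statistics import mean


def tanimoto(fp_a: set, fp_b: set) -> float:
    """Tanimoto (Jaccard) similarity of two binary fingerprints.

    Each fingerprint is a set of on-bit indices, so the similarity is
    |bits set in both| / |bits set in either|.
    """
    union = len(fp_a | fp_b)
    if union == 0:
        return 0.0  # assumption: two empty fingerprints count as dissimilar
    return len(fp_a & fp_b) / union


def avg_similarity_within(group: list) -> list:
    """For each member, its mean similarity to every *other* member of the group.

    Needs at least two members, otherwise there is nothing to compare against.
    """
    return [
        mean(tanimoto(fp, other) for j, other in enumerate(group) if j != i)
        for i, fp in enumerate(group)
    ]


def avg_similarity_between(group: list, others: list) -> list:
    """For each member of `group`, its mean similarity to every fingerprint in `others`."""
    return [mean(tanimoto(fp, other) for other in others) for fp in group]


# Hypothetical toy fingerprints (made-up bit indices, not real ECFPs).
steroids = [{1, 2, 3, 4}, {1, 2, 3, 5}, {1, 2, 6, 7}]
tetracyclines = [{1, 10, 11, 12}, {1, 10, 11, 13}]
rest = [{20, 21}, {1, 22, 23}]

print("steroid vs steroid:", mean(avg_similarity_within(steroids)))
print("tetracycline vs tetracycline:", mean(avg_similarity_within(tetracyclines)))
print("steroid vs tetracycline:", mean(avg_similarity_between(steroids, tetracyclines)))
print("steroid vs rest:", mean(avg_similarity_between(steroids, rest)))
```

<p>Averaging per compound and then across the group mirrors the procedure described above. In this made-up data the toy “tetracyclines” share more bits with each other than the toy “steroids” do, so their within-group average comes out higher.</p> <p>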
The higher the Tanimoto similarity, the more similar the compounds are to one another.</p> <div class="l-body"> <div class="row"> <div class="col-sm mt-2 mt-md-0"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/umap_post/download-480.webp 480w,/assets/img/umap_post/download-800.webp 800w,/assets/img/umap_post/download-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/umap_post/download.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </div> <div class="col-sm mt-2 mt-md-0"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/umap_post/download-1-480.webp 480w,/assets/img/umap_post/download-1-800.webp 800w,/assets/img/umap_post/download-1-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/umap_post/download-1.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </div> </div> <p> </p> </div> <p>These plots contextualize some of our observations about the <span style="color:#a48cf4;">steroid</span> and <span style="color:#36ada4;">tetracycline</span> clusters:</p> <ol> <li> <p>The <span style="color:#a48cf4;">steroids</span> and <span style="color:#36ada4;">tetracyclines</span> have roughly equivalent levels of similarity to the rest of the dataset (average similarities of 0.09 and 0.1, respectively). The distributions are not identical, but the difference is not large enough to explain the extra isolation of the <span style="color:#a48cf4;">steroid</span> cluster. This supports the common view that global distances between clusters are not always interpretable.</p> </li> <li> <p><span style="color:#36ada4;">Tetracyclines</span> are indeed more homogeneous than steroids.
<span style="color:#36ada4;">Tetracyclines</span> have, on average, a 0.15 higher Tanimoto similarity to other <span style="color:#36ada4;">tetracyclines</span> than <span style="color:#a48cf4;">steroids</span> have to other steroids. This indicates that the spread of each intra-group distribution actually reflects local chemical diversity and is not just a quirk of the UMAP embedding.</p> </li> <li> <p><span style="color:#36ada4;">Tetracyclines</span> and <span style="color:#a48cf4;">steroids</span> have higher similarity to each other than they do to the rest of the dataset. This means that the visual relationship between these embedded clusters reflects a real relationship between the compounds that the clusters represent.</p> </li> </ol> <p>This reveals that the global structure of this dataset is not maintained through exact distances between groups of compounds, but rather through the relationships between them. Local structure is expressed through each group’s intra-group distribution, which preserves the group’s chemical diversity. There is an even finer level of local structure in these embeddings that we’ll examine later, but these are the main conclusions from what we have seen so far. The embedding of two groups of compounds does not prove anything that we should expect to hold for all UMAP embeddings.
But these examples highlight important patterns in how chemical datasets get embedded.</p> <p>We can compare UMAP’s arrangement of the <span style="color:#a48cf4;">steroids</span> and <span style="color:#36ada4;">tetracyclines</span> with those of PCA and t-SNE to look for differences in the local and global structure of these embeddings:</p> <table border="1"> <tr> <th></th> <th colspan="3">Global Structure</th> <th colspan="2">Local Structure</th> </tr> <tr> <th></th> <th colspan="2">Clusters Identifiable</th> <th rowspan="2">Relationship between groups</th> <th colspan="2">Intra-group distribution</th> </tr> <tr> <th></th> <th><span style="color:#a48cf4;">Steroids</span></th> <th><span style="color:#36ada4;">Tetracyclines</span></th> <th><span style="color:#a48cf4;">Steroids</span></th> <th><span style="color:#36ada4;">Tetracyclines</span></th> </tr> <tr> <td>PCA</td> <td>yes</td> <td>no</td> <td>nearby</td> <td>dispersed</td> <td>dispersed</td> </tr> <tr> <td>t-SNE</td> <td>yes</td> <td>yes</td> <td>far apart</td> <td>dispersed</td> <td>tight</td> </tr> <tr> <td>UMAP</td> <td>yes</td> <td>yes</td> <td>nearby</td> <td>dispersed</td> <td>tight</td> </tr> </table> <p>Based on these observations, PCA is less successful at preserving the structure of these groups within the dataset. Specifically, t-SNE and UMAP highlight the distinctiveness and homogeneity of the <span style="color:#36ada4;">tetracyclines</span>, whereas PCA spreads the <span style="color:#36ada4;">tetracyclines</span> out among various other scaffolds, so they cannot be identified as a group. This again suggests that, although PCA maintains a few key elements of the global structure, t-SNE and UMAP preserve both global and local structure more consistently throughout the dataset.</p> <p>Differences between the embeddings are less noticeable for the <span style="color:#a48cf4;">steroids</span>. Each method embeds the <span style="color:#a48cf4;">steroids</span> in a clearly identifiable, yet dispersed, cluster.
UMAP’s <span style="color:#a48cf4;">steroid</span> cluster is the most isolated, but as discussed earlier, this extra separation is not particularly meaningful. Both PCA and UMAP place the chemically similar <span style="color:#a48cf4;">steroids</span> and <span style="color:#36ada4;">tetracyclines</span> near each other, while t-SNE does not. This seemingly implies that t-SNE’s global structure is less informative. However, t-SNE places the <span style="color:#36ada4;">tetracyclines</span> near the <span style="color:#97a431;">β-lactam antibiotics</span>, which, as we will see below, actually makes sense.</p> <p>The differences between the methods are most apparent when we examine the <span style="color:#97a431;">β-lactam antibiotics</span>. The namesake <span style="color:#97a431;">β-lactam</span> ring is present in every compound, and in terms of global structure, all three methods separate the <span style="color:#97a431;">β-lactam antibiotic</span> compounds from the rest of the dataset as a clearly distinct group.
Again, the cluster is less isolated in PCA than in the other methods, but it still stands out.</p> <div class="l-gutter"><p> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/umap_post/beta_lactam_examples-480.webp 480w,/assets/img/umap_post/beta_lactam_examples-800.webp 800w,/assets/img/umap_post/beta_lactam_examples-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/umap_post/beta_lactam_examples.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </p> </div> <p>The <span style="color:#97a431;">β-lactam antibiotics</span> are interesting because they show how much local organization t-SNE and UMAP preserve within the subclasses of this drug type:</p> <div class="l-page"> <div class="row"> <div class="col-sm mt-3 mt-md-0"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/umap_post/pca_antibiotic-2-480.webp 480w,/assets/img/umap_post/pca_antibiotic-2-800.webp 800w,/assets/img/umap_post/pca_antibiotic-2-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/umap_post/pca_antibiotic-2.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </div> <div class="col-sm mt-3 mt-md-0"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/umap_post/tsne_antibiotic-2-480.webp 480w,/assets/img/umap_post/tsne_antibiotic-2-800.webp 800w,/assets/img/umap_post/tsne_antibiotic-2-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/umap_post/tsne_antibiotic-2.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture>
</figure> </div> <div class="col-sm mt-3 mt-md-0"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/umap_post/umap_antibiotic-2-480.webp 480w,/assets/img/umap_post/umap_antibiotic-2-800.webp 800w,/assets/img/umap_post/umap_antibiotic-2-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/umap_post/umap_antibiotic-2.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </div> </div> </div> <p>We have labeled each of the various <a href="https://en.wikipedia.org/wiki/List_of_%CE%B2-lactam_antibiotics">subclasses</a> of <span style="color:#97a431;">β-lactam antibiotics</span>. They differ in the particular details of the β-lactam ring system in a given compound. You don’t need to understand the differences between the β-lactam subclasses, only that they are fairly small. The point of visualizing them is to compare where each embedding places them.</p> <p>PCA mixes all of the subtypes together, which makes sense: the principal components we are visualizing are the directions of maximal variance in the dataset. The nuances within a class of drugs account for little of that variance, so a 2-D PCA plot loses this local structure.
On the other hand, t-SNE and UMAP both preserve local structure by embedding the <span style="color:#97a431;">β-lactam antibiotics</span> in a way that separates the compounds by subclass.</p> <p>Zooming into the bottom of our UMAP plot where the <span style="color:#97a431;">β-lactam</span> compounds are embedded, we can better examine the details of each subclass:</p> <div class="l-body"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/umap_post/antibiotics_only-480.webp 480w,/assets/img/umap_post/antibiotics_only-800.webp 800w,/assets/img/umap_post/antibiotics_only-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/umap_post/antibiotics_only.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </div> <p>Not only are the <span style="color:#97a431;">β-lactam antibiotics</span> contained in this section of the UMAP embedding, but the individual subclasses are each grouped together. If we look at the interactive versions of the plots (<a href="https://reverie-public.s3.amazonaws.com/interactive_BBBP_plots/PCA_antibiotic.html">PCA</a>, <a href="https://reverie-public.s3.amazonaws.com/interactive_BBBP_plots/TSNE_antibiotic.html">t-SNE</a>, <a href="https://reverie-public.s3.amazonaws.com/interactive_BBBP_plots/UMAP_antibiotic.html">UMAP</a>), we see a similar phenomenon in the t-SNE embedding. This ability to preserve both global structure and such fine-grained local structure is what makes these methods useful for exploring a dataset.</p> <p>If we look even closer at the placement of the <span style="color:#97a431;">β-lactam antibiotics</span>, we discover a fascinating property of our embedding.
In the annotated plot above, there are two main clusters of <span style="color:#97a431;">β-lactam antibiotics</span> and one outlier compound, in the upper left corner, that has the penicillin (penam) substructure yet is located far away from them; in fact, it is placed in the <span style="color:#36ada4;">tetracycline</span> cluster.</p> <p>The <span style="color:#36ada4;">tetracycline antibiotics</span> are structurally quite different from the <span style="color:#97a431;">β-lactam antibiotics</span>, and yet they are embedded relatively near each other in both UMAP and t-SNE. We do not expect the distances between clusters to necessarily mean anything, but in this case, their relative positioning does.</p> <div class="l-body"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/umap_post/penimocycline_graph_plot-1-480.webp 480w,/assets/img/umap_post/penimocycline_graph_plot-1-800.webp 800w,/assets/img/umap_post/penimocycline_graph_plot-1-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/umap_post/penimocycline_graph_plot-1.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </div> <p>The outlier compound, Penimocycline, contains both the penicillin (penam) substructure and the tetracycline substructure. It is actually <a href="https://www.who.int/medicines/services/inn/StemBook_2011_Final.pdf">classified</a> as both types of antibiotic. In the graph representation that UMAP builds from this dataset, this compound is likely connected to both of these well-connected subgraphs. This would link the two groups together, ultimately placing their respective clusters near each other in the final embedding.</p> <p>To test this hypothesis, we can remove Penimocycline from our dataset and generate a new embedding.
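</p> <p>The bridging mechanism can be illustrated on a toy graph. The sketch below rests on a simplifying assumption rather than UMAP’s actual algorithm: it builds a plain, unweighted k-nearest-neighbor graph under Jaccard (Tanimoto) distance, whereas UMAP builds a weighted fuzzy graph. All fingerprints are invented, including a hypothetical <code class="language-plaintext highlighter-rouge">bridge</code> compound that carries the core bits of both groups, and <code class="language-plaintext highlighter-rouge">knn_graph</code> and <code class="language-plaintext highlighter-rouge">connected_components</code> are illustrative helper names, not part of any library.</p>

```python
from collections import defaultdict


def jaccard_distance(fp_a: set, fp_b: set) -> float:
    """1 - Tanimoto similarity for fingerprints stored as sets of on-bit indices."""
    union = len(fp_a | fp_b)
    return 1.0 - (len(fp_a & fp_b) / union if union else 0.0)


def knn_graph(fps: list, n_neighbors: int) -> dict:
    """Unweighted, symmetrized k-nearest-neighbor graph (a simplification of UMAP's graph)."""
    edges = defaultdict(set)
    for i, fp in enumerate(fps):
        # Rank all other points by distance; ties are broken by index for determinism.
        ranked = sorted((jaccard_distance(fp, other), j) for j, other in enumerate(fps) if j != i)
        for _, j in ranked[:n_neighbors]:
            edges[i].add(j)
            edges[j].add(i)  # symmetrize: an edge exists if either point picks the other
    return edges


def connected_components(n_points: int, edges: dict) -> list:
    """Group point indices into connected components with an iterative depth-first search."""
    seen, components = set(), []
    for start in range(n_points):
        if start in seen:
            continue
        seen.add(start)
        stack, component = [start], set()
        while stack:
            node = stack.pop()
            component.add(node)
            for neighbor in edges.get(node, ()):
                if neighbor not in seen:
                    seen.add(neighbor)
                    stack.append(neighbor)
        components.append(component)
    return components


# Made-up toy fingerprints: two groups with different core bits plus unique bits,
# and one hypothetical "bridge" compound that carries both cores.
A_CORE, B_CORE = {10, 11, 12}, {20, 21}
group_a = [A_CORE | {100, 101, 102}, A_CORE | {103, 104, 105}, A_CORE | {106, 107, 108}]
group_b = [B_CORE | {200, 201, 202, 203}, B_CORE | {204, 205, 206, 207}, B_CORE | {208, 209, 210, 211}]
bridge = A_CORE | B_CORE

with_bridge = group_a + group_b + [bridge]
without_bridge = group_a + group_b

for name, fps in [("with bridge", with_bridge), ("without bridge", without_bridge)]:
    comps = connected_components(len(fps), knn_graph(fps, n_neighbors=2))
    print(name, "->", len(comps), "component(s)")
```

<p>In this toy data, the bridge is every compound’s nearest neighbor, so it joins the two groups into a single component; without it, each compound links only to members of its own group and the graph splits in two. This only illustrates the mechanism and does not reproduce the embedding experiment.</p> <p>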
If this compound is functioning as a link between the <span style="color:#36ada4;">tetracycline antibiotics</span> and <span style="color:#97a431;">β-lactam antibiotics</span>, embedding the dataset without it should break the connection between the groups, and their clusters would no longer be placed close to each other.</p> <div class="l-body"> <div class="row"> <div class="col-sm mt-3 mt-md-0"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/umap_post/umap_full-480.webp 480w,/assets/img/umap_post/umap_full-800.webp 800w,/assets/img/umap_post/umap_full-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/umap_post/umap_full.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </div> <div class="col-sm mt-3 mt-md-0"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/umap_post/umap_no_penimocycline-1-480.webp 480w,/assets/img/umap_post/umap_no_penimocycline-1-800.webp 800w,/assets/img/umap_post/umap_no_penimocycline-1-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/umap_post/umap_no_penimocycline-1.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </div> </div> </div> <p>As hypothesized, in the new embedding on the right, the <span style="color:#36ada4;">tetracycline antibiotics</span> are no longer placed near the <span style="color:#97a431;">β-lactam antibiotics</span>.
<span style="color:#36ada4;">Tetracycline antibiotics</span> are still near the <span style="color:#a48cf4;">steroids</span> in both embeddings, but the induced separation between the <span style="color:#36ada4;">tetracycline</span> and <span style="color:#97a431;">β-lactam antibiotics</span> has shifted much of the rest of the embedding. UMAP is non-deterministic, so we re-ran both embeddings (the original on the left and the modified dataset on the right) multiple times, and this effect held up. This suggests that this single compound is indeed the influential node linking the <span style="color:#97a431;">β-lactam</span> and <span style="color:#36ada4;">tetracycline</span> antibiotic clusters. It also shows how strongly a UMAP embedding can be shaped by individual compounds that connect otherwise disconnected subgraphs of the dataset.</p> <p>If you continue exploring the other areas of this dataset in the <a href="https://reverie-public.s3.amazonaws.com/interactive_BBBP_plots/UMAP_antibiotic.html">interactive links</a> or the <a href="https://colab.research.google.com/gist/ElanaPearl/444b3331f61485bbe8862db27cb2b968/mapping-chemical-space-with-umap.ipynb">Colab notebook</a> provided, you will also find collections of narcotics, sedatives, NSAIDs, and more. Examining where specific compounds are placed in each embedding helps explain the structural differences between the embeddings.</p> <h3 id="hyperparameters">Hyperparameters</h3> <p>UMAP has several <a href="https://umap-learn.readthedocs.io/en/latest/parameters.html">hyperparameters</a> that give the user more control over the structure of the final embedding, depending on their priorities:</p> <ul> <li><code class="language-plaintext highlighter-rouge">n_components</code> is the dimensionality of the final embedding.
To create 2D visualizations, we keep this fixed at 2.</li> <li><code class="language-plaintext highlighter-rouge">metric</code> is the metric used to determine the distance between points. Because we are comparing ECFPs, we use <a href="https://en.wikipedia.org/wiki/Jaccard_index">Jaccard</a> distance (typically referred to as <a href="https://jcheminf.biomedcentral.com/articles/10.1186/s13321-015-0069-3">Tanimoto</a> distance in cheminformatics).</li> <li><code class="language-plaintext highlighter-rouge">n_neighbors</code> controls the balance between local and global structure in the embedding. It sets the number of neighbors each compound has in the graph representation of the dataset. If n_neighbors is small, the embedding focuses on the distances between similar compounds so that the small differences between them are well represented. If n_neighbors is larger, the distances between less similar compounds are given more weight.</li> <li><code class="language-plaintext highlighter-rouge">min_dist</code> is the minimum distance allowed between points in the low-dimensional embedding. It controls how tightly packed the embedding is: the larger min_dist is, the more spread out the compounds will be.</li> </ul> <p><code class="language-plaintext highlighter-rouge">n_neighbors</code> and <code class="language-plaintext highlighter-rouge">min_dist</code> can be tuned to the dataset’s properties and the user’s preferences, and the ideal settings vary from dataset to dataset.
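</p> <p>To build intuition for how <code class="language-plaintext highlighter-rouge">n_neighbors</code> shapes the graph, the sketch below counts the connected components of a k-nearest-neighbor graph as k grows. It rests on stated assumptions: it uses a plain, unweighted neighbor graph instead of UMAP’s weighted fuzzy graph, the fingerprints are random toy data, and <code class="language-plaintext highlighter-rouge">knn_edges</code> and <code class="language-plaintext highlighter-rouge">count_components</code> are hypothetical helper names. It also ignores <code class="language-plaintext highlighter-rouge">min_dist</code>, which only affects the low-dimensional layout, not the graph.</p>

```python
import random


def jaccard_distance(fp_a: set, fp_b: set) -> float:
    """1 - Tanimoto similarity for fingerprints stored as sets of on-bit indices."""
    union = len(fp_a | fp_b)
    return 1.0 - (len(fp_a & fp_b) / union if union else 0.0)


def knn_edges(fps: list, n_neighbors: int) -> set:
    """Undirected edges (i, j) with i < j of a symmetrized k-nearest-neighbor graph."""
    edges = set()
    for i, fp in enumerate(fps):
        # The same deterministic ranking is used for every k (ties broken by index),
        # so the top-k neighbors are always a subset of the top-(k+1).
        ranked = sorted((jaccard_distance(fp, other), j) for j, other in enumerate(fps) if j != i)
        for _, j in ranked[:n_neighbors]:
            edges.add((min(i, j), max(i, j)))
    return edges


def count_components(n_points: int, edges: set) -> int:
    """Count connected components with a small union-find structure."""
    parent = list(range(n_points))

    def find(x: int) -> int:
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving keeps the trees shallow
            x = parent[x]
        return x

    for i, j in edges:
        parent[find(i)] = find(j)
    return len({find(x) for x in range(n_points)})


# Made-up toy data: 30 random "fingerprints", each with 8 on-bits out of 64 positions.
rng = random.Random(0)
fps = [set(rng.sample(range(64), 8)) for _ in range(30)]

for k in (1, 2, 3, 5, 10, 15):
    edges = knn_edges(fps, k)
    print(f"n_neighbors={k:>2}: {len(edges)} edges, {count_components(len(fps), edges)} component(s)")
```

<p>Because each compound’s k nearest neighbors are always a subset of its k+1 nearest neighbors, increasing <code class="language-plaintext highlighter-rouge">n_neighbors</code> can only add edges: small values keep just the closest relationships and tend to leave the graph fragmented, while larger values also link less similar compounds. In umap-learn, these settings are passed as keyword arguments, for example <code class="language-plaintext highlighter-rouge">umap.UMAP(n_components=2, metric="jaccard", n_neighbors=15, min_dist=0.1)</code>, where 15 and 0.1 are the library’s defaults rather than recommendations.</p> <p>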
These plots show how varying <code class="language-plaintext highlighter-rouge">n_neighbors</code> (along the rows) and <code class="language-plaintext highlighter-rouge">min_dist</code> (along the columns) can influence the embedding:</p> <div class="l-body"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/umap_post/Screen-Shot-2021-03-28-at-9.31.14-PM-2-480.webp 480w,/assets/img/umap_post/Screen-Shot-2021-03-28-at-9.31.14-PM-2-800.webp 800w,/assets/img/umap_post/Screen-Shot-2021-03-28-at-9.31.14-PM-2-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/umap_post/Screen-Shot-2021-03-28-at-9.31.14-PM-2.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </div> <p>As these plots reveal, the spread of the embedding depends heavily on the <em>relationship</em> between these two parameters. When min_dist is very small, compounds that are very similar to each other are placed almost directly on top of each other, which makes it very easy to identify distinct clusters. However, it is harder to judge how many compounds there are and the nuanced differences between them. When min_dist is large, it is easier to gauge the full spread of the compounds but harder to isolate specific clusters. When n_neighbors is small, there are meaningful patterns <em>within</em> clusters, but the global structure is less interpretable. As this value increases, the relationships between clusters become more noticeable as the clusters become less sparse. As with the t-SNE / UMAP comparison, no single set of hyperparameters is <em>always best</em>.
Depending on what the user is looking for, there are many great options.</p> <h3 id="other-works">Other Works</h3> <p>For a non-exhaustive list of examples of others using UMAP in chemistry and biology, see:</p> <ul> <li><a href="https://www.nature.com/articles/s41467-020-15351-4">Dimensionality reduction by UMAP to visualize physical and genetic interactions</a></li> <li><a href="https://practicalcheminformatics.blogspot.com/2020/06/wicked-fast-cheminformatics-with-nvidia.html">Wicked Fast Cheminformatics with NVIDIA</a></li> <li><a href="https://pubs.rsc.org/en/content/articlehtml/2021/sc/d0sc04263c">One class classification as a practical approach for accelerating π–π co-crystal discovery</a></li> <li><a href="https://www.nature.com/articles/s42256-020-0160-y">Generative molecular design in low data regimes</a></li> <li><a href="https://www.biorxiv.org/content/10.1101/2020.04.22.021360v1.full">Focus Your Screening Library: Rapid Identification of Novel PDE2 Inhibitors with in silico Driven Library Prioritization and MicroScale Thermophoresis</a></li> <li><a href="https://www.blopig.com/blog/2020/12/umap-visualization-of-sars-cov-2-data-in-chembl/">UMAP Visualization of SARS-CoV-2 Data in ChEMBL</a></li> <li><a href="https://assets.researchsquare.com/files/rs-90793/v1_stamped.pdf">De novo design and Bioactivity Prediction of SARS-CoV-2 Main Protease Inhibitors using ULMFit</a></li> </ul> <p>This year, yet another alternative to t-SNE, <a href="https://jcheminf.biomedcentral.com/articles/10.1186/s13321-020-0416-x">TMAP</a>, has been developed. We haven’t extensively investigated that method yet, but it does seem promising.</p> <h3 id="caveat">Caveat</h3> <p>It is impossible to distill all of the complexities of chemical space into two dimensions, and a lot of information is lost in the process. Our low-dimensional embeddings can only be as good as their high-dimensional predecessors, the ECFPs.
ECFPs are an imperfect, yet important, method for vectorizing molecules, and because the underlying UMAP graphs are built from the distances between these sparse vectors, our embeddings inherit the limitations of those distances. Despite these flaws, we still find these plots to provide significant value to our design efforts.</p> <h2 id="practical-uses-for-umap-at-reverie-labs">Practical Uses for UMAP at Reverie Labs</h2> <p>We use UMAP in two main ways when examining a dataset:</p> <ol> <li><span style="color:#1864ab;">Dataset-Specific Embeddings</span>: Examine the particular distribution of compounds within a specific dataset</li> <li><span style="color:#e67700;">Dataset-Agnostic Embeddings</span>: Examine where the compounds of a dataset fit into our general embedding of global chemical space</li> </ol> <p>We can also use these two embeddings as new lenses to view the distribution of <span style="color:#5f3dc4;">measured properties</span> and other <span style="color:#0b7285;">physicochemical properties</span>.</p> <h2 id="dataset-specific-embeddings"><span style="color:#1864ab;">Dataset-Specific Embeddings</span></h2> <p>To create a dataset-specific embedding, we fit a UMAP model on the molecules of the particular dataset we are investigating. All of the visualizations we have used so far are based on dataset-specific embeddings of our example BBBP dataset. As we saw when examining the local and global structure of the UMAP embedding, a dataset-specific plot provides a good understanding of the nuances within the dataset.</p> <p>To add an extra dimension, we can color the plots based on any other relevant data we have on the compounds.
Earlier we did this using the drug types of certain compounds, but we can just as easily color the compounds by date of synthesis, by the measured property (in our case, blood-brain barrier permeability), by physicochemical properties, or by dataset split.</p> <p>This can help answer questions such as: Are there clusters of compounds that have not been actively developed in years? Are all of the most potent compounds in one area of chemical space? If we split the dataset for model training and evaluation, do certain splitting methods introduce particular artifacts?</p> <p><strong><span style="color:#5f3dc4;">Measured Properties</span></strong></p> <table style="border-collapse: collapse; border: none;"> <tr> <td width="50%" style="border: none;"> Here we color the points by the measured property of our dataset: whether or not each compound can permeate the blood-brain barrier. Many areas of this space contain exclusively permeable or exclusively impermeable compounds. This indicates that, within this dataset, certain types of compounds are consistently permeable or consistently impermeable, and that for many compounds the general scaffold alone is enough to determine permeability.
</td> <td width="65%" style="border: none;"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/umap_post/measured_property_umap-1-480.webp 480w,/assets/img/umap_post/measured_property_umap-1-800.webp 800w,/assets/img/umap_post/measured_property_umap-1-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/umap_post/measured_property_umap-1.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </td> </tr> </table> <p>If we were to use machine learning to model this dataset, we might want to ensure that individual homogeneously labeled scaffolds are not split between the training and test sets, since that could misleadingly inflate performance metrics.</p> <p><strong><span style="color:#0b7285;">Physicochemical Properties</span></strong></p> <table style="border-collapse: collapse; border: none;"> <tr> <td width="50%" style="border: none;"> Using molecular weight as our example physicochemical property, we observe that UMAP groups the heaviest compounds together in a single cluster, which contains many large macrocyclic compounds that are all impermeable.
</td> <td width="65%" style="border: none;"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/umap_post/molecular_weight_umap-480.webp 480w,/assets/img/umap_post/molecular_weight_umap-800.webp 800w,/assets/img/umap_post/molecular_weight_umap-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/umap_post/molecular_weight_umap.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </td> </tr> </table> <table style="border-collapse: collapse; border: none;"> <tr> <td width="50%" style="border: none;"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/umap_post/molecular_weight-1-480.webp 480w,/assets/img/umap_post/molecular_weight-1-800.webp 800w,/assets/img/umap_post/molecular_weight-1-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/umap_post/molecular_weight-1.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </td> <td width="50%" style="border: none;"> Visualizing the molecular weight (MW) of the permeable and impermeable compounds, we see a similar phenomenon: the majority of the heavy compounds (MW &gt; 500 Da) are impermeable. This aligns with the traditional assumption that compounds heavier than 500 Da will struggle to cross the blood-brain barrier.
</td> </tr> </table> <p>The UMAP plot makes it quick to identify where compounds of a given molecular weight lie, how diverse the molecular weights are within a given cluster, and how these properties interplay with blood-brain barrier permeability.</p> <h2 id="dataset-agnostic-embeddings"><span style="color:#e67700;">Dataset-Agnostic Embeddings</span></h2> <p>To create a dataset-agnostic embedding, we take advantage of the fact that UMAP treats fitting and transforming a dataset as two separate steps. We start by fitting a UMAP model on a large in-house corpus of drug-like compounds that we consider representative of ‘drug-like chemical space’. Then, when we want to examine a new dataset, we load this cached model and use it to transform the compounds of our dataset into this universal embedding space. With <span style="color:#1864ab;">Dataset-Specific</span> embeddings, the dimensionality reduction itself depends on the relationships among all the compounds in the dataset we seek to visualize. With <span style="color:#e67700;">Dataset-Agnostic</span> embeddings, the dimensionality reduction is fixed, so the location of a given compound does not depend on the other compounds in the dataset.</p> <p>This allows us to examine multiple datasets within a consistent frame of reference.
We can see whether the data is diverse and covers a wide area of this drug-like space, or whether it is confined to a few specific regions.</p> <div class="l-body"> <div class="row"> <div class="col-sm mt-3 mt-md-0"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/umap_post/global_data_agnostic_bbbp-480.webp 480w,/assets/img/umap_post/global_data_agnostic_bbbp-800.webp 800w,/assets/img/umap_post/global_data_agnostic_bbbp-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/umap_post/global_data_agnostic_bbbp.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </div> <div class="col-sm mt-3 mt-md-0"> <figure> <picture> <source class="responsive-img-srcset" srcset="/assets/img/umap_post/global_chemical_space-480.webp 480w,/assets/img/umap_post/global_chemical_space-800.webp 800w,/assets/img/umap_post/global_chemical_space-1400.webp 1400w," sizes="95vw" type="image/webp"/> <img src="/assets/img/umap_post/global_chemical_space.png" class="img-fluid rounded z-depth-1" width="100%" height="auto" data-zoomable="" loading="lazy" onerror="this.onerror=null; $('.responsive-img-srcset').remove();"/> </picture> </figure> </div> </div> </div> <p>Here we visualize both the original embedding of the global chemical space compounds used to fit the general UMAP model and a Dataset-Agnostic embedding of the BBBP dataset created with this fixed model. In both, compounds are colored by their cluster assignment in this embedded space. To get these cluster assignments, we pre-fit a <a href="https://hdbscan.readthedocs.io/en/latest/">clustering model</a> on the original global chemical space compounds in their UMAP embedding. We then apply this pre-trained clustering model to new datasets to quickly determine which of the global UMAP clusters our new compounds fall into.
We can even calculate the percentage of global clusters covered by a new dataset to create a quick, quantitative heuristic for how much chemical space the dataset covers.</p> <p>As we saw in our case study, distances in UMAP space aren’t necessarily meaningful, and global structure can be greatly influenced by individual compounds. Thus, clusters based on these embedded distances should be interpreted with caution. However, they can definitely help with visual examination of the dataset and provide a quick, quantitative approximation of chemical space coverage. By combining these cluster heuristics with the visuals of the global embedding, we can start to understand the diversity of a given chemical dataset.</p> <h2 id="final-thoughts">Final Thoughts</h2> <p>UMAP is useful because it makes creating local and global embeddings quick and easy. This allows us to treat this analysis as a standard piece of our internal pipeline for cleaning, analyzing, and preparing new datasets for machine learning. Overall, UMAP seems to be a great alternative to the more popular methods for dataset embedding, and it would be exciting to see more groups using it for chemical data.</p> <p>This post highlighted UMAP’s value for exploring the chemical space of datasets relevant to drug design. The t-SNE vs. UMAP debate is still hotly contested, but we don’t aim to use this example to prove that UMAP is fundamentally <em>better</em> than t-SNE (although <a href="https://www.biorxiv.org/content/10.1101/549659v1.full.pdf">many</a> have <a href="https://www.nature.com/articles/nbt.4314">argued this</a>, <a href="https://www.biorxiv.org/content/10.1101/2019.12.19.877522v1.full.pdf">some have refuted it</a>, and others have even <a href="https://towardsdatascience.com/why-umap-is-superior-over-tsne-faa039c28e99">refuted the refutation</a>!).
Similarly, although PCA is less useful for our particular goals, it has many <a href="https://stats.stackexchange.com/questions/238538/are-there-cases-where-pca-is-more-suitable-than-t-sne/249520#249520">advantages</a> over its nonlinear alternatives that make it very useful for other purposes. Ultimately, there are many great dimensionality reduction tools to choose from, but we hope this post has helped put UMAP on your map.</p>]]></content><author><name></name></author><summary type="html"><![CDATA[]]></summary></entry></feed>