A visual walkthrough of the AlphaFold3 architecture, with more details and diagrams than you were probably looking for.
Do you want to understand exactly how AlphaFold3 works? The architecture is quite complicated and the description in the paper can be overwhelming, so we made a much more friendly (but just as detailed!) visual walkthrough.
This is mostly written for an ML audience and multiple points assume familiarity with the steps of attention. If you’re rusty, see Jay Alammar’s The Illustrated Transformer for a thorough visual explanation. That post is one of the best explanations of a model architecture at the level of individual matrix operations and also the inspiration for the diagrams and naming.
There are already many great explanations of the motivation for protein structure prediction, the CASP competition, model failure modes, debates about evaluations, implications for biotech, etc. so we don’t focus on any of that. Instead we explore the how.
How are these molecules represented in the model and what are all of the operations that convert them into a predicted structure?
This is probably more exhaustive than most people are looking for, but if you want to understand all the details and you like learning via diagrams, this should help :)
We’ll start by pointing out that the goals of the model are a bit different from those of previous AlphaFold models: instead of just predicting the structure of individual protein sequences (AF2) or protein complexes (AF-Multimer), it predicts the structure of a protein, optionally complexed with other proteins, nucleic acids, or small molecules, all from sequence alone. So while previous AF models only had to represent sequences of standard amino acids, AF3 has to represent more complex input types, and thus there is a more complex featurization/tokenization scheme. Tokenization is described in its own section, but for now just know that when we say “token” it either represents a single amino acid (for proteins), nucleotide (for DNA/RNA), or an individual atom if that atom is not part of a standard amino acid/nucleotide.
The model can be broken down into 3 main sections:
1. Input preparation: converting the user-provided sequences into the token-level and atom-level tensors (plus the MSA and template representations) that feed the rest of the model.
2. Representation learning: the “trunk” of the model, where the token-level single and pair representations are refined by the MSA module and the Pairformer.
3. Structure prediction: a diffusion module that uses these learned representations to generate the 3D coordinates of every atom.
We also have additional sections describing 4. the loss function, confidence heads, and other relevant training details and 5. some thoughts on the model from an ML trends perspective.
Throughout the model a protein complex is represented in two primary forms: the “single” representation which represents all the tokens in our protein complex, and a “pair” representation which represents the relationships (e.g. distance, potential interactions) between all pairs of amino acids / atoms in the complex. Each of these can be represented at an atom-level or a token-level, and will always be shown with these names (as established in the AF3 paper) and colors:
The actual input a user provides to AF3 is the sequence of one protein and optionally additional molecules. The goal of this section is to convert these sequences into a series of 6 tensors that will be used as the input to the main trunk of the model as outlined in this diagram. These tensors are s, our token-level single representation, z, our token-level pair representation, q, our atom-level single representation, p, our atom-level pair representation, m, our MSA representation, and t, our template representation.
This section contains:
In AF2, as the model only represented proteins with a fixed set of amino acids, each amino acid was represented with its own token. This is maintained in AF3, but additional tokens are also introduced for the additional molecule types that AF3 can handle:
As a result, we can think of some tokens (like those for amino acids) as being associated with multiple atoms, while other tokens (like those for an atom in a ligand) are associated with only a single atom. So, while a protein with 35 standard amino acids (likely > 600 atoms) would be represented by 35 tokens, a ligand with 35 atoms would also be represented by 35 tokens.
One of the key early steps in AF3 is something akin to Retrieval Augmented Generation (RAG) in language models. We find similar sequences to our protein and RNA sequences of interest (collected into a multiple sequence alignment, “MSA”), and any structures related to those (called the “templates”), then include them as additional inputs to the model called m and t, respectively.
Versions of the same protein found in different species can be quite structurally and sequentially similar. By aligning these together into a Multiple Sequence Alignment (MSA), we can look at how an individual position in a protein sequence has changed throughout evolution. You can think about an MSA for a given protein as a matrix where each row is the sequence of the analogous protein from a different species. It has been shown that the conservation patterns found along the column of a specific position in the protein can reflect how critical it is for that position to have certain amino acids present, and the relationships between different columns reflect relationships between positions (i.e. if two amino acids are physically interacting, mutations at their positions will likely be correlated across evolution). Thus, MSAs are often used to enrich representations of single proteins.
Similarly, if any of these proteins have known structures, those are also likely to inform the structure of this protein. Instead of searching for full structures, only individual chains of the proteins are used. This resembles the practice of homology modeling, in which the structure of a query protein is modeled based on templates from known protein structures that are presumed to be similar.
The only new part of these retrieval steps compared to AF-multimer is the fact that we now do this retrieval for RNA sequences in addition to protein sequences. Note that this is not traditionally called “retrieval” as the practice of using structural templates to guide protein structure modeling has been common practice in the field of homology modeling long before the term RAG existed. However, even though AlphaFold doesn’t explicitly refer to this process as retrieval, it does quite resemble what has now been popularized as RAG.
How do we represent these templates?
From our template search, we have a 3D structure for each of our templates and information about which tokens are in which chains. First, the euclidean distances between all pairs of tokens in a given template are calculated. For tokens associated with multiple atoms, a representative “center atom” is used to calculate distances. This would be the Cɑ atom for amino acids and C1’ atom for standard nucleotides.
This generates a Ntoken x Ntoken matrix for each template. However, instead of representing each distance as a numerical value, the distances get discretized into a “distogram” (a histogram of distances).
To each distogram, we then append metadata about which chain each token belongs to and which positions were actually resolved in the template structure (and thus have valid distances).
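To make the binning concrete, here is a minimal NumPy sketch of how pairwise center-atom distances could be discretized into a one-hot distogram. The bin edges below (38 edges spanning 3.25–50.75 Å, borrowed from the AF2-style distogram) are an illustrative assumption; see the AF3 supplement for the exact values.

```python
import numpy as np

def template_distogram(coords, bin_edges=np.linspace(3.25, 50.75, 38)):
    """Discretize pairwise center-atom distances into a one-hot distogram.

    coords: (n_tokens, 3) array of center-atom coordinates for one template.
    Returns a (n_tokens, n_tokens, n_bins + 1) one-hot tensor; the final bin
    catches any distance beyond the last edge. Bin edges are illustrative.
    """
    dist = np.linalg.norm(coords[:, None, :] - coords[None, :, :], axis=-1)  # (n, n)
    bin_idx = np.digitize(dist, bin_edges)        # (n, n), integers in 0..len(bin_edges)
    return np.eye(len(bin_edges) + 1)[bin_idx]    # one-hot over distance bins
```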
To create q, our atom-level single representation, we need to pull all our atom-level features. The first step is to calculate a “reference conformer” for each amino acid, nucleotide, and ligand. While we do not yet know the structure of the entire complex, we have strong priors on the local structures of each individual component. A conformer (short for conformational isomer) is a 3D arrangement of atoms in a molecule that is generated by sampling rotations about single bonds. Each amino acid has a “standard” conformer which is just one of the low-energy conformations this amino acid can exist in, which can be retrieved through a look-up. However, each small molecule requires its own conformation generation. These are generated with RDKit’s ETKDGv3, an algorithm that combines experimental data and torsion angle preferences to produce 3D conformers.
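Small-molecule conformer generation can be sketched directly with RDKit; the ligand SMILES (caffeine) and the fixed random seed below are arbitrary choices for illustration, not anything specific to AF3.

```python
from rdkit import Chem
from rdkit.Chem import AllChem

# Build a molecule and generate one reference conformer with ETKDGv3
mol = Chem.AddHs(Chem.MolFromSmiles("Cn1cnc2c1c(=O)n(C)c(=O)n2C"))  # caffeine, for illustration
params = AllChem.ETKDGv3()          # experimental-torsion knowledge distance geometry
params.randomSeed = 0               # make the sampled conformer reproducible
AllChem.EmbedMolecule(mol, params)  # writes 3D coordinates onto the molecule

conformer = mol.GetConformer()
for atom in mol.GetAtoms():
    pos = conformer.GetAtomPosition(atom.GetIdx())
    # per-atom features of the kind described next: element, atomic number, charge, position
    print(atom.GetSymbol(), atom.GetAtomicNum(), atom.GetFormalCharge(), (pos.x, pos.y, pos.z))
```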
Then we concatenate the information from this conformer (relative location) with each atom’s charge, atomic number, and other identifiers. Matrix c stores this information for all the atoms in our sequences
Finally, we make a copy of our atom-level single representation, calling this copy q. This matrix q is what we will be updating going forward, but c does get saved and used later.
Having generated q (representation of all the atoms) and p (representation of each pair of atoms), we now want to update these representations based on other atoms nearby. Anytime AF3 applies attention at the atom-level, we use a module called the Atom Transformer. The atom transformer is a series of blocks that use attention to update q using both p and the original representation of q called c. As c does not get updated by the Atom Transformer, it can be thought of as a residual connection to the starting representation.
The Atom Transformer mostly follows a standard transformer structure using layer norm, attention, then an MLP transition. However, each step has been adapted to include additional input from c and p (including a secondary input like this is sometimes referred to as “conditioning”). There is also a ‘gating’ step between the attention and MLP blocks. Going through each of these 4 steps in more detail:
Adaptive LayerNorm (AdaNorm) is a variant of LayerNorm with one simple extension. Recall that for a given input matrix, traditional LayerNorm learns two parameters (a scaling factor gamma and a bias factor beta) that adjust the mean and standard deviation of each of the channels in our matrix. Instead of learning fixed parameters for gamma and beta, AdaNorm learns a function to generate gamma and beta adaptively based on the input matrix. However, instead of generating the parameters based on the input getting re-scaled (in the Atom Transformer this is q), a secondary input (c in the Atom Transformer) is used to predict the gamma and beta that re-scale the mean and standard deviation of q.
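A minimal sketch of Adaptive LayerNorm in PyTorch, assuming q and c have already been broadcast to matching shapes; the exact projections and initialization in AF3 differ slightly from this simplified version.

```python
import torch
import torch.nn as nn

class AdaptiveLayerNorm(nn.Module):
    """LayerNorm whose scale (gamma) and shift (beta) are predicted from a
    conditioning input c instead of being fixed learned parameters."""
    def __init__(self, dim_q, dim_c):
        super().__init__()
        self.norm_q = nn.LayerNorm(dim_q, elementwise_affine=False)  # normalize q, no learned affine
        self.norm_c = nn.LayerNorm(dim_c)
        self.to_gamma = nn.Linear(dim_c, dim_q)
        self.to_beta = nn.Linear(dim_c, dim_q, bias=False)

    def forward(self, q, c):
        c = self.norm_c(c)
        gamma = torch.sigmoid(self.to_gamma(c))   # per-channel scale predicted from c
        beta = self.to_beta(c)                    # per-channel shift predicted from c
        return gamma * self.norm_q(q) + beta
```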
Atom-Level Attention with Pair-Bias can be thought of as an extension of self-attention. Like in self-attention, the queries, keys, and values all come from the same 1D sequence (our single representation, q). However, there are 3 differences:
Pair-biasing: after the dot product of the queries and keys are calculated, a linear projection of the pair representation is added as a bias to scale the attention weights. Note that this operation does not involve any information from q being used to update p, just one way flow from the pair representation to q. The reasoning for this is that atoms that have a stronger pairwise relationship should attend to each other more strongly and p is effectively already encoding an attention map.
Gating: In addition to the queries, keys, and values, we create an additional projection of q that is passed through a sigmoid, to squash the values between 0 and 1. Our output is multiplied by this “gate” right before all the heads are re-combined. This effectively forces the model to ignore some of what it learned in this attention process. This type of gating appears frequently in AF3 and is discussed more in the ML-musings section. To briefly elaborate, because the model is constantly adding the outputs of each section to the residual stream, this gating mechanism can be thought of as the model’s way to specify what information does or does not get saved in this residual stream. It is presumably named a “gate” after the similar “gates” in LSTM which uses a sigmoid to learn a filter for what inputs get added to the running cell state.
Sparse attention:
Because the number of atoms can be much larger than the number of tokens, we do not run full attention at this step; rather, we use a type of sparse attention (called sequence-local atom attention) in which the attention is effectively run in local groups, where blocks of 32 atoms at a time can all attend to 128 nearby atoms. Sparse attention patterns are more thoroughly described elsewhere on the internet. A single-head sketch of the pair-biased, gated attention (ignoring the sparsity) is shown below.
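Here is that sketch: single-head, gated self-attention with pair bias over a dense set of atoms. It ignores multiple heads and the sequence-local sparsity, and the dimensions and layer choices are illustrative rather than AF3's exact ones.

```python
import math
import torch
import torch.nn as nn

class GatedAttentionWithPairBias(nn.Module):
    """Single-head sketch: queries, keys, values, and the gate all come from q;
    a scalar bias per (i, j) pair is read out of the pair representation p."""
    def __init__(self, dim, pair_dim):
        super().__init__()
        self.to_q = nn.Linear(dim, dim, bias=False)
        self.to_k = nn.Linear(dim, dim, bias=False)
        self.to_v = nn.Linear(dim, dim, bias=False)
        self.to_bias = nn.Linear(pair_dim, 1, bias=False)   # pair representation -> scalar bias
        self.to_gate = nn.Linear(dim, dim)                   # gate projected from q
        self.out = nn.Linear(dim, dim, bias=False)

    def forward(self, q_single, p_pair):
        # q_single: (n_atoms, dim); p_pair: (n_atoms, n_atoms, pair_dim)
        q, k, v = self.to_q(q_single), self.to_k(q_single), self.to_v(q_single)
        logits = q @ k.T / math.sqrt(q.shape[-1])            # standard query-key dot products
        logits = logits + self.to_bias(p_pair).squeeze(-1)   # add pair bias (one-way flow from p)
        attn = torch.softmax(logits, dim=-1)
        gate = torch.sigmoid(self.to_gate(q_single))         # values squashed to (0, 1)
        return self.out(gate * (attn @ v))                   # gate the output before projecting
```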
We apply another gate to our data, but this time the gate is generated from our original atom-level single matrix, c.
This step is equivalent to the MLP layers in a transformer, and is called “conditioned” because the MLP is sandwiched in between Adaptive LayerNorm (Step 1 of Atom Transformer) and Conditional Gating (Step 3 of Atom Transformer) which both depend on c.
The only other piece of note in this section is that AF3 uses SwiGLU in the transition block instead of ReLU. The switch from ReLU → SwiGLU happened with AF2 → AF3 and has been a common change in many recent architectures so we visualize it here.
With a ReLU-based transition layer (as in AF2), we take the activations, project them up to 4x the size, apply a ReLU, then down-project them back to their original size. When using SwiGLU (in AF3), the input activation creates two intermediate up-projections, one of which goes through a swish non-linearity (improved variant of ReLU), then these are multiplied before down-projecting. The diagram below shows the differences:
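A side-by-side sketch of the two transition variants; the 4x expansion factor and the bias-free linear layers are assumptions for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ReLUTransition(nn.Module):
    """AF2-style transition: up-project 4x, ReLU, down-project."""
    def __init__(self, dim, expansion=4):
        super().__init__()
        self.up = nn.Linear(dim, expansion * dim)
        self.down = nn.Linear(expansion * dim, dim)

    def forward(self, x):
        return self.down(torch.relu(self.up(x)))

class SwiGLUTransition(nn.Module):
    """AF3-style transition: two up-projections, swish on one, multiply, down-project."""
    def __init__(self, dim, expansion=4):
        super().__init__()
        self.up_gate = nn.Linear(dim, expansion * dim, bias=False)
        self.up_value = nn.Linear(dim, expansion * dim, bias=False)
        self.down = nn.Linear(expansion * dim, dim, bias=False)

    def forward(self, x):
        return self.down(F.silu(self.up_gate(x)) * self.up_value(x))  # silu == swish
```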
While the data so far has all been stored at an atom-level, the representation learning section of AF3 from here onwards operates at the token-level. To create these token-level representations, we first project our atom-level representation to a larger dimension (catom=128 → ctoken=384). Then, we take the mean over all atoms assigned to the same token. Note that this pooling only applies to atoms associated with standard amino acids and nucleotides; tokens that represent a single atom remain unchanged.
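A sketch of the mean-pooling step, assuming the atom features have already been projected up to the token dimension; token_index is a hypothetical per-atom array mapping each atom to its token.

```python
import torch

def atoms_to_tokens(q_atom, token_index, n_tokens):
    """Mean-pool atom-level features into token-level features.

    q_atom:      (n_atoms, c) atom features, already projected up to c_token.
    token_index: (n_atoms,) long tensor mapping each atom to its token.
    For single-atom tokens (e.g. ligand atoms) the mean over one atom is a no-op,
    so those representations pass through unchanged.
    """
    c = q_atom.shape[-1]
    sums = torch.zeros(n_tokens, c).index_add_(0, token_index, q_atom)
    counts = torch.zeros(n_tokens).index_add_(0, token_index, torch.ones(len(token_index)))
    return sums / counts.clamp(min=1).unsqueeze(-1)
```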
Now that we are working in “token space”, we concatenate our token-level features and statistics from our MSA (where available)
Now that we have created sinit, our initialized single representation, the next step is to initialize our pair representation zinit. The pair representation is a three dimensional tensor, but it’s easiest to think of it as a heatmap-like 2D matrix with an implicit depth dimension of cz=128 channels. So, entry zi,j of our pair representation is a cz dimensional vector meant to store information about the relationship between token i and token j in our token sequence. We have created an analogous atom-level matrix p, and we follow a similar process here at the token-level.
To initialize zi,j, we use a linear projection to make the channel dimension of our sequence representation match that of the pair representation (384 → 128) and add the resulting si and sj. To this, we add a relative positional encoding, pi,j
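A sketch of this initialization, assuming the relative positional encoding has already been embedded into an (Ntoken, Ntoken, cz) tensor; dimensions follow the text (cs = 384 → cz = 128).

```python
import torch
import torch.nn as nn

class PairInit(nn.Module):
    """Initialize z_{i,j} from projections of s_i and s_j plus a relative positional encoding."""
    def __init__(self, c_s=384, c_z=128):
        super().__init__()
        self.proj_i = nn.Linear(c_s, c_z, bias=False)
        self.proj_j = nn.Linear(c_s, c_z, bias=False)

    def forward(self, s, relpos):
        # s: (n_tokens, c_s); relpos: (n_tokens, n_tokens, c_z), assumed pre-embedded
        a = self.proj_i(s)                             # (n, c_z)
        b = self.proj_j(s)                             # (n, c_z)
        return a[:, None, :] + b[None, :, :] + relpos  # (n, n, c_z)
```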
Now we’ve successfully created and embedded all of the inputs that will be used in the rest of our model:
For Step 2, we will set aside the atom-level representations (c, q, p) and focus on updating our token-level representations s and z in the next section (with the help of m and t).
This section is the majority of the model, often referred to as the “trunk”, as it is where most of the computation is done. We call it the representation learning section of the model, as the goal is to learn improved representations of our token-level “single” (s) and “pair” (z) tensors initialized above.
This section contains:
Each individual block is repeated multiple times, and then the output of the whole section is fed back into itself again as input and the process is repeated (this is called recycling).
Each template (Ntemplates=2 in the diagram) goes through a linear projection and is added together with a linear projection of our pair representation (z). This newly combined matrix goes through a series of operations called the Pairformer Stack (described in depth later). Finally, all of the templates are averaged together and go through - you guessed it - another linear layer.
This module greatly resembles what was called “Evoformer” in AF2, and the goal of it is to simultaneously improve the MSA and pair representations. It does a series of operations independently on these two representations then also enables cross-talk between them.
The first step is to subsample the rows of the MSA, rather than use all rows of the MSA previously generated (which could be up to 16k), then add a projected version of our single representation to this subsampled MSA.
Next, we take the MSA representation and incorporate it into the pair representation via the “Outer Product Mean”. Comparing two columns of the MSA reveals information about the relationships between two positions in the sequence (e.g. how correlated are these two positions in the sequence across evolution). For each pair of token indices i,j, we iterate over all evolutionary sequences, taking the outer product of ms,i and ms,j, then averaging these across all the evolutionary sequences. We then flatten this outer product, project it back down, and add this to the pair representation zi,j (full details in diagram). While each outer product only compares values within a given sequence ms, when we take the mean of these, that mixes information across sequences. This is the only point in the model where information is shared across evolutionary sequences, and it is a significant change from AF2, made to reduce the computational complexity of the Evoformer.
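A sketch of the Outer Product Mean; the hidden dimension of 32 is an assumption borrowed from AF2's version of this module, and the channel sizes are illustrative.

```python
import torch
import torch.nn as nn

class OuterProductMean(nn.Module):
    """Fold MSA information into the pair representation via averaged outer products."""
    def __init__(self, c_m=64, c_z=128, c_hidden=32):
        super().__init__()
        self.norm = nn.LayerNorm(c_m)
        self.proj_a = nn.Linear(c_m, c_hidden, bias=False)
        self.proj_b = nn.Linear(c_m, c_hidden, bias=False)
        self.out = nn.Linear(c_hidden * c_hidden, c_z)

    def forward(self, m):
        # m: (n_seq, n_tokens, c_m) subsampled MSA representation
        m = self.norm(m)
        a, b = self.proj_a(m), self.proj_b(m)                     # (n_seq, n_tokens, c_hidden)
        # outer product for every pair (i, j), then mean over the sequence dimension s
        outer = torch.einsum("sic,sjd->ijcd", a, b) / m.shape[0]  # (n, n, c_hidden, c_hidden)
        return self.out(outer.flatten(start_dim=-2))              # (n, n, c_z), added to z
```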
Having updated the pair representation based on the MSA, the model next updates the MSA based on the pair representation. This specific update pattern is called row-wise gated self attention using only pair bias, and is a simplified version of self attention with pair bias, discussed in the Atom Transformer section, applied to every sequence (row) in the MSA independently. It is inspired by attention, but instead of using queries and keys to determine what other positions each token should attend to, we just use the existing relationships between tokens stored in our pair representation z.
In the pair representation, each zi,j is a vector containing information about the relationship between tokens i and j. When the tensor z gets projected down to a matrix, each zi,j vector becomes a scalar that can be used to determine how much token i should attend to token j. After applying row-wise softmax, these are now equivalent to attention scores, which are used to create a weighted average of the values as a typical attention map would.
Note that there is no information shared across the evolutionary sequences in the MSA as it is run independently for each row.
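A single-head sketch of this pair-weighted averaging; note there are no query or key projections of the MSA at all, only a per-pair logit read out of z. Dimensions are illustrative.

```python
import torch
import torch.nn as nn

class MSAPairWeightedAveraging(nn.Module):
    """Attention-like update of the MSA where the weights come only from the pair
    representation and are applied independently to every MSA row."""
    def __init__(self, c_m=64, c_z=128):
        super().__init__()
        self.to_weights = nn.Linear(c_z, 1, bias=False)  # z_{i,j} -> scalar "attention logit"
        self.to_v = nn.Linear(c_m, c_m, bias=False)
        self.to_gate = nn.Linear(c_m, c_m)
        self.out = nn.Linear(c_m, c_m, bias=False)

    def forward(self, m, z):
        # m: (n_seq, n_tokens, c_m); z: (n_tokens, n_tokens, c_z)
        w = torch.softmax(self.to_weights(z).squeeze(-1), dim=-1)  # (n, n), row-wise softmax
        v = self.to_v(m)                                           # (n_seq, n, c_m)
        avg = torch.einsum("ij,sjc->sic", w, v)                    # weighted average of values per row
        gate = torch.sigmoid(self.to_gate(m))
        return self.out(gate * avg)
```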
The last step of the MSA module is to update the pair representation through a series of steps referred to as triangle updates and attention. These triangle operations are described below with Pairformer, where they are used again. There are also some transition blocks that use SwiGLU to up/down project the matrix as was done in the Atom Transformer.
Diagram from AF3 supplement
Having updated our pair representation based on the templates and MSA Module, we now ignore them for the rest of the model. Instead, only the updated pair representation (z) and single representation (s) enter the Pairformer and are used to update each other. As the transition blocks have already been described, this section focuses on the Triangle Updates and Triangle Attention, then briefly explains how the Single Attention with Pair Bias differs from the variant described earlier. These triangle-based layers were first introduced in AF2 and are one of the pieces that not only remained in AF3, but are now even more present in the architecture, so they get quite a bit of attention.
The guiding principle here is the idea of the triangle inequality: “the sum of any two sides of a triangle is greater than or equal to the third side”. Recall that each zi,j in the pair tensor encodes the relationship between positions i and j in the sequence. While it does not literally encode the physical distances between pairs of tokens, let’s think about it for a moment as if it did. If we imagine that each zi,j is the distance between two amino acids and we know zi,j=1 and zj,k=1, then by the triangle inequality zi,k cannot be larger than 2. Knowing two of the distances gives us a strong constraint on what the third distance can be. The goal of triangle updates and triangle attention is to try to encode these geometric constraints into the model.
The triangle inequality is not enforced in the model; rather, it is encouraged by ensuring each position zi,j is updated by looking at all possible triplets of positions (i,j,k) at a time. So zi,j is updated based on zj,k and zi,k for all other tokens k. Because z represents the complex physical relationship between these tokens, rather than merely their distance, these relationships can be directional. So for zi,j, we also want to encourage consistency with zk,i and zk,j for all tokens k. If we think of the tokens as nodes in a graph, with z as a directed adjacency matrix, it makes sense that AlphaFold calls these “outgoing edges” and “incoming edges”.
Consider row i=0 of this adjacency matrix, and let’s say we want to update z0,2, which has been highlighted in purple. The idea behind the update is that if we know the distances between 0→1 and 2→1, that gives us some constraints on what 0→2 can be. Similarly, if we know the distances between 0→3 and 2→3, this also gives us a constraint on 0→2. This would apply for all tokens k.
So, in the triangle updates and attention, we effectively look at all directed paths for 3 nodes in this graph (a.k.a triangles, hence the name!).
Having carefully looked at the triangle operations from a graph theory perspective, we can see how this is implemented with tensor operations. In the outgoing update, every position zi,j in the pair representation gets updated independently based on a weighted combination of the other elements in the same row (zi,k), where the weighting of each zi,k is based on the third element in its outgoing edge triangle (zj,k).
Practically, we take three linear projections of z (called a, b, and g). To update zi,j, we take an element-wise multiplication of row i from a and row j from b. We then sum over all these rows (different values of k), and gate with our g projection.
For the incoming update, we effectively do the same thing but flipping the rows with the columns, so to update zi,j we take a weighted sum of the other elements in the same column (zk,j), where the weighting of each zk,j is based on the third element in its incoming edge triangle (zk,i). After creating the same linear projections, we take an element-wise multiplication of column i from a and column j from b, and sum over all the rows of this matrix. You’ll find that these operations exactly mirror the graph-theory adjacency view described above.
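Both triangle multiplicative updates can be written as a single einsum each. The hidden dimension, and the omission of the usual extra gating on the a and b projections, are simplifications for this sketch.

```python
import torch
import torch.nn as nn

class TriangleMultiplicativeUpdate(nn.Module):
    """Sketch of the triangle multiplicative update.
    outgoing=True sums a_{ik} * b_{jk} over k; outgoing=False sums a_{ki} * b_{kj} over k."""
    def __init__(self, c_z=128, c_hidden=128, outgoing=True):
        super().__init__()
        self.outgoing = outgoing
        self.norm_in = nn.LayerNorm(c_z)
        self.to_a = nn.Linear(c_z, c_hidden, bias=False)
        self.to_b = nn.Linear(c_z, c_hidden, bias=False)
        self.to_g = nn.Linear(c_z, c_z)
        self.norm_out = nn.LayerNorm(c_hidden)
        self.out = nn.Linear(c_hidden, c_z, bias=False)

    def forward(self, z):
        # z: (n_tokens, n_tokens, c_z)
        zn = self.norm_in(z)
        a, b = self.to_a(zn), self.to_b(zn)              # (n, n, c_hidden)
        if self.outgoing:   # "outgoing edges": combine row i of a with row j of b
            update = torch.einsum("ikc,jkc->ijc", a, b)
        else:               # "incoming edges": combine column i of a with column j of b
            update = torch.einsum("kic,kjc->ijc", a, b)
        gate = torch.sigmoid(self.to_g(zn))              # gate generated from z itself
        return gate * self.out(self.norm_out(update))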
After our two triangle update steps, we also update each zi,j using triangle attention for the outgoing edges and triangle attention for the incoming edges. The AF3 paper refers to the “outgoing edges” as attention “around starting node” and “incoming edges” as attention “around ending node”.
To build up to triangle attention, it can be helpful to start with typical self-attention over a 1D sequence. Recall that queries, keys, and values are all transformations of the original 1D sequence. An attention variant called axial attention extends this to matrices by applying independent 1D self-attention over the different axes of a 2D matrix (the rows, then the columns). Triangle attention adds the triangle principle we discussed earlier to this, updating zi,j by incorporating zi,k and zj,k for all tokens k. Specifically, in the “starting node” case, to calculate the attention scores along row i (to determine how much zi,j should be influenced by zi,k), we do a query-key comparison between zi,j and zi,k as usual, then bias the attention based on zj,k as is shown above.
For the “ending node” case, we again swap rows for columns. For zi,j, the keys and values will both come from column j of z, while the bias will come from column i. So, when comparing the query zi,j with the key zk,j, we bias that attention score based on zk,i. Then, once we have attention scores over all k, we use our value vectors from column j.
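A single-head sketch of triangle attention around the starting node (the ending-node variant is the transposed version); the hidden size and the single head are illustrative simplifications.

```python
import math
import torch
import torch.nn as nn

class TriangleAttentionStartingNode(nn.Module):
    """Query z_{i,j} attends over keys/values z_{i,k} (same row i),
    with the attention logits biased by the third edge z_{j,k}."""
    def __init__(self, c_z=128, c_hidden=32):
        super().__init__()
        self.norm = nn.LayerNorm(c_z)
        self.to_q = nn.Linear(c_z, c_hidden, bias=False)
        self.to_k = nn.Linear(c_z, c_hidden, bias=False)
        self.to_v = nn.Linear(c_z, c_hidden, bias=False)
        self.to_bias = nn.Linear(c_z, 1, bias=False)
        self.to_gate = nn.Linear(c_z, c_hidden)
        self.out = nn.Linear(c_hidden, c_z, bias=False)

    def forward(self, z):
        zn = self.norm(z)                                          # (n, n, c_z)
        q, k, v = self.to_q(zn), self.to_k(zn), self.to_v(zn)
        logits = torch.einsum("ijc,ikc->ijk", q, k) / math.sqrt(q.shape[-1])
        logits = logits + self.to_bias(zn).squeeze(-1)[None, :, :]  # add bias b_{j,k}
        attn = torch.softmax(logits, dim=-1)                        # softmax over k
        update = torch.einsum("ijk,ikc->ijc", attn, v)              # weighted sum of values z_{i,k}
        gate = torch.sigmoid(self.to_gate(zn))
        return self.out(gate * update)
```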
Now that we’ve updated our pair representation with these four triangle steps, we pass the pair representation through a Transition block as described above. Finally, we want to update our single representation (s) using this new updated pair representation (z), so we will use single attention with pair bias, pictured below. This is nearly identical to the Atom-Level Attention with Pair Bias described in the Atom Transformer section, except that it operates over tokens rather than atoms and uses full attention rather than the sparse, sequence-local variant.
We repeat the Pairformer for 48 blocks, eventually creating strunk and ztrunk.
Now, with these refined representations, we are ready to use s and z to predict the structure of our complex. One of the changes introduced in AF3 is that the entire structure prediction is based on atom-level diffusion. Existing posts more thoroughly explain the intuition and math behind diffusion, but the basic idea of a Diffusion Model is to start with real data, add random noise to the data, then train a model to predict what noise was added. Noise is iteratively added to the data over a series of T timesteps to create a sequence of T variants of each datapoint. We call the original data point xt=0 and the fully noised version xt=T. During training, at timestep t, the model is given xt and predicts what noise was added between xt-1 and xt. We take a gradient step on the predicted noise compared to the actual noise that had been added.
Then, at inference time, we simply start with random noise, which is equivalent to xt=T. For every time step, we predict the noise the model thinks has been added, and remove that predicted noise. After a pre-specified number of timesteps, we end up with a fully “denoised” datapoint that should resemble the original data from our dataset.
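A conceptual sketch of one training step under a toy additive-noise formulation. The model here predicts the total noise added up to step t (rather than the per-step increment), and the schedule and loss weighting are not AF3's; `model` is a hypothetical callable.

```python
import torch

def diffusion_training_step(model, x0, optimizer, n_steps=200):
    """One toy denoising-diffusion training step on clean coordinates x0: (n_atoms, 3)."""
    t = torch.randint(1, n_steps + 1, ())            # sample a random timestep
    noise = torch.randn_like(x0) * (t / n_steps)     # more noise at larger t (toy linear schedule)
    x_t = x0 + noise                                 # noised coordinates
    predicted_noise = model(x_t, t)                  # model predicts the added noise
    loss = ((predicted_noise - noise) ** 2).mean()   # compare predicted vs. actual noise
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```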
Conditional Diffusion lets the model ‘condition’ these de-noising predictions on some input. Practically this means that at each step, the model takes three inputs: the current (partially noised) datapoint xt, the timestep t, and a conditioning representation describing what we want the final output to look like.
As a result, the final generation is not just a random example that resembles the training data distribution, but should specifically match the information represented by this conditioning vector.
With AF3, the data we learn to de-noise is a matrix x with the x,y,z coordinates of all the atoms in our sequences. During training, we add Gaussian noise to these coordinates until they are effectively fully random. Then at inference time, we start with random coordinates. At each time step, we first randomly rotate and translate our entire predicted complex. This data-augmentation teaches the model that any rotation and translation of our complex is equally valid, and replaces the much more complicated Invariant Point Attention used in AF2.
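A sketch of the random rigid-body augmentation: sample a uniformly random rotation (via a random unit quaternion) and a random translation, and apply them to every atom. The translation scale is an arbitrary choice for illustration.

```python
import torch

def random_rigid_augmentation(x, translation_scale=1.0):
    """Apply one random global rotation and translation to all atom coordinates.
    x: (n_atoms, 3) tensor of coordinates."""
    q = torch.randn(4)
    w, i, j, k = (q / q.norm()).tolist()       # random unit quaternion -> uniform random rotation
    rot = torch.tensor([
        [1 - 2*(j*j + k*k), 2*(i*j - k*w),     2*(i*k + j*w)],
        [2*(i*j + k*w),     1 - 2*(i*i + k*k), 2*(j*k - i*w)],
        [2*(i*k - j*w),     2*(j*k + i*w),     1 - 2*(i*i + j*j)],
    ])
    translation = torch.randn(3) * translation_scale
    return x @ rot.T + translation
```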
In each de-noising diffusion step, we condition our prediction on multiple representations of the input sequences:
The AF3 paper breaks down its diffusion process into 4 steps that involve moving from tokens to atoms, back to tokens, and back to atoms:
To initialize our token-level conditioning representation, we concatenate ztrunk to the relative positional encodings then project this larger representation back down and pass it through several residual-connection transition blocks.
Similarly, for our token-level single representation, we concatenate the very first representation of the input created at the start of the model (sinputs) and our current representation (strunk), then project it back down to its original size. We then create a Fourier embedding based on the current diffusion time step and add it into this single conditioning representation.
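A sketch of a Fourier embedding of the timestep, with randomly sampled (and untrained) frequencies and phases; the embedding dimension is an assumption.

```python
import math
import torch
import torch.nn as nn

class FourierTimeEmbedding(nn.Module):
    """Project the (scalar) diffusion timestep onto random frequencies and take a cosine."""
    def __init__(self, dim=256):
        super().__init__()
        # random frequencies and phases, fixed at initialization rather than learned
        self.register_buffer("w", torch.randn(dim))
        self.register_buffer("b", torch.rand(dim))

    def forward(self, t):
        # t: scalar tensor (or float) for the current timestep / noise level
        return torch.cos(2 * math.pi * (t * self.w + self.b))
```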
At this point, our conditioning vectors are storing information at a per-token level, but we want to also run attention at the atom-level. To address this, we take our initial atom-level representations of the input created in the Embedding section (c and p), and update them based on the current token-level representations, to create atom-level conditioning tensors.
Next, we scale the atom’s current coordinates (x) by the variance of the data, effectively creating “dimensionless” coordinates with unit variance (called r). We then update q based on r such that q is now aware of the atom’s current location. Finally, we update q with the Atom Transformer (which also takes the pair representation as input), and aggregate the atoms back to tokens as we’ve previously seen.
At the end of this step, we return to token space with our token-level representation a (which now also encodes the atoms’ current positions), while keeping the atom-level tensors q, c, and p around to be re-used when we move back to atoms in step 4.
The goal of this step is to apply attention to update our token-level representation of the atom coordinates and sequence information, a. This step uses the Diffusion Transformer, which mirrors the Atom Transformer (described during input preparation) but operates on tokens.
Now, we return to atom space. We use our updated a (token-level representations based on current “center atom” locations) to update q (atom-level representation of all atoms based on current location) using the Atom Transformer. As was done in step 3, we broadcast our token representation to match the number of atoms we started with (selectively duplicating the tokens that represent multiple atoms), and run the Atom Transformer. Most importantly, one last linear layer maps this atom-level representation q back to R3. This is the key step: we’ve used all these conditioning representations to generate coordinate updates rupdate for all atoms. Now, because we generated these updates in the “dimensionless” space rl, we carefully re-scale them back to physical coordinates (undoing the earlier normalization) before applying them to the atom positions.
With that, we’ve completed our tour through the main architecture of AlphaFold 3! Now we provide some additional information about the loss function, auxiliary confidence heads, and training details.
\(L_{\text{loss}} = L_{\text{distogram}}*\alpha_{\text{distogram}}+L_{\text{diffusion}}*\alpha_{\text{diffusion}}+L_{\text{confidence}}*\alpha_{\text{confidence}}\)
The loss is a weighted sum of 3 terms:
The output of our model is atom-level coordinates, which can easily be used to create a ground-truth distogram; the distogram loss compares this against the distogram predicted from the pair representation (z).
The diffusion loss itself is a weighted sum of three terms, each computed over the atom positions and additionally scaled by the amount of noise added at that timestep.
The goal of this loss is not to improve the accuracy of the structure, but rather to teach the model to predict its own accuracy. This loss is a weighted sum of 4 terms that each correspond to a method of evaluating the quality of a predicted structure:
\[L_{\text{confidence}} = L_{\text{pLDDT}} + L_{\text{PDE}} + L_{\text{resolved}} + L_{\text{PAE}} * \alpha_{\text{PAE}}\]

pLDDT: Atom-level “local distance difference test” (lDDT), capturing the expected accuracy of an atom’s predicted distances to nearby atoms.
PAE: Predicted alignment error between token i’s predicted and ground-truth positions. We first rotate and translate the predicted token i and ground-truth token i into the frame of token j. That is, if we assume for a moment that token j is in exactly its ground-truth position, we predict how close token i is to where it should be, based on its relation to token j.
PDE: Predicted distance error between tokens, capturing the accuracy of the predicted distances between all pairs of tokens.
Experimentally resolved prediction: The model predicts which atoms were experimentally resolved (not every atom is experimentally resolved in every crystal structure).
To get these confidence losses for each of the metrics, AF3 predicts values for these error metrics, then these error metrics are calculated on the predicted structure, and the loss is based on the difference between these two. So even if the structure is really incorrect and the PAE is high, if the predicted PAE is also high, the Lpae will be low.
These confidence predictions are generated mid-way through the diffusion process. At a selected diffusion step t, the predicted coordinates rt are used to update the single and pair representations created in the representation learning trunk. The predicted errors are then calculated from linear projections of the updated pair representation (for PAE and PDE) or this updated single representation (pLDDT and experimentally resolved). Then, the actual error metrics are calculated based on the same generated atom coordinates (process described below, if interested) for comparison.
While these terms are included in the confidence head loss, gradients from these terms are only used to update the confidence prediction heads and do not affect the rest of the model.
pLDDT: The LDDT for atom l is calculated in the following way: in the current predicted structure, we calculate the distance between atom l and a set of atoms R (indexed by m), and compare this to the ground truth equivalent. To be in this set, an atom m must be part of a polymer chain, within 15 or 30 Å of l (depending on which kind of molecule m is part of), and the center atom of a token. We then calculate four binary distance tests with increasingly tight thresholds (4, 2, 1, and 0.5 Å), take the average pass rate across these thresholds, and aggregate over the atoms in R. We bin this percentage into 50 bins between 0 and 1.
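A sketch of the per-atom LDDT computation, assuming the reference set R has already been selected, and averaging over it; the exact normalization and atom filtering follow the AF3 supplement.

```python
import torch

def per_atom_lddt(pred_dists, true_dists, thresholds=(0.5, 1.0, 2.0, 4.0)):
    """LDDT score for one atom l, given its distances to a reference set R.

    pred_dists, true_dists: (|R|,) distances from atom l to each atom in R,
    in the predicted and ground-truth structures respectively.
    """
    error = (pred_dists - true_dists).abs()                            # per-pair distance error
    passed = torch.stack([(error < t).float() for t in thresholds])    # (4, |R|) binary tests
    return passed.mean()                                               # fraction of tests passed
```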
At inference time, we have a pLDDT head. This head takes the single representation of a given token, repeats it out across all the atoms "attached" to this token, and projects each atom’s representation to 50 bins, which are treated as logits over the binned LDDT values and trained with a classification loss.
Predicted Alignment Error (PAE): Every token is considered to have a frame, that is, a 3D coordinate frame created from three atoms (called a, b, c) involved in that token. Atom b within those three atoms forms the origin in this frame. In cases where each token has a single atom "attached", the center atom of the frame is the single atom of the token, and the two other nearest tokens of the same entity (e.g., same ligand) form the basis of the frame. For every token pair (i,j) we re-express the predicted coordinates of the center atom for token_i using the frame of token_j. We do the same for the ground-truth coordinates of the center atom of token_i. The euclidean distance between these transformed true and predicted coordinates of the center atom of token_i is our alignment error, binned into 64 bins. We predict this alignment error from the pair representation zi,j, projecting it to 64 dimensions that we treat as logits and convert to probabilities with a softmax. We train this head with a classification loss, with each bin as a class. See here for additional details.
Third, AF3 predicts the distance error (PDE) between tokens. The true distance error is calculated by taking the distance between center atoms for every token pair, and binning these distances over 64 uniformly-sized bins from 0 Å to 32 Å. The predicted distance error comes from projecting the pair representation zi,j plus the pair representation zj,i into 64 dimensions that we again treat as logits, and again convert to probabilities with a softmax.
Finally, AF3 predicts whether each atom was experimentally resolved in the ground-truth structure. Similar to the pLDDT head, we repeat the si single representation out for the number of atoms this token represents, and project to 2 dimensions and use a binary classification loss.
Now that the architecture is covered, the last pieces are some of the additional training details.
As introduced in AF2, AF3 recycles its weights; that is, rather than making the model deeper, the model weights are re-used and inputs are run through the modules multiple times to continually improve the representations. Diffusion inherently uses recycling at inference time, as the model is trained to incorporate the timestep information and use the same model weights for every time step.
AF3 uses a mix of synthetic training data generated both by itself (via self-distillation) and by AF2 (via cross-distillation). Specifically, the authors note that, by switching to the diffusion-based generative module, the model stopped producing the characteristic “spaghetti” regions that allowed users of AF2 to visually identify low-confidence and likely disordered regions. Just visually looking at the diffusion-based generations, all regions appeared equally high in confidence, making it more difficult to identify potential hallucinations.
To solve this problem, they included generations from AF2 and AF-Multimer in the training data for AF3. Because AF2 outputs these unfolded regions exactly where it is not confident in its prediction, this cross-distillation effectively “instructs” AF3 to do the same.
While no part of the model has an explicit restriction on the length of the input sequences, the memory and compute requirements increase significantly with sequence length (recall the multiple O(Ntokens³) operations). Thus, for efficiency, the inputs get randomly cropped during training. As introduced in AF-Multimer, because we want to model the interactions between multiple chains, the random cropping needs to take all of these chains into account. They use 3 methods for cropping, and all 3 are used in different proportions depending on the training data (e.g. PDB crystal structure vs. disordered PDB complex vs. distillation data).
While a model trained on random crops of 384 tokens can be applied to longer sequences, to improve the model’s ability to handle these sequences, it is iteratively fine-tuned on larger sequence lengths. The mix of datasets and other training details is also varied in each training stage, as is shown in the table below.
The authors note that AF3’s loss does not include a clash penalty for overlapping atoms. While switching to a diffusion-based structure module means the model could in theory predict two atoms to be in the same location, this seems to be minimal after training. That said, AF3 does employ a clashing penalty when ranking generated structures.
Although the diffusion process sounds quite involved, it is still significantly less computationally expensive than the trunk of the model. Thus, the AF3 authors found it more efficient to expand the batch size of the model after the trunk: each input structure is run through the embedding and trunk once, then 48 independent data-augmented versions of the structure are created, and the diffusion module is trained on all 48 in parallel.
That’s it for the training process! There are some other small details but this is probably already more than you need, and if you’ve made it this far, the rest should be easy to pick up from reading the AF3 supplement.
Having walked so thoroughly through the architecture of AF3 and its comparisons to AF2, it is interesting how the choices made by the authors fit into broader Machine Learning trends.
At the time AF2 was released, it was not common to include retrievals from the training set at inference time; in AF’s case, this retrieval takes the form of the MSA and template search. MSA-based methods were already being used for protein modeling, but this type of retrieval was less common in other areas of Deep Learning (e.g. in Computer Vision, ResNets do not embed relevant training images at inference time when classifying a new image). Although AF3 reduces the emphasis on the MSA compared to AF2 (it is no longer operated on and updated in the 48 blocks of the Evoformer/Pairformer), they still incorporate both the MSA and templates, even as other protein structure prediction models such as ESMFold have dropped retrieval in favor of fully parametric inference.
Interestingly, some of the largest and most successful Deep Learning models now often include similar additional information at inference time. While the details of the retrieval systems are not always disclosed, Large Language Models routinely use Retrieval Augmented Generation systems such as a traditional web search at inference time to orient the model toward relevant information (even if that information was likely already in its training data) that should guide inference. It will be interesting to see how the use of directly relevant examples at inference time develops in the future.
One of the major components of AF2 that is even more present in AF3 is Pair-Bias Attention. That is, attention where the queries, keys, and values all originate from the same source (like in self-attention), but where there is a bias term added to the attention map from another source. This effectively acts as a light-touch version of information sharing, without full cross-attention. Pair-Bias Attention appears in almost every module. While this type of attention is now used in other protein modeling architectures, we have not seen this particular type of cross-biasing used in other fields (although that does not mean it hasn’t been done!). Perhaps it only works well here because the pair-representation is naturally analogous to a self-attention map already, but is an intriguing alternative to pure self or pure cross-attention.
Self-supervised models like ESM have been able to achieve impressive results at predicting protein structure by replacing the MSA embedding with a “probabilistic MSA” using self-supervised pre-training. In AF2, the model had an additional task that predicted masked tokens from the MSA, achieving a similar self-supervision, but that was removed with AF3. We have not seen commentary from the authors on why they did not use any self-supervised language modeling pre-training approach on the MSA, and in fact decreased the compute used to process the MSA. Three possible reasons self-supervised learning is not used to initialize the MSA embeddings are 1) they viewed the massive pre-training phase as a suboptimal use of compute 2) they tried it and found that including a small MSA module outperformed pre-trained embeddings and was worth the additional inference-time cost or 3) utilizing a mix of pre-trained embeddings for amino acid tokens and randomly initialized embeddings for DNA/RNA/ligands would not be compatible or underperformed fully supervised training on their hybrid atom-token structure.
As in AF2, AF3 continues to use a mix of MSE and binned classification losses. The classification components are interesting as, if the model predicts a distogram bin that is only off-by-one, it gets no “credit” for being close rather than way off. It is unclear what informed this design decision, but perhaps the authors found the gradients to be more stable than working with several different MSE losses, and perhaps the per-atom losses saw so many gradient steps that the additional signal from a continuous loss would not have proven beneficial.
AF3’s architecture incorporates several design elements reminiscent of recurrent neural networks that are not typically found in traditional transformers:
It was shown in the AF2 ablations that the recycling was important, but there was little discussion on the importance of gating. Presumably it helps with training stability as in LSTMs, but it is interesting that it is so prevalent here yet not in many other transformer-based architectures.
The use of AF2 generations to re-introduce its distinctive style specifically for low-confidence regions is very interesting. If there is a lesson here, it may be the most practical of all: If your previous model is doing one specific thing better than your new model, you can try cross-distillation to get the best of both worlds!
If you’ve made it this far, thanks for reading and hope it was helpful!! If you have any questions / comments / corrections / feedback feel free to reach out to us on twitter (E, J) or email (E, J) !
Special thanks to Kristy Carpenter, Nicholas Joseph, Kyle Swanson, and Kara Liu for giving feedback on this 💜