DeepMind Astonishes Again
Lost in all the excitement over the new wave of large language models was a paper published in Nature by the big brains over at DeepMind. It’s really staggering that we live in a time when an assembly-writing AI is barely a blip on the radar, but here we are. Published in June of 2023, the paper details DeepMind’s new agent AlphaDev and how it was able to optimize sorting algorithms that last saw changes a decade ago. In fact, these changes are so impressive that they were included in LLVM’s C++ standard library, libc++.
The impact of this is difficult to overstate. Sorting is a critical operation, particularly in light of the massive amounts of data being generated every second of the day. It’s no exaggeration to say that sorting algorithms get deployed a trillion times a day, so any time savings are going to be a huge boon to data centers across the globe.
If we check out the results from the paper, we can see that latency, or how long it takes the program to execute, improves by around 4–5% in most of the sorting algorithms. When AlphaDev is turned loose in the VarInt environment, a roughly 3x speedup is achieved. If you’re not familiar with it, VarInt is part of protocol buffers, Google’s format for serializing structured data so that it can be transmitted over a network.
Improvements to code length are also pretty impressive. Even in the case of the length-4 fixed sort, AlphaDev managed to match the human benchmark. Improvements to more complex algorithms are in the neighborhood of tens of instructions removed. The reason for these improvements boils down to the “alien” way in which the AlphaDev agent “thinks” about sorting. We’ll get to an example later on, but the AI really does approach the problem differently than humans do.
The Problem of Sorting
Sorting large sequences of data really boils down to sorting smaller sequences and then merging the smaller sorted sequences together in a way that preserves the ordering. The simplest way of doing this is merge sort, where we take a list of unsorted elements, break it into smaller lists, sort those, and then systematically merge them back together by comparing the elements of each sorted list.
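To make that concrete, here is a minimal merge sort sketch in C++. This is just the textbook recursion, not anything from the paper:

#include <algorithm>
#include <iterator>
#include <vector>

// Textbook merge sort: split the list in half, sort each half recursively,
// then merge the two sorted halves by repeatedly comparing their front elements.
std::vector<int> merge_sort(const std::vector<int>& v) {
    if (v.size() <= 1) return v;  // zero or one element is already sorted
    const auto mid = v.begin() + v.size() / 2;
    std::vector<int> left  = merge_sort(std::vector<int>(v.begin(), mid));
    std::vector<int> right = merge_sort(std::vector<int>(mid, v.end()));
    std::vector<int> out;
    out.reserve(v.size());
    std::merge(left.begin(), left.end(), right.begin(), right.end(),
               std::back_inserter(out));
    return out;
}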
Really, there are two ways in which we can sort data. We can use a fixed sort, where we sort a pre-determined number of elements, or we can use a variable sort, where we sort up to N elements. In the simplest case, we use some if-then statements to find the length of the variable-length list and then call the appropriate fixed-length sorting algorithm. Even then, the code for variable sorting is longer than that of fixed sorting, if for no other reason than we need to stick in some conditionals and function calls.
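A minimal sketch of that dispatch might look like the following, assuming hypothetical in-place helpers sort2 and sort3 for the fixed-length cases (these names are mine, not from the paper):

#include <utility>

// Hypothetical fixed-length sorters that sort their input in place.
void sort2(int* a) { if (a[1] < a[0]) std::swap(a[0], a[1]); }
void sort3(int* a) { sort2(a); if (a[2] < a[1]) std::swap(a[1], a[2]); sort2(a); }

// A variable sort of up to 3 elements: branch on the length,
// then hand off to the appropriate fixed-length routine.
void variable_sort3(int* a, int length) {
    switch (length) {
        case 0:
        case 1:  return;           // nothing to sort
        case 2:  sort2(a); return;
        default: sort3(a); return;
    }
}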
Because efficiency, by which I mean minimizing latency (execution time), is paramount, sorting algorithms are usually written in C++. This code is then compiled into assembly code, which is finally converted into machine code for the computer to execute. AlphaDev skips the C++ part and goes straight to writing assembly code, because it’s a coding ultra giga Chad.
Just so we’re on the same page, let’s take a look at a figure from the paper, where they show us an example of going from C++ to assembly code.
Looking at the left panel, we see some C++ code for sorting a variable-length array of up to 2 elements. The first thing we want to do is make sure we’re sorting exactly two elements. We can do this with a switch statement, which you may not be familiar with if you’re a Python developer. You can think of it as a series of if-then statements, where we compare the length variable to the values in the switch cases. So, if length == 1 or 0, then return. If we have precisely 2 elements, then we can set about the business of sorting.
This can be done in a few lines of code. Simply store the first element of the array in a temporary variable. Then we set the first element of the array to the smaller of the two elements. Next, we compare the second element of the array to our temporary variable (which holds the original first element), and set it to the larger of the two, so the bigger value ends up in the second slot. When we’re done, we simply return.
Also worth noting: if you’re a Python programmer, you should know that we modified the array in place. That’s why you see a bare return statement, and not something like “return a”. We modified the contents of the array directly, so we don’t need to return it.
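Putting those three paragraphs together, a rough reconstruction of the left panel might look like this (pieced together from the description above, not copied verbatim from the figure):

#include <algorithm>

// Sort a variable-length array of up to two elements, in place.
void variable_sort_2(int length, int* a) {
    switch (length) {
        case 0:
        case 1:
            return;                       // zero or one element is already sorted
        case 2: {
            int tmp = a[0];               // remember the original first element
            a[0] = std::min(a[0], a[1]);  // first slot gets the smaller value
            a[1] = std::max(tmp, a[1]);   // second slot gets the larger value
            return;
        }
    }
}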
Looking at the right panel, we can see the equivalent assembly code. To parse this, let’s get a crash course in assembly.
A Crash Course in x86 Assembly Language
Assembly language is basically one step above binary code. We are going to deal with moving values around between the registers on our CPU and temporary memory like RAM (or even the disk). We could also handle input/output (i.e. from the mouse and keyboard), file operations, or anything else we would have to do on a computer, but for the purposes of sorting data we only need a few commands. All of our commands will look like Opcode <operand 1>, <operand 2>. We’ll use the mov, cmp, jX, and cmovX commands.
So for instance mov %eax, %ecx
will move (copy and paste!) the value contained in the eax register to the ecx register. Clear? Let’s check out some of the other commands we’ll need:
cmp <A>, <B> ; compares the values stored in A and B and sets the CPU’s condition flags accordingly
jX <A> ; conditional jump. X can be (G)reater, (L)ess, (N)ot (E)qual, etc. Jumps to label <A> if condition X holds, based on the flags set by the last cmp.
cmovX <A>, <B> ; conditional move. Copies the value from A to B if condition X holds, where X is defined as above.
It may seem like this is a paltry set of commands, not fit for doing anything useful, but in reality we can use combinations of these to sort an array of any length. If we return to the panel on the right from the code comparison, we can now parse out what the assembly code is doing.
First we compare the value in the edi register, which holds the length of the list, to the constant value 2. In this example the edi register holds a 2, so the two are equal and the equality flag is set.
Then we execute a jump-not-equal command: if the previous compare found that the array doesn’t have exactly two elements, we jump down to the .Label, which is just a return statement. Since we are dealing with two elements, we fall through and execute the rest of the code.
The first mov statement copies a value from the memory location pointed to by rsi into the eax register, and the next mov statement loads another value from memory into the ecx register. These are the two elements of our array, so we’ve loaded them from RAM into our registers.
Then we compare the values in the eax and ecx registers.
The equivalent of our temp variable comes next when we mov the value from eax to the edx register.
Then, if the above comparison evaluated to less than, meaning the value in ecx is less than the value in eax, we move the value in ecx to edx. Either way, edx now holds the smaller of the two elements.
Then we transfer our smallest element, now in the edx register, to the rsi memory location, which is the first slot of the array.
Next we have another conditional move. If the value in register ecx was greater than the one in eax, then we move from ecx to eax. In the figure’s example it’s not, so nothing happens and eax keeps the original first element, which is already the larger of the two.
Finally, we move the value in eax to the 4(rsi) memory location, that is, 4 bytes past rsi, which is the second slot of the array.
Now the sequence of 2 elements is sorted, with just a handful of commands.
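If it helps to see the same idea without the registers, here is a rough C++ equivalent of what that assembly is doing (my own sketch, not from the paper). The pair of conditional moves plays the role of the min/max, so there are no branches at all in the sorting part:

#include <cstdio>

// Rough C++ equivalent of the branchless assembly walkthrough above.
void sort2_branchless(int* a) {
    int x = a[0];              // load the first element (eax in the walkthrough)
    int y = a[1];              // load the second element (ecx in the walkthrough)
    int lo = (y < x) ? y : x;  // cmp + conditional move: the smaller element
    int hi = (y < x) ? x : y;  // second conditional move: the larger element
    a[0] = lo;                 // store the minimum back to the first slot
    a[1] = hi;                 // store the maximum back to the second slot
}

int main() {
    int a[2] = {7, 3};
    sort2_branchless(a);
    std::printf("%d %d\n", a[0], a[1]);  // prints: 3 7
    return 0;
}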
The AlphaDev Agent
It’s pretty clear from the above example that even small sorting algorithms are going to prove intractable for simple agents. While the action space is discrete (the opcodes and operands form a discrete and finite set), it is hopelessly large for something like Q learning.
AlphaDev builds on the work done on AlphaZero. The basic idea here is that one can use deep neural networks to approximate both the policy and the value function, and then use Monte Carlo tree search to perform simulations that allow the agent to update the weights of those deep neural networks. This is pretty vague, but the AlphaZero agent deserves its own post, so I’ll just leave it at that.
One critical component to solving this problem is the deep neural network that is used to represent the sorting algorithm. The AlphaDev representation network is made of two parts.
The first part is a transformer network that handles encoding the high level assembly instructions into an embedding vector that the AlphaDev agent can actually “understand”. If you recall, transformers are the networks behind the explosion of large language models we’ve witnessed recently, so it’s no surprise to see them here. They’re well suited for the task of deducing the meaning and structure of language, of which computer language is a subset.
Before the assembly instructions can be fed to the transformer, they first have to be converted to a one-hot encoding. This encoding can be constructed by taking every combination of opcode and operands and treating each as a separate word in our dictionary. Each gets a unique one-hot encoded vector, and these are stacked to produce the input to the transformer network.
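As a toy illustration of that idea (mine, with a made-up instruction vocabulary, not the paper’s actual tokenizer), building the dictionary and the stacked one-hot rows might look like this:

#include <cstddef>
#include <map>
#include <string>
#include <vector>

// Toy one-hot encoder: treat each full instruction string (opcode plus
// operands) as a word, assign it an index, and encode a program as one
// vocabulary-sized row per instruction.
std::vector<std::vector<float>> one_hot_program(
        const std::vector<std::string>& program,
        std::map<std::string, int>& vocab) {
    // First pass: make sure every instruction has a dictionary index.
    for (const auto& instr : program) {
        if (vocab.count(instr) == 0) {
            const int next_index = static_cast<int>(vocab.size());
            vocab[instr] = next_index;
        }
    }
    // Second pass: one row per instruction, with a single 1.0 at its index.
    std::vector<std::vector<float>> encoded;
    for (const auto& instr : program) {
        std::vector<float> row(vocab.size(), 0.0f);
        row[static_cast<std::size_t>(vocab[instr])] = 1.0f;
        encoded.push_back(row);
    }
    return encoded;
}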
Simultaneously, the state of the x86 memory and registers can be represented as a matrix. Details here are vague, but there are numerous options for representing the computer architecture. One simple option is a one-dimensional vector of heterogeneous data types, where each index represents a register or memory location, including the condition flags.
Regardless of the implementation details, the state is fed through an encoding network to generate an embedding. That embedding is then combined with the algorithm embedding to form the state representation the AlphaDev agent actually works with. I should point out that the authors don’t specify what they use for the CPU state encoder, so I would presume it is a more traditional deep neural network rather than a transformer.
AssemblyGame
State representations are fed into the AlphaDev agent’s policy, which produces a probability distribution over the action space. The action space is the set of combinations of opcodes and operands, and once an action is chosen, it is appended to the end of the current algorithm.
At each time step, the algorithm is tested by feeding it a test sequence that should produce a known output. The data is sorted according to the current algorithm, and the output is compared to the solution. The AlphaDev agent is then rewarded based on whether or not the output is correct, along with some sort of penalty for latency.
In some cases, such as fixed-length sorting, program length and execution latency are highly correlated, so AssemblyGame penalizes the AlphaDev agent for algorithm length. In cases where the two are not correlated, such as variable-length sorting, latency itself is used as a negative reward factor.
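As a rough sketch of that reward structure (my simplification, not the paper’s exact formulation), the per-step reward might combine a correctness term, scored against the known test outputs, with a penalty proportional to either program length or measured latency:

#include <cstddef>
#include <vector>

struct RewardWeights {
    double correctness = 1.0;  // weight on matching the expected outputs
    double cost = 0.01;        // weight on the length / latency penalty
};

// Reward a candidate program's output against the known solution, then
// subtract a cost that stands in for program length or measured latency.
double step_reward(const std::vector<int>& program_output,
                   const std::vector<int>& expected_output,
                   double length_or_latency,
                   const RewardWeights& w) {
    std::size_t correct = 0;
    for (std::size_t i = 0; i < expected_output.size(); ++i) {
        if (i < program_output.size() && program_output[i] == expected_output[i]) {
            ++correct;
        }
    }
    const double fraction_correct =
        expected_output.empty() ? 1.0
                                : static_cast<double>(correct) / expected_output.size();
    return w.correctness * fraction_correct - w.cost * length_or_latency;
}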
The combination of these reward factors encourages the AlphaDev agent to produce correct sorting algorithms that minimize both latency and program length. This is facilitated through the use of a value function network, which estimates the value of the rewards that follow the current time step.
In fact, AlphaDev comes equipped with two value functions: one for predicting program correctness, and the other for predicting latency. The latency head is trained on actual measured latency values, rather than some proxy. Evidently, the use of two separate value functions, as opposed to one function modeling the combination of latency and correctness rewards, results in better performance with respect to real-world latency.
How AlphaDev Approaches Sorting
To get an appreciation for the innovations this new agent was able to come up with, let’s take a look at the current human benchmark for sorting.
The above figure is called a sorting network. Specifically, it’s a sorting network for a fixed-length sort of 4 elements. The idea here is that the horizontal lines represent data inputs/outputs, and the vertical lines, called comparators, swap the data riding on the horizontal lines whenever it is out of order. So, for instance, the first comparator compares the values 2 and 1, and since 1 is smaller than 2 it moves up to the top horizontal line. Then the 3 and 4 are compared, where the 3 is smaller and so it stays on top. Then the 3 and 1 are compared, where again 1 is smaller and stays on top. Simultaneously, the 2 and 4 are compared with no change, since 2 is smaller and already on top. The final comparator compares the 2 and 3, swapping their positions since the 2 is smaller.
A network is optimized when it produces correct output with the least number of comparators possible.
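To make the comparator idea concrete, here is a small sketch (mine, not from the paper) of a 4-element sorting network written as compare-exchange operations. Each call corresponds to one vertical comparator in the diagram, and five of them are enough to sort any 4-element input:

#include <cstdio>
#include <utility>

// One comparator: after the call, the smaller value is in a.
void cmp_exchange(int& a, int& b) {
    if (b < a) std::swap(a, b);
}

// The classic 5-comparator sorting network for 4 elements.
void sort4_network(int* v) {
    cmp_exchange(v[0], v[1]);
    cmp_exchange(v[2], v[3]);
    cmp_exchange(v[0], v[2]);
    cmp_exchange(v[1], v[3]);
    cmp_exchange(v[1], v[2]);
}

int main() {
    int v[4] = {2, 1, 4, 3};
    sort4_network(v);
    std::printf("%d %d %d %d\n", v[0], v[1], v[2], v[3]);  // prints: 1 2 3 4
    return 0;
}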
With that in mind, let’s check out another figure from the paper. Here, we can see AlphaDev’s work on optimizing the right two comparators from a fixed length sort 3 network, as well as the corresponding assembly code.
What we see is that AlphaDev is able to remove an entire instruction from the assembly code (recall from the first data table that the fixed-length sort 3 goes from 18 to 17 instructions with AlphaDev’s help).
The circled portion of the sorting network takes inputs on lines A, B, and C and performs a min operation on them. However, right before that is a comparator that sorts lines B and C, so it’s necessarily the case that B <= C. We then only have to compare A and B to get a properly sorted output on line A.
What’s great is that this sorting operation (fixed length sort 3) is used in many other sorting algorithms. Every time it appears, AlphaDev is able to use this trick to save an instruction. In fact, it happens so often that the DeepMind team gave this a name: the AlphaDev swap move.
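In C++ terms, the trick looks roughly like this (my paraphrase, not the paper’s listing). When the preceding comparator already guarantees B <= C, the three-way min collapses to a two-way min, and one instruction disappears from the compiled assembly:

#include <algorithm>

// Original circuit segment: the output line receives min(A, B, C).
int min_of_three(int A, int B, int C) {
    return std::min(A, std::min(B, C));
}

// AlphaDev swap move: the comparator just before this point has already
// sorted B and C, so B <= C holds and min(B, C) is simply B.
int min_given_B_le_C(int A, int B, int /*C*/) {
    return std::min(A, B);
}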
From the first data table, we know that AlphaDev also changes up variable length sorting algorithms. It does so in some surprising ways.
On the left we see a flow diagram for the benchmark variable sort 4 algorithm. A variable length array is input and the program determines the number of elements in the array. Then we just call the appropriate fixed length algorithm and return the sorted output. Pretty straightforward and easy to conceptualize.
On the right we can see what AlphaDev thinks of this, and it’s really quite insightful. First, check whether the array has fewer than two elements; return if it does (a single element is sorted by definition). If it has precisely two elements, then call the sort 2 algorithm and return. If it has more than 2 elements, then immediately call the sort 3 algorithm on the first three elements of the array. If we only have 3 elements, we’re done and we can return. If we have 4, then we figure out where that final unsorted element goes in the already sorted 3-element array.
The innovation is that this final step, merging the one remaining unsorted element into an already sorted 3-element array, is significantly faster than running the traditional fixed-length sort 4 algorithm. This new variable sort 4 algorithm shaves off 32 instructions and saves about 5% in execution time.
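A hedged C++ sketch of that control flow (my reconstruction of the diagram, reusing the hypothetical sort2/sort3 helpers from earlier) might look like this:

#include <utility>

// Hypothetical in-place helpers, as sketched earlier.
void sort2(int* a) { if (a[1] < a[0]) std::swap(a[0], a[1]); }
void sort3(int* a) { sort2(a); if (a[2] < a[1]) std::swap(a[1], a[2]); sort2(a); }

// Reconstruction of the AlphaDev-style variable sort 4 flow described above
// (my sketch of the diagram, not the paper's code).
void varsort4_alphadev_style(int* a, int length) {
    if (length < 2) return;                 // zero or one element is already sorted
    if (length == 2) { sort2(a); return; }
    sort3(a);                               // more than 2 elements: sort the first three
    if (length == 3) return;
    // length == 4: slot the last element into the sorted first three.
    const int x = a[3];
    if (x >= a[2]) return;                  // already in its final position
    a[3] = a[2];
    if (x >= a[1]) { a[2] = x; return; }
    a[2] = a[1];
    if (x >= a[0]) { a[1] = x; return; }
    a[1] = a[0];
    a[0] = x;
}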
Wrapping Up
Reinforcement learning can often be thought of as a toy. This is a consequence of the fact that agents are often trained on toy environments, typically something like an Atari game. However, these algorithms are incredibly powerful at optimizing processes long thought optimal. Far from being a mere toy, reinforcement learning is a powerful tool for any researcher looking to optimize some aspect of their process.
The idea that an agent can write assembly code to optimize sorting operations probably seemed like science fiction a mere decade ago. Yet, here we are.
Personally, I can think of many uses for AlphaDev, or some variant of it, that could revolutionize computing. It’s not hard to see a future where RL agents handle writing code for nearly every optimizable task, and humans are left as artists to design problems as games.
Until then, you’re going to need to learn how to stay at the edge of reinforcement learning research. That’s what we do here at the Neuralnet Academy. If you want to level up your career, then check out our paid courses.