Better-than-Cubic Complexity for Matrix Multiplication in Rust
<p>Years ago, I <a href="https://github.com/mikecvet/strassen/blob/master/cpp/src/strassen/strassen_matrix_multiplier.hpp#L159" rel="noopener ugc nofollow" target="_blank">wrote an implementation</a> of the <a href="https://en.wikipedia.org/wiki/Strassen_algorithm" rel="noopener ugc nofollow" target="_blank">Strassen matrix multiplication algorithm</a> in C++, and recently <a href="https://github.com/mikecvet/strassen/tree/master/rust/src" rel="noopener ugc nofollow" target="_blank">re-implemented it in Rust</a> as I continue to learn the language. This was a useful exercise in learning about Rust performance characteristics and optimization techniques, because although the <em>algorithmic complexity</em> of Strassen is superior to the naive approach, it has a high <em>constant factor</em> from the overhead of allocations and recursion within the algorithm’s structure.</p>
<ul>
<li><a href="https://betterprogramming.pub/better-than-cubic-complexity-for-matrix-multiplication-in-rust-cf8dfb6299f6#b85b" rel="noopener ugc nofollow">The general algorithm</a></li>
<li><a href="https://betterprogramming.pub/better-than-cubic-complexity-for-matrix-multiplication-in-rust-cf8dfb6299f6#740c" rel="noopener ugc nofollow">Transposition for better performance</a></li>
<li><a href="https://betterprogramming.pub/better-than-cubic-complexity-for-matrix-multiplication-in-rust-cf8dfb6299f6#d7d7" rel="noopener ugc nofollow">Sub-cubic: How the Strassen algorithm works</a></li>
<li><a href="https://betterprogramming.pub/better-than-cubic-complexity-for-matrix-multiplication-in-rust-cf8dfb6299f6#8747" rel="noopener ugc nofollow">Parallelism</a></li>
<li><a href="https://betterprogramming.pub/better-than-cubic-complexity-for-matrix-multiplication-in-rust-cf8dfb6299f6#fc64" rel="noopener ugc nofollow">Benchmarking</a></li>
<li><a href="https://betterprogramming.pub/better-than-cubic-complexity-for-matrix-multiplication-in-rust-cf8dfb6299f6#98ce" rel="noopener ugc nofollow">Profiling and performance optimization</a></li>
</ul>
<h2>The general algorithm</h2>
<p>The general (naive) matrix multiplication algorithm is the <a href="https://github.com/mikecvet/strassen/blob/master/rust/src/mult.rs#L11" rel="noopener ugc nofollow" target="_blank">three nested loops approach</a> everyone learns in their first linear algebra class, which most will recognize as <em>O(n³)</em></p>
<pre>
pub fn
mult_naive (a: &Matrix, b: &Matrix) -> Matrix {
  if a.rows == b.cols {
    let m = a.rows;
    let n = a.cols;

    // Preallocate the result buffer
    let mut c: Vec<f64> = Vec::with_capacity(m * m);

    for i in 0..m {
      for j in 0..m {
        let mut sum: f64 = 0.0;
        for k in 0..n {
          sum += a.at(i, k) * b.at(k, j);
        }
        c.push(sum);
      }
    }
    return Matrix::with_vector(c, m, m);
  } else {
    panic!("Matrix sizes do not match");
  }
}</pre>
<p>This algorithm is slow not just because of the three nested loops, but because the inner-loop traversal of <code>B</code> by columns via <code>b.at(k, j)</code>, rather than by rows, is <em>terrible</em> for the CPU cache hit rate.</p>
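<p>For reference, the <code>Matrix</code> type used above can be sketched as a flat, row-major <code>Vec&lt;f64&gt;</code>. This is a minimal sketch inferred from the accessor pattern in <code>mult_naive</code>, not necessarily how the linked implementation defines it:</p>

```rust
// A minimal row-major matrix, assumed from the accessor pattern in mult_naive.
pub struct Matrix {
    pub rows: usize,
    pub cols: usize,
    data: Vec<f64>,
}

impl Matrix {
    pub fn with_vector(data: Vec<f64>, rows: usize, cols: usize) -> Self {
        assert_eq!(data.len(), rows * cols);
        Matrix { rows, cols, data }
    }

    // Row-major lookup: element (i, j) lives at offset i * cols + j.
    // Iterating over j with fixed i walks memory sequentially; iterating
    // over i with fixed j strides by `cols` elements, which is what hurts
    // the cache in the naive inner loop.
    pub fn at(&self, i: usize, j: usize) -> f64 {
        self.data[i * self.cols + j]
    }
}
```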
<h2>Transposition for better performance</h2>
<p>The <a href="https://github.com/mikecvet/strassen/blob/master/rust/src/mult.rs#L39" rel="noopener ugc nofollow" target="_blank">transposed-naive approach</a> reorganizes the multiplication strides of matrix B into a more cache-favorable format by letting the inner-loop iterations over B run across rows rather than columns. Thus <code>A x B</code> turns into <code>A x B^t</code>.</p>
<p><img alt="" src="https://miro.medium.com/v2/resize:fit:700/1*MT319bN-OEQDGFIeXcCrOg.png" style="height:321px; width:700px" /></p>
<p>It involves a new matrix allocation (in this implementation, anyway) and a complete pass over B (an <em>O(n²)</em> operation, so more precisely this approach is <em>O(n³) + O(n²)</em>); I will show further down how much better it performs. It looks like the following:</p>
<pre>
fn multiply_transpose (A: Matrix, B: Matrix):
  C = new Matrix(A.num_rows, B.num_cols)

  // Construct transpose; requires allocation and iteration through B
  B' = B.transpose()

  for i in 0 to A.num_rows:
    for j in 0 to B'.num_rows:
      sum = 0
      for k in 0 to A.num_cols:
        // Sequential access of B'[j, k] is much faster than B[k, j]
        sum += A[i, k] * B'[j, k]
      C[i, j] = sum

  return C</pre>
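<p>The pseudocode above can be rendered in Rust over plain row-major <code>f64</code> buffers. This is a self-contained sketch of the technique rather than the linked <code>mult.rs</code> implementation; the function name and the flat-slice signature are illustrative:</p>

```rust
// Transposed-naive multiply: `a` is m x n, `b` is n x p, both row-major.
// Returns the m x p product as a row-major Vec<f64>.
fn mult_transpose(a: &[f64], b: &[f64], m: usize, n: usize, p: usize) -> Vec<f64> {
    // Transpose b (n x p) into bt (p x n): one extra allocation plus an
    // O(n^2)-style full pass over b.
    let mut bt = vec![0.0; n * p];
    for k in 0..n {
        for j in 0..p {
            bt[j * n + k] = b[k * p + j];
        }
    }

    let mut c = Vec::with_capacity(m * p);
    for i in 0..m {
        for j in 0..p {
            // Both rows are now traversed sequentially, which is far
            // friendlier to the cache than striding down b's columns.
            let sum: f64 = (0..n).map(|k| a[i * n + k] * bt[j * n + k]).sum();
            c.push(sum);
        }
    }
    c
}
```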
<h2>Sub-cubic: How the Strassen algorithm works</h2>
<p>To understand how Strassen’s algorithm works (code in Rust <a href="https://github.com/mikecvet/strassen/tree/master/rust/src" rel="noopener ugc nofollow" target="_blank">here</a>), first consider how matrices can be represented by <em>quadrants</em>.</p>
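<p>As a starting point, splitting a square, even-sided, row-major matrix into its four quadrants might be sketched as follows; the function and quadrant names here are hypothetical, chosen only to illustrate the partitioning Strassen recurses on:</p>

```rust
// Split an n x n row-major matrix (n even) into four (n/2) x (n/2)
// quadrants, returned in the order [A11, A12, A21, A22]:
//
//   | A11  A12 |
//   | A21  A22 |
fn quadrants(a: &[f64], n: usize) -> [Vec<f64>; 4] {
    assert!(n % 2 == 0, "odd sizes need padding first");
    let h = n / 2;
    let mut out = [
        vec![0.0; h * h], // A11: top-left
        vec![0.0; h * h], // A12: top-right
        vec![0.0; h * h], // A21: bottom-left
        vec![0.0; h * h], // A22: bottom-right
    ];
    for i in 0..h {
        for j in 0..h {
            out[0][i * h + j] = a[i * n + j];
            out[1][i * h + j] = a[i * n + j + h];
            out[2][i * h + j] = a[(i + h) * n + j];
            out[3][i * h + j] = a[(i + h) * n + j + h];
        }
    }
    out
}
```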
<p><a href="https://betterprogramming.pub/better-than-cubic-complexity-for-matrix-multiplication-in-rust-cf8dfb6299f6">Read More</a></p>