Better-than-Cubic Complexity for Matrix Multiplication in Rust

<p>Years ago, I&nbsp;<a href="https://github.com/mikecvet/strassen/blob/master/cpp/src/strassen/strassen_matrix_multiplier.hpp#L159" rel="noopener ugc nofollow" target="_blank">wrote an implementation</a>&nbsp;of the&nbsp;<a href="https://en.wikipedia.org/wiki/Strassen_algorithm" rel="noopener ugc nofollow" target="_blank">Strassen matrix multiplication algorithm</a>&nbsp;in C++, and recently&nbsp;<a href="https://github.com/mikecvet/strassen/tree/master/rust/src" rel="noopener ugc nofollow" target="_blank">re-implemented it in Rust</a>&nbsp;as I continue to learn the language. This was a useful exercise in learning about Rust performance characteristics and optimization techniques, because although the&nbsp;<em>algorithmic complexity</em>&nbsp;of Strassen is superior to the naive approach, it has a high&nbsp;<em>constant factor</em>&nbsp;from the overhead of allocations and recursion within the algorithm&rsquo;s structure.</p> <ul> <li><a href="https://betterprogramming.pub/better-than-cubic-complexity-for-matrix-multiplication-in-rust-cf8dfb6299f6#b85b" rel="noopener ugc nofollow">The general algorithm</a></li> <li><a href="https://betterprogramming.pub/better-than-cubic-complexity-for-matrix-multiplication-in-rust-cf8dfb6299f6#740c" rel="noopener ugc nofollow">Transposition for better performance</a></li> <li><a href="https://betterprogramming.pub/better-than-cubic-complexity-for-matrix-multiplication-in-rust-cf8dfb6299f6#d7d7" rel="noopener ugc nofollow">Sub-cubic: How the Strassen algorithm works</a></li> <li><a href="https://betterprogramming.pub/better-than-cubic-complexity-for-matrix-multiplication-in-rust-cf8dfb6299f6#8747" rel="noopener ugc nofollow">Parallelism</a></li> <li><a href="https://betterprogramming.pub/better-than-cubic-complexity-for-matrix-multiplication-in-rust-cf8dfb6299f6#fc64" rel="noopener ugc nofollow">Benchmarking</a></li> <li><a 
href="https://betterprogramming.pub/better-than-cubic-complexity-for-matrix-multiplication-in-rust-cf8dfb6299f6#98ce" rel="noopener ugc nofollow">Profiling and performance optimization</a></li> </ul> <h2>The general algorithm</h2> <p>The general (naive) matrix multiplication algorithm is the&nbsp;<a href="https://github.com/mikecvet/strassen/blob/master/rust/src/mult.rs#L11" rel="noopener ugc nofollow" target="_blank">three nested loops approach</a>&nbsp;everyone learns in their first linear algebra class, which most will recognize as&nbsp;<em>O(n&sup3;)</em>:</p> <pre>
pub fn mult_naive (a: &amp;Matrix, b: &amp;Matrix) -&gt; Matrix {
  if a.cols == b.rows {
    let m = a.rows;
    let n = a.cols;
    let p = b.cols;

    // preallocate
    let mut c: Vec&lt;f64&gt; = Vec::with_capacity(m * p);

    for i in 0..m {
      for j in 0..p {
        let mut sum: f64 = 0.0;
        for k in 0..n {
          sum += a.at(i, k) * b.at(k, j);
        }
        c.push(sum);
      }
    }
    return Matrix::with_vector(c, m, p);
  } else {
    panic!(&quot;Matrix sizes do not match&quot;);
  }
}
</pre> <p>This algorithm is slow not just because of the three nested loops, but because the inner-loop traversal of&nbsp;<code>B</code>&nbsp;by columns via&nbsp;<code>b.at(k, j)</code>&nbsp;rather than by rows is&nbsp;<em>terrible</em>&nbsp;for CPU cache hit rate.</p> <h2>Transposition for better performance</h2> <p>The&nbsp;<a href="https://github.com/mikecvet/strassen/blob/master/rust/src/mult.rs#L39" rel="noopener ugc nofollow" target="_blank">transposed-naive approach</a>&nbsp;reorganizes matrix B into a more cache-favorable format by letting the inner-loop iterations over B run along rows rather than columns.</p>
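<p>Both the cache argument above and the transposition trick depend on how the <code>Matrix</code> type lays out its elements. Here is a minimal row-major sketch, with hypothetical field and method names chosen to match the calls in the code above; the repository&rsquo;s actual definition may differ:</p>

```rust
// Minimal row-major matrix: all elements of row i sit contiguously in `data`.
// Walking a row (fixed row, increasing col) touches adjacent memory, while
// walking a column jumps `cols` elements per step, which is what makes
// b.at(k, j) in the naive inner loop cache-hostile.
pub struct Matrix {
    pub rows: usize,
    pub cols: usize,
    data: Vec<f64>,
}

impl Matrix {
    pub fn with_vector(data: Vec<f64>, rows: usize, cols: usize) -> Self {
        assert_eq!(data.len(), rows * cols);
        Matrix { rows, cols, data }
    }

    pub fn at(&self, row: usize, col: usize) -> f64 {
        self.data[row * self.cols + col]
    }

    // O(n^2): one full pass over the matrix to build the transposed copy
    pub fn transpose(&self) -> Matrix {
        let mut t = vec![0.0; self.rows * self.cols];
        for i in 0..self.rows {
            for j in 0..self.cols {
                t[j * self.rows + i] = self.at(i, j);
            }
        }
        Matrix::with_vector(t, self.cols, self.rows)
    }
}
```

With this layout, transposing B up front converts every column walk into a row walk, which is the entire point of the next optimization.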
<p>Thus&nbsp;<code>A x B</code>&nbsp;turns into&nbsp;<code>A x B^t</code>.</p> <p><img alt="" src="https://miro.medium.com/v2/resize:fit:700/1*MT319bN-OEQDGFIeXcCrOg.png" style="height:321px; width:700px" /></p> <p>It involves a new matrix allocation (in&nbsp;this implementation, anyway) and a complete matrix iteration (an&nbsp;<em>O(n&sup2;)</em>&nbsp;operation, so more precisely this approach is&nbsp;<em>O(n&sup3;) + O(n&sup2;)</em>); I will show further down how much better it performs. It looks like the following:</p> <pre>
fn multiply_transpose (A: Matrix, B: Matrix):
  C = new Matrix(A.num_rows, B.num_cols)

  // Construct transpose; requires allocation and iteration through B
  B&#39; = B.transpose()

  for i in 0 to A.num_rows:
    for j in 0 to B&#39;.num_rows:
      sum = 0
      for k in 0 to A.num_cols:
        // Sequential access of B&#39;[j, k] is much faster than B[k, j]
        sum += A[i, k] * B&#39;[j, k]
      C[i, j] = sum

  return C
</pre> <h2>Sub-cubic: How the Strassen algorithm works</h2> <p>To understand how Strassen&rsquo;s algorithm works (<a href="https://github.com/mikecvet/strassen/tree/master/rust/src" rel="noopener ugc nofollow" target="_blank">code in Rust here</a>), first consider how matrices can be represented by&nbsp;<em>quadrants</em>.</p> <p><a href="https://betterprogramming.pub/better-than-cubic-complexity-for-matrix-multiplication-in-rust-cf8dfb6299f6">Read More</a></p>
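<p>For intuition, here is what the quadrant arithmetic looks like in the smallest case: a 2&times;2 multiply, where each &ldquo;quadrant&rdquo; is a single scalar. This is a standalone sketch of the standard seven-product Strassen identities, not the repository&rsquo;s recursive implementation:</p>

```rust
// Strassen's seven products for a 2x2 multiply, with each "quadrant" reduced
// to a plain scalar. In the full algorithm each a*/b* value is itself a
// submatrix, and these same formulas are applied recursively.
fn strassen_2x2(a: [[f64; 2]; 2], b: [[f64; 2]; 2]) -> [[f64; 2]; 2] {
    let (a11, a12, a21, a22) = (a[0][0], a[0][1], a[1][0], a[1][1]);
    let (b11, b12, b21, b22) = (b[0][0], b[0][1], b[1][0], b[1][1]);

    // Seven multiplications instead of the naive eight
    let m1 = (a11 + a22) * (b11 + b22);
    let m2 = (a21 + a22) * b11;
    let m3 = a11 * (b12 - b22);
    let m4 = a22 * (b21 - b11);
    let m5 = (a11 + a12) * b22;
    let m6 = (a21 - a11) * (b11 + b12);
    let m7 = (a12 - a22) * (b21 + b22);

    // Recombine the products into the four quadrants of C
    [
        [m1 + m4 - m5 + m7, m3 + m5],
        [m2 + m4, m1 - m2 + m3 + m6],
    ]
}
```

Seven multiplications replace the naive eight; applied recursively to submatrices, that saving is what drives the complexity down from <em>O(n&sup3;)</em> to roughly <em>O(n^2.81)</em>.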