Ergonomic SIMD in Rust
Let's explore SIMD type erasure in Rust
Why Bother?
Chances are, if you've ever dabbled in a performance-critical project, you've come across SIMD; if not, I'd encourage you to check out this post by Shnatsel. In a nutshell, SIMD lets a single CPU instruction perform the same arithmetic operation on a batch of values, which can yield considerable performance improvements.
Now this is great, but it comes at a cost to your experience as a developer: targeting all relevant instruction sets and vectorizing some algorithms (a future post on that) while offering ergonomic APIs can pose a serious challenge. As laid out in Shnatsel's great blog post, there are four approaches to using SIMD, but we'll focus on the third one, portable SIMD abstractions, simply because auto-vectorization in Rust is still brittle and raw intrinsics are just too much work. Rust has its own built-in portable SIMD module, std::simd (currently nightly-only, behind #![feature(portable_simd)]), which is perfect for our use here because it's self-contained and has great platform support.
Ok, so practically, the issue is that we will have to duplicate code for scalar and SIMD types:
use std::simd::f32x4;
fn add_f32(a: f32, b: f32) -> f32 {
a + b
}
fn add_f32x4(a: f32x4, b: f32x4) -> f32x4 {
a + b
}
// ...
We want to write numeric code and have it work efficiently for both scalars and SIMD
vectors without code duplication. We'll explore portable SIMD and leverage Rust's type
system to define a trait Num that abstracts over both, letting us write our logic once:
// Works over both scalar and SIMD
fn add<N: Num>(a: N, b: N) -> N {
a + b
}
A First Attempt...
Num is the abstraction all our arithmetic logic will be written against, covering both
scalar and vector types. We know it has to be cheap to pass around by value and should
support all the arithmetic operators +, -, *, etc.
With this in mind, we can already start defining this Num trait with the appropriate
bounds and a couple of functions:
use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, Sub, SubAssign};
pub trait Num:
Copy
+ Neg<Output = Self>
+ Add<Output = Self>
+ Sub<Output = Self>
+ Mul<Output = Self>
+ Div<Output = Self>
+ Rem<Output = Self>
+ AddAssign
+ SubAssign
+ MulAssign
+ DivAssign
{
fn min(self, other: Self) -> Self;
fn max(self, other: Self) -> Self;
fn clamp(self, min: Self, max: Self) -> Self;
fn pi() -> Self;
fn tau() -> Self;
fn abs(self) -> Self;
fn cos(self) -> Self;
fn sin(self) -> Self;
// ...
}
This should be enough to cover the basics. It's now time to implement this trait for
our desired primitive and vector types. The implementation should be pretty
straightforward: most of the functions are tiny wrappers with no logic of their own.
To do so, we simply forward to the implementor's corresponding method, with one small
caveat: we have to make sure the compiler inlines the indirection. These functions are
so hot, and such thin wrappers, that we can't afford the runtime cost of an actual
call, so we force the compiler to inline them with the attribute #[inline(always)].
Generally speaking, it's not a good idea to do this (the compiler usually knows
better), but this is one of the very few instances where the attribute is safe to
reach for. Most of the time, you'd rather hint at the compiler with #[inline].
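To make the shape concrete before we get to the macro, here's the kind of forwarding wrapper we'll be generating, written as a free function purely for illustration (in the actual impls it's a trait method):

```rust
// A forwarding wrapper: all it does is delegate to the inherent method.
// Without guaranteed inlining, every call through the abstraction would
// pay for this extra hop.
#[inline(always)]
fn min_f32(a: f32, b: f32) -> f32 {
    f32::min(a, b)
}
```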
So, we will implement Num for four types, f32, f64, f32x4 and f64x4, though you may
extend this to others.
Implementing each by hand gets redundant, so we'll enlist a declarative macro to do the grunt work of writing out the same implementation for every type. Let's call it impl_num.
Here's our impl_num macro:
// Define `impl_num` macro
macro_rules! impl_num {
// Match against scalar ident.
($scalar:ident) => {
// Implement 'Num' for each match ident.
impl Num for $scalar {
// Force compiler to inline the method
#[inline(always)]
fn min(self, other: Self) -> Self {
// Forward the corresponding method
$scalar::min(self, other)
}
// Repeat...
}
};
}
// Implement 'Num' for each scalar
impl_num!(f32);
impl_num!(f64);
To get a clearer view of what's going on, let's expand it using cargo expand:
impl Num for f32 {
#[inline(always)]
fn min(self, other: Self) -> Self {
f32::min(self, other)
}
// ...
}
impl Num for f64 {
#[inline(always)]
fn min(self, other: Self) -> Self {
f64::min(self, other)
}
// ...
}
We can see how the trait Num, defined earlier, is implemented for f32 and f64, and
that the compiler is told to inline the indirection via the attribute mentioned above.
The pattern extends to just about any function you'd want to throw in there.
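As a self-contained sketch of that claim (a trimmed-down Num with only three methods, scalar impls only), adding a function is just another forwarding arm in the macro body:

```rust
// Trimmed-down version of the Num trait, for illustration only
pub trait Num: Copy {
    fn min(self, other: Self) -> Self;
    fn abs(self) -> Self;
    fn clamp(self, min: Self, max: Self) -> Self;
}

macro_rules! impl_num {
    ($scalar:ident) => {
        impl Num for $scalar {
            // Each method is a tiny inlined forwarder, as before
            #[inline(always)]
            fn min(self, other: Self) -> Self {
                $scalar::min(self, other)
            }
            #[inline(always)]
            fn abs(self) -> Self {
                $scalar::abs(self)
            }
            #[inline(always)]
            fn clamp(self, min: Self, max: Self) -> Self {
                $scalar::clamp(self, min, max)
            }
        }
    };
}

impl_num!(f32);
impl_num!(f64);
```

Every new method is one more `#[inline(always)]` forwarder per macro arm; the invocations at the bottom never change.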
Vector types follow the exact same principle; we'll write another macro along the
same lines:
use std::simd::{f32x4, f64x4, StdFloat, num::SimdFloat};
macro_rules! impl_num {
($simd:ident) => {
impl Num for $simd {
#[inline(always)]
fn min(self, other: Self) -> Self {
$simd::simd_min(self, other)
}
// ...
}
}
}
impl_num!(f32x4);
impl_num!(f64x4);
Note that you may have to disambiguate some functions and use some tricks here and there:
use std::simd::{f32x4, f64x4, StdFloat, num::SimdFloat};
macro_rules! impl_num{
// We are adding an extra `$element` ident which is our SIMD element
($simd:ident, $element:ident) => {
impl Num for $simd {
// ...
#[inline(always)]
fn exp(self) -> Self {
<$simd as StdFloat>::exp(self)
}
#[inline(always)]
fn sinh(self) -> Self {
$simd::from_array(self.to_array().map($element::sinh))
}
// ...
}
}
impl_num!(f32x4, f32);
impl_num!(f64x4, f64);
Here StdFloat doesn't provide a native sinh method we can forward to, but the
underlying element type (f32/f64) does have one. So we fall back to mapping the
corresponding sinh over the lanes.
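The lane-wise fallback itself is easy to exercise on plain arrays, no SIMD types needed, since to_array/from_array merely round-trip through [f32; 4]:

```rust
// Lane-wise fallback: apply a scalar-only function to each element.
// This mirrors what the macro arm does between to_array and from_array.
fn sinh4(lanes: [f32; 4]) -> [f32; 4] {
    lanes.map(f32::sinh)
}
```

Of course, this fallback gives up vectorization for that one operation; its only job is to keep the API uniform.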
But overall it works in the same fashion, and just like that, we can now start using
our generic add function:
use std::f32::consts::{PI,TAU};
use std::simd::f32x4;
// Super neat
fn add<N: Num>(a: N, b: N) -> N {
a + b
}
fn main() {
// f32
assert_eq!(add(PI,PI), TAU);
// ..and simd!
assert_eq!(add(f32x4::splat(PI), f32x4::splat(PI)), f32x4::splat(TAU));
}
That's great, but this abstraction has one major flaw: it doesn't handle comparison operations. The problem with branching is that there's a fundamental type mismatch we can't simply shove under the rug:
fn gt(self, other: Self) -> ???
You see, branching logic behaves differently for vector types.
Comparing SIMD values produces masks, which are their own type with their own usage,
as opposed to a plain bool for scalars. It seems this is where our neatly unified
abstraction starts to show a fundamental design flaw.
As a first solution, we could add an associated type Mask to our Num trait:
type Mask; // bool for scalars, Mask<T, N> for SIMD
fn gt(self, other: Self) -> Self::Mask;
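For the scalar side, such an impl is straightforward. Here's a minimal self-contained sketch, where NumCmp stands in for the comparison part of Num:

```rust
// Sketch of the associated-Mask idea, scalar side only.
// For f32x4, Mask would instead be std::simd::Mask<i32, 4>.
trait NumCmp: Copy {
    type Mask;
    fn gt(self, other: Self) -> Self::Mask;
}

impl NumCmp for f32 {
    type Mask = bool; // scalars compare down to a plain bool
    #[inline(always)]
    fn gt(self, other: Self) -> bool {
        self > other
    }
}
```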
such that when implementing Num, we specify the Mask as well, in our case either
std::simd::Mask<T, N> or bool. But now we are generic over Mask too,
which spreads complexity:
fn some_clamp<N: Num>(val: N, min: N, max: N) -> N
where
N::Mask: ??? // Now we need an extra bound
{
let too_low = val.gt(min);
let too_high = val.gt(max);
// And all the mask operations...
}
This totally defeats the purpose if every function doing a comparison has to be aware
of N::Mask. That's Rust's type system kindly letting us know our solution isn't the
right one.
Comparison can be reasonably unified with arithmetic, but SIMD also involves data
selection, which doesn't integrate well with either. So we'll have to make a
compromise here, and abandon the grand idea of a single trait Num that does it all.
A Better Approach
As it stands, our solution isn't inherently wrong, it's just incomplete, and relying on it as-is would introduce too much friction. Rust offers a modern type system, and we can definitely put it to use to bridge this gap more elegantly.
For starters, we can introduce two traits, Float and Simd, each with Num as a
supertrait. They give us finer-grained control without the twiddly generics and type
mismatches we ran into above. The user decides which trait to reach for: Num for
general small math utilities, Float for more involved structs, and Simd for
explicitly vectorized hot paths, all designed to work together smoothly.
So let's start with Float:
pub trait Float: Num {
type SIMD: Simd<Self, 4>;
}
We see here that Float must implement Num and carries an associated type SIMD, which
in turn implements Simd<T, N> (more on that below). For practical reasons, we
hardcode the lane width N to 4: this lets us elide the generic parameter, which would
otherwise spread everywhere. Feel free to handle it differently.
That is really all we need to do here; Num, which we've already defined, does the rest.
We do need to define Simd<T, N> before we can implement Float:
pub trait Simd<T: Copy, const N: usize>:
Num
+ From<[T; N]>
+ Into<[T; N]>
+ PartialEq
+ SimdPartialOrd
+ SimdPartialEq
+ SimdFloat
+ StdFloat
{}
Let's see what's going on here. First off, From<[T; N]> and Into<[T; N]> are there to
offer seamless conversion between the SIMD type and plain arrays of T; here's a
practical example of how that plays out:
fn process_audio<F>(samples: &[F]) -> Vec<F>
where
    F: Float,
{
    let mut out = Vec::with_capacity(samples.len());
    for chunk in samples.chunks_exact(4) {
        let array: [F; 4] = chunk.try_into().unwrap();
        // Easy SIMD usage from F
        let simd: F::SIMD = array.into();
        let processed = simd.sqrt();
        let result: [F; 4] = processed.into();
        out.extend_from_slice(&result);
    }
    // (handling of the remainder lanes left out for brevity)
    out
}
SimdFloat and StdFloat give us the operations and higher-level math for parity with
the scalar counterparts, and finally PartialEq, SimdPartialOrd and SimdPartialEq give
us the different flavors of equality and ordering checks. All in all, this should
make our SIMD type reasonably complete.
Right, so now that we have a marker trait Simd<T, N>,
we can implement it for f32x4/f64x4. But why restrict ourselves to these two types
when we have the type-level groundwork to cover any std::simd::Simd<T, N>? After all,
f32x4 and f64x4 are just type aliases of that wrapper struct. We'll import
std::simd::Simd as StdSimd, since it shares a name with our Simd<T, N> trait.
use std::simd::{num::SimdFloat,
Simd as StdSimd,
LaneCount, StdFloat,
SupportedLaneCount,
cmp::{SimdPartialEq, SimdPartialOrd}};
use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
impl<T, const N: usize> Simd<T, N> for StdSimd<T, N> where
T: Copy + PartialEq + std::simd::SimdElement,
LaneCount<N>: SupportedLaneCount,
StdSimd<T, N>:
Num
+ From<[T; N]>
+ Into<[T; N]>
+ PartialEq
+ SimdPartialOrd
+ SimdPartialEq
+ SimdFloat
+ StdFloat
{}
First off, we can see T must implement std::simd::SimdElement, a safety constraint
ensuring our element type can actually be packed into a SIMD register. Likewise,
LaneCount<N>: SupportedLaneCount restricts us to the lane counts the portable SIMD
API supports (small powers of two). The rest of what you see simply restates our
previously declared bounds. Now f32x4 and f64x4 (and all the rest) implement
Simd<T, N>, great!
Back to Float, now that we have our SIMD trait complete, we can implement our Float
trait for scalars:
impl Float for f32 { type SIMD = StdSimd<Self, 4>; }
impl Float for f64 { type SIMD = StdSimd<Self, 4>; }
And just like that, we have seamless interoperability between unified numerics with Num,
primitives with Float, and vector types with Simd.
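To see the whole bridge in one self-contained place, here's a stand-in sketch on stable Rust: Pack4 (my toy type, doing plain array math rather than real SIMD) plays the role of F::SIMD, and FloatLike mirrors Float with just enough bounds to compile:

```rust
use std::ops::Add;

// Toy stand-in for f32x4: plain array math, not real SIMD
#[derive(Copy, Clone, Debug, PartialEq)]
struct Pack4([f32; 4]);

impl Add for Pack4 {
    type Output = Self;
    fn add(self, rhs: Self) -> Self {
        let mut out = self.0;
        for (o, r) in out.iter_mut().zip(rhs.0) {
            *o += r;
        }
        Pack4(out)
    }
}

impl From<[f32; 4]> for Pack4 {
    fn from(a: [f32; 4]) -> Self {
        Pack4(a)
    }
}
impl From<Pack4> for [f32; 4] {
    fn from(p: Pack4) -> Self {
        p.0
    }
}

// The Float-style bridge: the scalar names its vector companion
trait FloatLike: Copy {
    type Simd: Add<Output = Self::Simd> + From<[Self; 4]> + Into<[Self; 4]>;
}
impl FloatLike for f32 {
    type Simd = Pack4;
}

// Generic code only ever mentions F; the vector type comes along for free
fn sum4<F: FloatLike>(a: [F; 4], b: [F; 4]) -> [F; 4] {
    (F::Simd::from(a) + F::Simd::from(b)).into()
}
```

Note that sum4 never names Pack4; swapping in the real std::simd types only changes the FloatLike impl, which is exactly the decoupling the Float trait buys us.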
We'll put this code to the test in an upcoming post, but here's a little example that illustrates how you'd typically make use of it (Waveforms stands for a user-defined trait that generates samples):
pub struct Osc<F: Float, W: Waveforms<F::SIMD>> {
waveform: W,
delta: F::SIMD,
}
impl<F: Float, W: Waveforms<F::SIMD>> Osc<F, W> {
#[inline]
pub fn process(&mut self) -> F::SIMD {
self.waveform.generate(self.delta)
}
}
Wrapping Up
We've just seen that, despite being deeply imperative, SIMD and numeric code can be reasoned about at the type level. Rust offers all the tools to mitigate the friction and offer a better developer experience downstream. We'll put this to use in an upcoming post. Until then, take care, bye bye!