Ergonomic SIMD in Rust

Let's explore ergonomic SIMD abstractions in Rust

Why Bother?

Chances are, if you've ever dabbled in a performance-critical project, you've come across SIMD; if not, I'd encourage you to check out this post by Shnatsel. In a nutshell, SIMD lets a single CPU instruction perform the same arithmetic operation on a batch of values, which can give your code a considerable performance improvement.

Now this is great, but it comes at a cost to your experience as a developer: targeting all relevant instruction sets and vectorizing some algorithms (a future post on that) while offering ergonomic APIs can pose a serious challenge. As laid out in Shnatsel's great blog post, there are four approaches to using SIMD, but we'll focus on the third, portable SIMD abstractions, simply because auto-vectorization in Rust is still brittle and raw intrinsics are just too much work. Rust has its own built-in portable SIMD module, std::simd (nightly-only for now), which is perfect for our use here because it's self-contained and has great platform support.

Ok, so practically, the issue is that we will have to duplicate code for scalar and SIMD types:

use std::simd::f32x4;

fn add_f32(a: f32, b: f32) -> f32 {
  a + b
}

fn add_f32x4(a: f32x4, b: f32x4) -> f32x4 {
  a + b
}

// ...

We want to write numeric code and have it work efficiently for both scalars and SIMD vectors without code duplication. We'll explore portable SIMD and leverage Rust's type system to define a trait Num that will serve as the abstraction to unify our logic under:

// Works over both scalar and SIMD
fn add<N: Num>(a: N, b: N) -> N {
  a + b
}

A First Attempt...

Num is essentially the opaque type used for all arithmetic logic, abstracting over scalar and vector types alike. We know it has to be cheap to pass around by value and should support all the arithmetic operators +, -, *, etc. With this in mind, we can already start defining the Num trait with the appropriate bounds and a couple of functions:

use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Rem, Sub, SubAssign};

pub trait Num:
    Copy
    + Neg<Output = Self>
    + Add<Output = Self>
    + Sub<Output = Self>
    + Mul<Output = Self>
    + Div<Output = Self>
    + Rem<Output = Self>
    + AddAssign
    + SubAssign
    + MulAssign
    + DivAssign
{
  fn min(self, other: Self) -> Self;
  fn max(self, other: Self) -> Self;
  fn clamp(self, min: Self, max: Self) -> Self;
  fn pi() -> Self;
  fn tau() -> Self;
  fn abs(self) -> Self;
  fn cos(self) -> Self;
  fn sin(self) -> Self;
 // ...
}

This should be enough to cover the basics. It is now time to implement this trait for our desired primitive and vector types. The implementation should be pretty straightforward: most of the functions are tiny wrappers with no logic of their own, so we simply forward to the implementor's corresponding method. There is one small caveat: we have to make sure the compiler inlines the indirection. These functions are so hot, and such small wrappers, that we can't afford any call overhead, so we force inlining with the attribute #[inline(always)]. Generally speaking, it's not a good idea to do this, the compiler usually knows better, but this is one of the very few instances where you'd be safe to use this attribute. Most of the time, you'd rather hint to the compiler with #[inline].

So, we will implement Num for four types, f32, f64, f32x4 and f64x4, but you may extend this to other types.

Having to implement this by hand gets a bit redundant, so we'll enlist the help of a declarative macro. Let's call it impl_num; it will do the grunt work of writing out the same implementation for each type.

Here's our impl_num macro:

// Define `impl_num` macro
macro_rules! impl_num {
    // Match against scalar ident.
    ($scalar:ident) => {
        // Implement 'Num' for each match ident.
        impl Num for $scalar {
            // Force compiler to inline the method
            #[inline(always)] 
            fn min(self, other: Self) -> Self {
                // Forward the corresponding method
                $scalar::min(self, other)
            }
            // Repeat...
        }
    };
}
// Implement 'Num' for each scalar
impl_num!(f32);
impl_num!(f64);
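As an aside, the macro can be made variadic so a single invocation covers every type. Here's a minimal, self-contained sketch of that variant, using a stripped-down stand-in for Num purely for illustration:

```rust
use std::ops::Add;

// Stripped-down stand-in for the full `Num` trait, for illustration only.
pub trait Num: Copy + Add<Output = Self> {
    fn min(self, other: Self) -> Self;
}

// Variadic variant: one invocation implements `Num` for every listed type.
macro_rules! impl_num {
    ($($scalar:ident),* $(,)?) => {
        $(
            impl Num for $scalar {
                #[inline(always)]
                fn min(self, other: Self) -> Self {
                    // Forward to the inherent method
                    $scalar::min(self, other)
                }
            }
        )*
    };
}

impl_num!(f32, f64);

fn main() {
    assert_eq!(Num::min(1.0f32, 2.0), 1.0);
    assert_eq!(Num::min(3.0f64, 2.0), 2.0);
}
```

The `$( ... )*` repetition stamps out one impl per listed type, so adding a new scalar is a one-token change.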

To have a clearer view of what's going on, we will expand it using cargo expand.

impl Num for f32 {
  #[inline(always)]
  fn min(self, other: Self) -> Self {
    f32::min(self, other)
  }
  // ...
}

impl Num for f64 {
  #[inline(always)]
  fn min(self, other: Self) -> Self {
    f64::min(self, other)
  }
  // ...
}

We can see how the trait Num, defined earlier, is implemented for f32 and f64. We make sure the compiler has inlined the indirection using the attribute mentioned above. The logic is amenable to just about any function you'd want to throw in there. Vector types follow the exact same principle, for which we will write another macro, following the same pattern:

use std::simd::{f32x4, f64x4, StdFloat, num::SimdFloat};

macro_rules! impl_num {
    ($simd:ident) => {
        impl Num for $simd {
            #[inline(always)]
            fn min(self, other: Self) -> Self {
                $simd::simd_min(self, other)
            }
            // ...
        }
    };
}

impl_num!(f32x4);
impl_num!(f64x4);

Note that you may have to disambiguate some functions and use some tricks here and there:

use std::simd::{f32x4, f64x4, StdFloat, num::SimdFloat};

macro_rules! impl_num {
  // We add an extra `$element` ident which is our SIMD element type
  ($simd:ident, $element:ident) => {
    impl Num for $simd {
      // ...
      #[inline(always)]
      fn exp(self) -> Self {
        <$simd as StdFloat>::exp(self)
      }
      #[inline(always)]
      fn sinh(self) -> Self {
        $simd::from_array(self.to_array().map($element::sinh))
      }
      // ...
    }
  };
}

impl_num!(f32x4, f32);
impl_num!(f64x4, f64);

Here StdFloat doesn't have a native sinh method we can forward to, but the underlying element type (f32/f64) does. So we fall back to a lane-wise mapping of the corresponding scalar sinh.
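The lane-wise fallback pattern can be shown in isolation with plain arrays, no std::simd required (a sketch; note that each lane pays the full scalar cost here, so there is no SIMD speedup, only API parity):

```rust
// Lane-wise fallback: apply a scalar function to each element of a fixed-size
// array, mirroring `from_array(self.to_array().map(...))` on a SIMD vector.
fn sinh_lanes(lanes: [f32; 4]) -> [f32; 4] {
    lanes.map(f32::sinh)
}

fn main() {
    let out = sinh_lanes([0.0, 1.0, -1.0, 2.0]);
    // Each lane matches the scalar result exactly.
    assert_eq!(out[0], 0.0f32.sinh());
    assert_eq!(out[1], 1.0f32.sinh());
    assert_eq!(out[2], (-1.0f32).sinh());
}
```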

But overall it works in similar fashion and just like that we can now start using our generic add function:

use std::f32::consts::{PI,TAU};
use std::simd::f32x4;

// Super neat
fn add<N: Num>(a: N, b: N) -> N {
  a + b
}

fn main() {
  // f32
  assert_eq!(add(PI,PI), TAU);
  // ..and simd!
  assert_eq!(add(f32x4::splat(PI), f32x4::splat(PI)), f32x4::splat(TAU));
}

That's great, but this abstraction has one major flaw: it doesn't handle comparison operations. The problem with branching is that there's a fundamental type mismatch we can't simply shove under the rug:

fn gt(self, other: Self) -> ???

You see, branching logic behaves differently for vector types. Comparing SIMD values involves masks, which are their own type with their own usage, as opposed to a plain bool for scalars. This is where our neatly unified abstraction starts to show a fundamental design flaw.

As a first solution, we could add an associated type Mask to our Num trait:

type Mask; // bool for scalars, Mask<T, N> for SIMD
fn gt(self, other: Self) -> Self::Mask;

such that when implementing Num, we specify the Mask as well, in our case either std::simd::Mask<T, N> or bool. But now we are generic over Mask too, which spreads complexity:

fn some_clamp<N: Num>(val: N, min: N, max: N) -> N 
where
    N::Mask: ???  // Now we need an extra bound
{
    let too_low = val.gt(min);
    let too_high = val.gt(max);
    // And all the mask operations...
}

This totally defeats the purpose if every function doing a comparison has to be aware of N::Mask. That's Rust's type system kindly letting us know our solution isn't the right one.
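To make the friction concrete, here's a stripped-down, stable-Rust stand-in (the toy Num, gt and select here are purely illustrative, not the real design) showing how every branching helper has to thread Self::Mask through its body; any extra mask operation (and/or/not) would mean yet more bounds:

```rust
// Toy stand-in for the `Num`-with-`Mask` design, to illustrate the friction.
trait Num: Copy {
    type Mask;
    fn gt(self, other: Self) -> Self::Mask;
    // Selecting between values needs the mask type too...
    fn select(mask: Self::Mask, if_true: Self, if_false: Self) -> Self;
}

impl Num for f32 {
    type Mask = bool;
    fn gt(self, other: Self) -> bool { self > other }
    fn select(mask: bool, if_true: Self, if_false: Self) -> Self {
        if mask { if_true } else { if_false }
    }
}

// Every caller that branches must route all control flow through mask values
// and `select`, instead of ordinary `if`/`else`.
fn some_clamp<N: Num>(val: N, min: N, max: N) -> N {
    let too_low = min.gt(val);
    let clamped = N::select(too_low, min, val);
    let too_high = clamped.gt(max);
    N::select(too_high, max, clamped)
}

fn main() {
    assert_eq!(some_clamp(5.0f32, 0.0, 1.0), 1.0);
    assert_eq!(some_clamp(-3.0f32, 0.0, 1.0), 0.0);
    assert_eq!(some_clamp(0.5f32, 0.0, 1.0), 0.5);
}
```

In this toy version the mask operations live on the trait itself, and it still forces every numeric helper to be written in branchless select style.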

As much as comparison can reasonably be unified with arithmetic, SIMD also involves data selection, which doesn't integrate well with either of the two. So we'll have to make a compromise here and abandon the grand idea of a single Num trait that does it all.

A Better Approach

As it stands, our solution isn't inherently wrong, it's just incomplete, and relying on it would introduce too much friction. Rust offers a modern type system, and we can definitely make use of it to bridge this gap more elegantly.

For starters, we can introduce two traits, Float and Simd, each with Num as a supertrait. They give us finer-grained control without the twiddly generics and type mismatches. We let the user decide which of these traits to use: Num for small general math utilities, Float for more involved structs, and Simd for explicitly vectorized hot callbacks, all designed to work together smoothly.

So let's start with Float:

pub trait Float: Num {
    type SIMD: Simd<Self, 4>;
}

We see here that Float extends Num and has an associated type SIMD which in turn implements Simd<T, N> (more on that later). For practical reasons, we have hardcoded the lane width N to 4 so that we can elide this generic parameter, otherwise it spreads everywhere; feel free to handle it differently. That is really all we need to do here, Num does the rest and we've already defined it.

We need to define Simd<T,N> before implementing Float:

pub trait Simd<T: Copy, const N: usize>:
    Num
    + From<[T; N]>
    + Into<[T; N]>
    + PartialEq
    + SimdPartialOrd
    + SimdPartialEq
    + SimdFloat
    + StdFloat
{
}

Let's see what's going on here. First off, From<[T; N]> and Into<[T; N]> are convenience bounds that offer seamless conversion between the SIMD type and arrays of T; here's a practical example of how that works:

fn process_audio<F>(samples: &[F]) -> Vec<F>
where
    F: Float,
{
    let mut out = Vec::with_capacity(samples.len());
    // Note: a real implementation would also handle the remainder lanes.
    for chunk in samples.chunks_exact(4) {
        let array: [F; 4] = chunk.try_into().unwrap();
        // Easy SIMD usage from F
        let simd: F::SIMD = array.into();
        let processed = simd.sqrt();
        let result: [F; 4] = processed.into();
        out.extend_from_slice(&result);
    }
    out
}

SimdFloat and StdFloat give us ops and higher-level math for parity with the scalar counterparts, and finally PartialEq, SimdPartialOrd and SimdPartialEq give us the different flavors of equality and ordering checks. All in all, this should make our SIMD type reasonably complete.

Right, so now that we have our Simd<T, N> trait, we can implement it for f32x4/f64x4. But why restrict ourselves to these two types when we have the type-level groundwork to extend it to any std::simd::Simd<T, N>? Both f32x4 and f64x4 are just type aliases of this wrapper struct. We will also rename std::simd::Simd<T, N> to StdSimd<T, N> in our imports, since it shares a name with our Simd<T, N> trait.

use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign};
use std::simd::{
    cmp::{SimdPartialEq, SimdPartialOrd},
    num::SimdFloat,
    LaneCount, Simd as StdSimd, StdFloat, SupportedLaneCount,
};

impl<T, const N: usize> Simd<T, N> for StdSimd<T, N>
where
    T: Copy + PartialEq + std::simd::SimdElement,
    LaneCount<N>: SupportedLaneCount,
    StdSimd<T, N>: Num
        + From<[T; N]>
        + Into<[T; N]>
        + PartialEq
        + SimdPartialOrd
        + SimdPartialEq
        + SimdFloat
        + StdFloat,
{
}

First off, we can see T must implement std::simd::SimdElement, a simple safety constraint ensuring our element type can actually be packed into a SIMD register. Same goes for LaneCount<N>: SupportedLaneCount, which restricts N to the lane counts the portable SIMD API supports. The rest of what you see is there to comply with our previously stated bounds. Now f32x4 and f64x4 (and all the rest) implement Simd<T, N>, great!

Back to Float, now that we have our SIMD trait complete, we can implement our Float trait for scalars:

impl Float for f32 { type SIMD = StdSimd<Self, 4>; }
impl Float for f64 { type SIMD = StdSimd<Self, 4>; }

And just like that, we have seamless interoperability between unified numerics with Num, primitives with Float and vector types with Simd.

We will be putting this code to the test in an upcoming post, but here's a little example that illustrates how you'd typically make use of it:

pub struct Osc<F: Float, W: Waveforms<F::SIMD>> {
  waveform: W,
  delta: F::SIMD,
}

impl<F: Float, W: Waveforms<F::SIMD>> Osc<F, W> { 
  #[inline] 
  pub fn process(&mut self) -> F::SIMD {
    self.waveform.generate(self.delta)
  }
}

Wrapping Up

We've just seen that, despite being deeply imperative, SIMD and numeric code can be reasoned about at the type level. Rust offers all the tools to mitigate the friction and offer a better developer experience downstream. We will put all of this to use in an upcoming post; until then, take care, bye bye 👋