Today I am happy to announce the release of version 1.0 of `torchkbnufft` (GitHub, Documentation). There are many changes: complex number support, an improved backend that is 4 times faster on the CPU and 2 times faster on the GPU, a better density compensation function, and more detailed documentation. Why all the updates now? Well, PyTorch recently began supporting complex tensors natively (you can read about complex number support here). Before, we had to implement each complex multiplication as four high-level real multiplies in Python. With native complex tensor support, we can move these multiplications down to lower-level PyTorch code for a significant speed-up. This, along with updates to the PyTorch FFT API, prompted a rewrite of the package.
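To make the difference concrete, here is a minimal sketch (not `torchkbnufft`'s actual code) that contrasts the pre-1.0 style with native complex tensors. In the old style, real and imaginary parts live in a size-2 dimension and each complex product costs four real multiplies; with native support it is one complex multiply. The helper name `complex_mult_real` is hypothetical.

```python
import torch


def complex_mult_real(a: torch.Tensor, b: torch.Tensor, dim: int) -> torch.Tensor:
    """Hypothetical pre-1.0-style complex multiply.

    `a` and `b` are real tensors whose real/imaginary parts are stored along
    `dim` (which must have size 2). Each complex product costs four real multiplies.
    """
    a_re, a_im = a.select(dim, 0), a.select(dim, 1)
    b_re, b_im = b.select(dim, 0), b.select(dim, 1)
    real = a_re * b_re - a_im * b_im
    imag = a_re * b_im + a_im * b_re
    return torch.stack((real, imag), dim=dim)


torch.manual_seed(0)
a = torch.randn(5, 2, 100)  # [batch, real/imag, samples]
b = torch.randn(5, 2, 100)
old_result = complex_mult_real(a, b, dim=1)

# Native complex: move real/imag last, view as complex, and multiply once.
a_c = torch.view_as_complex(a.transpose(1, 2).contiguous())  # [5, 100], complex64
b_c = torch.view_as_complex(b.transpose(1, 2).contiguous())
new_result = a_c * b_c
```

Both paths compute the same values; the native version simply hands the arithmetic to PyTorch's own complex kernels instead of issuing four separate real multiplies from Python.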
In updating the code for complex multiplications I noticed many other areas for improvement. In this post I’ll document some of the more important ones as well as my reasoning for making the changes.
An Updated API
The first thing that most users will notice with `torchkbnufft` version 1.0 is a different API. Previously, for an MRI problem with a batch size of 5, 8 channels, and a height/width of 64, you would pass the forward NUFFT a tensor of shape `[5, 8, 2, 64, 64]`, where the `2` dimension was the real/imaginary dimension. This was always a little bit strange: PyTorch's FFT expected the real/imaginary dimension to be at the end of the shape. The reason for the old layout was that a lot of early deep learning MRI models included real/imaginary in the channel dimension for convolutions. However, for version 1.0, we decided to make the NUFFT follow the PyTorch FFT convention. There are a couple of reasons for this: 1) it brings us in line with the PyTorch ecosystem, and 2) it's very easy to convert real tensors (with a last dimension of size `2`) to complex tensors. So now the package does this conversion for any real input, and we can have a more efficient backend based on complex tensors.
As a result, for our problem above, you'll now be expected to pass in a complex-valued tensor with shape `[5, 8, 64, 64]`. You can still pass in a real tensor with separate real/imaginary dimensions as `[5, 8, 64, 64, 2]`, but the NUFFT code will convert it to complex internally and then back to real before returning it to you. Not having to deal with real tensors in the backend simplifies the code base and makes things more efficient.
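Converting between these layouts is cheap in PyTorch. The sketch below uses only core PyTorch calls (`torch.view_as_complex` and `torch.view_as_real`), not `torchkbnufft`'s internal conversion code, to move an old-style `[5, 8, 2, 64, 64]` tensor into the two layouts accepted by 1.0.

```python
import torch

# Pre-1.0 layout: real/imaginary stacked in dim 2 -> [batch, coil, 2, H, W]
old_layout = torch.randn(5, 8, 2, 64, 64)

# 1.0 real layout: real/imaginary in the last dim -> [batch, coil, H, W, 2].
# view_as_complex needs that last dim to have stride 1, hence .contiguous().
real_layout = old_layout.permute(0, 1, 3, 4, 2).contiguous()

# 1.0 complex layout: [batch, coil, H, W] with dtype complex64 (a view, no copy)
complex_layout = torch.view_as_complex(real_layout)

# ...and back to the real layout [batch, coil, H, W, 2]
round_trip = torch.view_as_real(complex_layout)
```

Because both `view_as_*` calls return views, the round trip between real and complex layouts does not copy data; only the `permute(...).contiguous()` away from the old layout does.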
Another change that people will notice is the shape of the k-space trajectory. Previously, it would have been `[5, 2, klength]`, where `klength` was the number of k-space samples. The idea was that you could apply a different k-space trajectory to each batch element. In the end, I decided to remove this feature and only use one k-space trajectory per forward pass. The reason is that in the underlying code, I just wrote a `for` loop over the different trajectories, which took away some optimization opportunities in the backend (detailed below). It's better for `torchkbnufft` to take only one trajectory per forward pass and have users write `for` loops over their trajectories while I write a more efficient backend, so this is the behavior in 1.0.
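If you relied on per-batch trajectories, the replacement is a loop in your own code. So that the example runs without `torchkbnufft`, the sketch below uses a hypothetical `naive_ndft_2d`, an exact (and slow) non-uniform DFT, as a stand-in for the forward NUFFT. With the package, you would call its forward NUFFT inside the loop instead. The coordinate convention (radians per voxel, centered image grid) is an assumption of this sketch.

```python
import math

import torch


def naive_ndft_2d(image: torch.Tensor, omega: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for a forward NUFFT: an exact, slow non-uniform DFT.

    image: complex [batch, coil, H, W]
    omega: real [2, klength], k-space coordinates in radians/voxel
    returns: complex [batch, coil, klength]
    """
    H, W = image.shape[-2:]
    # centered image-domain coordinates
    ys = torch.arange(H, dtype=omega.dtype) - H // 2
    xs = torch.arange(W, dtype=omega.dtype) - W // 2
    # phase[k, y, x] = omega_y[k] * y + omega_x[k] * x
    phase = (
        omega[0, :, None, None] * ys[None, :, None]
        + omega[1, :, None, None] * xs[None, None, :]
    )
    kernel = torch.complex(torch.cos(phase), -torch.sin(phase))  # exp(-i * phase)
    return torch.einsum("bchw,khw->bck", image, kernel.to(image.dtype))


torch.manual_seed(0)
batch, coils, H, W, klength = 3, 4, 16, 16, 100
image = torch.randn(batch, coils, H, W, dtype=torch.complex64)

# One trajectory per batch element, which the pre-1.0 API accepted directly.
trajectories = [(torch.rand(2, klength) * 2 - 1) * math.pi for _ in range(batch)]

# In 1.0, the loop over trajectories lives in user code.
kdata = torch.cat(
    [naive_ndft_2d(image[i : i + 1], traj) for i, traj in enumerate(trajectories)]
)
```

Keeping the loop on the user side lets the backend assume a single trajectory and optimize for it, which is the trade-off described above.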
Improved Indexing Operations
The slowest part of `torchkbnufft` is its indexing operations. These are pretty difficult to handle in a high-level library, and my current solutions may still not be ideal. Nonetheless, for version 1.0 we managed to make some improvements over what the package did previously, achieving about a four-fold speed-up for forward/backward on the CPU and a two-fold speed-up on the GPU. For all the pseudo-code I show below, you can find the full, up-to-date version on GitHub. Prior to version 1.0, the indexing operation for the forward interpolation looked like this:
```python
coef, arr_ind = calc_coef_and_indices(
    tm, kofflist, Jlist[:, Jind], table, centers, L, dims
)

# unsqueeze coil and real/imag dimensions for on-grid indices
arr_ind = (
    arr_ind.unsqueeze(0).unsqueeze(0).expand(kdat.shape[0], kdat.shape[1], -1)
)

# gather and multiply coefficients
kdat += complex_mult(
    coef.unsqueeze(0), torch.gather(griddat, 2, arr_ind), dim=1
)
```
The code calculates `coef`, which are interpolation coefficients based on the Kaiser-Bessel kernel, and `arr_ind`, which are the indices of the neighbors to use for interpolation. The key indexing operation is `torch.gather(griddat, 2, arr_ind)`. The GPU implementation in 1.0 is basically the same, but uses complex numbers for multiplication and `griddat[:, :, arr_ind]` instead of `torch.gather`. I'll focus on the larger changes for the CPU version.
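The two indexing styles compute the same interpolation. The toy sketch below assumes a simplified layout (`griddat` as `[coil, flattened grid]`, and `arr_ind`/`coef` as `[neighbors, klength]`) with random indices and weights rather than ones derived from a Kaiser-Bessel table. It shows that `torch.gather` needs its index expanded to match `griddat`'s dimensions, while advanced indexing does not.

```python
import torch


def interp_with_gather(griddat, arr_ind, coef):
    """Interpolate with torch.gather (index must match griddat's number of dims)."""
    kdat = torch.zeros(griddat.shape[0], arr_ind.shape[1], dtype=griddat.dtype)
    for j in range(arr_ind.shape[0]):
        # expand the [klength] neighbor indices across the coil dimension
        ind = arr_ind[j].unsqueeze(0).expand(griddat.shape[0], -1)
        kdat += coef[j] * torch.gather(griddat, 1, ind)
    return kdat


def interp_with_indexing(griddat, arr_ind, coef):
    """Same computation with advanced indexing; no expand needed."""
    kdat = torch.zeros(griddat.shape[0], arr_ind.shape[1], dtype=griddat.dtype)
    for j in range(arr_ind.shape[0]):
        kdat += coef[j] * griddat[:, arr_ind[j]]
    return kdat


torch.manual_seed(0)
num_coils, grid_points, klength, num_neighbors = 4, 128, 50, 6
griddat = torch.randn(num_coils, grid_points, dtype=torch.complex64)  # flattened grid
# hypothetical neighbor indices and weights (random, not Kaiser-Bessel derived)
arr_ind = torch.randint(0, grid_points, (num_neighbors, klength))
coef = torch.randn(num_neighbors, klength, dtype=torch.complex64)

kdat_gather = interp_with_gather(griddat, arr_ind, coef)
kdat_index = interp_with_indexing(griddat, arr_ind, coef)
```

Each k-space sample is a weighted sum of its neighbors on the grid. Which style is faster depends on device and PyTorch version, consistent with the CPU/GPU split described in this post.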
The primary issue with this code on the CPU is that indexing into an array is slow in PyTorch. We can mitigate this by minimizing the size of each indexing problem: in version 1.0 of `torchkbnufft`, we split up the k-space trajectory and send a different chunk of it to each forked task as follows:
```python
from typing import List

import torch
from torch import Tensor


@torch.jit.script
def table_interp_over_batches(
    image: Tensor,
    omega: Tensor,
    tables: List[Tensor],
    n_shift: Tensor,
    numpoints: Tensor,
    table_oversamp: Tensor,
    offsets: Tensor,
    num_forks: int,
) -> Tensor:
    """Table interpolation backend (see table_interp())."""
    # indexing is worse when we have repeated indices - let's spread them out
    klength = omega.shape[1]
    omega_chunks = [omega[:, ind:klength:num_forks] for ind in range(num_forks)]

    # table_interp_one_batch (defined elsewhere in torchkbnufft) handles one chunk
    futures: List[torch.jit.Future[torch.Tensor]] = []
    for omega_chunk in omega_chunks:
        futures.append(
            torch.jit.fork(
                table_interp_one_batch,
                image,
                omega_chunk,
                tables,
                n_shift,
                numpoints,
                table_oversamp,
                offsets,
            )
        )

    kdat = torch.zeros(
        image.shape[0],
        image.shape[1],
        omega.shape[1],
        dtype=image.dtype,
        device=image.device,
    )

    # write each chunk back to the interleaved positions it came from
    for ind, future in enumerate(futures):
        kdat[:, :, ind:klength:num_forks] = torch.jit.wait(future)

    return kdat
```
In this case, `table_interp_one_batch` is basically the same as our old table interpolation function. Using `torch.jit.fork` (see here), the forks execute asynchronously over their separate k-space chunks, and at the end we join them all together and return. Note that the chunks are interleaved (every `num_forks`-th sample) rather than contiguous, which spreads out repeated indices as the code comment explains. This speeds up indexing by reducing the number of k-space points each task has to handle, and it is one of the main sources of our improvements.
We've also changed the adjoint, where we have to scatter k-space data onto an equispaced grid using the Kaiser-Bessel kernel. Prior to 1.0, it looked like this:
```python
coef, arr_ind = calc_coef_and_indices(
    tm, kofflist, Jlist[:, Jind], table, centers, L, dims, conjcoef=True
)

# the following code takes ordered data and scatters it on to an image grid
# profiling for a 2D problem showed drastic differences in performances
# for these two implementations on cpu/gpu, but they do the same thing
if device == torch.device("cpu"):
    tmp = complex_mult(coef.unsqueeze(0), kdat, dim=1)
    for bind in range(griddat.shape[0]):
        for riind in range(griddat.shape[1]):
            griddat[bind, riind].index_put_(
                tuple(arr_ind.unsqueeze(0)), tmp[bind, riind], accumulate=True
            )
else:
    griddat.index_add_(2, arr_ind, complex_mult(coef.unsqueeze(0), kdat, dim=1))
```
You might notice the device branch. For some reason, `index_put_` with `accumulate=True` was faster on the CPU, whereas `index_add_` was faster on the GPU. I no longer observe this difference when building PyTorch from its master branch, so we'll probably use `index_add_` for everything going forward once the next version of PyTorch is out.
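Whatever their speed, the two calls compute the same scatter-add, including summing contributions that land on the same grid point. The small sketch below (toy sizes, hand-picked repeated indices) checks this, running `index_add_` on real views of the complex tensors in the spirit of the real-valued GPU path. It does not reproduce the timing difference.

```python
import torch

torch.manual_seed(0)
grid_size = 8
# Repeated indices: several k-space samples contribute to the same grid point.
arr_ind = torch.tensor([0, 3, 3, 5, 0, 7])
data = torch.randn(6, dtype=torch.complex64)

# index_put_ with accumulate=True sums values at duplicate indices.
via_put = torch.zeros(grid_size, dtype=torch.complex64)
via_put.index_put_((arr_ind,), data, accumulate=True)

# index_add_ on real views ([grid, 2] and [samples, 2]); the in-place update
# writes through the view into the complex tensor.
via_add = torch.zeros(grid_size, dtype=torch.complex64)
torch.view_as_real(via_add).index_add_(0, arr_ind, torch.view_as_real(data))
```

The accumulation matters here: many k-space samples share grid neighbors, so a plain indexed assignment would overwrite contributions instead of summing them.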
The issue with the old adjoint code is that the double `for` loop over batch and real/imaginary indices isn't very fast on the CPU branch. Furthermore, `index_add_` doesn't work very well over batch dimensions on the GPU branch, either. It would be better to dispatch a set of workers to handle every independent batch and coil element, and this is exactly what 1.0 does. The code below is a partial construction of how we now do adjoint interpolation, showing the key pieces.
```python
from typing import List

import torch
from torch import Tensor


def accum_tensor_index_add(image: Tensor, arr_ind: Tensor, data: Tensor) -> Tensor:
    """We fork this function for the adjoint accumulation."""
    return image.index_add_(0, arr_ind, data)


def accum_tensor_index_put(image: Tensor, arr_ind: Tensor, data: Tensor) -> Tensor:
    """We fork this function for the adjoint accumulation."""
    return image.index_put_((arr_ind,), data, accumulate=True)


@torch.jit.script
def fork_and_accum(image: Tensor, arr_ind: Tensor, data: Tensor, num_forks: int):
    device = image.device

    futures: List[torch.jit.Future[torch.Tensor]] = []
    for batch_ind in range(image.shape[0]):
        for coil_ind in range(image.shape[1]):
            # if we've used all our forks, wait for one to finish and pop
            if len(futures) == num_forks:
                torch.jit.wait(futures[0])
                futures.pop(0)

            # one of these is faster on cpu, other is faster on gpu
            if device == torch.device("cpu"):
                futures.append(
                    torch.jit.fork(
                        accum_tensor_index_put,
                        image[batch_ind, coil_ind],
                        arr_ind,
                        data[batch_ind, coil_ind],
                    )
                )
            else:
                futures.append(
                    torch.jit.fork(
                        accum_tensor_index_add,
                        image[batch_ind, coil_ind],
                        arr_ind,
                        data[batch_ind, coil_ind],
                    )
                )

    _ = [torch.jit.wait(future) for future in futures]


...

coef, arr_ind = calc_coef_and_indices(
    tm=tm,
    base_offset=base_offset,
    offset_increments=offset,
    tables=tables,
    centers=centers,
    table_oversamp=table_oversamp,
    grid_size=grid_size,
    conjcoef=True,
)

tmp = coef * data

# the GPU path is still real-valued (index_add_ lacks complex support there)
if device != torch.device("cpu"):
    tmp = torch.view_as_real(tmp)

# this is a much faster way of doing index accumulation
if USING_OMP:
    torch.set_num_threads(threads_per_fork)
fork_and_accum(image, arr_ind, tmp, num_forks)
if USING_OMP:
    torch.set_num_threads(num_threads)
```
Using `torch.jit.fork`, we create a new asynchronous task for every batch and coil element, and each task handles accumulation for its own element. The calls to `torch.jit.wait` cause the code to wait for these asynchronous tasks to finish. Since the accumulation is done in-place, we don't have to worry about what these tasks return. However, there is one thing we do have to worry about with forking: OpenMP. Without some thread management, we can end up using more threads than we were given, or suffer performance degradation from oversubscription. To prevent this, `fork_and_accum` never keeps more than `num_forks` tasks in flight, and when OpenMP is in use we temporarily set the thread count to `threads_per_fork` around the call.
The adjoint operation with forking is faster: more than a factor of 4 over the previous implementation on the CPU and a factor of 2 on the GPU. (Note: the GPU operations are still real-valued, but this should change in the future when `index_add_` supports complex numbers.)
Overall, these improvements have made version 1.0 of `torchkbnufft` about four times as fast as before on the CPU and two times as fast on the GPU. The forward operation was bound more by the complex multiplies and indexing; we get about a 2-3x speed-up by using complex tensors and using `torch.jit.fork` to break up the trajectory. The adjoint operation was bound by the accumulation, and we get a 2-5x speed-up by using `torch.jit.fork` to dispatch over batches and coils.
The package should scale very well over coil and batch dimensions. In general, we're bound by our indexing operations, so the main factor that makes NUFFTs slower or faster is the size of the k-space trajectory.
One thing that does affect indexing is using a 3D NUFFT. The package is faster for 3D than before, but unfortunately the speedup isn’t as consistent. PyTorch indexing begins to perform worse with larger arrays, and this is the situation we have for 3D NUFFTs. There are a few steps you can take that will help:
- Use 32-bit precision instead of 64-bit.
- Lower the oversampling ratio.
- Use fewer neighbors for interpolation.
- Use a GPU.
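A bit of arithmetic shows why 3D is harder and why the first three tips help. The sketch assumes a tensor-product interpolation kernel with `numpoints` neighbors per dimension, so each k-space sample touches `numpoints ** ndims` grid points. It uses 2x oversampling and 6 neighbors as example values. It only counts grid points and neighbors; it is not a measurement of the package's actual memory use. `nufft_problem_size` is a hypothetical helper.

```python
def nufft_problem_size(im_size, oversamp, numpoints, bytes_per_element):
    """Rough size of the interpolation problem for a tensor-product kernel.

    Returns (oversampled grid points, neighbors per k-space sample, grid bytes).
    """
    grid_points = 1
    for n in im_size:
        grid_points *= round(oversamp * n)
    neighbors_per_sample = numpoints ** len(im_size)
    return grid_points, neighbors_per_sample, grid_points * bytes_per_element


# complex128 = 16 bytes per element, complex64 = 8 bytes per element
print(nufft_problem_size((256, 256), 2.0, 6, 16))       # (262144, 36, 4194304)
print(nufft_problem_size((128, 128, 128), 2.0, 6, 16))  # (16777216, 216, 268435456)
print(nufft_problem_size((128, 128, 128), 1.25, 4, 8))  # (4096000, 64, 32768000)
```

Going from 2D to 3D multiplies the neighbors per sample by `numpoints` and makes the grid far larger. Lowering the oversampling ratio, the neighbor count, and the precision together shrinks the 3D grid from 256 MiB to about 31 MiB in this example and cuts the neighbors per sample from 216 to 64.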
But if that’s not good enough, then you’re running into the limitations of the package.
Updates to Documentation
`torchkbnufft` was decently documented on its GitHub repository, with the `README.md` and several Jupyter notebooks, but the documentation on Read the Docs was a bit lacking. It consisted only of an API reference, and the layout of the table of contents made it hard to navigate.
This has also been updated substantially for 1.0. We now prominently display our core modules, including `ToepNufft`. Each core module is now accompanied by a mathematical description of its operations as well as details connecting it to the notation in Fessler's NUFFT paper. (Note: If you have any comments or notice any errors in the documentation, please let me know!) We also prominently display our primary utility functions, including `calc_toeplitz_kernel`. This should make it a lot easier for beginners to navigate the package.
A New Density Compensation Function
Thanks to a notification from Zaccharie Ramzi and an implementation by Chaithya G.R., we got a pull request implementing Pipe's density compensation method. This is quite a bit better than my original method, which presumably only worked for radial trajectories. The density compensation function calculator also has a simplified interface.
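For readers unfamiliar with Pipe's method: it iteratively rescales the weights so that gridding them and interpolating back gives roughly uniform density. The toy 1D sketch below uses one common formulation, `w <- w / (G @ G.T @ w)`, where `G` is an explicit interpolation matrix. It builds `G` from a triangle kernel instead of the Kaiser-Bessel kernel, and the helper names are hypothetical. This is not the package's implementation; a NUFFT-based version would replace the matrix products with interpolation and its adjoint.

```python
import torch


def triangle_kernel_matrix(ktraj, grid, half_width):
    # G[j, g] = max(0, 1 - |k_j - grid_g| / half_width)
    dist = (ktraj[:, None] - grid[None, :]).abs()
    return torch.clamp(1 - dist / half_width, min=0)


def pipe_dcf(G, num_iterations=10):
    w = torch.ones(G.shape[0], dtype=G.dtype)
    for _ in range(num_iterations):
        # grid the weights (G.T @ w), interpolate back (G @ ...), then divide
        w = w / (G @ (G.T @ w))
    return w


# Samples are denser near k = 0, as with radial-like trajectories.
t = torch.linspace(-1, 1, 101, dtype=torch.float64)
ktraj = 0.5 * t * t.abs()
grid = torch.linspace(-1, 1, 81, dtype=torch.float64)
G = triangle_kernel_matrix(ktraj, grid, half_width=0.05)
w = pipe_dcf(G)
```

Because the samples crowd around the center, the resulting weights are small there and larger at the sparsely sampled edges, which is the behavior density compensation is meant to provide.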
Version 1.0 of `torchkbnufft` was essentially a complete rewrite of the repository and its documentation. The result is a faster, better-documented NUFFT package that retains its original benefit of being written completely in high-level Python.
This remains a personal project unaffiliated with my official position at FAIR, so all of this work was done on my own time. Still, I think it was quite rewarding, and I’m happy with the improvements to the repository.
For my next project, I think it may finally be time to move beyond Python. I've grown to love Python and PyTorch over the last 2+ years, but there are so many cool languages out there that I think I'll have to try one of those next…