Fast Fourier Transform with PyTorch

Recently, Google released a model (FNet) similar to the Transformer, but with self-attention replaced by the Fourier transform. The model is only slightly less accurate, but 6x faster on GPU. This is a great opportunity to get familiar with the Fourier transform in PyTorch.

Introduction

The Fourier transform is a transformation applied to a temporal sequence. You may have heard that “it transforms the series from the time domain into the frequency domain” - which made literally no sense to me. In this article, I’ll try to explain the basics of the FT and how to use it in PyTorch.

Note
This article is intended as an easy introduction. If you need detailed math, please look for something else ;)

Before we start with the transform, we need to clarify what a temporal series is. For our purposes, we will often call it a signal.

In layman’s terms, a temporal sequence is just a sequence of values, each with an assigned time. So each observation has a value and the time when it happened. The time scale can vary a lot - it can be measured in nanoseconds or in years. For this article, we need one assumption - measurements are taken at a constant rate. That is, the time difference between consecutive observations is always the same. Let’s take a look at some examples.

The first one is the CO2 concentration at Mauna Loa. Here, the time step is one month.

Source: https://datahub.io/core/co2-ppm#data-cli

The second one comes from astronomy and represents the variation in the brightness of an asteroid over time (a light curve).

Source: https://en.wikipedia.org/wiki/Light_curve

Now we have something interesting: text can also be considered a temporal sequence. For the concept of time, we do not need any units. It is perfectly fine to assume that consecutive words appear at \(t=0,1,2,3,4,\ldots\)

Text is also a temporal sequence
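Here is a minimal sketch of this idea. It uses a made-up sentence, and word length stands in for the numeric value of each word (real language models use token embeddings instead). The word position simply becomes the time index.

```python
import torch

# Hypothetical example sentence; any text would do.
sentence = "the quick brown fox jumps over the lazy dog"
words = sentence.split()

# Time needs no unit here: the i-th word simply appears at t = i.
t = torch.arange(len(words))

# Illustrative numeric value for each time step: the length of the word.
signal = torch.tensor([len(w) for w in words], dtype=torch.float32)

for ti, word, value in zip(t.tolist(), words, signal.tolist()):
    print(f"t={ti}  word={word!r:<8}  value={value:.0f}")
```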

Sine wave

Such real-life signals are challenging to analyze, so let’s start with a simple sine wave. In an ideal world, it is an infinite, continuous, repeating pattern.

Part of infinite sine wave

In real life, we have only part of the signal - say, 3 seconds. Also, we do not have infinitely many values - we sample the real signal at some frequency. We call this the sampling rate \(f_s\) and measure it in hertz [Hz]. Sampling is uniform, so the time step is \(t_s=\frac{1}{f_s}\) (written \(dt\) below). The data looks more like this:

Higher sampling rate - higher quality but more data to deal with.
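To make the relationship concrete, here is a small sketch with arbitrary illustrative values: 3 seconds, sampled at 10 Hz and at 100 Hz. The time step is \(t_s = 1/f_s\), and the \(k\)-th sample is taken at \(k\,t_s\). A higher sampling rate gives more samples for the same duration.

```python
import torch

duration = 3.0  # seconds of signal we actually observe (illustrative)

for sampling_rate in (10, 100):                    # f_s in Hz (illustrative values)
    time_step = 1 / sampling_rate                  # t_s = 1 / f_s
    total_samples = int(sampling_rate * duration)  # N = duration * f_s
    t = torch.arange(total_samples) * time_step    # t_k = k * t_s, k = 0 .. N-1
    print(f"f_s={sampling_rate:>3} Hz  t_s={time_step} s  N={total_samples}  last t={t[-1]:.2f} s")
```

Note that the last sample falls just before the 3-second mark (2.9 s and 2.99 s), since counting starts at \(t=0\).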

Wave properties

Each wave has some properties - we need to understand them before proceeding. For us, the most important are:

  • Duration - how long the signal is (in original units)
  • Sampling rate \(f_s\) - how many samples are taken per second [Hz]
  • Time step \(dt\) - time between two sampling events (\(dt = \frac{1}{f_s}\))
  • Total samples \(N\) - how many samples we have in total (duration times sampling rate)
  • Amplitude \(A\) - maximum absolute value
  • Frequency \(f\) - number of cycles per second [Hz]
  • Phase \(\varphi\) - initial offset [rad]

Amplitude and frequency

Phase shift - a wave does not have to start at zero.

Generating sine wave

We need a piece of code for generating a sine wave. The equation is as follows:

$$ y(t) = A\sin(2\pi f t + \varphi) $$

In PyTorch we can do something like this:

```python
import torch
import numpy as np
from typing import Tuple


def generate_sine_wave(
    amplitude: float,
    frequency: float,
    phase: float,
    sampling_rate: int,
    duration: float,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Generates a sine wave

    Args:
        amplitude (float): Amplitude
        frequency (float): Frequency [Hz]
        phase (float): Phase [rad]
        sampling_rate (int): Sampling rate [Hz]
        duration (float): Duration [s]

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: t, y(t)
    """
    total_steps = int(sampling_rate * duration)
    # t_k = k / sampling_rate: consecutive samples are exactly 1 / sampling_rate apart
    t = torch.arange(total_steps) / sampling_rate
    y = amplitude * torch.sin(2 * np.pi * frequency * t + phase)

    return t, y
```
```python
import matplotlib.pyplot as plt

t1, y1 = generate_sine_wave(amplitude=1.5, frequency=3, phase=1, sampling_rate=1_000, duration=1)
t2, y2 = generate_sine_wave(amplitude=0.3, frequency=1.5, phase=np.pi/2, sampling_rate=1_000, duration=1.2)

fig, ax = plt.subplots()
ax.scatter(t1, y1, s=1)
ax.scatter(t2, y2, s=1)
ax.grid()
```
Generated signals

Combining multiple waves

The fun thing is that we can combine these simple waves to create more complex ones. Let’s see some examples:

Combined two waves

Combined three waves
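Combining waves boils down to adding their samples taken at the same time points. Below is a minimal sketch. The helper `sine` and all amplitudes, frequencies and phases are illustrative choices, not the waves used in the plots above.

```python
import math
import torch


def sine(t: torch.Tensor, amplitude: float, frequency: float, phase: float = 0.0) -> torch.Tensor:
    """y(t) = A * sin(2*pi*f*t + phase), evaluated at the given time points."""
    return amplitude * torch.sin(2 * math.pi * frequency * t + phase)


sampling_rate, duration = 100, 10.0
t = torch.arange(int(sampling_rate * duration)) / sampling_rate

# Illustrative components: 1 Hz, 3 Hz and 7 Hz waves with decreasing amplitudes.
two_waves = sine(t, 1.0, 1.0) + sine(t, 0.5, 3.0)
three_waves = two_waves + sine(t, 0.25, 7.0, phase=math.pi / 4)

print(t.shape, two_waves.shape, three_waves.shape)
```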

Decomposing wave

What if I told you that you can reverse this operation?

Combining smaller things into a bigger mixture is easy. But what if we have a complicated wave and would like to decompose it into elementary pieces? This is what the Fourier transform is about. Here we will use the Fast Fourier Transform (FFT), an efficient algorithm for computing the discrete Fourier transform of a sampled signal, widely used in computer engineering.
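The FFT computes the same numbers as the textbook discrete Fourier transform, \(X_k = \sum_{n=0}^{N-1} x_n e^{-2\pi i k n / N}\). It is just much faster: \(O(N \log N)\) instead of \(O(N^2)\). The sketch below checks this on random data. The helper `naive_dft` is hypothetical and is written only for the comparison.

```python
import math
import torch


def naive_dft(x: torch.Tensor) -> torch.Tensor:
    """O(N^2) DFT of a real 1D tensor: X_k = sum_n x_n * exp(-2j*pi*k*n/N)."""
    N = x.shape[0]
    n = torch.arange(N, dtype=torch.float64)
    k = n.reshape(-1, 1)                          # column of output indices
    basis = torch.exp(-2j * math.pi * k * n / N)  # (N, N) complex128 matrix
    return basis @ x.to(torch.complex128)


torch.manual_seed(0)
x = torch.randn(64, dtype=torch.float64)
print(torch.allclose(naive_dft(x), torch.fft.fft(x)))  # True
```

The \(N \times N\) matrix makes the cost quadratic, which is exactly what the FFT avoids.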

Imagine we are given the following signal.

Exercise 1
Try to guess the component sine waves before continuing, and see if you were right ;) You need to find their amplitudes and frequencies.

We are going to apply FFT to get elementary parts with PyTorch. There is a dedicated module, torch.fft. The first function we will use is rfft. The docs say:

Computes the one dimensional Fourier transform of real-valued input.

Our input is a 1D sequence of real values, so we are good.

```python
import torch
import torch.fft as fft

# `signal`: the 1D real-valued tensor shown above
fourier = fft.rfft(signal, norm='forward')
```

That was easy. But what did we get?

```python
print(signal.shape, fourier.shape)
print(fourier[:5])

# torch.Size([1499]) torch.Size([750])
# tensor([ 2.1671e-08+0.0000e+00j, -2.9042e-08+2.2379e-09j,
#         -2.7894e-08-1.0006e-07j,  1.0370e-07-4.0000e-01j,
#          2.2421e-08+6.8265e-08j])
```

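rfft returned a complex tensor with 750 elements. For \(N\) real samples, it returns \(\lfloor N/2 \rfloor + 1\) values (here 1499 // 2 + 1 = 750). The negative-frequency half of a real signal’s spectrum is just the complex conjugate of the positive half, so it carries no extra information. Each value is a complex number. What do those numbers mean? First, let’s compute their absolute values and make a plot.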
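To see where the 750 comes from, here is a sketch on random data of the same length (the data is made up; only the length matches the output above). For real input, rfft returns the first \(\lfloor N/2 \rfloor + 1\) values of the full fft. With norm="forward", the result is simply divided by \(N\):

```python
import torch

torch.manual_seed(0)
N = 1499
x = torch.randn(N)  # random real-valued signal (illustrative data)

half = torch.fft.rfft(x, norm="forward")
full = torch.fft.fft(x, norm="forward")

print(half.shape)  # torch.Size([750]), i.e. N // 2 + 1
# rfft keeps only the non-negative frequencies of the full transform
print(torch.allclose(half, full[: N // 2 + 1], atol=1e-6))
# norm="forward" is the default (unnormalized) transform divided by N
print(torch.allclose(half, torch.fft.rfft(x) / N, atol=1e-6))
```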

Danger
The absolute value of a complex number \(a+bi\) is its modulus \(\sqrt{a^2+b^2}\) - not the same as for real numbers!
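Here is a tiny check with the hand-picked value \(3+4i\). It shows that Tensor.abs() on a complex tensor returns the modulus:

```python
import torch

z = torch.tensor(3 + 4j)                      # complex64 scalar
print(z.abs())                                # tensor(5.)
print(torch.sqrt(z.real ** 2 + z.imag ** 2))  # tensor(5.), the same formula written out
print(z.real.abs())                           # tensor(3.), NOT the magnitude
```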
```python
import matplotlib.pyplot as plt

absolutes = fourier.abs()

fig, ax = plt.subplots()
ax.set_xlabel('?')
ax.set_ylabel('$|F(s)|$')
ax.grid()
ax.scatter(x=torch.arange(absolutes.shape[0]), y=absolutes, s=8)

plt.savefig('../plots/fourier_1.svg')
```

What we see: a lot of zeros, a couple of dots on the left side, and a mysterious X-axis. What do those numbers mean?

We already said that the Fourier transform decomposes the signal into elementary waves. In fact, the dots represent the parameters of the component waves.

The Y-axis (the abs() of rfft) represents the magnitude of a specific wave. Only two values are non-zero, so we can conclude that the signal is built from two waves. But which waves? That is encoded on the X-axis. To interpret it, we need to decode those numbers, and there is a function for that ;).

Yes, there is a dedicated function to get labels for the X-axis - rfftfreq. It takes two parameters - the length of the signal (\(N\)) and the time step (\(dt\)) of the original signal. In return, we receive the frequencies corresponding to the subsequent values from rfft. They are often called bins.
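Under the hood, the \(k\)-th bin is \(f_k = \frac{k}{N\,dt}\) for \(k = 0, 1, \ldots, \lfloor N/2 \rfloor\). The spacing between bins is therefore \(\frac{1}{N\,dt}\), roughly one over the signal duration. The sketch below rebuilds the bins by hand, using made-up values (8 samples taken 0.5 s apart):

```python
import torch

N, dt = 8, 0.5                                # 8 samples, 0.5 s apart (4 s of signal)
freq = torch.fft.rfftfreq(N, d=dt)
manual = torch.arange(N // 2 + 1) / (N * dt)  # f_k = k / (N * dt)

print(freq)                          # tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000])
print(torch.allclose(freq, manual))  # True
```

Here the last bin is exactly the Nyquist frequency \(f_s/2 = 1\) Hz. For an odd \(N\) (like 1499), the last bin falls slightly below it.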

Let’s give it a try:

```python
freq = fft.rfftfreq(len(signal), 1/sampling_rate)
```
Bottom view is zoomed in

Now we clearly see that the original signal is composed of two waves, 0.2 Hz and 2 Hz, with magnitudes 0.4 and 0.1. We can confirm it by printing the bins with the largest magnitudes:

```python
ranked = absolutes.argsort(descending=True)
for i in ranked[:5]:
    print(f'f={freq[i]:.2f}\t val={absolutes[i]:.4f}')

# f=0.20	 val=0.4000
# f=2.00	 val=0.1000
# f=1.93	 val=0.0000
# f=2.07	 val=0.0000
# f=0.13	 val=0.0000
```

Magnitude, 0 Hz bin and phase

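But wait, what is magnitude? Intuitively, it is the “amount” of signal in a specific bin, so it is closely related to amplitude. In our case, we can compute the amplitude by multiplying the magnitude by 2. With norm="forward", the transform is already divided by \(N\), and a real sine wave splits its amplitude equally between a positive and a negative frequency. rfft returns only the positive half, so each bin holds \(A/2\), i.e. \(A = 2|X_k|\), where \(X_k\) is the rfft value in that bin. With a different normalization, we would rescale differently; for example, with the default norm="backward", \(A = 2|X_k|/N\).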

Now we can safely say that the signal is composed of two waves: the first with \(A=0.8, f=0.2\text{ Hz}\), the second with \(A=0.2, f=2\text{ Hz}\).
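To double-check the factor of 2, we can rebuild a signal from these two waves and run the same steps on it. This is a sketch under assumptions. The 100 Hz sampling rate and 15 s duration are hypothetical; they give 1500 samples, so 0.2 Hz and 2 Hz fall exactly on bins. The signal is a reconstruction, not the original data.

```python
import math
import torch

sampling_rate, duration = 100, 15.0                       # hypothetical sampling setup
t = torch.arange(int(sampling_rate * duration)) / sampling_rate
signal = 0.8 * torch.sin(2 * math.pi * 0.2 * t) + 0.2 * torch.sin(2 * math.pi * 2.0 * t)

X = torch.fft.rfft(signal, norm="forward")
freq = torch.fft.rfftfreq(len(signal), d=1 / sampling_rate)
amplitude = 2 * X.abs()  # valid for bins other than 0 Hz (and the Nyquist bin)

for i in amplitude.argsort(descending=True)[:2]:
    print(f"f={freq[i]:.2f} Hz  A={amplitude[i]:.4f}")
```

It should report 0.20 Hz with \(A \approx 0.8\) and 2.00 Hz with \(A \approx 0.2\).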

There is one more interesting point: the 0 Hz bin. It represents a constant signal (DC), much like an integration constant. With norm="forward", it is equal to the mean of all samples, so it shows whether the signal has a vertical offset. Unlike the other bins, it should not be multiplied by 2.
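Here is a quick sketch of the 0 Hz bin, using an arbitrary 1 Hz sine wave shifted up by 0.5 (all values are illustrative):

```python
import math
import torch

sampling_rate = 100
t = torch.arange(10 * sampling_rate) / sampling_rate  # 10 s of samples
signal = 0.5 + torch.sin(2 * math.pi * 1.0 * t)       # sine wave with a vertical offset of 0.5

X = torch.fft.rfft(signal, norm="forward")
print(X[0])           # ≈ 0.5 + 0j: the offset itself
print(signal.mean())  # ≈ 0.5: the same value
```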

Wait, but a wave can also have a phase shift. In this example, both waves have zero phase (\(\varphi=0 \text{ rad}\)). Extracting the phase is more difficult, so it won’t be covered here; it is not needed to understand the concept itself.

Undesired effects

Working with the FFT in real life is not so easy. For example, if I change the signal duration from 15 to 16 seconds, the transform looks like this:

Signal is almost the same - just one second longer

Are there more than two waves?

Zooming in shows multiple waves

What happened? 😱

We see that instead of two waves, there are many of them. Let’s take a closer look:

```python
ranked = absolutes.argsort(descending=True)
for i in ranked[:5]:
    print(f'f={freq[i]:.2f}\t val={absolutes[i]:.4f}')
# f=0.19	 val=0.3705
# f=0.25	 val=0.0974
# f=2.00	 val=0.0962
# f=0.13	 val=0.0593
# f=0.31	 val=0.0453

for f in freq[:5]:
    print(f'f={f:.2f}')

# f=0.00
# f=0.06
# f=0.13
# f=0.19
# f=0.25
```

The bins are now different. Previously we had a bin at f=0.2 Hz, perfectly aligned with the signal, but now the closest bins are f=0.19 Hz and f=0.25 Hz. Neither matches the exact frequency, so the FFT has to “spread” the wave across several bins. This happens because the bin frequencies depend on the number of samples and the time step (recall the arguments of rfftfreq). The spacing between bins is \(\frac{1}{N\,dt}\), roughly one over the signal duration. We cannot fully avoid this; we just need to know that real-world results are rarely perfect.

Info
This phenomenon is known as spectral leakage.
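The effect is easy to reproduce. This sketch assumes the same two waves as above and a hypothetical 100 Hz sampling rate. With 15 s, both frequencies sit exactly on bins. With 16 s, the bins are 0.0625 Hz apart, so the 0.2 Hz wave falls between the 0.1875 Hz and 0.25 Hz bins and leaks into its neighbors.

```python
import math
import torch


def magnitude_spectrum(duration: float, sampling_rate: int = 100):
    """Bins and |rfft| (norm="forward") of 0.8*sin(2*pi*0.2*t) + 0.2*sin(2*pi*2*t)."""
    t = torch.arange(int(sampling_rate * duration)) / sampling_rate
    signal = 0.8 * torch.sin(2 * math.pi * 0.2 * t) + 0.2 * torch.sin(2 * math.pi * 2.0 * t)
    freq = torch.fft.rfftfreq(len(signal), d=1 / sampling_rate)
    return freq, torch.fft.rfft(signal, norm="forward").abs()


significant_bins = {}
for duration in (15, 16):
    freq, mag = magnitude_spectrum(duration)
    # Count bins holding a noticeable share of the signal (threshold chosen arbitrarily).
    significant_bins[duration] = int((mag > 0.01).sum())
    print(f"duration={duration} s  bin spacing={freq[1]:.4f} Hz  bins above 0.01: {significant_bins[duration]}")
```

With 15 s, only the two true frequencies pass the threshold. With 16 s, the 0.2 Hz wave is spread over several bins.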

Closing

Ok, I think this is enough for today. Now you should have a conceptual understanding of what the Fourier transform is. If you would like to continue exploring the topic, I recommend the following starting point: https://betterexplained.com/articles/an-interactive-guide-to-the-fourier-transform/