如何将TensorFlow代码的这一小部分翻译成pyTorch?

xxhby3vn  于 2022-11-23  发布在  其他
关注(0)|答案(1)|浏览(137)

如何将TensorFlow代码的这一小部分翻译成pyTorch?

def transforms(x):
        # stft returns spectogram for each sample and each eeg
        # input X contains 3 signals, apply stft for each 
        # and get array with shape [samples, num_of_eeg, time_stamps, freq]
        # change dims and return [samples, time_stamps, freq, num_of_eeg]
        spectrograms = tf.signal.stft(x, frame_length=32, frame_step=4, fft_length=64)
        spectrograms = tf.abs(spectrograms)
        return tf.einsum("...ijk->...jki", spectrograms)
piv4azn7

piv4azn71#

您可以找到STFT pytorch实现here的文档。剩下的是快进。它应该是:

def transforms(x: torch.Tensor) -> torch.Tensor:
    """Return Fourrier spectrogram."""
    spectrograms = torch.stft(x, win_length=32, n_fft=4, hop_length=64)
    spectrograms = torch.abs(spectrograms)
    return torch.einsum("...ijk->...jki", spectrograms)

相关问题