如何将TensorFlow代码的这一小部分翻译成pyTorch?
def transforms(x):
# stft returns spectogram for each sample and each eeg
# input X contains 3 signals, apply stft for each
# and get array with shape [samples, num_of_eeg, time_stamps, freq]
# change dims and return [samples, time_stamps, freq, num_of_eeg]
spectrograms = tf.signal.stft(x, frame_length=32, frame_step=4, fft_length=64)
spectrograms = tf.abs(spectrograms)
return tf.einsum("...ijk->...jki", spectrograms)
1条答案
按热度按时间piv4azn71#
您可以找到STFT pytorch实现here的文档。剩下的是快进。它应该是: