python-3.x: How to implement a 2D-DWT in a Keras CNN architecture?

mgdq6dx1 · asked on 2023-01-14 · Python

I am trying to implement a 2D-DWT (discrete wavelet transform) block in a CNN. It takes a tensor as input and produces four outputs: one approximation and three details:

import numpy as np
import pywt

data = np.ones((4, 4), dtype=np.float64)
coeffs = pywt.dwt2(data, 'haar')
cA, (cH, cV, cD) = coeffs

Output (the four 2×2 coefficient arrays, which I stack into a single tensor):

   tensor([[[2.0000, 2.0000],
            [2.0000, 2.0000]],

           [[0.0000, 0.0000],
            [0.0000, 0.0000]],

           [[0.0000, 0.0000],
            [0.0000, 0.0000]],

            [[0.0000, 0.0000],
             [0.0000, 0.0000]]], dtype=torch.float64)
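The values above can also be reproduced without pywt: for the Haar wavelet, a single DWT level reduces to sums and differences of the even/odd row and column samples. A minimal NumPy sketch (the sign conventions for the detail bands are one common choice and may differ from pywt's):

```python
import numpy as np

data = np.ones((4, 4), dtype=np.float64)

# Polyphase components: even/odd rows crossed with even/odd columns
x00, x01 = data[0::2, 0::2], data[0::2, 1::2]
x10, x11 = data[1::2, 0::2], data[1::2, 1::2]

cA = (x00 + x01 + x10 + x11) / 2.0  # approximation
cH = (x10 + x11 - x00 - x01) / 2.0  # horizontal detail
cV = (x01 + x11 - x00 - x10) / 2.0  # vertical detail
cD = (x00 - x01 - x10 + x11) / 2.0  # diagonal detail

print(cA)  # all entries 2.0, matching pywt's cA for the all-ones input
```

For a constant input, all detail bands are zero and the approximation carries the (scaled) mean, which is exactly the output shown above.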

Now, I have the following architecture:

input_shape = (BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, CHANNELS)
n_classes = 10

model = models.Sequential([
    data_scaling,
    # data_augmentation,
    layers.Conv2D(32, kernel_size=(3, 3), activation='relu',
                  input_shape=input_shape[1:]),  # per-sample shape, without the batch dim
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, kernel_size=(3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, kernel_size=(3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),

    # *WaveMix-Lite block goes here*

    layers.Flatten(),
    layers.Dense(64, activation='relu'),
    layers.Dense(n_classes, activation='softmax'),
])

model.build(input_shape=input_shape)

How can I implement the 2D-DWT block in this architecture?
I tried to create a new layer with a custom class, but without success. I just need to implement the 2D-DWT block so that the architecture works.


jucafojl · Answer 1

You can implement the 2D-DWT block as a custom layer in your CNN architecture using the tf.keras.layers.Layer class. Note, however, that pywt.dwt2 operates on NumPy arrays: calling it directly inside the layer's call method will fail once Keras traces the model on symbolic tensors. For the Haar wavelet, the transform is simple enough to express directly in TensorFlow ops instead, which also keeps the layer differentiable.
Here is one way to implement it:

import tensorflow as tf

class DWT2D(tf.keras.layers.Layer):
    """Single-level 2D Haar DWT built from tensor ops (graph-compatible,
    differentiable). The four sub-bands are concatenated along the channel
    axis; detail-band sign conventions may differ from pywt's."""
    def call(self, x):
        # Polyphase components: even/odd rows crossed with even/odd columns
        x00, x01 = x[:, 0::2, 0::2, :], x[:, 0::2, 1::2, :]
        x10, x11 = x[:, 1::2, 0::2, :], x[:, 1::2, 1::2, :]
        cA = (x00 + x01 + x10 + x11) / 2.0  # approximation
        cH = (x10 + x11 - x00 - x01) / 2.0  # horizontal detail
        cV = (x01 + x11 - x00 - x10) / 2.0  # vertical detail
        cD = (x00 - x01 - x10 + x11) / 2.0  # diagonal detail
        return tf.concat([cA, cH, cV, cD], axis=-1)
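Note that calling pywt directly on symbolic tensors fails when Keras traces the graph. One hedged workaround, if you want to keep pywt rather than re-implement the wavelet, is to wrap it with tf.numpy_function. This sketch runs on CPU, provides no gradient, and loses static shape information (the function and layer names here are illustrative):

```python
import numpy as np
import pywt
import tensorflow as tf

def dwt2_numpy(x):
    # x: (batch, H, W, C) NumPy array; Haar DWT over the spatial axes
    cA, (cH, cV, cD) = pywt.dwt2(x, 'haar', axes=(1, 2))
    return np.concatenate([cA, cH, cV, cD], axis=-1).astype(np.float32)

class DWT2DWrapped(tf.keras.layers.Layer):
    def call(self, x):
        y = tf.numpy_function(dwt2_numpy, [x], tf.float32)
        # tf.numpy_function drops static shape info; restore what we know
        c = x.shape[-1]
        y.set_shape([None, None, None, None if c is None else 4 * c])
        return y
```

Because there is no gradient through tf.numpy_function, this variant only makes sense if the DWT sits before any trainable layers; otherwise prefer the tensor-ops implementation.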

input_shape = (BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, CHANNELS)
n_classes = 10

model = models.Sequential([
    data_scaling,

    layers.Conv2D(32, kernel_size=(3, 3), activation='relu',
                  input_shape=input_shape[1:]),  # per-sample shape, without the batch dim
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, kernel_size=(3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, kernel_size=(3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),

    DWT2D(),

    layers.Flatten(),
    layers.Dense(64, activation='relu'),
    layers.Dense(n_classes, activation='softmax'),
])
model.build(input_shape=input_shape)

You add the custom layer to the architecture by instantiating the DWT2D class and placing it in the Sequential list. When the model runs, the layer applies the 2D-DWT to its input tensor: spatial dimensions are halved and the channel count is quadrupled, since the four sub-bands (cA, cH, cV, cD) are concatenated along the channel axis. That single-tensor output is what makes the layer usable inside Sequential; if you instead need the four sub-bands as separate tensors feeding separate branches, return them as a tuple and build the model with the functional API, because Sequential only supports single-input, single-output layers.
Since the transform is built from ordinary TensorFlow ops, it works directly on the model's float32 tensors; no conversion to np.float64 is required.
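For the separate-sub-bands case, a minimal functional-API sketch (the DWT2DSplit name, input size, and head are illustrative assumptions, not part of the original answer):

```python
import tensorflow as tf
from tensorflow.keras import layers, models

class DWT2DSplit(tf.keras.layers.Layer):
    """Haar DWT returning the four sub-bands as separate tensors."""
    def call(self, x):
        x00, x01 = x[:, 0::2, 0::2, :], x[:, 0::2, 1::2, :]
        x10, x11 = x[:, 1::2, 0::2, :], x[:, 1::2, 1::2, :]
        cA = (x00 + x01 + x10 + x11) / 2.0
        cH = (x10 + x11 - x00 - x01) / 2.0
        cV = (x01 + x11 - x00 - x10) / 2.0
        cD = (x00 - x01 - x10 + x11) / 2.0
        return cA, cH, cV, cD

inputs = tf.keras.Input(shape=(32, 32, 3))
cA, cH, cV, cD = DWT2DSplit()(inputs)
# Each sub-band could feed its own branch; here they are simply flattened
# and merged before a small classification head.
merged = layers.Concatenate()([layers.Flatten()(t) for t in (cA, cH, cV, cD)])
outputs = layers.Dense(10, activation='softmax')(merged)
model = models.Model(inputs, outputs)
```

Each sub-band has shape (None, 16, 16, 3) for a 32×32×3 input, and the functional graph tracks all four outputs of the layer.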
