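I'm working on a project in which I need to separate sleep data from its labels, but I'm stuck on the error mentioned above.
Since I'm new to machine learning, I would be very grateful if someone could help me solve this problem.
I have implemented a model with the following code: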
import numpy as np
import tensorflow as tf
EEG_training_data = EEG_training_data.reshape(EEG_training_data.shape[0], EEG_training_data.shape[1],1)
print(EEG_training_data.shape)# (5360, 5000, 1)
EEG_validation_data = EEG_validation_data.reshape(EEG_validation_data.shape[0], EEG_validation_data.shape[1],1)
print(EEG_validation_data.shape)#(1396, 5000, 1)
label_class = (np.unique(EEG_training_label))
num_classes = label_class.size # num_classes = 5
#define the model using CNN
model = tf.keras.Sequential()
model.add(tf.keras.layers.Conv1D(filters=64, kernel_size= 16, activation='relu', batch_input_shape=(None,5000, 1))) # #input_shape=(5000, 1)
model.add(tf.keras.layers.Dropout(0.5))
model.add(tf.keras.layers.MaxPool1D(8, padding='same'))
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(16, activation='relu'))
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dropout(0.2))
model.add(tf.keras.layers.Dense(num_classes, activation='softmax'))
#Summary of the model defined:
model.summary()
#Define loss function
model.compile(
loss= 'categorical_crossentropy', # 'sparse_categorical_crossentropy',
optimizer='adam',
metrics=[tf.keras.metrics.FalseNegatives(), tf.keras.metrics.FalsePositives(), 'accuracy'])
#one Hot Encoding
y_train_hot = tf.keras.utils.to_categorical(EEG_training_label, num_classes)
print('New y_train shape: ', y_train_hot.shape)#(5360, 5)
y_valid_hot = tf.keras.utils.to_categorical(EEG_validation_label, num_classes)
print('New y_valid shape: ', y_valid_hot.shape)#(1396, 5)
# apply fit on data
model_history = model.fit(
x=EEG_training_data,
y=y_train_hot,
batch_size=32,
epochs=5,
validation_data=(EEG_validation_data, y_valid_hot),
)
model_prediction = model.predict(EEG_testing_data)
predicted_matrix = tf.math.confusion_matrix(labels=EEG_testing_label.argmax(axis=1), predictions=model_prediction.argmax(axis=1)).numpy()
print(predicted_matrix)
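Since the answer below points at data shapes, here is a minimal numpy-only sketch of the checks worth running, under some assumptions. Two things can be read directly from the code above: `EEG_testing_data` is never reshaped to `(N, 5000, 1)` the way the training and validation arrays are, and `EEG_testing_label.argmax(axis=1)` only works if the test labels are one-hot, whereas the training and validation labels had to be passed through `to_categorical`, which suggests they are integer class ids. Whether either one causes the error is an assumption, because the error message itself is not shown here. The helpers `to_model_input`, `to_int_labels` and `confusion_matrix_np` are hypothetical names, and the arrays are synthetic stand-ins.

```python
import numpy as np

SEQ_LEN = 5000  # samples per EEG epoch, from the shapes printed in the question


def to_model_input(x):
    """Hypothetical helper: reshape (N, SEQ_LEN) data to the (N, SEQ_LEN, 1) layout Conv1D expects."""
    x = np.asarray(x)
    if x.ndim == 2:
        x = x.reshape(x.shape[0], x.shape[1], 1)
    if x.ndim != 3 or x.shape[1:] != (SEQ_LEN, 1):
        raise ValueError(f"expected shape (N, {SEQ_LEN}, 1), got {x.shape}")
    return x


def to_int_labels(y):
    """Hypothetical helper: return 1-D integer class ids from one-hot (N, C) or integer (N,) labels."""
    y = np.asarray(y)
    if y.ndim == 2 and y.shape[1] > 1:
        return y.argmax(axis=1)  # one-hot rows -> index of the 1
    return y.reshape(-1).astype(int)  # already integer ids, possibly shaped (N, 1)


def confusion_matrix_np(labels, predictions, num_classes):
    """Rows are true classes, columns are predicted classes (same layout as tf.math.confusion_matrix)."""
    cm = np.zeros((num_classes, num_classes), dtype=int)
    np.add.at(cm, (labels, predictions), 1)  # count each (true, predicted) pair
    return cm


# Synthetic stand-ins; the real arrays come from the question's data loading.
x_test = np.random.default_rng(0).standard_normal((4, SEQ_LEN))
print(to_model_input(x_test).shape)  # (4, 5000, 1)

y_int = np.array([0, 2, 1, 2])
y_hot = np.eye(5)[y_int]                 # same labels, one-hot encoded
probs = np.eye(5)[[0, 2, 2, 2]]          # pretend output of model.predict
print(confusion_matrix_np(to_int_labels(y_hot), probs.argmax(axis=1), 5))
```

Applied to the question's code, this would mean calling `model.predict(to_model_input(EEG_testing_data))` and passing `to_int_labels(EEG_testing_label)` as `labels` to `tf.math.confusion_matrix`, so the same line works whether the test labels are one-hot or integer ids.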
1 answer
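There is nothing wrong with the code you provided. Try running the code below; it should work. If it does, double-check that the shapes of all your data, such as
EEG_training_data
and the others, match the ones shown below: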