keras: Loading large datasets for TensorFlow deep learning

q43xntqr  posted on 2023-02-23  in  Other

I am loading data that consists of thousands of MRI images. I use nibabel to get the 3D data array from each MRI file:

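import nibabel as nib
import numpy as np
import pandas as pd
import tensorflow as tf

def get_voxels(path):
    # Load the MRI file and return its voxel data as a NumPy array
    img = nib.load(path)
    data = img.get_fdata()

    return data.copy()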

df = pd.read_csv("/home/paths_updated_shuffled_4.csv")
df = df.reset_index()

labels = []
images = []
for index, row in df.iterrows():
    images.append(get_voxels(row['path']))
    labels.append(row['pass'])
labels = np.array(labels)
images = np.array(images)

n = len(df.index)
train_n = int(0.8 * n)
train_images = images[:train_n]
train_labels = labels[:train_n]
validation_n = (n - train_n) // 2
validation_end = train_n + validation_n
validation_images, validation_labels = images[train_n:validation_end], labels[train_n:validation_end]
test_images = images[validation_end:]
test_labels = labels[validation_end:]

train_ds = tf.data.Dataset.from_tensor_slices((train_images, train_labels))
validation_ds = tf.data.Dataset.from_tensor_slices((validation_images, validation_labels))
test_ds = tf.data.Dataset.from_tensor_slices((test_images, test_labels))

As you can see, I am using tf.data.Dataset.from_tensor_slices, but with this many large files I run out of memory.
Is there a better way to do this in TensorFlow or Keras?
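
For a sense of scale, here is a rough back-of-the-envelope estimate of what the images array alone would occupy. The scan count and volume shape below are hypothetical placeholders, not values from the question; the 8 bytes per voxel reflects that nibabel's get_fdata() returns float64 by default.

```python
def estimate_gib(num_volumes, volume_shape, bytes_per_voxel=8):
    """Approximate size in GiB of num_volumes dense arrays of volume_shape."""
    voxels = 1
    for dim in volume_shape:
        voxels *= dim
    return num_volumes * voxels * bytes_per_voxel / 2**30


# Hypothetical numbers: 2000 scans of 256 x 256 x 160 voxels, stored as float64
print(estimate_gib(2000, (256, 256, 160)))  # 156.25
```

Calling np.array(images) on the Python list also allocates a second copy while the list is still alive, so the peak usage in the loading code above is likely close to double that figure.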

7eumitmz1#

Follow the approach described in "3D image classification from CT scans" by Hasib Zunair:
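
The code in this answer relies on three helpers from that tutorial: read_nifti_file, normalize and resize_volume. read_nifti_file can be the same as get_voxels in the question. A minimal sketch of the other two might look like the following. The tutorial targets CT scans, so the intensity scaling here is swapped for plain min-max scaling, which is an assumption and not the tutorial's method; the default target shape is likewise only a placeholder.

```python
import numpy as np
from scipy import ndimage


def normalize(volume):
    """Rescale intensities to [0, 1] using min-max scaling (an assumption for MRI)."""
    volume = np.asarray(volume, dtype=np.float32)
    lo, hi = float(volume.min()), float(volume.max())
    if hi == lo:
        # Constant volume: avoid dividing by zero
        return np.zeros_like(volume)
    return (volume - lo) / (hi - lo)


def resize_volume(volume, target_shape=(128, 128, 64)):
    """Resample a 3D volume to target_shape with linear interpolation."""
    factors = [t / s for t, s in zip(target_shape, volume.shape)]
    return ndimage.zoom(volume, factors, order=1)
```

Shrinking every scan to a small fixed shape and storing it as float32 is what reduces the memory footprint in this approach.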

import nibabel as nib
import pandas as pd
import numpy as np

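# read_nifti_file, normalize and resize_volume are helper functions from
# the tutorial; they must be defined before process_scan is called.
def process_scan(path):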
    """Read and resize volume"""
    # Read scan
    volume = read_nifti_file(path)
    # Normalize
    volume = normalize(volume)
    # Resize width, height and depth
    volume = resize_volume(volume)
    return volume

df = pd.read_csv("/home/paths_updated_shuffled_4.csv")
n = len(df.index)
passing_rows = df.loc[df['pass'] == 1]
normal_scan_paths = passing_rows['path'].tolist()
failing_rows = df.loc[df['pass'] == 0]
abnormal_scan_paths = failing_rows['path'].tolist()

print("Passing MRI scans: " + str(len(normal_scan_paths)))
print("Failing MRI scans: " + str(len(abnormal_scan_paths)))

# Loading data and preprocessing
# Read and process the scans.
# Each scan is resized across height, width, and depth and rescaled.
abnormal_scans = np.array([process_scan(path) for path in abnormal_scan_paths])
normal_scans = np.array([process_scan(path) for path in normal_scan_paths])
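
Note that the two list comprehensions above still build every processed volume in memory at once. If that is still too much, one option is to stream the scans from disk with tf.data.Dataset.from_generator so that only a few volumes are resident at any time. This is a sketch under assumptions: iter_volumes and make_streaming_dataset are hypothetical names, load_fn stands for any per-file loader such as process_scan, and every volume is assumed to come out with the same volume_shape.

```python
import numpy as np


def iter_volumes(paths, labels, load_fn):
    """Yield (volume, label) pairs one at a time; nothing is loaded up front."""
    for path, label in zip(paths, labels):
        yield np.asarray(load_fn(path), dtype=np.float32), np.int32(label)


def make_streaming_dataset(paths, labels, load_fn, volume_shape, batch_size=2):
    """Wrap iter_volumes in a batched, prefetching tf.data pipeline."""
    # Imported lazily so iter_volumes stays usable without TensorFlow
    import tensorflow as tf

    ds = tf.data.Dataset.from_generator(
        lambda: iter_volumes(paths, labels, load_fn),
        output_signature=(
            tf.TensorSpec(shape=volume_shape, dtype=tf.float32),
            tf.TensorSpec(shape=(), dtype=tf.int32),
        ),
    )
    # Batch the volumes and overlap file loading with training
    return ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)
```

With the CSV from the question, the path and label columns can be split 80/10/10 as before and each slice passed in, for example make_streaming_dataset(train_paths, train_labels, process_scan, (128, 128, 64)), and the result handed to model.fit. The trade-off is that every epoch re-reads the files, so training becomes bound by disk and preprocessing speed instead of RAM.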
