我有一个包含超过60亿行数据的Spark RDD,我想使用train_on_batch来训练深度学习模型。我无法将所有行都放入内存,因此我希望一次获得10K左右的数据,以批量处理成64或128的块(取决于模型大小)。我目前使用的是rdd.sample(),但我不认为这能保证得到所有行。有没有更好的方法来划分数据,使其更易于管理,以便我可以编写一个生成器函数来获取批处理?我的代码如下:
data_df = spark.read.parquet(PARQUET_FILE)
print(f'RDD Count: {data_df.count()}') # 6B+
data_sample = data_df.sample(True, 0.0000015).take(6400)
sample_df = data_sample.toPandas()
def get_batch():
for row in sample_df.itertuples():
# TODO: put together a batch size of BATCH_SIZE
yield row
for i in range(10):
print(next(get_batch()))
字符串
2条答案
按热度按时间hgqdbh6s1#
试试这个:
字符串
z9smfwbn2#
我不相信Spark让你抵消或分页您的数据。
但你可以添加一个索引,然后分页,首先:
字符串
这不是最佳的,因为使用了pandas Dataframe ,它会严重利用spark,但会解决你的问题。
如果
id
影响到您的功能,请不要忘记删除它。