How to split train/test sets in PySpark by column value instead of by row

k7fdbhmy | published 2021-05-18 in Spark

I want to generate a training set and a test set for machine learning. Suppose I have a DataFrame with the following columns:

account_id | session_id | feature_1 | feature_2 | label

In this dataset, each row has a unique session_id, but an account_id can appear multiple times. However, I want my training and test sets to have mutually exclusive account_ids (almost the opposite of stratified sampling).
In pandas this is straightforward. I have something like the following:

def train_test_split(df, split_col, feature_cols, label_col, test_fraction=0.2):
    """
    While sklearn train_test_split splits by each row in the dataset,
    this function will split by a specific column. In that way, we can 
    separate account_id such that train and test sets will have mutually
    exclusive accounts, to minimize cross-talk between train and test sets.
    """
    split_values = df[split_col].drop_duplicates()
    test_values = split_values.sample(frac=test_fraction, random_state=42)

    df_test = df[df[split_col].isin(test_values)]
    df_train = df[~df[split_col].isin(test_values)]

    return df_test, df_train
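A quick usage sketch on toy data (the column values below are made up for illustration), with the function from the question repeated so it runs standalone:

```python
import pandas as pd


def train_test_split(df, split_col, feature_cols, label_col, test_fraction=0.2):
    # Function from the question, repeated here so the sketch runs standalone.
    split_values = df[split_col].drop_duplicates()
    test_values = split_values.sample(frac=test_fraction, random_state=42)

    df_test = df[df[split_col].isin(test_values)]
    df_train = df[~df[split_col].isin(test_values)]

    return df_test, df_train


# Toy frame: 10 accounts with 3 sessions each (hypothetical data).
df = pd.DataFrame({
    "account_id": [i // 3 for i in range(30)],
    "session_id": range(30),
    "feature_1": range(30),
    "label": [i % 2 for i in range(30)],
})

df_test, df_train = train_test_split(df, "account_id", ["feature_1"], "label")

# Every row of an account lands in exactly one split.
assert set(df_test["account_id"]).isdisjoint(df_train["account_id"])
assert len(df_test) + len(df_train) == len(df)
```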

Now my dataset is too large to fit in memory, and I have to switch from pandas to PySpark. How can I split a train and test set in PySpark so that they have mutually exclusive account_ids, without pulling everything into memory?


4smxwvx5 #1

You can use the rand() function from pyspark.sql.functions to assign each distinct account_id a random number, then create train and test DataFrames based on that number.

from pyspark.sql import functions as F

TEST_FRACTION = 0.2

train_test_split = (df.select("account_id")
                      .distinct()  # one row per account_id
                      .withColumn("rand_val", F.rand(seed=42))  # seed for a reproducible split
                      .withColumn("data_type",
                                  F.when(F.col("rand_val") < TEST_FRACTION, "test")
                                   .otherwise("train")))

train_df = (train_test_split.filter(F.col("data_type") == "train")
                            .join(df, on="account_id"))  # inner join removes all rows other than train

test_df = (train_test_split.filter(F.col("data_type") == "test")
                           .join(df, on="account_id"))

Since an account_id cannot be in both train and test at once, train_df and test_df will have mutually exclusive account_ids.
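Note that, unlike the pandas version with a fixed frac, this gives an approximate split: each account lands in the test set independently with probability TEST_FRACTION, so the realized fraction fluctuates around 0.2. A plain-Python sketch of the same per-account assignment idea (hypothetical names, no Spark required):

```python
import random

random.seed(42)
TEST_FRACTION = 0.2

# One random draw per distinct account, mirroring F.rand() applied to
# the distinct-account DataFrame in the answer above.
account_ids = [f"acct_{i}" for i in range(1000)]
assignment = {a: ("test" if random.random() < TEST_FRACTION else "train")
              for a in account_ids}

test_accounts = {a for a, t in assignment.items() if t == "test"}
train_accounts = {a for a, t in assignment.items() if t == "train"}

# Mutually exclusive by construction; test fraction is only approximately 0.2.
assert test_accounts.isdisjoint(train_accounts)
```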
