pyspark 3.0.1 cannot run distributed training with tensorflow 2.1.0


I am trying to train a simple Fashion MNIST model on TensorFlow, following the TensorBoard API documentation on hyperparameter tuning that you can find here.
Currently, for testing purposes, I am running in standalone mode, so master = 'local[*]'. I have installed pyspark==3.0.1 and tensorflow==2.1.0. Here is what I am trying to run:


# For a given set of hyperparameters, this trains the model and returns the model + its accuracy.

# This works when I run it without Spark.

from typing import Any, Tuple

import tensorflow as tf
from tensorflow.keras import Model, Sequential
from tensorflow.keras.layers import Dense, Dropout, Flatten


def train(hparam) -> Tuple[Model, Any]:
    # Standard Fashion MNIST dataset shipped with Keras
    fashion_mnist = tf.keras.datasets.fashion_mnist
    (x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
    x_train, x_test = x_train / 255.0, x_test / 255.0
    model = Sequential([
        Flatten(),
        Dense(hparam['num_units'], activation=tf.nn.relu),
        Dropout(hparam['dropout']),
        Dense(10, activation=tf.nn.softmax),
    ])
    model.compile(
        optimizer=hparam['optimizer'],
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )
    model.fit(x_train, y_train, epochs=1)  # Run with 1 epoch to speed things up for demo purposes
    _, accuracy = model.evaluate(x_test, y_test)
    return model, accuracy
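
Calling the function directly, with no Spark involved, runs end to end and gives me back a model plus its accuracy. For example, a single local run with one hyperparameter combination (the exact values here are just an illustration):

# Sanity check: one local call, outside of Spark, works fine
model, accuracy = train({'num_units': 16, 'dropout': 0.1, 'optimizer': 'adam'})
print('local accuracy:', accuracy)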

Here is the Spark code I run:

from pyspark.sql import SparkSession
from tensorboard.plugins.hparams import api as hp

if __name__ == '__main__':

    hp_nums = hp.HParam('num_units', hp.Discrete([16, 32]))
    hp_dropouts = hp.HParam('dropout', hp.RealInterval(0.1, 0.2))
    hp_opts = hp.HParam('optimizer', hp.Discrete(['adam', 'sgd']))

    all_hparams = []  # list of every hyperparameter combination to try

    for num_units in hp_nums.domain.values:
        for dropout_rate in (hp_dropouts.domain.min_value, hp_dropouts.domain.max_value):
            for optimizer in hp_opts.domain.values:
                hparams = {
                    'num_units': num_units,
                    'dropout': dropout_rate,
                    'optimizer': optimizer,
                }
                all_hparams.append(hparams)

    spark_sess = SparkSession.builder.master(
        'local[*]'
    ).appName(
        'LocalTraining'
    ).getOrCreate()

    # One Spark task per hyperparameter combination, each running train()
    res = spark_sess.sparkContext.parallelize(
        all_hparams, len(all_hparams)
    ).map(
        train  # the function defined above
    ).collect()

    best_acc = 0.0
    best_model = None
    for model, acc in res:
        if acc > best_acc:
            best_acc = acc
            best_model = model

    print("best accuracy is -> " + str(best_acc))

This looks correct and works for any simple map-reduce job (like the basic examples; see the sketch below), which made me believe my environment is set up fine.
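
For example, a trivial map job along these lines (just squaring numbers, nothing TensorFlow-related, reusing the same spark_sess) completes without any socket errors; this is only a sketch of the kind of basic example I mean:

# A trivial map-reduce on the same SparkSession runs to completion
nums = spark_sess.sparkContext.parallelize(range(100), 4)
print(nums.map(lambda x: x * x).sum())  # prints 328350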
My environment:

java : Java 11.0.8 2020-07-14 LTS
python: Python 3.6.5
pyspark: 3.0.1
tensorflow: 2.1.0
Keras: 2.3.1
windows: 10 (if this really matters)
cores : 8 (i5 10th gen)
Memory: 6G

But when I run the code above, I get the following error. I can see the training run, and after one executor task finishes it just stops:

59168/60000 [============================>.] - ETA: 0s - loss: 0.7350 - accuracy: 0.7471
60000/60000 [==============================] - 3s 42us/step - loss: 0.7331 - accuracy: 0.7477
20/12/05 14:03:57 ERROR Executor: Exception in task 0.0 in stage 0.0 (TID 0)
java.net.SocketException: Connection reset
    at java.base/java.net.SocketInputStream.read(SocketInputStream.java:186)
    at java.base/java.net.SocketInputStream.read(SocketInputStream.java:140)
    at java.base/java.io.BufferedInputStream.fill(BufferedInputStream.java:252)
    at java.base/java.io.BufferedInputStream.read(BufferedInputStream.java:271)
20/12/05 14:03:57 ERROR TaskSetManager: Task 0 in stage 0.0 failed 1 times; aborting job
20/12/05 14:03:57 INFO TaskSchedulerImpl: Cancelling stage 0
20/12/05 14:03:57 INFO TaskSchedulerImpl: Killing all running tasks in stage 0: Stage cancelled
20/12/05 14:03:57 INFO Executor: Executor is trying to kill task 1.0 in stage 0.0 (TID 1), reason: Stage cancelled
20/12/05 14:03:57 INFO TaskSchedulerImpl: Stage 0 was cancelled
20/12/05 14:03:57 INFO DAGScheduler: ResultStage 0 (collect at C:/Users/<>/<>/<>/main.py:<>) failed in 7.506 s due to Job aborted due to stage failure: Task 0 in stage 0.0 failed 1 times, most recent failure: Lost task 0.0 in stage 0.0 (TID 0, host.docker.internal, executor driver): java.net.SocketException: Connection reset
    at java.base/java.net.SocketInputStream.read(SocketInputStream.java:186)
    at java.base/java.net.SocketInputStream.read(SocketInputStream.java:140)
    at java.base/java.io.BufferedInputStream.fill(BufferedInputStream.java:252)
    at java.base/java.io.BufferedInputStream.read(BufferedInputStream.java:271)
    at java.base/java.io.DataInputStream.readInt(DataInputStream.java:392)
    at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:628)
    at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:621)
    at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:456)
    at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
    at scala.collection.Iterator.foreach(Iterator.scala:941)
    at scala.collection.Iterator.foreach$(Iterator.scala:941)

py4j.protocol.Py4JJavaError: An error occurred while calling z:org.apache.spark.api.python.PythonRDD.collectAndServe.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 0.0 failed 1 times, most recent failure: Lost task 0.0 in stage 0.0 (TID 0, host.docker.internal, executor driver): java.net.SocketException: Connection reset
    at java.base/java.net.SocketInputStream.read(SocketInputStream.java:186)
    at java.base/java.net.SocketInputStream.read(SocketInputStream.java:140)
    at java.base/java.io.BufferedInputStream.fill(BufferedInputStream.java:252)
    at java.base/java.io.BufferedInputStream.read(BufferedInputStream.java:271)
    at java.base/java.io.DataInputStream.readInt(DataInputStream.java:392)
    at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:628)
    at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:621)
    at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:456)
    at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
    at scala.collection.Iterator.foreach(Iterator.scala:941)

Driver stacktrace:
20/12/05 14:03:57 INFO DAGScheduler: Job 0 failed: collect at C:/<>/<>/<>/main.py, took 7.541442 s
Traceback (most recent call last):
  File "C:/<>/<>/<>/main.py", line 68, in main
    return res.collect()
  File "C:\Users\<>\<>\<>\venv\lib\site-packages\pyspark\rdd.py", line 889, in collect
    sock_info = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd())
  File "C:\Users\<>\<>\<>\venv\lib\site-packages\py4j\java_gateway.py", line 1305, in __call__
    answer, self.gateway_client, self.target_id, self.name)
  File "C:\Users\<>\<>\<>\venv\lib\site-packages\pyspark\sql\utils.py", line 128, in deco
    return f(*a,**kw)
  File "C:\Users\<>\<>\<>\venv\lib\site-packages\py4j\protocol.py", line 328, in get_return_value
    format(target_id, ".", name), value)

The error is on the model.fit() line. It only happens when I call model.fit; if I comment it out and put something else there, it works perfectly fine (see the stub below). I am not sure why it fails at model.fit().
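
To illustrate, a hypothetical stub like this (my own placeholder, with the actual training removed) goes through the same parallelize / map / collect pipeline without any errors, which is why I suspect model.fit() specifically:

# Hypothetical stub: same signature as train(), but model.fit() is skipped.
# Mapping this instead of train() collects all results without the socket error.
def train_stub(hparam) -> Tuple[None, float]:
    return None, 0.0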
