I'm trying to train a simple Fashion-MNIST model on TensorFlow, following the TensorBoard API documentation on hyperparameter tuning that you can find here.
For now, for testing purposes, I'm running in standalone mode, so master = 'local[*]'.
I have pyspark==3.0.1 and tensorflow==2.1.0 installed. Here is what I'm trying to run:
# Imports used by the snippets below:
import tensorflow as tf
from typing import Any, Tuple
from tensorflow.keras import Model
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Flatten
from tensorboard.plugins.hparams import api as hp
from pyspark.sql import SparkSession

# For a given set of hyperparameters, this runs the training and returns the model + the accuracy I'm looking for.
# This works when I run it without Spark.
def train(hparam) -> Tuple[Model, Any]:
    fashion_mnist = tf.keras.datasets.fashion_mnist
    (x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
    x_train, x_test = x_train / 255.0, x_test / 255.0
    model = Sequential([
        Flatten(),
        Dense(hparam['num_units'], activation=tf.nn.relu),
        Dropout(hparam['dropout']),
        Dense(10, activation=tf.nn.softmax),
    ])
    model.compile(
        optimizer=hparam['optimizer'],
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'],
    )
    model.fit(x_train, y_train, epochs=1)  # 1 epoch to speed things up for demo purposes
    _, accuracy = model.evaluate(x_test, y_test)
    return model, accuracy
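For reference, calling it directly without Spark succeeds, along these lines (a minimal sketch; the hyperparameter values are just illustrative):

# Plain local call, no Spark involved -- this trains and evaluates fine.
model, acc = train({'num_units': 16, 'dropout': 0.1, 'optimizer': 'adam'})
print(acc)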
Here is the Spark code I run:
if __name__ == '__main__':
    hp_nums = hp.HParam('num_units', hp.Discrete([16, 32]))
    hp_dropouts = hp.HParam('dropout', hp.RealInterval(0.1, 0.2))
    hp_opts = hp.HParam('optimizer', hp.Discrete(['adam', 'sgd']))
    all_hparams = []  # contains a list of the different hparam combinations
    for num_units in hp_nums.domain.values:
        for dropout_rate in (hp_dropouts.domain.min_value, hp_dropouts.domain.max_value):
            for optimizer in hp_opts.domain.values:
                hparams = {
                    'num_units': num_units,
                    'dropout': dropout_rate,
                    'optimizer': optimizer,
                }
                all_hparams.append(hparams)
    spark_sess = SparkSession.builder.master(
        'local[*]'
    ).appName(
        'LocalTraining'
    ).getOrCreate()
    res = spark_sess.sparkContext.parallelize(
        all_hparams, len(all_hparams)
    ).map(
        train  # the function above
    ).collect()
    temp = 0.0
    best_model = None
    for model, acc in res:
        if acc > temp:
            temp = acc
            best_model = model
    print("best accuracy is -> " + str(temp))
This setup works fine for any simple map-reduce job (like the basic examples), which leads me to believe my environment itself is fine.
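For example, a trivial job along these lines completes with no errors (a representative sketch of what I tested, not my exact code):

# Simple map-reduce sanity check -- this runs fine on the same session.
nums = spark_sess.sparkContext.parallelize(range(100), 8)
print(nums.map(lambda x: x * x).reduce(lambda a, b: a + b))  # 328350, no errors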
My environment:
java: Java 11.0.8 2020-07-14 LTS
python: Python 3.6.5
pyspark: 3.0.1
tensorflow: 2.1.0
keras: 2.3.1
windows: 10 (if this really matters)
cores: 8 (i5 10th gen)
memory: 6G
But when I run the code above, I get the following error. I can see the training run, and right after one executor finishes, it stops:
59168/60000 [============================>.] - ETA: 0s - loss: 0.7350 - accuracy: 0.7471
60000/60000 [==============================] - 3s 42us/step - loss: 0.7331 - accuracy: 0.7477
20/12/05 14:03:57 ERROR Executor: Exception in task 0.0 in stage 0.0 (TID 0)
java.net.SocketException: Connection reset
at java.base/java.net.SocketInputStream.read(SocketInputStream.java:186)
at java.base/java.net.SocketInputStream.read(SocketInputStream.java:140)
at java.base/java.io.BufferedInputStream.fill(BufferedInputStream.java:252)
at java.base/java.io.BufferedInputStream.read(BufferedInputStream.java:271)
20/12/05 14:03:57 ERROR TaskSetManager: Task 0 in stage 0.0 failed 1 times; aborting job
20/12/05 14:03:57 INFO TaskSchedulerImpl: Cancelling stage 0
20/12/05 14:03:57 INFO TaskSchedulerImpl: Killing all running tasks in stage 0: Stage cancelled
20/12/05 14:03:57 INFO Executor: Executor is trying to kill task 1.0 in stage 0.0 (TID 1), reason: Stage cancelled
20/12/05 14:03:57 INFO TaskSchedulerImpl: Stage 0 was cancelled
20/12/05 14:03:57 INFO DAGScheduler: ResultStage 0 (collect at C:/Users/<>/<>/<>/main.py:<>) failed in 7.506 s due to Job aborted due to stage failure: Task 0 in stage 0.0 failed 1 times, most recent failure: Lost task 0.0 in stage 0.0 (TID 0, host.docker.internal, executor driver): java.net.SocketException: Connection reset
at java.base/java.net.SocketInputStream.read(SocketInputStream.java:186)
at java.base/java.net.SocketInputStream.read(SocketInputStream.java:140)
at java.base/java.io.BufferedInputStream.fill(BufferedInputStream.java:252)
at java.base/java.io.BufferedInputStream.read(BufferedInputStream.java:271)
at java.base/java.io.DataInputStream.readInt(DataInputStream.java:392)
at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:628)
at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:621)
at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:456)
at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
at scala.collection.Iterator.foreach(Iterator.scala:941)
at scala.collection.Iterator.foreach$(Iterator.scala:941)
py4j.protocol.Py4JJavaError: An error occurred while calling z:org.apache.spark.api.python.PythonRDD.collectAndServe.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 0.0 failed 1 times, most recent failure: Lost task 0.0 in stage 0.0 (TID 0, host.docker.internal, executor driver): java.net.SocketException: Connection reset
at java.base/java.net.SocketInputStream.read(SocketInputStream.java:186)
at java.base/java.net.SocketInputStream.read(SocketInputStream.java:140)
at java.base/java.io.BufferedInputStream.fill(BufferedInputStream.java:252)
at java.base/java.io.BufferedInputStream.read(BufferedInputStream.java:271)
at java.base/java.io.DataInputStream.readInt(DataInputStream.java:392)
at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:628)
at org.apache.spark.api.python.PythonRunner$$anon$3.read(PythonRunner.scala:621)
at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:456)
at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
at scala.collection.Iterator.foreach(Iterator.scala:941)
Driver stacktrace:
20/12/05 14:03:57 INFO DAGScheduler: Job 0 failed: collect at C:/<>/<>/<>/main.py, took 7.541442 s
Traceback (most recent call last):
File "C:/<>/<>/<>/main.py", line 68, in main
return res.collect()
File "C:\Users\<>\<>\<>\venv\lib\site-packages\pyspark\rdd.py", line 889, in collect
sock_info = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd())
File "C:\Users\<>\<>\<>\venv\lib\site-packages\py4j\java_gateway.py", line 1305, in __call__
answer, self.gateway_client, self.target_id, self.name)
File "C:\Users\<>\<>\<>\venv\lib\site-packages\pyspark\sql\utils.py", line 128, in deco
return f(*a,**kw)
File "C:\Users\<>\<>\<>\venv\lib\site-packages\py4j\protocol.py", line 328, in get_return_value
format(target_id, ".", name), value)
The error is on the model.fit() line. It only happens when I call model.fit; if I comment it out and put something else in its place, it works perfectly fine. I'm not sure why it fails at model.fit().
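To make that last point concrete, this is roughly the control experiment I ran (a sketch; train_dummy and its return values are made up): swapping model.fit out for a plain computation lets the whole job collect successfully.

# Hypothetical control: same pipeline, but no model.fit -- this collects fine.
def train_dummy(hparam):
    return None, float(hparam['num_units'])  # dummy "accuracy"

res = spark_sess.sparkContext.parallelize(
    all_hparams, len(all_hparams)
).map(train_dummy).collect()  # no SocketException here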