python Tensorflow M2 Pro故障

3pvhb19x  于 2023-03-28  发布在  Python
关注(0)|答案(1)|浏览(186)

当我运行下面的tensorflow测试脚本时

import tensorflow as tf
cifar = tf.keras.datasets.cifar100
(x_train, y_train), (x_test, y_test) = cifar.load_data()
model = tf.keras.applications.ResNet50(
    include_top=True,
    weights=None,
    input_shape=(32, 32, 3),
    classes=100,)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(optimizer="adam", loss=loss_fn, metrics=["accuracy"])
model.fit(x_train, y_train, epochs=5, batch_size=4)

我得到以下终端输出:

Metal device set to: Apple M2 Pro

systemMemory: 16.00 GB
maxCacheSize: 5.33 GB

2023-03-23 00:26:32.203361: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2023-03-23 00:26:32.203521: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)
zsh: bus error  python3 app/model/tf_verify.py
xxhby3vn

xxhby3vn1#

虽然从苹果官方的documentation不清楚,但看起来tensorflow-macos版本应该与“发布”部分的tensorflow-metal插件版本相匹配。由于您使用的是tensorflow-macos==2.9,因此应该使用tensorflow-metal==0.5.0而不是tensorflow-metal==0.6.0
我能够在MacBook Pro M1 Pro上重现并解决这个问题,训练效果很好。

相关问题