TensorFlow.js：向模型中添加大量图像样本会占满显卡内存并导致崩溃

yftpprvb  于 2021-09-23  发布在  Java
关注(0)|答案(1)|浏览(349)

我参考的是这个示例：
https://github.com/tensorflow/tfjs-examples/tree/master/webcam-transfer-learning
https://storage.googleapis.com/tfjs-examples/webcam-transfer-learning/dist/index.html
可以看到，该演示可以采集不同场景的画面、训练模型，然后对网络摄像头当前的场景进行预测。
我把演示代码改成了自己的代码，通过文件输入（file input）上传大量图片作为训练样本。
当我上传大量（300–400 张）224×244 的图片（每张约 70 KB）时，显卡（RX 570，4 GB）的内存会被占满，随后程序崩溃。
这是我的演示视频
https://www.youtube.com/watch?v=irnd29lcqi0
错误信息：Uncaught (in promise) Error: Failed to compile fragment shader. 以下是我的代码：

/**
 * Accumulates training examples (e.g. MobileNet activations) and their one-hot
 * labels into two growing tensors, `xs` and `ys`.
 */
class ControllerDataset {
    /**
     * @param {number} numClasses Number of distinct labels; used as the one-hot depth.
     */
    constructor(numClasses) {
        this.numClasses = numClasses;
    }

    /**
     * Adds an example to the controller dataset.
     *
     * The dataset never takes ownership of `example`. It stores either a kept clone
     * (first call) or a kept concatenation (later calls). The caller is therefore
     * always responsible for disposing `example`, either explicitly or by calling this
     * method inside tf.tidy().
     *
     * NOTE: every call after the first allocates new `xs`/`ys` tensors one row larger
     * and disposes the old ones. Adding n examples therefore copies O(n^2) elements
     * in total. For very large batches, consider collecting examples first and
     * concatenating once.
     *
     * @param {Tensor} example A tensor representing the example. It can be an image,
     *     an activation, or any other type of Tensor.
     * @param {number} label The label of the example. Should be a number.
     */
    addExample(example, label) {
        // One-hot encode the label.
        const y = tf.tidy(
            () => tf.oneHot(tf.tensor1d([label]).toInt(), this.numClasses));

        if (this.xs == null) {
            // Store a clone so the dataset stays valid even if the caller disposes
            // `example` directly. keep() additionally prevents an enclosing tf.tidy()
            // from disposing the stored tensors.
            this.xs = tf.keep(example.clone());
            this.ys = tf.keep(y);
        } else {
            const oldX = this.xs;
            this.xs = tf.keep(oldX.concat(example, 0));

            const oldY = this.ys;
            this.ys = tf.keep(oldY.concat(y, 0));

            // The concatenations above are new tensors; release the previous ones and
            // the per-call label tensor so only the current xs/ys remain allocated.
            oldX.dispose();
            oldY.dispose();
            y.dispose();
        }
    }

    /**
     * Releases the stored example and label tensors and resets the dataset to empty.
     */
    dispose() {
        if (this.xs != null) {
            this.xs.dispose();
        }
        if (this.ys != null) {
            this.ys.dispose();
        }
        this.xs = null;
        this.ys = null;
    }
}

// Truncated MobileNet feature extractor. It is undefined until
// loadTruncatedMobileNet() resolves and the result is assigned below.
let truncatedMobileNet;
const NUM_CLASSES = 3;
const controllerDataset = new ControllerDataset(NUM_CLASSES);

/**
 * Runs each image file through the truncated MobileNet and adds the resulting
 * activation to `controllerDataset` under `label`, one file at a time.
 *
 * @param {FileList|File[]} files Image files chosen by the user.
 * @param {number} label Class index applied to every file in `files`.
 */
async function addMultiSampleFromInputfile(files, label) {
    for (const file of Array.from(files)) {
        const image = await readFileToImageElement(file);
        const { sourceImageTensor, imageTensorNormalize } = getTensorImgFromElement(image);
        try {
            // predict() allocates a new activation tensor for every image. Previously it
            // was never disposed, so activations piled up in backend (GPU) memory.
            // Inside tf.tidy() the activation is released afterwards, unless addExample()
            // retained it via tf.keep().
            tf.tidy(() => {
                controllerDataset.addExample(truncatedMobileNet.predict(imageTensorNormalize), label);
            });
        } finally {
            // Release the per-image input tensors even if prediction throws.
            sourceImageTensor.dispose();
            imageTensorNormalize.dispose();
        }
    }
}

/**
 * Loads MobileNet from the URL typed into the #MobileNetUrl input. It then builds a
 * model that shares MobileNet's inputs but stops at the 'conv_pw_13_relu' layer.
 * That internal activation is what the classifier is trained on.
 *
 * @returns {Promise<tf.LayersModel>} The truncated feature-extractor model.
 */
async function loadTruncatedMobileNet() {
    const { value: modelUrl } = document.getElementById("MobileNetUrl");
    const fullModel = await tf.loadLayersModel(modelUrl);

    const { output: activation } = fullModel.getLayer('conv_pw_13_relu');
    return tf.model({ inputs: fullModel.inputs, outputs: activation });
}

/**
 * Converts an image element into the tensors used for prediction.
 *
 * @param {HTMLImageElement} element Source image.
 * @returns {{sourceImageTensor: Tensor, imageTensorNormalize: Tensor}}
 *     The raw pixel tensor, plus a float copy with an added leading batch axis,
 *     scaled by `/ 127 - 1`. The caller owns and must dispose both.
 */
function getTensorImgFromElement(element) {
    const sourceImageTensor = tf.browser.fromPixels(element);
    // tidy() frees the intermediate tensors and keeps only the returned result.
    const imageTensorNormalize = tf.tidy(() => {
        const batched = sourceImageTensor.expandDims(0);
        return batched.toFloat().div(127).sub(1);
    });
    return { sourceImageTensor, imageTensorNormalize };
}

/**
 * Reads an image File into a loaded <img> element via a data: URL.
 *
 * @param {File} file Image file selected by the user.
 * @returns {Promise<HTMLImageElement>} Resolves once the image has loaded. Rejects
 *     if the file cannot be read or the browser fails to load it as an image.
 */
function readFileToImageElement(file) {
    return new Promise((resolve, reject) => {
        const reader = new FileReader();
        // Without error handlers a bad file would leave this promise pending forever,
        // stalling any caller that awaits it inside a loop.
        reader.onerror = () => reject(reader.error ?? new Error(`Failed to read file: ${file.name}`));
        reader.onload = () => {
            const image = document.createElement('img');
            // Attach handlers before assigning src so neither event can be missed.
            image.onload = () => resolve(image);
            image.onerror = () => reject(new Error(`Failed to decode image: ${file.name}`));
            image.src = reader.result;
        };
        reader.readAsDataURL(file);
    });
}

// Start loading the feature extractor immediately. Until this resolves,
// truncatedMobileNet stays unset.
loadTruncatedMobileNet()
    .then((model) => {
        truncatedMobileNet = model;
    })
    // Report load failures (bad URL, network error) instead of leaving an unhandled rejection.
    .catch((err) => console.error('Failed to load MobileNet:', err));

// Add multiple samples from the HTML file input.
// addMultiSmapleBtn, imagefiles and label are not declared in this file; presumably
// they are DOM elements exposed as globals via their id attributes (confirm in HTML).
addMultiSmapleBtn.onclick = () => {
    // Nothing to do until the MobileNet feature extractor has been assigned.
    if (!truncatedMobileNet) {
        return;
    }
    // label value is 0 or 1 or 2
    const labelIndex = Number.parseInt(label.value, 10);
    if (!Number.isInteger(labelIndex)) {
        console.warn(`Ignoring invalid label value: ${label.value}`);
        return;
    }
    // Surface failures instead of leaving an unhandled promise rejection.
    addMultiSampleFromInputfile(imagefiles.files, labelIndex)
        .catch((err) => console.error('Failed to add samples:', err));
};
iugsix8n

iugsix8n1#

试试这个（用 `URL.createObjectURL` 代替 `FileReader` 把文件加载为图片）：

/**
 * Loads an image File into a decoded <img> element via a temporary object URL.
 * This avoids building a large base64 data: URL string.
 *
 * @param {File|Blob} file Image file to load.
 * @returns {Promise<HTMLImageElement>} The decoded image. Rejects if decoding fails.
 */
async function readFileToImageElement(file) {
  const img = new Image();
  const objectUrl = URL.createObjectURL(file);
  try {
    img.src = objectUrl;
    await img.decode();
  } finally {
    // Once the image has loaded it no longer needs the URL. Revoking here (also on
    // failure) releases the blob reference without relying on callers to do it.
    URL.revokeObjectURL(objectUrl);
  }
  return img;
}

// No separate URL.revokeObjectURL(img.src) call is needed: readFileToImageElement
// revokes its object URL itself once decoding has finished.

相关问题