如何在我的python代码中启用面向Pytorch的英特尔扩展(IPEX)?

bjg7j2ky  于 2022-11-09  发布在  Python
关注(0)|答案(4)|浏览(360)

我想在我的代码中使用Pytorch的英特尔扩展来提高整体性能。请参考此GitHub(https://github.com/intel/intel-extension-for-pytorch)进行安装。
目前,我正在试用一个拥抱脸总结PyTorch示例(https://github.com/huggingface/transformers/blob/master/examples/pytorch/summarization/run_summarization.py)。


# Initialize our Trainer

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics if training_args.predict_with_generate else None,
    )

我不知道在此代码中启用IPEX。有人能帮助我吗?
提前致谢!

uemypmqf

uemypmqf1#

启用IPEX所需的主要变更包括:


# Import the library:

    import intel_extension_for_pytorch as ipex
    #Apply the optimizations to the model for its datatype:
    model = ipex.optimize(model)
    #torch.channels_last should be applied to both of the model object and data to raise CPU resource usage efficiency.
    model = model.to(memory_format=torch.channels_last)
    data = data.to(memory_format=torch.channels_last)

另外,请查看https://intel.github.io/intel-extension-for-pytorch/latest/tutorials/examples.html了解IPEX示例。请查看IPEX官方页面https://www.intel.com/content/www/us/en/developer/tools/oneapi/extension-for-pytorch.html

c0vxltue

c0vxltue2#

要启用Intel Extension for Pytorch,您只需在代码中给予以下内容:

import intel_extension_for_pytorch as ipex

导入上述内容扩展了PyTorch,优化了英特尔硬件上的性能提升
之后,您必须在代码中添加此代码

model = model.to(ipex.DEVICE)
t9eec4r0

t9eec4r03#

首先,您需要将Trainer对象子类化,并创建一个自定义优化器,如Hugging Face docs
用于使用intel_extension_for_pytorch的API发生了一些变化,要使用该库,您只需执行以下操作:

import intel_extension_for_prytorch as ipex

model, optimizer = ipex.optimize(model, optimizer=optimizer)
umuewwlo

umuewwlo4#

目前,Transformers 4.21已支持IPEX。IPEX图形优化采用JIT模式

python run_qa.py
    --model_name_or_path csarron/bert-base-uncased-squad-v1 \
    --dataset_name squad \
    --do_eval \
    --max_seq_length 384 \
    --doc_stride 128 \
    --output_dir /tmp/ \
    --no_cuda \
    --use_ipex \
    --jit_mode_eval

相关问题