How do I manually split a model across multiple GPUs with PyTorch Lightning?

Asked by xytpbqjk, 12 months ago

I have a model that is too large for a single GPU. It contains a list of transformers, and once the forward pass gets past the 17th of them I hit a CUDA out of memory error, so I want to run the model across multiple GPUs.

import torch
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning.strategies import DDPStrategy

class SplitModel(pl.LightningModule):
    def __init__(self):
        super(SplitModel, self).__init__()

        # Define your model segments
        self.segment1 = # arbitrary input layer
        self.transformers = # a torch.nn.ModuleList of transformers
        self.segment2 = # arbitrary output layer

        self.loss_fn = nn.CrossEntropyLoss()
        self.device1 = torch.device('cuda:0')
        self.device2 = torch.device('cuda:1')

    def forward(self, x):
        # Forward pass for segment1 on device1
        x = self.segment1(x)
        for i, transformer in enumerate(self.transformers):
            current_device = '[' + ''.join(
                "{}: {} ".format(name, next(child.parameters()).device if list(child.parameters()) else "CPU")
                for name, child in transformer.named_children()
            ) + ']'
            print("iterating through transformer {} on device {}".format(i, current_device))
            attn, ff = transformer
            x = attn(x) + x
            x = ff(x) + x

        # Forward pass for segment2 on device2
        x = self.segment2(x)

        return x

    def training_step(self, batch, batch_idx):
        inputs, labels = batch

        # Forward pass
        outputs = self(inputs)

        # Calculate loss using segment2 outputs
        loss = self.loss_fn(outputs, labels)

        # Log loss for monitoring (optional)
        self.log('train_loss', loss)

        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

model = SplitModel()
ddp_strategy = DDPStrategy(find_unused_parameters=True)
trainer = pl.Trainer(precision="16-mixed", accelerator="cuda", devices=[0, 1], strategy=ddp_strategy)
data_loader = # some dataloader
trainer.fit(model, data_loader)

This is what an example output looks like:

iterating through transformer 0 on device [0: cuda:0]
iterating through transformer 1 on device [0: cuda:0 1: cuda:0]
iterating through transformer 2 on device [0: cuda:0 1: cuda:0]
iterating through transformer 3 on device [0: cuda:0 1: cuda:0]
iterating through transformer 4 on device [0: cuda:0 1: cuda:0]
iterating through transformer 5 on device [0: cuda:0 1: cuda:0]
iterating through transformer 6 on device [0: cuda:0 1: cuda:0]
iterating through transformer 7 on device [0: cuda:0 1: cuda:0]
iterating through transformer 8 on device [0: cuda:0 1: cuda:0]
iterating through transformer 9 on device [0: cuda:0 1: cuda:0]
iterating through transformer 10 on device [0: cuda:0 1: cuda:0]
iterating through transformer 11 on device [0: cuda:0 1: cuda:0]
iterating through transformer 12 on device [0: cuda:0 1: cuda:0]
iterating through transformer 13 on device [0: cuda:0 1: cuda:0]
iterating through transformer 14 on device [0: cuda:0 1: cuda:0]
iterating through transformer 15 on device [0: cuda:0 1: cuda:0]
iterating through transformer 16 on device [0: cuda:0 1: cuda:0]
iterating through transformer 17 on device [0: cuda:0 1: cuda:0]
CUDA out of memory error

However, if I add the following lines to the forward pass:

self.segment2 = self.segment2.to(self.device2)
for i, transformer in enumerate(self.transformers):
    if i == 17:
        x = x.to(self.device2)
    if i > 16:
        transformer = transformer.to(self.device2)
    # the rest of iterating through the transformers
return self.segment2(x).to(self.device1)

then I no longer get a CUDA out of memory error; instead, the backward pass fails with the following error:

RuntimeError: grad.device() == bucket_view.device() INTERNAL ASSERT FAILED at "../torch/csrc/distributed/c10d/reducer.cpp":314, please report a bug to PyTorch.

I also looked into sharding the model instead of manually deciding which parts to place on which GPU. With strategy="fsdp" in the pl.Trainer, I got an error about batch norm variables having mismatched types, one being torch.cuda.FloatTensor and the other torch.cuda.HalfTensor.
Is there a way to accomplish this, for example by creating a custom layer that manually changes devices in the backward pass?


zpqajqem #1

Let's assume you have a model like the one you described:

import torch
import torch.nn as nn
import pytorch_lightning as pl

class SplitModel(pl.LightningModule):
    def __init__(self):
        super(SplitModel, self).__init__()

        # Define your model segments
        self.segment1 = nn.Sequential(
            # arbitrary layers
        )
        self.transformers = nn.ModuleList([
            # List of transformers
        ])
        self.segment2 = nn.Sequential(
            # arbitrary layers
        )

        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):
        # Forward pass for segment1
        x = self.segment1(x)
        for transformer in self.transformers:
            x = transformer(x)

        # Forward pass for segment2
        x = self.segment2(x)
        return x

    def training_step(self, batch, batch_idx):
        inputs, labels = batch

        # Forward pass
        outputs = self(inputs)

        # Calculate loss using segment2 outputs
        loss = self.loss_fn(outputs, labels)

        # Log loss for monitoring (optional)
        self.log('train_loss', loss)

        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

model = SplitModel()

Now let's set up the training with DDP and multi-GPU support:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# Define the data loader
data_loader = # Your data loader

# Set up DDP and Trainer
checkpoint_callback = ModelCheckpoint(dirpath='./checkpoints', filename='model-{epoch:02d}')
trainer = Trainer(
    accelerator='gpu',
    devices=[0, 1],  # List of GPUs to use
    strategy='ddp',  # Use DistributedDataParallel (DDP)
    precision=16,  # You can change the precision as needed
    callbacks=[checkpoint_callback],
)

# Train the model
trainer.fit(model, data_loader)

This setup should work for multi-GPU training of the model. DDP automatically handles data parallelism by replicating the model on each of the specified GPUs. Make sure your model is constructed correctly and your data loader is set up properly.
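If you want to double-check where each replica actually ends up (similar to the device print in the question's forward pass), a minimal sketch, assuming the SplitModel defined above, is to add an on_train_start hook that logs the process rank and the parameter devices; under DDP each rank should report a different CUDA device:

class SplitModel(pl.LightningModule):
    # ... __init__, forward, training_step, configure_optimizers as above ...

    def on_train_start(self):
        # self.device and self.trainer.global_rank are provided by Lightning
        param_devices = {p.device for p in self.parameters()}
        print("rank {}: module on {}, parameters on {}".format(
            self.trainer.global_rank, self.device, param_devices))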
If you run into issues related to batch normalization layers when using mixed precision, you may need to set sync_batchnorm=False in the Trainer to disable batch norm synchronization across GPUs:

trainer = Trainer(
    accelerator='gpu',
    devices=[0, 1],
    strategy='ddp',
    precision=16,
    callbacks=[checkpoint_callback],
    sync_batchnorm=False,  # Disable batch norm synchronization
)

If you still run into memory or other errors, consider model optimization techniques such as gradient accumulation, reducing the batch size, or gradient clipping to manage memory usage (a sketch of these Trainer options follows below). Also make sure you are on recent versions of PyTorch and PyTorch Lightning, since updates may have resolved some of the issues you mention.
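As a rough illustration of those knobs, here is a minimal sketch assuming the same model, checkpoint_callback and Trainer setup as above; accumulate_grad_batches and gradient_clip_val are standard Trainer arguments, and train_dataset is a placeholder for whatever dataset you are using:

from torch.utils.data import DataLoader

# Smaller per-step batch size; gradient accumulation restores the effective batch size
data_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)

trainer = Trainer(
    accelerator='gpu',
    devices=[0, 1],
    strategy='ddp',
    precision=16,
    accumulate_grad_batches=4,   # accumulate gradients over 4 batches before each optimizer step
    gradient_clip_val=1.0,       # clip the gradient norm at 1.0
    callbacks=[checkpoint_callback],
)
trainer.fit(model, data_loader)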
