Calculating per-class accuracy with a CNN and PyTorch

Posted by oyjwcjzk on 2022-11-09 in Other

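I can compute the accuracy after each epoch with this code. However, I want to compute the accuracy of each class at the end. How can I do that? I have two folders, train and val. Each folder contains 7 sub-folders, one for each of 7 different classes. The train folder is used for training and the val folder is used for testing.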

import copy
import time

import torch
from torch.autograd import Variable


def train_model(model, criterion, optimizer, lr_scheduler, num_epochs=25):
    since = time.time()

    best_model = model
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                mode='train'
                optimizer = lr_scheduler(optimizer, epoch)
                model.train()  # Set model to training mode
            else:
                model.eval()
                mode='val'

            running_loss = 0.0
            running_corrects = 0

            counter=0
            # Iterate over data.
            for data in dset_loaders[phase]:
                inputs, labels = data
                print(inputs.size())
                # wrap them in Variable
                if use_gpu:
                    try:
                        inputs, labels = (Variable(inputs.float().cuda()),
                                          Variable(labels.long().cuda()))
                    except:
                        print(inputs,labels)
                else:
                    inputs, labels = Variable(inputs), Variable(labels)

                # Set gradients to zero so that gradients from the previous batch are not accumulated. Autograd tracks the operations below so that differentiation can be done automatically.
                optimizer.zero_grad()
                outputs = model(inputs)
                _, preds = torch.max(outputs.data, 1)

                loss = criterion(outputs, labels)
                # print('loss done')                
                # Just so that you can keep track that something's happening and don't feel like the program isn't running.
                # if counter%10==0:
                #     print("Reached iteration ",counter)
                counter+=1

                # backward + optimize only if in training phase
                if phase == 'train':
                    # print('loss backward')
                    loss.backward()
                    # print('done loss backward')
                    optimizer.step()
                    # print('done optim')
                # print evaluation statistics
                try:
                    # running_loss += loss.data[0]
                    running_loss += loss.item()
                    # print(labels.data)
                    # print(preds)
                    running_corrects += torch.sum(preds == labels.data)
                    # print('running correct =',running_corrects)
                except:
                    print('unexpected error, could not calculate loss or do a sum.')
            print('trying epoch loss')
            epoch_loss = running_loss / dset_sizes[phase]
            epoch_acc = running_corrects.item() / float(dset_sizes[phase])
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            # deep copy the model
            if phase == 'val':
                if USE_TENSORBOARD:
                    foo.add_scalar_value('epoch_loss',epoch_loss,step=epoch)
                    foo.add_scalar_value('epoch_acc',epoch_acc,step=epoch)
                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_model = copy.deepcopy(model)
                    print('new best accuracy = ',best_acc)
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))
    print('returning and looping back')
    return best_model

def exp_lr_scheduler(optimizer, epoch, init_lr=BASE_LR, lr_decay_epoch=EPOCH_DECAY):
    """Decay learning rate by a factor of DECAY_WEIGHT every lr_decay_epoch epochs."""
    lr = init_lr * (DECAY_WEIGHT**(epoch // lr_decay_epoch))

    if epoch % lr_decay_epoch == 0:
        print('LR is set to {}'.format(lr))

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    return optimizer
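
A minimal sketch of what the question asks for, assuming it is run once after `train_model` returns: pass the best model and the validation loader, and keep two counters per class (samples seen, samples classified correctly) across all batches. `evaluate_per_class` is a hypothetical helper; `loader` is assumed to yield `(inputs, labels)` batches with integer labels in `0..len(class_names)-1`. If the loaders were built with torchvision's `ImageFolder` (the question does not show this), `dset_loaders['val'].dataset.classes` lists the folder names in label order.

```python
import torch


def evaluate_per_class(model, loader, class_names, device="cpu"):
    """Return {class_name: accuracy} over the whole loader (hypothetical helper).

    Accuracy of class c = (# images of class c predicted as c) / (# images of class c).
    """
    num_classes = len(class_names)
    correct = torch.zeros(num_classes, dtype=torch.long)
    total = torch.zeros(num_classes, dtype=torch.long)

    model.eval()  # disable dropout, use running batch-norm statistics
    with torch.no_grad():  # no gradients needed for evaluation
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            preds = model(inputs).argmax(dim=1)

            # How many samples of each class are in this batch ...
            total += torch.bincount(labels, minlength=num_classes).cpu()
            # ... and how many of those were classified correctly
            correct += torch.bincount(labels[preds == labels], minlength=num_classes).cpu()

    # Classes that never appear in the loader come out as NaN (0 / 0)
    acc = correct.double() / total.double()
    return {name: acc[i].item() for i, name in enumerate(class_names)}
```

Usage under the same assumptions: `evaluate_per_class(best_model, dset_loaders['val'], dset_loaders['val'].dataset.classes)`. Summing counts before dividing weights every image equally, so the result does not depend on how the classes are spread over batches.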

Answer #1 (ryhaxcpt)

Computing the overall accuracy is fairly simple:

outputs = model(inputs)
_, preds = torch.max(outputs.data, 1)

acc_all = (preds == labels).float().mean()

Computing it per class takes a few more lines of code:

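# list_of_classes is assumed to hold integer class indices 0..C-1, e.g. range(num_classes)
acc = [0 for c in list_of_classes]
for c in list_of_classes:
    # correct predictions of class c / number of samples of class c (at least 1 to avoid dividing by zero)
    acc[c] = ((preds == labels) * (labels == c)).float().sum() / (labels == c).sum().clamp(min=1)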
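
The loop above can also be vectorized with `torch.bincount`, which drops the assumption that `list_of_classes` doubles as a list of indices. A sketch for one batch; `per_class_accuracy_batch` and `num_classes` are illustrative names. For the example tensors it returns 0.5, 1.0, 2/3 and 0.0: class 3 does not occur in the batch, and `clamp(min=1)` turns that into 0 rather than NaN. As with the loop, these are single-batch numbers; for a whole validation set, sum `correct` and `total` over batches before dividing.

```python
import torch


def per_class_accuracy_batch(preds, labels, num_classes):
    # Number of samples of each class in the batch
    total = torch.bincount(labels, minlength=num_classes)
    # Number of correctly classified samples of each class
    correct = torch.bincount(labels[preds == labels], minlength=num_classes)
    # clamp(min=1) plays the role of max(..., 1): absent classes give 0 instead of NaN
    return correct.float() / total.clamp(min=1).float()


preds = torch.tensor([0, 1, 1, 2, 0, 2])
labels = torch.tensor([0, 0, 1, 2, 2, 2])
result = per_class_accuracy_batch(preds, labels, num_classes=4)
print(result)
```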

Answer #2 (mbyulnm0)

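You could also consider using sklearn's classification_report for a detailed report on the performance of a multi-class classification model. It gives you metrics such as precision, recall and F1-score for every class, followed by the overall macro and weighted averages.
You can do this with the following snippet.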

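import torch
from sklearn.metrics import classification_report

model.eval()
with torch.no_grad():
    output = model(test_input.float())
    _, predictions = torch.max(output, dim=1)

# sklearn expects array-likes on the CPU, not (possibly CUDA) tensors
print(classification_report(true_labels, predictions.cpu().numpy()))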
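
Two remarks on the snippet above, offered as a sketch rather than the answerer's code. First, `classification_report` only sees what it is given, so for a full test set the predictions of every batch should be collected first. Second, the per-class accuracy the question asks about (correct predictions of class c divided by the number of true class-c samples) is exactly the report's recall column; `recall_score(..., average=None)` returns it as an array. `collect_predictions` and `class_names` are hypothetical names.

```python
import torch
from sklearn.metrics import classification_report, recall_score


def collect_predictions(model, loader, device="cpu"):
    """Run the model over every batch and return (true_labels, predictions) as lists."""
    all_labels, all_preds = [], []
    model.eval()
    with torch.no_grad():
        for inputs, labels in loader:
            preds = model(inputs.to(device)).argmax(dim=1)
            all_labels.extend(labels.tolist())
            all_preds.extend(preds.cpu().tolist())
    return all_labels, all_preds


# Usage (model, loader and class_names assumed to exist):
# y_true, y_pred = collect_predictions(model, loader)
# print(classification_report(y_true, y_pred,
#                             labels=list(range(len(class_names))),
#                             target_names=class_names))
# per_class_acc = recall_score(y_true, y_pred, average=None)  # one value per class
```

Passing `labels=` keeps `target_names` aligned with the class indices even if some class never occurs in the test set.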

Answer #3 (8e2ybdfx)

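I stumbled on this question while trying to understand the problems with my CNN. I had already used Victor's solution. I also wanted to check which classes were not being trained properly and which classes were being misclassified as other classes.
The code is here: https://github.com/alexcpn/cnn_lenet_pytorch/blob/main/cnn/model_accuracy.py
Snippet below: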

from collections import Counter

import numpy as np
import torch

# model, test_loader, device and categories (class names in index order) are defined elsewhere
wrong_per_class = {i: [] for i in range(len(categories))}  # true class -> wrongly predicted classes
precision_per_class = {name: [] for name in categories}    # class name -> per-batch accuracies

with torch.no_grad():
    model.eval()  # IMPORTANT: set model to eval mode before inference
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)  # Outputs = torch.Size([64, 10])
        _, predicted = torch.max(outputs, 1)  # class with the highest score; predicted = torch.Size([64])
        total += labels.size(0)  # labels = torch.Size([64])
        correct += (predicted == labels).float().sum().item()  # e.g. 56 out of 64
        #print("classification_report",classification_report(labels.cpu(), predicted.cpu()))

        # -------- Also check which classes are wrongly predicted as which other classes
        mask = (predicted != labels)
        wrong_predicted = torch.masked_select(predicted, mask)
        wrong_labels = torch.masked_select(labels, mask)
        for true_class, pred_class in zip(wrong_labels.tolist(), wrong_predicted.tolist()):
            wrong_per_class[true_class].append(pred_class)

        for index, element in enumerate(categories):
            n_in_class = (labels == index).sum().item()
            if n_in_class == 0:
                continue  # class absent from this batch; skip instead of computing 0/0 = nan
            # True only where an image of class `index` was classified correctly (Torch Tensor semantics)
            cal = ((predicted == labels) * (labels == index)).sum().item() / n_in_class
            precision_per_class[element].append(cal)
            # >>> import torch
            # >>> some_integers = torch.tensor((2, 3, 5, 7, 11, 13, 17, 19))
            # >>> some_integers3 = torch.tensor((12, 3, 5, 7, 11, 13, 17, 19))
            # >>> (some_integers ==some_integers3)*(some_integers == 3)
            # tensor([False,  True, False, False, False, False, False, False])
            # >>> ((some_integers ==some_integers3)*(some_integers >12)).sum().item()
            # 3

    avg_accuracy = []
    for key, val in precision_per_class.items():
        avg = np.mean(val)  # mean of this class's per-batch accuracies
        precision_per_class[key] = avg
        avg_accuracy.append(avg)
        print(f"Accuracy of Class {key}={avg}")

    # Just to cross-check with the average accuracy results below
    print(f"Average accuracy={np.mean(avg_accuracy)}")

    for key, val in wrong_per_class.items():
        print(f"wrong_per_class {categories[key]}={Counter(val)}")

    print(
        "Accuracy of the network on the {} test/validation images: {} %".format(
            total, 100 * correct / total
        )
    )

Output

Accuracy of Class tench=0.8504464285714286
Accuracy of Class English springer=0.6907253691128322
Accuracy of Class cassette player=0.7420465648174286
Accuracy of Class chain saw=0.5169889160564968
Accuracy of Class church=0.6264965534210205
Accuracy of Class French horn=0.5337499976158142
Accuracy of Class garbage truck=0.7543565290314811
Accuracy of Class gas pump=0.5343750034059797
Accuracy of Class golf ball=0.5873511944498334
Accuracy of Class parachute=0.5481353274413517
Average accuracy=0.6384671883923666
wrong_per_class tench=Counter({3: 25, 8: 16, 1: 10, 2: 7, 6: 3, 5: 3, 7: 2, 9: 1})
wrong_per_class English springer=Counter({3: 39, 0: 23, 8: 21, 6: 7, 5: 7, 7: 3, 9: 3, 4: 3, 2: 3})
wrong_per_class cassette player=Counter({7: 36, 6: 14, 3: 13, 8: 11, 0: 8, 5: 4, 1: 4, 4: 2})
wrong_per_class chain saw=Counter({0: 49, 1: 30, 6: 27, 7: 22, 5: 21, 4: 19, 2: 12, 8: 8, 9: 4})
wrong_per_class church=Counter({6: 23, 5: 21, 3: 20, 7: 19, 8: 16, 0: 14, 2: 10, 9: 7, 1: 5})
wrong_per_class French horn=Counter({3: 64, 4: 26, 2: 22, 1: 21, 7: 19, 0: 13, 8: 12, 6: 11})
wrong_per_class garbage truck=Counter({3: 28, 4: 23, 2: 14, 7: 14, 0: 8, 5: 5, 1: 4, 8: 2})
wrong_per_class gas pump=Counter({2: 50, 6: 46, 3: 41, 4: 23, 1: 11, 5: 9, 8: 8, 0: 7, 9: 2})
wrong_per_class golf ball=Counter({1: 38, 0: 37, 3: 27, 4: 17, 9: 11, 5: 10, 2: 9, 6: 7, 7: 6})
wrong_per_class parachute=Counter({8: 56, 3: 46, 4: 19, 6: 13, 7: 12, 0: 10, 2: 6, 1: 6, 5: 2})
Accuracy of the network on the 3925 test/validation images: 64.07643312101911 %
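
The per-class numbers above differ slightly from the confusion-matrix numbers in the next answer (for example 0.850 versus 0.827 for tench), even though the overall accuracy is identical. A plausible reason is that `np.mean(val)` averages per-batch ratios, so a batch holding one image of a class counts as much as a batch holding thirty. The toy numbers below are made up purely to illustrate the effect for a single class:

```python
# (correct, total) for one class in each of two batches; illustrative numbers only
batches = [(1, 1), (1, 3)]

# What the loop above does: average the per-batch ratios
mean_of_ratios = sum(c / t for c, t in batches) / len(batches)

# Pooled accuracy: total correct divided by total samples of the class
pooled = sum(c for c, _ in batches) / sum(t for _, t in batches)

print(f"mean of per-batch ratios = {mean_of_ratios:.3f}")  # 0.667
print(f"pooled accuracy          = {pooled:.3f}")          # 0.500
```

The pooled value is what a confusion matrix built over the whole test set yields.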

I will update the answer later with proper plots of this data.

Answer #4 (zfciruhq)

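A more accurate method than my previous answer is to first build a confusion matrix and then derive the metrics from it; the matrix is also useful for other analyses of the training.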
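
Written out, with C_ij the number of images whose true class is i and whose predicted class is j (rows are the truth, columns the prediction, matching how `confusion_matrix[k][l]` is filled in the code that follows), the per-class accuracy and the overall accuracy printed by this answer are as below; per-class precision, which normalizes by the column instead of the row, is included for contrast:

```latex
\mathrm{acc}_i = \frac{C_{ii}}{\sum_{j} C_{ij}}, \qquad
\mathrm{precision}_i = \frac{C_{ii}}{\sum_{j} C_{ji}}, \qquad
\mathrm{accuracy} = \frac{\sum_{i} C_{ii}}{\sum_{i}\sum_{j} C_{ij}}
```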


import numpy as np
import torch

# model, test_loader, device and categories (class names in index order) are defined elsewhere
num_classes = len(categories)
confusion_matrix = np.zeros((num_classes, num_classes))  # rows = true class, columns = predicted class
wrong_per_class = {i: [] for i in range(num_classes)}
right_per_class = {i: [] for i in range(num_classes)}

# In test phase, we don't need to compute gradients (for memory efficiency)
with torch.no_grad():
    model.eval()  # IMPORTANT: set model to eval mode before inference
    correct = 0
    total = 0

    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)

        # ------------------------------------------------------------------------------------------
        # Predict for the batch of images
        # ------------------------------------------------------------------------------------------
        outputs = model(images)  # torch.Size([64, 10]): one score (logit) for each of the 10 classes
        _, predicted = torch.max(outputs, 1)  # class with the highest score, one per image: torch.Size([64])
        total += labels.size(0)  # labels = torch.Size([64]): the true class of each image
        correct += (predicted == labels).float().sum().item()  # number of correctly classified images

        # ------------------------------------------------------------------------------------------
        # Also record which classes are wrongly predicted as which other classes,
        # to build a multi-class confusion matrix
        # ------------------------------------------------------------------------------------------
        mask = (predicted != labels)  # wrongly predicted
        wrong_predicted = torch.masked_select(predicted, mask)
        wrong_labels = torch.masked_select(labels, mask)

        mask = (predicted == labels)  # rightly predicted
        rightly_predicted = torch.masked_select(predicted, mask)
        right_labels = rightly_predicted  # same as torch.masked_select(labels, mask)

        # Note that this is for a single batch: add to the list associated with each class
        for k, l in zip(wrong_labels.tolist(), wrong_predicted.tolist()):  # k = label, l = predicted
            wrong_per_class[k].append(l)
            confusion_matrix[k][l] += 1

        for k, l in zip(right_labels.tolist(), rightly_predicted.tolist()):  # k = label, l = predicted
            right_per_class[k].append(l)
            confusion_matrix[k][l] += 1

    # ------------------------------------------------------------------------------------------
    # Print confusion matrix in pretty-print format
    # ------------------------------------------------------------------------------------------
    print(categories)
    for i in range(num_classes):
        for j in range(num_classes):
            print(f"\t{confusion_matrix[i][j]}", end='')
        print(f"\t{categories[i]}")
    # ------------------------------------------------------------------------------------------
    # Calculate accuracy per class: diagonal entry divided by the row sum
    # ------------------------------------------------------------------------------------------
    print("---------------------------------------")
    total_correct = 0
    for i in range(num_classes):
        print(f"Average accuracy per class {categories[i]} from confusion matrix {confusion_matrix[i][i]/confusion_matrix[i].sum()}")
        total_correct += confusion_matrix[i][i]

    print(f"Average Accuracy/precision from the confusion matrix is {total_correct/confusion_matrix.sum()}")

    # Overall accuracy as below
    print(
        "Accuracy of the network on the {} test/validation images: {} %".format(
            total, 100 * correct / total
        )
    )
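
The two Python loops that fill `confusion_matrix` can also be collapsed into one vectorized update per batch. This sketch uses the common trick of encoding each (true, predicted) pair as the flat index `label * num_classes + prediction` and counting with `torch.bincount`; `update_confusion_matrix` is a hypothetical helper, and here the matrix is a `torch` tensor rather than a NumPy array.

```python
import torch


def update_confusion_matrix(confusion_matrix, labels, predicted):
    """Add one batch to a (num_classes x num_classes) matrix; rows = true, cols = predicted."""
    num_classes = confusion_matrix.shape[0]
    # Each (true, predicted) pair maps to a unique flat index in [0, num_classes**2)
    flat = labels * num_classes + predicted
    counts = torch.bincount(flat, minlength=num_classes * num_classes)
    confusion_matrix += counts.reshape(num_classes, num_classes).to(
        device=confusion_matrix.device, dtype=confusion_matrix.dtype
    )
    return confusion_matrix


cm = torch.zeros(3, 3)
update_confusion_matrix(cm, torch.tensor([0, 0, 1, 2, 2]), torch.tensor([0, 1, 1, 2, 0]))

per_class_accuracy = cm.diag() / cm.sum(dim=1)  # diagonal over row sums
overall_accuracy = cm.diag().sum() / cm.sum()   # trace over total count
```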

Note that the masks in this answer's code rely on the Torch Tensor semantics illustrated below:


# Below illustrates the above Torch Tensor semantics

# >>> import torch
# >>> some_integers = torch.tensor((2, 3, 5, 7, 11, 13, 17, 19))
# >>> some_integers3 = torch.tensor((12, 3, 5, 7, 11, 13, 17, 19))
# >>> (some_integers ==some_integers3)*(some_integers == 3)
# tensor([False,  True, False, False, False, False, False, False])
# >>> ((some_integers ==some_integers3)*(some_integers >12)).sum().item()
# 3
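
The same illustration as runnable code. On boolean tensors `*` acts as element-wise logical AND (the `&` operator states that intent more explicitly), and `.sum()` counts the `True` entries:

```python
import torch

some_integers = torch.tensor((2, 3, 5, 7, 11, 13, 17, 19))
some_integers3 = torch.tensor((12, 3, 5, 7, 11, 13, 17, 19))

# True only where both conditions hold
both = (some_integers == some_integers3) * (some_integers == 3)
print(both.tolist())  # [False, True, False, False, False, False, False, False]

# `&` gives the same result for boolean tensors
same = (some_integers == some_integers3) & (some_integers == 3)
print(torch.equal(both, same))  # True

# Count elements that are equal in both tensors and also greater than 12
print(((some_integers == some_integers3) & (some_integers > 12)).sum().item())  # 3
```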

Output

2022-10-20 13:38:01,112 Gpu device NVIDIA GeForce RTX 3060 Laptop GPU
['tench', 'English springer', 'cassette player', 'chain saw', 'church', 'French horn', 'garbage truck', 'gas pump', 'golf ball', 'parachute']
        320.0   10.0    7.0     25.0    0.0     3.0     3.0     2.0     16.0    1.0     tench
        23.0    286.0   3.0     39.0    3.0     7.0     7.0     3.0     21.0    3.0     English springer
        8.0     4.0     265.0   13.0    2.0     4.0     14.0    36.0    11.0    0.0     cassette player
        49.0    30.0    12.0    194.0   19.0    21.0    27.0    22.0    8.0     4.0     chain saw
        14.0    5.0     10.0    20.0    274.0   21.0    23.0    19.0    16.0    7.0     church
        13.0    21.0    22.0    64.0    26.0    206.0   11.0    19.0    12.0    0.0     French horn
        8.0     4.0     14.0    28.0    23.0    5.0     291.0   14.0    2.0     0.0     garbage truck
        7.0     11.0    50.0    41.0    23.0    9.0     46.0    222.0   8.0     2.0     gas pump
        37.0    38.0    9.0     27.0    17.0    10.0    7.0     6.0     237.0   11.0    golf ball
        10.0    6.0     6.0     46.0    19.0    2.0     13.0    12.0    56.0    220.0   parachute
---------------------------------------
Average accuracy per class tench from confusion matrix 0.8268733850129198
Average accuracy per class English springer from confusion matrix 0.7240506329113924
Average accuracy per class cassette player from confusion matrix 0.742296918767507
Average accuracy per class chain saw from confusion matrix 0.5025906735751295
Average accuracy per class church from confusion matrix 0.6699266503667481
Average accuracy per class French horn from confusion matrix 0.5228426395939086
Average accuracy per class garbage truck from confusion matrix 0.7480719794344473
Average accuracy per class gas pump from confusion matrix 0.5298329355608592
Average accuracy per class golf ball from confusion matrix 0.5939849624060151
Average accuracy per class parachute from confusion matrix 0.5641025641025641
Average Accuracy/precision from the confusion matrix is 0.640764331210191
Accuracy of the network on the 3925 test/validation images: 64.07643312101911
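
As a cross-check, the per-class values above can be recomputed from the printed matrix with NumPy (rows = true class, columns = predicted class). `per_class_accuracy[0]` reproduces 0.8268733850129198 for tench and `overall_accuracy` reproduces 0.640764331210191. Dividing the diagonal by column sums instead gives per-class precision, a different quantity; the overall figure is plain accuracy (which, for single-label multi-class data, equals micro-averaged precision).

```python
import numpy as np

# Confusion matrix copied from the output above (rows = true, columns = predicted)
cm = np.array([
    [320, 10, 7, 25, 0, 3, 3, 2, 16, 1],
    [23, 286, 3, 39, 3, 7, 7, 3, 21, 3],
    [8, 4, 265, 13, 2, 4, 14, 36, 11, 0],
    [49, 30, 12, 194, 19, 21, 27, 22, 8, 4],
    [14, 5, 10, 20, 274, 21, 23, 19, 16, 7],
    [13, 21, 22, 64, 26, 206, 11, 19, 12, 0],
    [8, 4, 14, 28, 23, 5, 291, 14, 2, 0],
    [7, 11, 50, 41, 23, 9, 46, 222, 8, 2],
    [37, 38, 9, 27, 17, 10, 7, 6, 237, 11],
    [10, 6, 6, 46, 19, 2, 13, 12, 56, 220],
], dtype=float)

per_class_accuracy = np.diag(cm) / cm.sum(axis=1)   # correct / all true samples of the class
per_class_precision = np.diag(cm) / cm.sum(axis=0)  # correct / all predictions of the class
overall_accuracy = np.trace(cm) / cm.sum()

print(per_class_accuracy[0])  # tench
print(overall_accuracy)
```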
