def train_model(model, criterion, optimizer, lr_scheduler, num_epochs=25,
                dataloaders=None, dataset_sizes=None, use_cuda=None):
    """Train ``model`` and return the snapshot with the best validation accuracy.

    Each epoch runs a ``'train'`` phase (forward, backward, optimizer step)
    followed by a ``'val'`` phase (forward only, no gradients). Whenever the
    validation accuracy improves, a deep copy of the model is kept.

    Args:
        model: the ``torch.nn.Module`` to train (updated in place).
        criterion: loss function called as ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        lr_scheduler: callable ``(optimizer, epoch) -> optimizer`` invoked at
            the start of every training phase (e.g. ``exp_lr_scheduler``).
        num_epochs: number of train/val epochs.
        dataloaders: mapping ``{'train': iterable, 'val': iterable}`` yielding
            ``(inputs, labels)`` batches. Defaults to the module-level
            ``dset_loaders``.
        dataset_sizes: mapping ``{'train': int, 'val': int}`` giving the number
            of samples per phase. Defaults to the module-level ``dset_sizes``.
        use_cuda: move each batch to the GPU when true. Defaults to the
            module-level ``use_gpu``.

    Returns:
        A deep copy of ``model`` taken at the epoch with the highest
        validation accuracy, or ``model`` itself if the validation accuracy
        never rose above 0.
    """
    # Fall back to the module-level globals the original implementation used.
    if dataloaders is None:
        dataloaders = dset_loaders
    if dataset_sizes is None:
        dataset_sizes = dset_sizes
    if use_cuda is None:
        use_cuda = use_gpu

    since = time.time()

    best_model = model
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            is_train = phase == 'train'
            if is_train:
                optimizer = lr_scheduler(optimizer, epoch)
                model.train()  # Set model to training mode
            else:
                model.eval()  # Validation must use inference behaviour

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in dataloaders[phase]:
                if use_cuda:
                    inputs = inputs.float().cuda()
                    labels = labels.cuda()

                if is_train:
                    # Clear gradients left over from the previous step.
                    optimizer.zero_grad()

                # Only build the autograd graph when we are going to backprop.
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if is_train:
                        loss.backward()
                        optimizer.step()

                # Statistics are gathered in BOTH phases. For a mean-reduced
                # criterion (the PyTorch default) loss.item() is the batch
                # mean, so scale by the batch size before dividing by the
                # dataset size below.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += (preds == labels).sum().item()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / float(dataset_sizes[phase])
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            # deep copy the model whenever validation accuracy improves
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model = copy.deepcopy(model)
                print('new best accuracy = ', best_acc)

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:.4f}'.format(best_acc))
    return best_model

def exp_lr_scheduler(optimizer, epoch, init_lr=BASE_LR, lr_decay_epoch=EPOCH_DECAY):
    """Apply step-wise exponential learning-rate decay to every parameter group.

    The new rate is ``init_lr * DECAY_WEIGHT ** (epoch // lr_decay_epoch)``,
    i.e. it is multiplied by ``DECAY_WEIGHT`` once every ``lr_decay_epoch``
    epochs. The rate is announced on epochs that are multiples of
    ``lr_decay_epoch``.

    Returns the same optimizer object, whose param groups were updated in place.
    """
    completed_decays = epoch // lr_decay_epoch
    new_lr = init_lr * (DECAY_WEIGHT ** completed_decays)

    # Only log on the epochs where a decay boundary is reached.
    at_boundary = not (epoch % lr_decay_epoch)
    if at_boundary:
        print('LR is set to {}'.format(new_lr))

    for group in optimizer.param_groups:
        group['lr'] = new_lr

    return optimizer



outputs = model(inputs)
_, preds = torch.max(outputs.data, 1)

# Overall accuracy of this batch.
acc_all = (preds == labels).float().mean()

# Per-class accuracy (recall): correct predictions of class c divided by the
# number of samples whose true label is c. clamp(min=1) avoids 0/0 when a class
# is absent from the batch (its accuracy is then reported as 0).
# Assumes list_of_classes holds class values comparable to `labels`
# (e.g. integer indices) -- TODO confirm.
acc = [0 for _ in list_of_classes]
for i, c in enumerate(list_of_classes):
    in_class = labels == c
    acc[i] = ((preds == labels) & in_class).float().sum() / in_class.sum().clamp(min=1)


您也可以考虑使用 sklearn 的 classification_report 来详细报告多类分类模型的性能。它会给出每个类别的精确率（precision）、召回率（recall）和 F1 分数（f1-score），以及总体的宏平均（macro avg）和加权平均（weighted avg）。

from sklearn.metrics import classification_report

# Inference only: switch to eval mode (dropout off, batch-norm running stats)
# and skip building the autograd graph.
model.eval()
with torch.no_grad():
    output = model(test_input.float())
_, predictions = torch.max(output, dim=1)

# sklearn works on host-side array-likes; a CUDA tensor cannot be converted
# implicitly, so move the predictions to the CPU first. true_labels is assumed
# to already be a CPU sequence/array -- TODO confirm.
print(classification_report(true_labels, predictions.cpu().numpy()))



with torch.no_grad():
    model.eval()  # IMPORTANT: set model to eval mode before inference
    correct = 0
    total = 0
    num_classes = len(categories)
    # Per-class tallies accumulated over the whole test set. Averaging
    # per-batch ratios instead would weight small batches unfairly and yield
    # NaN for batches that contain no sample of a class.
    class_correct = [0] * num_classes
    class_total = [0] * num_classes
    # true class index -> list of the (wrong) class indices predicted for it
    wrong_per_class = {index: [] for index in range(num_classes)}

    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        # index of the highest-scoring class per image (dim 1 = classes)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        is_correct = predicted == labels
        correct += is_correct.sum().item()

        # Record which classes are confused with which.
        mask = ~is_correct
        wrong_predicted = torch.masked_select(predicted, mask)
        wrong_labels = torch.masked_select(labels, mask)
        for true_idx, pred_idx in zip(wrong_labels.tolist(), wrong_predicted.tolist()):
            wrong_per_class[true_idx].append(pred_idx)

        for index in range(num_classes):
            in_class = labels == index
            class_total[index] += in_class.sum().item()
            class_correct[index] += (is_correct & in_class).sum().item()

    # NOTE: correct / total for one true class is that class's recall; the
    # name precision_per_class is kept for compatibility with existing code.
    precision_per_class = {}
    avg_accuracy = []
    for index, name in enumerate(categories):
        if class_total[index] == 0:
            continue  # class absent from the test set: accuracy undefined
        class_acc = class_correct[index] / class_total[index]
        precision_per_class[name] = class_acc
        avg_accuracy.append(class_acc)
        print(f"Accuracy of Class {name}={class_acc}")

    # Unweighted mean of the per-class accuracies (macro average), to cross
    # check against the overall accuracy printed below.
    print(f"Average accuracy={np.mean(avg_accuracy)}")

    for key, val in wrong_per_class.items():
        print(f"wrong_per_class {categories[key]}={Counter(val)}")

    print(
        "Accuracy of the network on the {} test/validation images: {} %".format(
            total, 100 * correct / total
        )
    )


Accuracy of Class tench=0.8504464285714286
Accuracy of Class English springer=0.6907253691128322
Accuracy of Class cassette player=0.7420465648174286
Accuracy of Class chain saw=0.5169889160564968
Accuracy of Class church=0.6264965534210205
Accuracy of Class French horn=0.5337499976158142
Accuracy of Class garbage truck=0.7543565290314811
Accuracy of Class gas pump=0.5343750034059797
Accuracy of Class golf ball=0.5873511944498334
Accuracy of Class parachute=0.5481353274413517
Average accuracy=0.6384671883923666
wrong_per_class tench=Counter({3: 25, 8: 16, 1: 10, 2: 7, 6: 3, 5: 3, 7: 2, 9: 1})
wrong_per_class English springer=Counter({3: 39, 0: 23, 8: 21, 6: 7, 5: 7, 7: 3, 9: 3, 4: 3, 2: 3})
wrong_per_class cassette player=Counter({7: 36, 6: 14, 3: 13, 8: 11, 0: 8, 5: 4, 1: 4, 4: 2})
wrong_per_class chain saw=Counter({0: 49, 1: 30, 6: 27, 7: 22, 5: 21, 4: 19, 2: 12, 8: 8, 9: 4})
wrong_per_class church=Counter({6: 23, 5: 21, 3: 20, 7: 19, 8: 16, 0: 14, 2: 10, 9: 7, 1: 5})
wrong_per_class French horn=Counter({3: 64, 4: 26, 2: 22, 1: 21, 7: 19, 0: 13, 8: 12, 6: 11})
wrong_per_class garbage truck=Counter({3: 28, 4: 23, 2: 14, 7: 14, 0: 8, 5: 5, 1: 4, 8: 2})
wrong_per_class gas pump=Counter({2: 50, 6: 46, 3: 41, 4: 23, 1: 11, 5: 9, 8: 8, 0: 7, 9: 2})
wrong_per_class golf ball=Counter({1: 38, 0: 37, 3: 27, 4: 17, 9: 11, 5: 10, 2: 9, 6: 7, 7: 6})
wrong_per_class parachute=Counter({8: 56, 3: 46, 4: 19, 6: 13, 7: 12, 0: 10, 2: 6, 1: 6, 5: 2})
Accuracy of the network on the 3925 test/validation images: 64.07643312101911 %





# In test phase, we don't need to compute gradients (for memory efficiency)

with torch.no_grad():
    model.eval()  # IMPORTANT: set model to eval mode before inference
    correct = 0
    total = 0
    num_classes = len(categories)
    # Rows are true classes, columns are predicted classes. Reset here so the
    # counts cover exactly one pass over test_loader.
    confusion_matrix = np.zeros((num_classes, num_classes))

    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)

        # ------------------------------------------------------------------------------------------
        # Predict for the batch of images
        # ------------------------------------------------------------------------------------------
        outputs = model(images)  # class scores, presumably (batch, num_classes) -- confirm
        _, predicted = torch.max(outputs, 1)  # index of the highest-scoring class per image
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        # ------------------------------------------------------------------------------------------
        # Accumulate the multi-class confusion matrix: every (true, predicted)
        # pair adds one count. Off-diagonal cells are confusions; the diagonal
        # holds the correct predictions.
        # ------------------------------------------------------------------------------------------
        for true_idx, pred_idx in zip(labels.tolist(), predicted.tolist()):
            confusion_matrix[true_idx][pred_idx] += 1

    # ------------------------------------------------------------------------------------------
    # Print Confusion matrix: one tab-separated row per true class, name last
    # ------------------------------------------------------------------------------------------
    for i, name in enumerate(categories):
        row = "".join(f"\t{count}" for count in confusion_matrix[i])
        print(f"{row}\t{name}")

    # ------------------------------------------------------------------------------------------
    # Calculate Accuracy per class: diagonal / row total (i.e. per-class recall)
    # ------------------------------------------------------------------------------------------
    total_correct = 0
    for i, name in enumerate(categories):
        row_total = confusion_matrix[i].sum()
        # A class with no test samples has an undefined accuracy.
        class_acc = confusion_matrix[i][i] / row_total if row_total else float("nan")
        print(f"Average accuracy per class {name} from confusion matrix {class_acc}")
        total_correct += confusion_matrix[i][i]

    # Trace / total count: this is the overall accuracy (not precision).
    print(f"Average Accuracy/precision from the confusion matrix is {total_correct/confusion_matrix.sum()}")

    # Overall accuracy as below
    print(
        "Accuracy of the network on the {} test/validation images: {} %".format(
            total, 100 * correct / total
        )
    )


# The examples below illustrate the boolean-mask tensor arithmetic used above

        # >>> import torch
        # >>> some_integers = torch.tensor((2, 3, 5, 7, 11, 13, 17, 19))
        # >>> some_integers3 = torch.tensor((12, 3, 5, 7, 11, 13, 17, 19))
        # >>> (some_integers ==some_integers3)*(some_integers == 3)
        # tensor([False,  True, False, False, False, False, False, False])
        # >>> ((some_integers ==some_integers3)*(some_integers >12)).sum().item()
        # 3


2022-10-20 13:38:01,112 Gpu device NVIDIA GeForce RTX 3060 Laptop GPU
['tench', 'English springer', 'cassette player', 'chain saw', 'church', 'French horn', 'garbage truck', 'gas pump', 'golf ball', 'parachute']
        320.0   10.0    7.0     25.0    0.0     3.0     3.0     2.0     16.0    1.0     tench
        23.0    286.0   3.0     39.0    3.0     7.0     7.0     3.0     21.0    3.0     English springer
        8.0     4.0     265.0   13.0    2.0     4.0     14.0    36.0    11.0    0.0     cassette player
        49.0    30.0    12.0    194.0   19.0    21.0    27.0    22.0    8.0     4.0     chain saw
        14.0    5.0     10.0    20.0    274.0   21.0    23.0    19.0    16.0    7.0     church
        13.0    21.0    22.0    64.0    26.0    206.0   11.0    19.0    12.0    0.0     French horn
        8.0     4.0     14.0    28.0    23.0    5.0     291.0   14.0    2.0     0.0     garbage truck
        7.0     11.0    50.0    41.0    23.0    9.0     46.0    222.0   8.0     2.0     gas pump
        37.0    38.0    9.0     27.0    17.0    10.0    7.0     6.0     237.0   11.0    golf ball
        10.0    6.0     6.0     46.0    19.0    2.0     13.0    12.0    56.0    220.0   parachute
Average accuracy per class tench from confusion matrix 0.8268733850129198
Average accuracy per class English springer from confusion matrix 0.7240506329113924
Average accuracy per class cassette player from confusion matrix 0.742296918767507
Average accuracy per class chain saw from confusion matrix 0.5025906735751295
Average accuracy per class church from confusion matrix 0.6699266503667481
Average accuracy per class French horn from confusion matrix 0.5228426395939086
Average accuracy per class garbage truck from confusion matrix 0.7480719794344473
Average accuracy per class gas pump from confusion matrix 0.5298329355608592
Average accuracy per class golf ball from confusion matrix 0.5939849624060151
Average accuracy per class parachute from confusion matrix 0.5641025641025641
Average Accuracy/precision from the confusion matrix is 0.640764331210191
Accuracy of the network on the 3925 test/validation images: 64.07643312101911
