WeightedRandomSampler example

x33g5p2x, reposted on 2022-03-09 under "Other"

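The length of weights equals the number of samples in the dataset: there is one weight per sample.

If weights is [0.5, 0.5], only the first and second samples (indices 0 and 1) can ever be drawn, because the sampler returns indices in range(len(weights)).

For a dataset of 20,000 samples, weights has 20,000 entries. Setting each value controls the weight, that is, the relative sampling probability, of the corresponding sample.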
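To make the role of the individual weights concrete, here is a minimal sketch with four items and unequal weights. The specific weight values, the sample count, and the seed are assumptions chosen for illustration only; the weights are relative and do not have to sum to 1.

```python
from collections import Counter

import torch
from torch.utils.data import WeightedRandomSampler

torch.manual_seed(0)  # only to make the run repeatable

# Four items; weights are relative and need not sum to 1.
# Item 3 has weight 0, so it is never drawn.
weights = [1.0, 1.0, 6.0, 0.0]

sampler = WeightedRandomSampler(weights, num_samples=8000, replacement=True)
drawn = list(sampler)      # plain Python ints in range(len(weights))
counts = Counter(drawn)

# Print the observed share of each index.
for idx in range(len(weights)):
    print(idx, counts[idx] / len(drawn))
```

With these weights, index 2 should make up roughly 6/8 of the draws and indices 0 and 1 roughly 1/8 each.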
 

Draw 200 samples from 1,000 items, each with equal probability:

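from torch.utils.data import WeightedRandomSampler

# One weight per item; equal weights give every index the same probability.
weights = [1] * 1000

# Draw 200 indices (values in 0..999) with replacement, so repeats are possible.
indices = list(WeightedRandomSampler(weights, 200, replacement=True))

print(indices)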
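For comparison, the same draw can be made without replacement. This is a small sketch, not part of the original post; the variable name unique_indices is made up for illustration.

```python
from torch.utils.data import WeightedRandomSampler

weights = [1] * 1000

# replacement=False: each index can appear at most once,
# so num_samples must not exceed len(weights).
unique_indices = list(WeightedRandomSampler(weights, 200, replacement=False))

# Both numbers are 200 because no index is repeated.
print(len(unique_indices), len(set(unique_indices)))
```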

An example that uses the sampler together with a dataset:

from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision import transforms

# DogsCatsDataset is the author's own Dataset class; its definition is not shown here.

# One weight per sample; the length should match len(train_ds).
weights = [1] * 20000
sampler = WeightedRandomSampler(weights, num_samples=200, replacement=True)

_image_size = 32
_mean = [0.485, 0.456, 0.406]
_std = [0.229, 0.224, 0.225]
trans = transforms.Compose([
    transforms.RandomCrop(_image_size),
    # transforms.RandomHorizontalFlip(),
    # transforms.ColorJitter(.3, .3, .3),
    transforms.ToTensor(),
    # transforms.Normalize(_mean, _std),
])

if __name__ == '__main__':
    train_ds = DogsCatsDataset(r"D:\data\ocr\wanqu\archive", "train", transform=trans)
    # sampler and shuffle=True are mutually exclusive in DataLoader.
    train_dl = DataLoader(train_ds, batch_size=2, num_workers=1, sampler=sampler)
    # train_dl = DataLoader(train_ds, batch_size=20, num_workers=1, shuffle=True)

    for i, (data, target) in enumerate(train_dl):
        # print(i, target)
        # Report every batch that contains at least one sample labelled 1.
        if (target == 1).any():
            print('find 1')
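The example above depends on a dataset class and a local path that are not shown, so here is a self-contained sketch of the same pattern. It is an illustration under stated assumptions, not the author's setup: the toy TensorDataset, the 950/50 class split, and the inverse-class-frequency weights are all made up to show how per-sample weights can rebalance classes.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

torch.manual_seed(0)  # only to make the run repeatable

# Hypothetical toy data: 950 samples of class 0 and 50 of class 1.
labels = torch.cat([torch.zeros(950, dtype=torch.long),
                    torch.ones(50, dtype=torch.long)])
features = torch.randn(len(labels), 4)
dataset = TensorDataset(features, labels)

# One weight per sample: the inverse of its class frequency,
# so both classes carry the same total weight.
class_counts = torch.bincount(labels)
sample_weights = 1.0 / class_counts[labels].double()

sampler = WeightedRandomSampler(sample_weights, num_samples=2000, replacement=True)
loader = DataLoader(dataset, batch_size=20, sampler=sampler)

# Collect the labels actually served and measure the share of class 1.
seen = torch.cat([target for _, target in loader])
minority_fraction = (seen == 1).float().mean().item()
print(len(seen), minority_fraction)
```

Although class 1 is only 5% of the data, it should make up about half of the sampled labels, since each class receives the same total weight.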
