WeightedRandomSampler example

x33g5p2x, reposted on 2022-03-09 under "Other"

The length of weights must equal the length of the dataset: one weight per item.

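If weights is just [0.5, 0.5], the sampler only ever returns indices 0 and 1, so only the 1st and 2nd items can be drawn, however large the dataset is.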
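
A minimal sketch of this behaviour, assuming PyTorch is installed. The value num_samples=10 is an arbitrary choice for illustration:

```python
from torch.utils.data import WeightedRandomSampler

# Only two weights are given, regardless of how large the dataset really is.
sampler = WeightedRandomSampler([0.5, 0.5], num_samples=10, replacement=True)
indices = list(sampler)

# Every drawn index is 0 or 1; items beyond the second are never selected.
print(sorted(set(indices)))
```

The sampler draws indices from range(len(weights)), which is why a short weights list silently hides the rest of the data.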

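If the dataset has 20,000 items, weights has 20,000 entries, and the value of each entry controls how likely the corresponding item is to be drawn.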
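
To illustrate how individual weights steer the draw, here is a small sketch with three made-up weights. The seeded generator is only there to make the run repeatable and is not part of the original example:

```python
from collections import Counter

import torch
from torch.utils.data import WeightedRandomSampler

# Hypothetical weights: item 0 is never drawn, item 2 is 9x as likely as item 1.
weights = [0.0, 1.0, 9.0]
generator = torch.Generator().manual_seed(0)
sampler = WeightedRandomSampler(weights, num_samples=10000,
                                replacement=True, generator=generator)

# Count how often each index was drawn.
counts = Counter(sampler)
print(counts)
```

The weights do not need to sum to 1; only their relative sizes matter.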
 

Sample 200 items from 1,000 with equal probability:

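    from torch.utils.data import WeightedRandomSampler

    # One weight per item; equal weights give every index the same probability.
    weights = [1] * 1000
    # Draw 200 indices with replacement, so an index may appear more than once.
    indices = list(WeightedRandomSampler(weights, 200, replacement=True))
    print(indices)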

An example that combines the sampler with a dataset:

    import numpy as np
    from torch.utils.data import DataLoader, WeightedRandomSampler
    from torchvision import transforms

    # One weight per item; the length should match len(train_ds) (20,000 here).
    weights = [1] * 20000
    sampler = WeightedRandomSampler(weights, num_samples=200, replacement=True)

    _image_size = 32
    _mean = [0.485, 0.456, 0.406]
    _std = [0.229, 0.224, 0.225]
    trans = transforms.Compose([
        transforms.RandomCrop(_image_size),
        # transforms.RandomHorizontalFlip(),
        # transforms.ColorJitter(.3, .3, .3),
        transforms.ToTensor(),
        # transforms.Normalize(_mean, _std),
    ])

    if __name__ == '__main__':
        # DogsCatsDataset is not defined in this post; it is assumed to be a
        # Dataset class that returns (image, label) pairs.
        train_ds = DogsCatsDataset(r"D:\data\ocr\wanqu\archive", "train", transform=trans)
        # sampler and shuffle=True cannot be combined, so only the sampler is passed.
        train_dl = DataLoader(train_ds, batch_size=2, num_workers=1, sampler=sampler)
        # train_dl = DataLoader(train_ds, batch_size=20, num_workers=1, shuffle=True)
        for i, (data, target) in enumerate(train_dl):
            # print(i, target)
            if len(np.where(target.numpy() == 1)[0]) > 0:
                print('find 1')
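
One common use of per-item weights is to rebalance classes. The sketch below is not from the original post: it uses a made-up TensorDataset with 950 items of class 0 and 50 of class 1, and weights each item by the inverse frequency of its class, so that label 1 shows up in roughly half of the drawn samples.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# Hypothetical imbalanced toy data: 950 items of class 0, 50 of class 1.
labels = torch.cat([torch.zeros(950, dtype=torch.long),
                    torch.ones(50, dtype=torch.long)])
features = torch.randn(len(labels), 4)
dataset = TensorDataset(features, labels)

# Weight each item by the inverse frequency of its class (one weight per item).
class_counts = torch.bincount(labels)
weights = 1.0 / class_counts[labels].double()

sampler = WeightedRandomSampler(weights, num_samples=2000, replacement=True)
loader = DataLoader(dataset, batch_size=100, sampler=sampler)

# Fraction of class 1 among all drawn samples.
ones = sum(int((target == 1).sum()) for _, target in loader)
fraction = ones / 2000
print(fraction)
```

With equal weights, as in the example above, class 1 would appear in only about 5% of the samples; the inverse-frequency weights are one possible choice, not the only one.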
