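The length of weights equals the number of samples in the dataset.
If weights is [0.5, 0.5], the sampler can only ever return the first and second samples (indices 0 and 1).
With 20,000 samples, weights has 20,000 entries; set each entry to control how likely that sample is to be drawn.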
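To make the points above concrete, here is a minimal sketch. The weight values and sample counts are made up for illustration, and the seed is only there so the run is repeatable. It shows that the returned indices always lie in range(len(weights)), and that weights are relative: they do not need to sum to 1, and a weight of 0 excludes that index.

```python
import torch
from torch.utils.data import WeightedRandomSampler

torch.manual_seed(0)  # only to make the run repeatable

# Two weights: the sampler can only ever return index 0 or index 1.
two = list(WeightedRandomSampler([0.5, 0.5], num_samples=10, replacement=True))
print(two)

# Weights are relative, not probabilities; a weight of 0 excludes that index.
skewed = list(WeightedRandomSampler([0.0, 1.0, 3.0], num_samples=1000, replacement=True))
counts = [skewed.count(i) for i in range(3)]
# Index 0 never appears; index 2 is drawn roughly three times as often as index 1.
print(counts)
```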
Sample 200 items from 1,000 with equal probability:
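from torch.utils.data import WeightedRandomSampler

# 1000 equal weights: every index is equally likely to be drawn
weights = [1] * 1000
# Draw 200 indices (with replacement, so an index may repeat)
indices = list(WeightedRandomSampler(weights, 200, replacement=True))
print(indices)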
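The call above uses replacement=True, so the same index can be drawn more than once. As a rough sketch of the difference (variable names here are illustrative), replacement=False draws each index at most once, which means num_samples must not exceed the number of entries with a non-zero weight:

```python
from torch.utils.data import WeightedRandomSampler

weights = [1] * 1000

# With replacement: 200 draws, duplicates are possible.
with_repl = list(WeightedRandomSampler(weights, 200, replacement=True))
# Without replacement: 200 draws, every index is unique.
without_repl = list(WeightedRandomSampler(weights, 200, replacement=False))

print(len(set(with_repl)), len(set(without_repl)))
```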
An example that uses the sampler together with a dataset:
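import numpy as np
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision import transforms

_image_size = 32
_mean = [0.485, 0.456, 0.406]
_std = [0.229, 0.224, 0.225]
trans = transforms.Compose([
    transforms.RandomCrop(_image_size),
    # transforms.RandomHorizontalFlip(),
    # transforms.ColorJitter(.3, .3, .3),
    transforms.ToTensor(),
    # transforms.Normalize(_mean, _std),
])

if __name__ == '__main__':
    # DogsCatsDataset is not defined in this post; it is assumed to be a custom Dataset class.
    train_ds = DogsCatsDataset(r"D:\data\ocr\wanqu\archive", "train", transform=trans)
    # One weight per sample (20,000 here); equal weights give uniform sampling.
    weights = [1] * len(train_ds)
    sampler = WeightedRandomSampler(weights, num_samples=200, replacement=True)
    # A sampler cannot be combined with shuffle=True in DataLoader.
    train_dl = DataLoader(train_ds, batch_size=2, num_workers=1, sampler=sampler)
    # train_dl = DataLoader(train_ds, batch_size=20, num_workers=1, shuffle=True)
    for i, (data, target) in enumerate(train_dl):
        # print(i, target)
        # Report every batch that contains at least one sample with label 1
        if len(np.where(target.numpy() == 1)[0]) > 0:
            print('find 1')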
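The loop above checks how often label 1 turns up in the sampled batches. With equal weights that depends only on how common label 1 is in the data. One common way to change this, sketched below on hypothetical toy data (a TensorDataset with 950 samples of class 0 and 50 of class 1, not the dataset used above), is to give each sample a weight inversely proportional to the size of its class. This is an illustration under those assumptions, not the original author's setup.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

torch.manual_seed(0)  # only to make the run repeatable

# Hypothetical toy data: 950 samples of class 0, 50 samples of class 1.
labels = torch.cat([torch.zeros(950, dtype=torch.long),
                    torch.ones(50, dtype=torch.long)])
features = torch.randn(len(labels), 4)
dataset = TensorDataset(features, labels)

# One weight per sample: the inverse of that sample's class count.
class_counts = torch.bincount(labels)
weights = 1.0 / class_counts[labels].double()

sampler = WeightedRandomSampler(weights, num_samples=len(dataset), replacement=True)
loader = DataLoader(dataset, batch_size=50, sampler=sampler)

# Count how often each class is actually drawn over one pass.
seen = torch.zeros(2, dtype=torch.long)
for _, target in loader:
    seen += torch.bincount(target, minlength=2)
print(seen)  # the two classes are drawn in roughly equal numbers
```

Because sampling is done with replacement, minority-class samples are repeated within one pass rather than new ones being created.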
Original link: https://blog.csdn.net/jacke121/article/details/123365946