栏目分类:
子分类:
返回
文库吧用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
文库吧 > IT > 软件开发 > 游戏开发 > Cocos2d-x

YOLOv5的Tricks | 【Trick8】图片采样策略——按数据集各类别权重采样

Cocos2d-x 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

YOLOv5的Tricks | 【Trick8】图片采样策略——按数据集各类别权重采样


如有错误,恳请指出。


文章目录
  • 1. 图片采样策略想法
  • 2. 图片采样策略代码

这篇文章用来记录一下yolov5在训练过程中提出的一个图片采样策略,简单来说,就是根据图片的权重来决定其采样顺序。


1. 图片采样策略想法
  • 图片采样策略想法

在我们训练数据集的时候,一般是对数据集随机采样几张图像然后构建成一个mini-batch来批量输入网络处理。个人猜想,一个可能的想法就是,这种随机的图像采集会不会过于随意,因为有些图像的目标是过少的,那么这种图像可能对网络来说比较简单;而有些图像的目标是比较多的,这种是比较困难的。而对于开始训练的初期就使用这种简答图像对网络的训练可能带来不了多大的学习提升。

所以,如果可以对数据集中的每张图像做一个权重的划分,在训练模型的时候依照图像的权重大小依次按难到易的大概顺序来进行训练,让模型从一开始的困难的样本较快的学习到潜在特征,到之后通过简单的图像样本来对参数进行微调,说不定是一个好的方法。

(以上内容是个人的思考猜测,可能是有误的,欢迎探讨。)

  • 图片采样策略思路

那么具体的实现思路就是,对整个数据集的图像目标做类别统计,然后类别的数目越大权重越小(成反比的关系)。然后再使用整个数据集的类别权重对每一张图像做类别权重的叠加。也就是根据每一张的图片的类别权重和来作为采样的权重,决定其采用的顺序。在代码的实现中是从大到小排序的。


2. 图片采样策略代码
  • yolov5参考代码

大概的注释都写在代码里了:

def train():
	...
	model.class_weights = labels_to_class_weights(dataset.labels, nc).to(device) * nc  # attach class weights
	...
	for epoch in range(start_epoch, epochs): 
		model.train()

		# Update image weights (optional, single-GPU only)
        if opt.image_weights:
            # 根据数据集的类别数目构建每个类别的权重(类别权重与类别数目成反比)
            cw = model.class_weights.cpu().numpy() * (1 - maps) ** 2 / nc  # class weights
            # 对每张图片的目标计算其类别权重和作为图片的采集权重
            iw = labels_to_image_weights(dataset.labels, nc=nc, class_weights=cw)  # image weights
            # 再更具每张图片的采集权重来构建图片的采样顺序
            dataset.indices = random.choices(range(dataset.n), weights=iw, k=dataset.n)  # rand weighted idx
	...


def labels_to_class_weights(labels, nc=80):
    # Get class weights (inverse frequency) from training labels
    # labels是当前数据集的训练集的所有图像: {list: 682}
    # 列表的每个对象格式是: (ndarray: (k, 5)) k表示当前图像的目表个数, 5是(class+xywh)
    if labels[0] is None:  # no labels loaded
        return torch.Tensor()
    # 把图像的标签列表直接转化为标签列表:{ndarray: (labels, 5)} labels表示全部图像的所有标签个数
    labels = np.concatenate(labels, 0)  # labels.shape = (866643, 5) for COCO
    # 提取类别 labels[:, 0] 数据来为每一类做统计 .astype(np.int): 取整
    classes = labels[:, 0].astype(np.int)  # labels = [class xywh]
    # weight: 统计每个类别出现的次数
    weights = np.bincount(classes, minlength=nc)  # occurrences per class

    # Prepend gridpoint count (for uCE training)
    # gpi = ((320 / 32 * np.array([1, 2, 4])) ** 2 * 3).sum()  # gridpoints per image
    # weights = np.hstack([gpi * len(labels)  - weights.sum() * 9, weights * 9]) ** 0.5  # prepend gridpoints to start

    # 将出现次数为0的类别权重全部取1
    weights[weights == 0] = 1  # replace empty bins with 1
    # 类别权重取类别出现次数的倒数, 也就是表示类别次数与权重成反比, 标签频率越高的类别权重越低, 因为越不罕见
    weights = 1 / weights  # number of targets per class
    # 归一化操作: 求出每一类别的占比
    weights /= weights.sum()  # normalize
    return torch.from_numpy(weights)  # numpy -> tensor


def labels_to_image_weights(labels, nc=80, class_weights=np.ones(80)):
    # Produces image weights based on class_weights and image contents
    # out:{ndarray: (682,3)} 统计每一张图片中类类别的数目 这里我用的是mask数据集有3个类别 每个位置存储图像中对应类别目标出现的个数
    class_counts = np.array([np.bincount(x[:, 0].astype(np.int), minlength=nc) for x in labels])
    # class_weights:[n_class] -> [1, n_class]
    # 每张图片的每个类别个数[label_nums, n_class] * 整个数据集每个类别的权重[1, n_class] = 每张图片的对应每个类别的权重[label_nums, n_class_weight]
    # 然后每个类别的权重加在一起等于当前这张图片的权重
    image_weights = (class_weights.reshape(1, nc) * class_counts).sum(1)
    # index = random.choices(range(n), weights=image_weights, k=1)  # weight image sample
    return image_weights
  • 构造Dataset使用的地方
class LoadImagesAndLabels(Dataset):
	def __init__(self, img_size=640, batch_size=16, image_weights=False, ...):
		...
		self.indices = range(n)
	
	def __len__(self):
        return len(self.img_files)
	
    def __getitem__(self, index):
    	# 重点使用部分, 就是用权重采样策略替代了随机采样
    	# 随机采样: index返回的是随机值(shuffle = True),所以注意到其实在
    	# 权重采样: index是按顺序从0开始, 然后依次提取indices所指向的图像索引
        index = self.indices[index]  # linear, shuffled, or image_weights
        img, labels = load_mosaic(self, index)
        ...
        return torch.from_numpy(img), labels_out, self.img_files[index], shapes

# 因为可以注意到, 构建dataloader的时候yolov5代码中是没有使用shuffle=True这个随机采样的参数的
def create_dataloader(path, imgsz, batch_size, stride, single_cls=False, hyp=None, augment=False, cache=False, pad=0.0,
                      rect=False, rank=-1, workers=8, image_weights=False, quad=False, prefix=''):
    # Make sure only the first process in DDP process the dataset first, and the following others can use the cache
    with torch_distributed_zero_first(rank):
        dataset = LoadImagesAndLabels(path, imgsz, batch_size,
                                      augment=augment,  # augment images
                                      hyp=hyp,  # augmentation hyperparameters
                                      rect=rect,  # rectangular training
                                      cache_images=cache,
                                      single_cls=single_cls,
                                      stride=int(stride),
                                      pad=pad,
                                      image_weights=image_weights,
                                      prefix=prefix)

    batch_size = min(batch_size, len(dataset))

    # 这里对num_worker进行更改
    # nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, workers])  # number of workers
    nw = 0  # 可以适当提高这个参数0, 2, 4, 8, 16…

    sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
    loader = torch.utils.data.DataLoader if image_weights else InfiniteDataLoader
    # Use torch.utils.data.DataLoader() if dataset.properties will update during training else InfiniteDataLoader()
    
    # 没有使用 shuffle=True 这个参数
	dataloader = loader(dataset,
                        batch_size=batch_size,
                        num_workers=nw,
                        sampler=sampler,
                        pin_memory=True,
                        collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn)
    return dataloader, dataset

所以从代码中可以看见,如果不使用图像采样策略,这里也不会使用随机的选择策略,而且index从0开始提取,验证如下:

第一次断点调试:index从0开始,想法验证成功


参考资料:

1. 【YOLOV5-5.x 源码解读】general.py

转载请注明:文章转载自 www.wk8.com.cn
本文地址:https://www.wk8.com.cn/it/968043.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 wk8.com.cn

ICP备案号:晋ICP备2021003244-6号