栏目分类:
子分类:
返回
文库吧用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
文库吧 > IT > 软件开发 > 后端开发 > Python

Yolact源码解析

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

Yolact源码解析

数据加载
with timer.env('Load Data'):
    # Load one evaluation sample, timed under the 'Load Data' label via
    # timer.env (presumably a timing context manager — defined elsewhere).
    # Example values recorded by the write-up for one image with 3 annotated
    # objects and a 1920x1080 source image (illustrative only):
    #   img:       (550, 550, 3) as written; pull_item returns the array after
    #              permute(2, 0, 1), so the returned tensor is presumably
    #              (3, 550, 550) — confirm
    #   gt:        (3, 5) -> one row per object: 4 box values + class label.
    #              The write-up calls the box values center/width/height, but
    #              Yolact's loss handles gt boxes as point-form — TODO confirm
    #   gt_masks:  (3, 1080, 1920) -> one mask per object; the write-up says
    #              1080x1920 is the loaded original image size
    #   h, w:      1080, 1920
    #   num_crowd: 0 (number of crowd annotations)
    img, gt, gt_masks, h, w, num_crowd = dataset.pull_item(image_idx)
# Abridged excerpt of pull_item from coco.py (Yolact's COCO dataset class).
# The caller invokes it as dataset.pull_item(image_idx), so the real method
# takes a sample index and must define img, target, masks, height, width and
# num_crowds before returning; none of those names are defined in this excerpt.
def pull_item():
	"""Return one sample: (image tensor, targets, masks, height, width, num_crowds).

	The image array is converted with torch.from_numpy and reordered with
	permute(2, 0, 1), moving the last axis to the front (presumably HWC -> CHW).
	"""
	return torch.from_numpy(img).permute(2, 0, 1), target, masks, height, width, num_crowds
forward


YOLACT 将实例分割问题分解为两个并行的子任务,分别生成“prototype masks”(原型掩码)和“mask coefficients”(掩码系数)。

  • Protonet分支

使用全卷积网络(FCN)来生成一组“原型掩码”(prototype masks),该掩码不依赖于任何一个特定的实例,是共用的,对于每张输入图像预测k(32)个prototype masks

  • Prediction Head分支

在目标检测分支(预测 anchor)上额外添加一个 head,为每个实例/anchor 预测一组“掩码系数”(mask coefficients)。也就是说,该分支为每个候选框同时生成类别置信度(confidence)、anchor 的位置偏移(location)以及 prototype mask 的系数(coefficient)。

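  • Mask Assembly

将 prototype masks 与每个实例的 mask coefficients 做线性组合,再经过 sigmoid 激活,得到各实例的掩码:

$$M = \sigma\left(P C^{\top}\right)$$

其中 P 为 h×w×k 的 prototype mask,C 为 n×k 的 mask 系数矩阵(n 为经过 NMS 保留下来的实例数),M 为 h×w×n 的实例掩码(对应下文 postprocess 中的 `proto_data @ masks.t()`)。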

# Run the network on the input batch. The forward() shown below ends by
# returning self.detect(pred_outs, self), i.e. post-NMS detections.
preds = net(batch)
# forward is defined in yolact.py.
# Example input x: torch.Size([1, 3, 550, 550]).
def forward(self, x):
    """Run the Yolact network on an image batch and return post-NMS detections.

    Pipeline: backbone -> (optional) FPN -> protonet (prototype masks) and
    prediction heads (box offsets, class confidences, mask coefficients,
    priors) -> softmax over classes -> ``self.detect``.

    Tensor shapes in the comments are example values recorded for a single
    550x550 input; they depend on the configuration.

    Args:
        x: image batch; ``x.size()`` is unpacked as (N, C, H, W).

    Returns:
        Whatever ``self.detect(pred_outs, self)`` returns. The Detect
        ``__call__`` in layers/functions/detection.py builds one
        ``{'detection': ..., 'net': ...}`` dict per image — TODO confirm
        self.detect is that object.
    """
    _, _, img_h, img_w = x.size()
    # Stash the input size on the global config (presumably read elsewhere,
    # e.g. when generating priors — not visible here).
    cfg._tmp_img_h = img_h
    cfg._tmp_img_w = img_w

    with timer.env('backbone'):
        # Tuple of 4 feature maps, e.g.
        #   outs[0]: [1, 256, 138, 138]
        #   outs[1]: [1, 512, 69, 69]
        #   outs[2]: [1, 1024, 35, 35]
        #   outs[3]: [1, 2048, 18, 18]
        outs = self.backbone(x)

    if cfg.fpn is not None:
        with timer.env('fpn'):
            # cfg.backbone.selected_layers is e.g. [1, 2, 3]; only those maps feed the FPN.
            outs = [outs[i] for i in cfg.backbone.selected_layers]
            # The FPN returns 5 levels, e.g.
            #   outs[0]: [1, 256, 69, 69] (P3)   outs[1]: [1, 256, 35, 35] (P4)
            #   outs[2]: [1, 256, 18, 18] (P5)   outs[3]: [1, 256, 9, 9]   (P6)
            #   outs[4]: [1, 256, 5, 5]   (P7)
            outs = self.fpn(outs)

    proto_out = None
    if cfg.mask_type == mask_type.lincomb and cfg.eval_mask_branch:
        with timer.env('proto'):
            # Protonet input: the raw image when proto_src is None, otherwise
            # an FPN level (e.g. outs[0] = P3, [1, 256, 69, 69]).
            proto_x = x if self.proto_src is None else outs[self.proto_src]
            if self.num_grids > 0:
                grids = self.grid.repeat(proto_x.size(0), 1, 1, 1)
                proto_x = torch.cat([proto_x, grids], dim=1)
            # k prototype masks (the paper finds k = 32 works best), e.g. [1, 32, 138, 138].
            proto_out = self.proto_net(proto_x)
            # The paper uses ReLU here (cfg.mask_proto_prototype_activation).
            proto_out = cfg.mask_proto_prototype_activation(proto_out)

            if cfg.mask_proto_prototypes_as_features:
                # Clone so the channels-last permute below does not affect this copy.
                proto_downsampled = proto_out.clone()
                if cfg.mask_proto_prototypes_as_features_no_grad:
                    proto_downsampled = proto_out.detach()

            # Channels-last for mask assembly, e.g. [1, 138, 138, 32].
            proto_out = proto_out.permute(0, 2, 3, 1).contiguous()

            if cfg.mask_proto_bias:
                # Append one constant-ones prototype. new_ones keeps proto_out's
                # device and dtype; a bare torch.ones(...) would be allocated on
                # the CPU with the default dtype and break torch.cat on GPU.
                bias_shape = list(proto_out.size())
                bias_shape[-1] = 1
                proto_out = torch.cat([proto_out, proto_out.new_ones(bias_shape)], -1)

    # The prediction heads do not depend on the protonet branch, so they run
    # at function level (proto_out may legitimately still be None here).
    with timer.env('pred_heads'):
        pred_outs = {'loc': [], 'conf': [], 'mask': [], 'priors': []}
        if cfg.use_mask_scoring:
            pred_outs['score'] = []
        if cfg.use_instance_coeff:
            pred_outs['inst'] = []

        # One PredictionModule per FPN level (self.selected_layers, e.g. [0, 1, 2, 3, 4]).
        # Module 0 was printed as a 3x3 "upfeature" conv + ReLU followed by 3x3
        # bbox / conf / mask convs with 12, 30 and 96 output channels
        # (3 anchors x 4 box values, x 10 classes, x 32 coefficients).
        for idx, pred_layer in zip(self.selected_layers, self.prediction_layers):
            pred_x = outs[idx]

            if cfg.mask_type == mask_type.lincomb and cfg.mask_proto_prototypes_as_features:
                # Resize the prototypes to this level's spatial size and append
                # them as extra input channels.
                proto_downsampled = F.interpolate(
                    proto_downsampled, size=outs[idx].size()[2:],
                    mode='bilinear', align_corners=False)
                pred_x = torch.cat([pred_x, proto_downsampled], dim=1)

            # For every level but the first, point .parent at module 0
            # (presumably so all levels share its weights).
            if cfg.share_prediction_module and pred_layer is not self.prediction_layers[0]:
                pred_layer.parent = [self.prediction_layers[0]]

            # Per-level outputs, e.g. for level 0 (69x69 cells x 3 anchors = 14283 priors):
            #   loc [1, 14283, 4], conf [1, 14283, 10], mask [1, 14283, 32], priors [14283, 4]
            # Levels 1-4 give 3675, 972, 243 and 75 priors respectively.
            p = pred_layer(pred_x)
            for k, v in p.items():
                pred_outs[k].append(v)

    # Concatenate all levels along the prior axis, e.g. loc -> [1, 19248, 4].
    for k, v in pred_outs.items():
        pred_outs[k] = torch.cat(v, -2)

    if proto_out is not None:
        pred_outs['proto'] = proto_out

    pred_outs['conf'] = F.softmax(pred_outs['conf'], -1)
    # Detection (decode + NMS) lives in layers/functions/detection.py.
    return self.detect(pred_outs, self)

layers/functions/detection.py文件

def __call__(self, predictions, net):
    """Decode boxes and run per-image detection on the raw head outputs.

    Args:
        predictions: dict produced by the network's forward pass with keys
            'loc', 'conf', 'mask', 'priors' and, optionally, 'proto' / 'inst'.
        net: the network object; attached to every image's result.

    Returns:
        A list with one ``{'detection': result, 'net': net}`` dict per image.
        ``result`` is what ``self.detect`` returned for that image; when it is
        not None and prototypes are available, the image's prototypes are
        stored in it under 'proto'.
    """
    loc = predictions['loc']
    conf = predictions['conf']
    coeffs = predictions['mask']
    priors = predictions['priors']
    protos = predictions.get('proto')
    inst = predictions.get('inst')

    per_image = []
    with timer.env('Detect'):
        n_images = loc.size(0)
        # Regroup confidences as (batch, num_classes, num_priors) so detect()
        # can slice class rows directly.
        class_major_conf = (
            conf.view(n_images, priors.size(0), self.num_classes)
            .transpose(2, 1)
            .contiguous()
        )

        for i in range(n_images):
            boxes_i = decode(loc[i], priors)
            det = self.detect(i, class_major_conf, boxes_i, coeffs, inst)
            if det is not None and protos is not None:
                det['proto'] = protos[i]
            per_image.append({'detection': det, 'net': net})
    return per_image
    
def decode(loc, priors, use_yolo_regressors:bool=False):
    """Turn predicted offsets into point-form boxes relative to the priors.

    Args:
        loc: (num_priors, 4) predicted offsets.
        priors: (num_priors, 4) anchors; columns 0-1 are treated as the
            center and columns 2-3 as the width/height.
        use_yolo_regressors: if True, add the offsets to the centers directly,
            scale the sizes by exp(offset), and convert the result with
            point_form. Otherwise use SSD-style decoding with variances
            (0.1, 0.2).

    Returns:
        (num_priors, 4) boxes as (x1, y1, x2, y2).
    """
    if use_yolo_regressors:
        centers = loc[:, :2] + priors[:, :2]
        sizes = priors[:, 2:] * torch.exp(loc[:, 2:])
        return point_form(torch.cat((centers, sizes), 1))

    center_var, size_var = 0.1, 0.2
    prior_xy, prior_wh = priors[:, :2], priors[:, 2:]
    centers = prior_xy + loc[:, :2] * center_var * prior_wh
    sizes = prior_wh * torch.exp(loc[:, 2:] * size_var)
    # Center/size -> corners: top-left = center - size/2, bottom-right = top-left + size.
    top_left = centers - sizes / 2
    bottom_right = sizes + top_left
    return torch.cat((top_left, bottom_right), 1)

def detect(self, batch_idx, conf_preds, decoded_boxes, mask_data, inst_data):
    """Threshold and NMS the candidate detections of one image.

    Args:
        batch_idx: index of the image within the batch.
        conf_preds: (batch, num_classes, num_priors) class scores; class row 0
            is dropped (background).
        decoded_boxes: (num_priors, 4) decoded boxes for this image.
        mask_data: (batch, num_priors, k) mask coefficients.
        inst_data: not used by this implementation; kept for interface
            compatibility.

    Returns:
        ``{'box', 'mask', 'class', 'score'}`` from the NMS routine, or None
        when no prior's best foreground score exceeds ``self.conf_thresh``
        (the caller checks for None).

    Raises:
        NotImplementedError: if ``self.use_fast_nms`` is False; only the fast
            NMS variants are wired up here.
    """
    # Drop the background row, e.g. [1, 10, 19248] -> [9, 19248].
    cur_scores = conf_preds[batch_idx, 1:, :]
    # Best foreground score per prior, e.g. [19248].
    conf_scores, _ = torch.max(cur_scores, dim=0)

    keep = conf_scores > self.conf_thresh
    # K = number of priors kept (varies per image).
    scores = cur_scores[:, keep]            # [num_classes - 1, K]
    boxes = decoded_boxes[keep, :]          # [K, 4]
    masks = mask_data[batch_idx, keep, :]   # [K, k]

    # Nothing survived the threshold: report "no detections" instead of
    # running NMS on empty tensors.
    if scores.size(1) == 0:
        return None

    if not self.use_fast_nms:
        # Previously this fell through with classes/boxes unbound.
        raise NotImplementedError('detect() only supports fast NMS (use_fast_nms=True)')

    nms = self.cc_fast_nms if self.use_cross_class_nms else self.fast_nms
    boxes, masks, classes, scores = nms(boxes, masks, scores, self.nms_thresh, self.top_k)

    # Example output: box [36, 4], mask [36, 32], class [36], score [36].
    return {'box': boxes, 'mask': masks, 'class': classes, 'score': scores}
postprocess
def postprocess(det_output, w, h, batch_idx=0, interpolation_mode='bilinear',
                visualize_lincomb=False, crop_masks=True, score_threshold=0):
    """Turn one image's raw detections into final classes, scores, boxes and masks.

    Args:
        det_output: list from the detection layer; each element is a dict
            with 'detection' (a result dict, or None) and 'net'.
        w, h: output width and height. Masks are resized to (h, w), and box
            coordinates are passed to sanitize_coordinates with these sizes.
        batch_idx: which image of det_output to process.
        interpolation_mode: mode passed to F.interpolate when resizing masks.
        visualize_lincomb: if True, call display_lincomb on the prototypes.
        crop_masks: if True, crop assembled masks with crop(masks, boxes).
        score_threshold: keep only detections whose score is above this
            value (no filtering when it is 0).

    Returns:
        (classes, scores, boxes, masks). With no detections, a list of four
        references to the same empty tensor is returned.
    """
    dets = det_output[batch_idx]
    net = dets['net']
    dets = dets['detection']

    # The detection layer can yield None for an image (it checks
    # `result is not None` itself), meaning nothing passed its threshold.
    if dets is None:
        return [torch.Tensor()] * 4

    if score_threshold > 0:
        keep = dets['score'] > score_threshold
        for k in dets:
            if k != 'proto':
                dets[k] = dets[k][keep]
        if dets['score'].size(0) == 0:
            return [torch.Tensor()] * 4

    classes = dets['class']
    boxes = dets['box']
    scores = dets['score']
    masks = dets['mask']

    if cfg.mask_type == mask_type.lincomb and cfg.eval_mask_branch:
        # Prototypes of this image, e.g. [138, 138, 32].
        proto_data = dets['proto']

        if cfg.mask_proto_debug:
            np.save('scripts/proto.npy', proto_data.cpu().numpy())

        if visualize_lincomb:
            display_lincomb(proto_data, masks)

        # Mask assembly from the paper (@ is matrix multiplication):
        #   proto_data [138, 138, 32] @ masks.t() [32, n] -> [138, 138, n],
        # then cfg.mask_proto_mask_activation (e.g. activation_func.sigmoid).
        masks = proto_data @ masks.t()
        masks = cfg.mask_proto_mask_activation(masks)

        if crop_masks:
            masks = crop(masks, boxes)

        # Instance-first layout: [n, 138, 138].
        masks = masks.permute(2, 0, 1).contiguous()

        if cfg.use_maskiou:
            with timer.env('maskiou_net'):
                with torch.no_grad():
                    maskiou_p = net.maskiou_net(masks.unsqueeze(1))
                    # Keep each detection's predicted IoU for its own class.
                    maskiou_p = torch.gather(maskiou_p, dim=1, index=classes.unsqueeze(1)).squeeze(1)
                    if cfg.rescore_mask:
                        if cfg.rescore_bbox:
                            scores = scores * maskiou_p
                        else:
                            scores = [scores, scores * maskiou_p]

        # Resize to the output size, then binarize in place (> 0.5 -> 1, else 0).
        masks = F.interpolate(masks.unsqueeze(0), (h, w), mode=interpolation_mode, align_corners=False).squeeze(0)
        masks.gt_(0.5)

    boxes[:, 0], boxes[:, 2] = sanitize_coordinates(boxes[:, 0], boxes[:, 2], w, cast=False)
    boxes[:, 1], boxes[:, 3] = sanitize_coordinates(boxes[:, 1], boxes[:, 3], h, cast=False)
    boxes = boxes.long()

    if cfg.mask_type == mask_type.direct and cfg.eval_mask_branch:
        # Direct masks are cfg.mask_size x cfg.mask_size per box; paste each
        # one, resized to its box, into a full-size canvas.
        full_masks = torch.zeros(masks.size(0), h, w)

        for jdx in range(masks.size(0)):
            x1, y1, x2, y2 = boxes[jdx, :]
            mask_w = x2 - x1
            mask_h = y2 - y1

            # Skip degenerate boxes.
            if mask_w * mask_h <= 0 or mask_w < 0:
                continue

            mask = masks[jdx, :].view(1, 1, cfg.mask_size, cfg.mask_size)
            mask = F.interpolate(mask, (mask_h, mask_w), mode=interpolation_mode, align_corners=False)
            mask = mask.gt(0.5).float()
            full_masks[jdx, y1:y2, x1:x2] = mask

        masks = full_masks

    return classes, scores, boxes, masks
Loss
def forward(self, net, predictions, targets, masks, num_crowds):
    """Compute the Yolact training losses for one batch.

    Args:
        net: the network (passed to the mask-IoU loss).
        predictions: dict of head outputs with 'loc', 'conf', 'mask',
            'priors', plus 'proto' / 'score' / 'inst' / 'classes' / 'segm'
            depending on the config.
        targets: per-image tensors; the last column is the class label, the
            other columns are the box.
        masks: per-image ground-truth masks. NOTE: when an image has crowd
            annotations, its entry in this list is replaced in place with the
            non-crowd masks.
        num_crowds: per-image number of crowd annotations, stored at the end
            of that image's targets and masks.

    Returns:
        dict of losses keyed by the letters listed at the end of this function.
    """
    loc_data = predictions['loc']
    conf_data = predictions['conf']
    mask_data = predictions['mask']
    priors = predictions['priors']

    if cfg.mask_type == mask_type.lincomb:
        proto_data = predictions['proto']

    score_data = predictions['score'] if cfg.use_mask_scoring else None
    inst_data = predictions['inst'] if cfg.use_instance_coeff else None

    labels = [None] * len(targets)  # per-image labels; also used by the semantic segmentation loss

    batch_size = loc_data.size(0)
    num_priors = priors.size(0)
    num_classes = self.num_classes

    # Matching targets, created on the same device as loc_data.
    loc_t = loc_data.new(batch_size, num_priors, 4)
    gt_box_t = loc_data.new(batch_size, num_priors, 4)
    conf_t = loc_data.new(batch_size, num_priors).long()
    idx_t = loc_data.new(batch_size, num_priors).long()

    if cfg.use_class_existence_loss:
        class_existence_t = loc_data.new(batch_size, num_classes - 1)

    for idx in range(batch_size):
        truths = targets[idx][:, :-1].data
        labels[idx] = targets[idx][:, -1].data.long()

        if cfg.use_class_existence_loss:
            # One-hot per object, collapsed with max into an existence vector
            # (crowd annotations included). Use .device: get_device() returns
            # -1 for CPU tensors, which torch.eye does not accept as a device.
            class_existence_t[idx, :] = torch.eye(num_classes - 1, device=conf_t.device)[labels[idx]].max(dim=0)[0]

        # Crowd annotations come bundled at the end; split them off.
        cur_crowds = num_crowds[idx]
        if cur_crowds > 0:
            crowd_boxes, truths = truths[-cur_crowds:], truths[:-cur_crowds]
            # Crowd labels and masks are not used.
            labels[idx] = labels[idx][:-cur_crowds]
            masks[idx] = masks[idx][:-cur_crowds]
        else:
            crowd_boxes = None

        # match() receives loc_t / conf_t / idx_t and the image index;
        # presumably it fills row idx in place (conf_t > 0 marks positives,
        # idx_t holds the matched gt index) — defined elsewhere.
        match(self.pos_threshold, self.neg_threshold,
              truths, priors.data, labels[idx], crowd_boxes,
              loc_t, conf_t, idx_t, idx, loc_data[idx])

        # Ground-truth box of the object matched to each prior.
        gt_box_t[idx, :, :] = truths[idx_t[idx]]

    # Wrap targets (no gradient).
    loc_t = Variable(loc_t, requires_grad=False)
    conf_t = Variable(conf_t, requires_grad=False)
    idx_t = Variable(idx_t, requires_grad=False)

    pos = conf_t > 0
    num_pos = pos.sum(dim=1, keepdim=True)

    # [batch, num_priors, 4] selector for the positives' box entries.
    pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data)

    losses = {}
    # Only the lincomb mask loss produces mask-IoU targets; default to None so
    # the mask-IoU check below never reads an unbound name.
    maskiou_targets = None

    # Localization loss (smooth L1).
    if cfg.train_boxes:
        loc_p = loc_data[pos_idx].view(-1, 4)
        loc_t = loc_t[pos_idx].view(-1, 4)
        losses['B'] = F.smooth_l1_loss(loc_p, loc_t, reduction='sum') * cfg.bbox_alpha

    if cfg.train_masks:
        if cfg.mask_type == mask_type.direct:
            if cfg.use_gt_bboxes:
                pos_masks = []
                for idx in range(batch_size):
                    pos_masks.append(masks[idx][idx_t[idx, pos[idx]]])
                masks_t = torch.cat(pos_masks, 0)
                masks_p = mask_data[pos, :].view(-1, cfg.mask_dim)
                losses['M'] = F.binary_cross_entropy(torch.clamp(masks_p, 0, 1), masks_t, reduction='sum') * cfg.mask_alpha
            else:
                losses['M'] = self.direct_mask_loss(pos_idx, idx_t, loc_data, mask_data, priors, masks)
        elif cfg.mask_type == mask_type.lincomb:
            ret = self.lincomb_mask_loss(pos, idx_t, loc_data, mask_data, priors, proto_data, masks, gt_box_t, score_data, inst_data, labels)
            if cfg.use_maskiou:
                loss, maskiou_targets = ret
            else:
                loss = ret
            losses.update(loss)

            if cfg.mask_proto_loss is not None:
                if cfg.mask_proto_loss == 'l1':
                    losses['P'] = torch.mean(torch.abs(proto_data)) / self.l1_expected_area * self.l1_alpha
                elif cfg.mask_proto_loss == 'disj':
                    losses['P'] = -torch.mean(torch.max(F.log_softmax(proto_data, dim=-1), dim=-1)[0])

    # Confidence loss.
    if cfg.use_focal_loss:
        if cfg.use_sigmoid_focal_loss:
            losses['C'] = self.focal_conf_sigmoid_loss(conf_data, conf_t)
        elif cfg.use_objectness_score:
            losses['C'] = self.focal_conf_objectness_loss(conf_data, conf_t)
        else:
            losses['C'] = self.focal_conf_loss(conf_data, conf_t)
    else:
        if cfg.use_objectness_score:
            # NOTE(review): loc_p (and the flattened loc_t) only exist when
            # cfg.train_boxes is set.
            losses['C'] = self.conf_objectness_loss(conf_data, conf_t, batch_size, loc_p, loc_t, priors)
        else:
            losses['C'] = self.ohem_conf_loss(conf_data, conf_t, pos, batch_size)

    # Mask IoU loss.
    if cfg.use_maskiou and maskiou_targets is not None:
        losses['I'] = self.mask_iou_loss(net, maskiou_targets)

    # These losses do not depend on anchors.
    if cfg.use_class_existence_loss:
        losses['E'] = self.class_existence_loss(predictions['classes'], class_existence_t)
    if cfg.use_semantic_segmentation_loss:
        losses['S'] = self.semantic_segmentation_loss(predictions['segm'], masks, labels)

    # Anchor-based losses are divided by the number of positives (at least 1,
    # so a batch without positives does not produce NaN/inf); the
    # anchor-independent ones (P, E, S) are divided by the batch size.
    total_num_pos = num_pos.data.sum().float().clamp(min=1)
    for k in losses:
        if k not in ('P', 'E', 'S'):
            losses[k] /= total_num_pos
        else:
            losses[k] /= batch_size

    # Loss keys:
    #  - B: Box Localization Loss
    #  - C: Class Confidence Loss
    #  - M: Mask Loss
    #  - P: Prototype Loss
    #  - D: Coefficient Diversity Loss
    #  - E: Class Existence Loss
    #  - S: Semantic Segmentation Loss
    #  - I: Mask IoU Loss
    return losses
 

def lincomb_mask_loss(self, pos, idx_t, loc_data, mask_data, priors, proto_data, masks, gt_box_t, score_data, inst_data, labels, interpolation_mode='bilinear'):
    """Mask loss for linear-combination masks (prototypes x coefficients).

    For every image, the positive priors' coefficients are combined with the
    prototypes, passed through cfg.mask_proto_mask_activation, optionally
    cropped to a box, and compared with the matched ground-truth masks
    downsampled to prototype resolution.

    Args:
        pos: (batch, num_priors) bool mask of positive priors.
        idx_t: (batch, num_priors) matched ground-truth index per prior.
        loc_data, priors: box predictions and priors (used when cropping
            with predicted boxes).
        mask_data: (batch, num_priors, k) mask coefficients.
        proto_data: (batch, mask_h, mask_w, k) prototypes.
        masks: per-image ground-truth masks (num_objects, H, W).
        gt_box_t: (batch, num_priors, 4) matched gt boxes, point form.
        score_data: mask-scoring predictions (used when cfg.use_mask_scoring).
        inst_data: unused here.
        labels: per-image class labels.
        interpolation_mode: mode used to downsample the gt masks.

    Returns:
        ``{'M': ...}`` (plus 'D' when cfg.mask_proto_coeff_diversity_loss).
        When cfg.use_maskiou is set, a tuple ``(losses, maskiou_targets)``
        where maskiou_targets is ``[net_input, iou_targets, labels]``, or None
        if every mask was discarded.
    """
    # proto_data: e.g. [8, 138, 138, 32].
    mask_h = proto_data.size(1)
    mask_w = proto_data.size(2)

    process_gt_bboxes = cfg.mask_proto_normalize_emulate_roi_pooling or cfg.mask_proto_crop

    if cfg.mask_proto_remove_empty_masks:
        # Keep a private copy because pos may be edited to drop all-zero masks.
        pos = pos.clone()

    loss_m = 0
    loss_d = 0  # coefficient diversity loss (not accumulated in this version)

    maskiou_t_list = []
    maskiou_net_input_list = []
    label_t_list = []

    # mask_data: e.g. [8, 19248, 32].
    for idx in range(mask_data.size(0)):
        with torch.no_grad():
            # masks[idx]: e.g. [33, 550, 550] -> [33, 138, 138] -> [138, 138, 33].
            downsampled_masks = F.interpolate(masks[idx].unsqueeze(0), (mask_h, mask_w),
                                              mode=interpolation_mode, align_corners=False).squeeze(0)
            downsampled_masks = downsampled_masks.permute(1, 2, 0).contiguous()

            if cfg.mask_proto_binarize_downsampled_gt:
                # Strictly greater than 0.5 -> 1.0, otherwise 0.0.
                downsampled_masks = downsampled_masks.gt(0.5).float()

        cur_pos = pos[idx]               # e.g. [19248] bool
        pos_idx_t = idx_t[idx, cur_pos]  # matched gt index per positive, e.g. [91]

        if process_gt_bboxes:
            # Point-form boxes, one per positive prior: [num_pos, 4].
            if cfg.mask_proto_crop_with_pred_box:
                pos_gt_box_t = decode(loc_data[idx, :, :], priors.data, cfg.use_yolo_regressors)[cur_pos]
            else:
                pos_gt_box_t = gt_box_t[idx, cur_pos]

        proto_masks = proto_data[idx]            # e.g. [138, 138, 32]
        proto_coef = mask_data[idx, cur_pos, :]  # e.g. [91, 32]
        if cfg.use_mask_scoring:
            # Must exist before the subsampling below slices it.
            mask_scores = score_data[idx, cur_pos, :]

        # Randomly subsample when there are more positives than allowed.
        old_num_pos = proto_coef.size(0)
        if old_num_pos > cfg.masks_to_train:
            perm = torch.randperm(proto_coef.size(0))
            select = perm[:cfg.masks_to_train]

            proto_coef = proto_coef[select, :]
            pos_idx_t = pos_idx_t[select]

            if process_gt_bboxes:
                pos_gt_box_t = pos_gt_box_t[select, :]
            if cfg.use_mask_scoring:
                mask_scores = mask_scores[select, :]

        num_pos = proto_coef.size(0)
        mask_t = downsampled_masks[:, :, pos_idx_t]  # e.g. [138, 138, 91]
        label_t = labels[idx][pos_idx_t]             # e.g. [91]

        # Mask assembly: prototypes @ coefficients^T -> e.g. [138, 138, 91],
        # then cfg.mask_proto_mask_activation (e.g. activation_func.sigmoid).
        pred_masks = proto_masks @ proto_coef.t()
        pred_masks = cfg.mask_proto_mask_activation(pred_masks)

        if cfg.mask_proto_crop:
            pred_masks = crop(pred_masks, pos_gt_box_t)

        if cfg.mask_proto_mask_activation == activation_func.sigmoid:
            pre_loss = F.binary_cross_entropy(torch.clamp(pred_masks, 0, 1), mask_t, reduction='none')
        else:
            pre_loss = F.smooth_l1_loss(pred_masks, mask_t, reduction='none')

        if cfg.mask_proto_normalize_emulate_roi_pooling:
            # Divide each mask's summed loss by its gt box area in prototype
            # pixels; weight = mask_h * mask_w (e.g. 138 * 138 = 19044) when cropping.
            weight = mask_h * mask_w if cfg.mask_proto_crop else 1
            pos_gt_csize = center_size(pos_gt_box_t)
            gt_box_width = pos_gt_csize[:, 2] * mask_w
            gt_box_height = pos_gt_csize[:, 3] * mask_h
            pre_loss = pre_loss.sum(dim=(0, 1)) / gt_box_width / gt_box_height * weight

        # If the number of masks was limited, scale the loss back up accordingly.
        if old_num_pos > num_pos:
            pre_loss *= old_num_pos / num_pos

        loss_m += torch.sum(pre_loss)

        if cfg.use_maskiou:
            if cfg.discard_mask_area > 0:
                # Drop masks whose gt area (in prototype pixels) is too small.
                gt_mask_area = torch.sum(mask_t, dim=(0, 1))
                select = gt_mask_area > cfg.discard_mask_area

                if torch.sum(select) < 1:
                    continue

                if process_gt_bboxes:
                    pos_gt_box_t = pos_gt_box_t[select, :]
                pred_masks = pred_masks[:, :, select]
                mask_t = mask_t[:, :, select]
                label_t = label_t[select]

            maskiou_net_input = pred_masks.permute(2, 0, 1).contiguous().unsqueeze(1)
            pred_masks = pred_masks.gt(0.5).float()
            maskiou_t = self._mask_iou(pred_masks, mask_t)

            maskiou_net_input_list.append(maskiou_net_input)
            maskiou_t_list.append(maskiou_t)
            label_t_list.append(label_t)

    losses = {'M': loss_m * cfg.mask_alpha / mask_h / mask_w}

    if cfg.mask_proto_coeff_diversity_loss:
        losses['D'] = loss_d

    if cfg.use_maskiou:
        # discard_mask_area discarded every mask in the batch; nothing to train on.
        if len(maskiou_t_list) == 0:
            return losses, None

        maskiou_t = torch.cat(maskiou_t_list)
        label_t = torch.cat(label_t_list)
        maskiou_net_input = torch.cat(maskiou_net_input_list)

        num_samples = maskiou_t.size(0)
        if cfg.maskious_to_train > 0 and num_samples > cfg.maskious_to_train:
            perm = torch.randperm(num_samples)
            # Cap by the same limit the condition tests (maskious_to_train).
            select = perm[:cfg.maskious_to_train]
            maskiou_t = maskiou_t[select]
            label_t = label_t[select]
            maskiou_net_input = maskiou_net_input[select]

        return losses, [maskiou_net_input, maskiou_t, label_t]

    return losses
转载请注明:文章转载自 www.wk8.com.cn
本文地址:https://www.wk8.com.cn/it/280330.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 wk8.com.cn

ICP备案号:晋ICP备2021003244-6号