# Excerpt from the evaluation loop: load one sample from the dataset.
with timer.env('Load Data'):
    # img:       (3, 550, 550) -- pull_item returns the image permuted from HWC to CHW
    # gt:        (3, 5) -- 3 objects; each row is 4 box coordinates followed by the
    #            class label, i.e. [x1, y1, x2, y2, class]. The box is in point form:
    #            the loss code uses gt[:, :-1] as the box, gt[:, -1] as the label,
    #            and treats the matched boxes as point-form.
    # gt_masks:  (3, 1080, 1920) -- one mask per object; 1080 x 1920 is the size of
    #            the original image as loaded
    # h, w:      1080, 1920 -- original image height and width
    # num_crowd: 0 -- number of crowd annotations (the loss code splits them off
    #            the end of the targets / masks)
    img, gt, gt_masks, h, w, num_crowd = dataset.pull_item(image_idx)


# pull_item is defined in coco.py; only its return statement is reproduced here.
def pull_item():
    return torch.from_numpy(img).permute(2, 0, 1), target, masks, height, width, num_crowds


# ===== forward =====
YOLACT 将实例分割问题分解为两个并行的子任务:一个分支生成一组“原型掩码”(prototype masks),另一个分支为每个实例预测“掩码系数”(mask coefficients),最后将两者线性组合得到每个实例的掩码。
- Protonet分支
使用全卷积网络(FCN)生成一组“原型掩码”(prototype masks)。这些掩码不依赖于任何特定实例,而是由图像中的所有实例共享;每张输入图像预测 k 个 prototype masks(本实现中 k = 32)。
- Prediction Head分支
在目标检测分支(基于 anchor 进行预测)上添加一个额外的 head,为每个实例/anchor 预测一组“掩码系数”(mask coefficients)。因此该分支对每个候选框同时输出:类别置信度(confidence)、anchor 的位置偏移(location)以及 prototype mask 的系数(coefficient)。
- Mask Assembly
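设 $P$ 为 $h \times w \times k$ 的 prototype mask 张量,$C$ 为 $n \times k$ 的掩码系数矩阵($n$ 为经过置信度阈值和 NMS 筛选后保留的实例数)。最终的实例掩码为 $M = \sigma(P C^{T})$,形状为 $h \times w \times n$,其中 $\sigma$ 为 sigmoid 函数(对应 postprocess 中的 `proto_data @ masks.t()` 及随后的 sigmoid 激活)。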
preds = net(batch)  # forward() is defined in yolact.py


# x: torch.Size([1, 3, 550, 550])
def forward(self, x):
    """Run backbone, FPN, protonet and prediction heads, then hand off to Detect.

    Abridged excerpt of Yolact.forward (inference path). The per-level head
    outputs are concatenated along the anchor axis, class scores are
    softmaxed, and the result of ``self.detect(pred_outs, self)`` is returned.
    """
    _, _, img_h, img_w = x.size()
    cfg._tmp_img_h = img_h
    cfg._tmp_img_w = img_w

    with timer.env('backbone'):
        # outs is a tuple, len(outs) = 4
        # outs[0]: torch.Size([1, 256, 138, 138])
        # outs[1]: torch.Size([1, 512, 69, 69])
        # outs[2]: torch.Size([1, 1024, 35, 35])
        # outs[3]: torch.Size([1, 2048, 18, 18])
        outs = self.backbone(x)

    if cfg.fpn is not None:
        with timer.env('fpn'):
            # cfg.backbone.selected_layers = [1, 2, 3]
            outs = [outs[i] for i in cfg.backbone.selected_layers]
            # The FPN produces 5 outputs, len(outs) = 5
            # outs[0]: torch.Size([1, 256, 69, 69]) ----- P3
            # outs[1]: torch.Size([1, 256, 35, 35]) ----- P4
            # outs[2]: torch.Size([1, 256, 18, 18]) ----- P5
            # outs[3]: torch.Size([1, 256, 9, 9])   ----- P6
            # outs[4]: torch.Size([1, 256, 5, 5])   ----- P7
            outs = self.fpn(outs)

    proto_out = None
    if cfg.mask_type == mask_type.lincomb and cfg.eval_mask_branch:
        with timer.env('proto'):
            # proto_x corresponds to outs[0], torch.Size([1, 256, 69, 69]) ----- P3
            proto_x = x if self.proto_src is None else outs[self.proto_src]

            if self.num_grids > 0:
                grids = self.grid.repeat(proto_x.size(0), 1, 1, 1)
                proto_x = torch.cat([proto_x, grids], dim=1)

            # The paper reports k = 32 prototypes as the best choice
            # proto_out: torch.Size([1, 32, 138, 138])
            proto_out = self.proto_net(proto_x)
            # The paper uses ReLU as the prototype activation
            # cfg.mask_proto_prototype_activation: activation_func.relu
            proto_out = cfg.mask_proto_prototype_activation(proto_out)

            if cfg.mask_proto_prototypes_as_features:
                # Clone here because we don't want to permute this, though idk if contiguous makes this unnecessary
                proto_downsampled = proto_out.clone()

                if cfg.mask_proto_prototypes_as_features_no_grad:
                    proto_downsampled = proto_out.detach()

            # Move channels last: proto_out: torch.Size([1, 138, 138, 32])
            proto_out = proto_out.permute(0, 2, 3, 1).contiguous()

            if cfg.mask_proto_bias:
                # Append a constant-1 prototype channel. new_ones inherits proto_out's
                # device and dtype; torch.ones(*shape) would use the default tensor
                # type, which only matches when that default was set to the same device.
                bias_shape = list(proto_out.size())
                bias_shape[-1] = 1
                proto_out = torch.cat([proto_out, proto_out.new_ones(bias_shape)], -1)

    with timer.env('pred_heads'):
        pred_outs = {'loc': [], 'conf': [], 'mask': [], 'priors': []}

        if cfg.use_mask_scoring:
            pred_outs['score'] = []

        if cfg.use_instance_coeff:
            pred_outs['inst'] = []

        # self.selected_layers: [0, 1, 2, 3, 4]
        # self.prediction_layers: ModuleList(
        #   (0): PredictionModule(
        #     (upfeature): Sequential(
        #       (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        #       (1): ReLU(inplace=True)
        #     )
        #     (bbox_layer): Conv2d(256, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        #     (conf_layer): Conv2d(256, 30, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        #     (mask_layer): Conv2d(256, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        #   )
        #   (1): PredictionModule()
        #   (2): PredictionModule()
        #   (3): PredictionModule()
        #   (4): PredictionModule())
        for idx, pred_layer in zip(self.selected_layers, self.prediction_layers):
            pred_x = outs[idx]

            if cfg.mask_type == mask_type.lincomb and cfg.mask_proto_prototypes_as_features:
                # Scale the prototypes down to the current prediction layer's size and add it as inputs
                proto_downsampled = F.interpolate(proto_downsampled, size=outs[idx].size()[2:], mode='bilinear', align_corners=False)
                pred_x = torch.cat([pred_x, proto_downsampled], dim=1)

            # For idx = 1, 2, 3, 4, ModuleList[0] is attached as this head's parent
            # (cfg.share_prediction_module)
            if cfg.share_prediction_module and pred_layer is not self.prediction_layers[0]:
                pred_layer.parent = [self.prediction_layers[0]]

            # Output of ModuleList[0]: p={'loc':torch.Size([1, 14283, 4]), 'conf':torch.Size([1, 14283, 10]), 'mask':torch.Size([1, 14283, 32]), 'priors':torch.Size([14283, 4])}
            # Output of ModuleList[1]: p={'loc':torch.Size([1, 3675, 4]), 'conf':torch.Size([1, 3675, 10]), 'mask':torch.Size([1, 3675, 32]), 'priors':torch.Size([3675, 4])}
            # Output of ModuleList[2]: p={'loc':torch.Size([1, 972, 4]), 'conf':torch.Size([1, 972, 10]), 'mask':torch.Size([1, 972, 32]), 'priors':torch.Size([972, 4])}
            # Output of ModuleList[3]: p={'loc':torch.Size([1, 243, 4]), 'conf':torch.Size([1, 243, 10]), 'mask':torch.Size([1, 243, 32]), 'priors':torch.Size([243, 4])}
            # Output of ModuleList[4]: p={'loc':torch.Size([1, 75, 4]), 'conf':torch.Size([1, 75, 10]), 'mask':torch.Size([1, 75, 32]), 'priors':torch.Size([75, 4])}
            p = pred_layer(pred_x)

            for k, v in p.items():
                pred_outs[k].append(v)

    # Concatenate the per-level values of each key along the anchor axis,
    # e.g. pred_outs['loc']: torch.Size([1, 19248, 4]) (14283+3675+972+243+75 = 19248)
    for k, v in pred_outs.items():
        pred_outs[k] = torch.cat(v, -2)

    if proto_out is not None:
        pred_outs['proto'] = proto_out

    pred_outs['conf'] = F.softmax(pred_outs['conf'], -1)

    # Detect.__call__ is defined in layers/functions/detection.py
    return self.detect(pred_outs, self)
layers/functions/detection.py文件
def __call__(self, predictions, net):
    """Decode raw head outputs into per-image detections.

    Args:
        predictions: dict with keys 'loc', 'conf', 'mask', 'priors' and
            optionally 'proto' / 'inst'.
        net: the network; passed through unchanged into every output entry.

    Returns:
        A list with one dict per image, ``{'detection': result, 'net': net}``,
        where ``result`` is what ``self.detect`` returned (None when no box
        survived the confidence threshold).
    """
    loc_data = predictions['loc']
    conf_data = predictions['conf']
    mask_data = predictions['mask']
    prior_data = predictions['priors']

    proto_data = predictions['proto'] if 'proto' in predictions else None
    inst_data = predictions['inst'] if 'inst' in predictions else None

    out = []

    with timer.env('Detect'):
        batch_size = loc_data.size(0)
        num_priors = prior_data.size(0)

        # (batch, num_priors, num_classes) -> (batch, num_classes, num_priors)
        conf_preds = conf_data.view(batch_size, num_priors, self.num_classes).transpose(2, 1).contiguous()

        for batch_idx in range(batch_size):
            decoded_boxes = decode(loc_data[batch_idx], prior_data)
            result = self.detect(batch_idx, conf_preds, decoded_boxes, mask_data, inst_data)

            if result is not None and proto_data is not None:
                result['proto'] = proto_data[batch_idx]

            out.append({'detection': result, 'net': net})

    return out


def decode(loc, priors, use_yolo_regressors: bool = False):
    """Decode predicted offsets ``loc`` against ``priors`` into point-form boxes.

    Both tensors are indexed as (num_priors, 4). Priors are in center-size
    form [cx, cy, w, h]; the returned boxes are [x1, y1, x2, y2].
    """
    if use_yolo_regressors:
        # Decoded boxes in center-size notation
        boxes = torch.cat((
            loc[:, :2] + priors[:, :2],
            priors[:, 2:] * torch.exp(loc[:, 2:])
        ), 1)

        boxes = point_form(boxes)
    else:
        # variances[0] scales the center offsets, variances[1] the log-size offsets
        variances = [0.1, 0.2]

        boxes = torch.cat((
            priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
            priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
        # center-size -> point form, in place: (cx, cy) -> (x1, y1), then (w, h) -> (x2, y2)
        boxes[:, :2] -= boxes[:, 2:] / 2
        boxes[:, 2:] += boxes[:, :2]

    return boxes


def detect(self, batch_idx, conf_preds, decoded_boxes, mask_data, inst_data):
    """Apply the confidence threshold and NMS to the detections of one image.

    Returns a dict with keys 'box', 'mask', 'class', 'score', or None when no
    prior has a foreground score above ``self.conf_thresh``.
    """
    # conf_preds: torch.Size([1, 10, 19248]) -> cur_scores: torch.Size([9, 19248])
    # (class index 0, the background, is dropped) -> conf_scores: torch.Size([19248])
    cur_scores = conf_preds[batch_idx, 1:, :]
    conf_scores, _ = torch.max(cur_scores, dim=0)

    # keep: torch.Size([19248]); True where the best foreground score exceeds the threshold
    keep = (conf_scores > self.conf_thresh)
    # scores: torch.Size([9, 12]), boxes: torch.Size([12, 4]), masks: torch.Size([12, 32]);
    # 12 is variable: it is the number of True entries in keep
    scores = cur_scores[:, keep]
    boxes = decoded_boxes[keep, :]
    masks = mask_data[batch_idx, keep, :]

    # Nothing passed the threshold: __call__ treats None as "no detections".
    if scores.size(1) == 0:
        return None

    if self.use_fast_nms:
        if self.use_cross_class_nms:
            boxes, masks, classes, scores = self.cc_fast_nms(boxes, masks, scores, self.nms_thresh, self.top_k)
        else:
            boxes, masks, classes, scores = self.fast_nms(boxes, masks, scores, self.nms_thresh, self.top_k)
    else:
        # Only the Fast NMS path is available here; fail clearly instead of
        # reaching the return below with ``classes`` unbound.
        raise NotImplementedError('detect() only supports use_fast_nms=True')

    # box: torch.Size([36, 4])
    # mask: torch.Size([36, 32])
    # class: torch.Size([36])
    # score: torch.Size([36])
    return {'box': boxes, 'mask': masks, 'class': classes, 'score': scores}


# ===== postprocess =====
def postprocess(det_output, w, h, batch_idx=0, interpolation_mode='bilinear',
                visualize_lincomb=False, crop_masks=True, score_threshold=0):
    """Turn Detect output into final classes, scores, boxes and image-sized masks.

    Args:
        det_output: list returned by Detect.__call__; each entry is
            ``{'detection': dict or None, 'net': net}``.
        w, h: width and height of the original image. Masks are resized to
            (h, w) and box coordinates are passed through sanitize_coordinates
            with these sizes.
        batch_idx: which image of the batch to process.
        interpolation_mode: mode passed to F.interpolate when resizing masks.
        visualize_lincomb: if True, call display_lincomb on the prototypes.
        crop_masks: if True, apply crop() to the assembled masks using the boxes.
        score_threshold: when > 0, detections with score <= this are dropped.

    Returns:
        ``(classes, scores, boxes, masks)``. When there are no detections, a
        list of four empty tensors is returned instead.
    """
    dets = det_output[batch_idx]
    net = dets['net']
    dets = dets['detection']

    # Detect.__call__ stores whatever detect() returned, and it explicitly
    # allows that to be None (no box passed the confidence threshold).
    if dets is None:
        return [torch.Tensor()] * 4  # Note: four references to the same empty tensor

    if score_threshold > 0:
        keep = dets['score'] > score_threshold
        # Filter every per-detection entry; 'proto' is per-image, not per-detection
        for k in dets:
            if k != 'proto':
                dets[k] = dets[k][keep]

        if dets['score'].size(0) == 0:
            return [torch.Tensor()] * 4

    classes = dets['class']
    boxes = dets['box']
    scores = dets['score']
    masks = dets['mask']

    if cfg.mask_type == mask_type.lincomb and cfg.eval_mask_branch:
        # Here masks still holds the per-detection prototype coefficients
        proto_data = dets['proto']

        if cfg.mask_proto_debug:
            np.save('scripts/proto.npy', proto_data.cpu().numpy())

        if visualize_lincomb:
            display_lincomb(proto_data, masks)

        # Mask Assembly from the paper; @ is matrix multiplication
        # proto_data: torch.Size([138, 138, 32])
        # masks.t():  torch.Size([32, 6])
        # result:     torch.Size([138, 138, 6])
        # cfg.mask_proto_mask_activation: activation_func.sigmoid
        masks = proto_data @ masks.t()
        masks = cfg.mask_proto_mask_activation(masks)

        # crop() is defined elsewhere; presumably it zeroes each mask outside its box
        if crop_masks:
            masks = crop(masks, boxes)

        # (mask_h, mask_w, n) -> (n, mask_h, mask_w)
        masks = masks.permute(2, 0, 1).contiguous()

        if cfg.use_maskiou:
            with timer.env('maskiou_net'):
                with torch.no_grad():
                    maskiou_p = net.maskiou_net(masks.unsqueeze(1))
                    # Keep the predicted IoU of each detection's own class
                    maskiou_p = torch.gather(maskiou_p, dim=1, index=classes.unsqueeze(1)).squeeze(1)
                    if cfg.rescore_mask:
                        if cfg.rescore_bbox:
                            scores = scores * maskiou_p
                        else:
                            scores = [scores, scores * maskiou_p]

        # Resize the masks to the original image size
        masks = F.interpolate(masks.unsqueeze(0), (h, w), mode=interpolation_mode, align_corners=False).squeeze(0)

        # Binarize in place: 1.0 where > 0.5, else 0.0
        masks.gt_(0.5)

    boxes[:, 0], boxes[:, 2] = sanitize_coordinates(boxes[:, 0], boxes[:, 2], w, cast=False)
    boxes[:, 1], boxes[:, 3] = sanitize_coordinates(boxes[:, 1], boxes[:, 3], h, cast=False)
    boxes = boxes.long()

    if cfg.mask_type == mask_type.direct and cfg.eval_mask_branch:
        # Upscale masks
        full_masks = torch.zeros(masks.size(0), h, w)

        for jdx in range(masks.size(0)):
            x1, y1, x2, y2 = boxes[jdx, :]

            mask_w = x2 - x1
            mask_h = y2 - y1

            # Just in case
            if mask_w * mask_h <= 0 or mask_w < 0:
                continue

            mask = masks[jdx, :].view(1, 1, cfg.mask_size, cfg.mask_size)
            mask = F.interpolate(mask, (mask_h, mask_w), mode=interpolation_mode, align_corners=False)
            mask = mask.gt(0.5).float()
            full_masks[jdx, y1:y2, x1:x2] = mask

        masks = full_masks

    return classes, scores, boxes, masks


# ===== Loss =====
def forward(self, net, predictions, targets, masks, num_crowds):
    """Compute the training losses for one batch.

    Args:
        net: the network (passed to the mask-IoU loss).
        predictions: dict with keys 'loc', 'conf', 'mask', 'priors', plus
            'proto' for lincomb masks and, depending on cfg, 'score', 'inst',
            'classes', 'segm'.
        targets: list with one tensor per image, e.g. (num_objects, 5); columns
            [:-1] are the point-form box and column -1 is the class label.
        masks: list with one ground-truth mask tensor per image. Crowd entries
            are split off IN PLACE, so this list is modified.
        num_crowds: per-image number of crowd annotations, stored at the end
            of each target / mask tensor.

    Returns:
        dict of scalar losses keyed by the letters listed at the end.
    """
    loc_data = predictions['loc']
    conf_data = predictions['conf']
    mask_data = predictions['mask']
    priors = predictions['priors']

    if cfg.mask_type == mask_type.lincomb:
        proto_data = predictions['proto']

    score_data = predictions['score'] if cfg.use_mask_scoring else None
    inst_data = predictions['inst'] if cfg.use_instance_coeff else None

    labels = [None] * len(targets)  # Used in sem segm loss

    batch_size = loc_data.size(0)
    num_priors = priors.size(0)
    num_classes = self.num_classes

    # Match priors (default boxes) and ground truth boxes
    # These tensors will be created with the same device as loc_data
    loc_t = loc_data.new(batch_size, num_priors, 4)
    gt_box_t = loc_data.new(batch_size, num_priors, 4)
    conf_t = loc_data.new(batch_size, num_priors).long()
    idx_t = loc_data.new(batch_size, num_priors).long()

    if cfg.use_class_existence_loss:
        class_existence_t = loc_data.new(batch_size, num_classes - 1)

    for idx in range(batch_size):
        truths = targets[idx][:, :-1].data
        labels[idx] = targets[idx][:, -1].data.long()

        if cfg.use_class_existence_loss:
            # Construct a one-hot vector for each object and collapse it into an existence vector with max
            # Also it's fine to include the crowd annotations here.
            # Use .device: .get_device() returns -1 for CPU tensors, which torch.eye rejects.
            class_existence_t[idx, :] = torch.eye(num_classes - 1, device=conf_t.device)[labels[idx]].max(dim=0)[0]

        # Split the crowd annotations because they come bundled in
        cur_crowds = num_crowds[idx]
        if cur_crowds > 0:
            split = lambda x: (x[-cur_crowds:], x[:-cur_crowds])
            crowd_boxes, truths = split(truths)

            # We don't use the crowd labels or masks
            _, labels[idx] = split(labels[idx])
            _, masks[idx] = split(masks[idx])
        else:
            crowd_boxes = None

        match(self.pos_threshold, self.neg_threshold,
              truths, priors.data, labels[idx], crowd_boxes,
              loc_t, conf_t, idx_t, idx, loc_data[idx])

        gt_box_t[idx, :, :] = truths[idx_t[idx]]

    # wrap targets: .detach() is the modern equivalent of the deprecated
    # Variable(t, requires_grad=False), which detaches its input.
    loc_t = loc_t.detach()
    conf_t = conf_t.detach()
    idx_t = idx_t.detach()

    # Positive priors are those matched to a foreground class (class 0 = background)
    pos = conf_t > 0
    num_pos = pos.sum(dim=1, keepdim=True)

    # Shape: [batch,num_priors,4]
    pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data)

    losses = {}
    # Only produced by the lincomb mask loss; initialized so the Mask IoU check
    # below cannot hit an unbound name when masks are not trained.
    maskiou_targets = None

    # Localization Loss (Smooth L1)
    if cfg.train_boxes:
        loc_p = loc_data[pos_idx].view(-1, 4)
        loc_t = loc_t[pos_idx].view(-1, 4)
        losses['B'] = F.smooth_l1_loss(loc_p, loc_t, reduction='sum') * cfg.bbox_alpha

    if cfg.train_masks:
        if cfg.mask_type == mask_type.direct:
            if cfg.use_gt_bboxes:
                masks_t = torch.cat([masks[idx][idx_t[idx, pos[idx]]] for idx in range(batch_size)], 0)
                masks_p = mask_data[pos, :].view(-1, cfg.mask_dim)
                losses['M'] = F.binary_cross_entropy(torch.clamp(masks_p, 0, 1), masks_t, reduction='sum') * cfg.mask_alpha
            else:
                losses['M'] = self.direct_mask_loss(pos_idx, idx_t, loc_data, mask_data, priors, masks)
        elif cfg.mask_type == mask_type.lincomb:
            ret = self.lincomb_mask_loss(pos, idx_t, loc_data, mask_data, priors, proto_data, masks, gt_box_t, score_data, inst_data, labels)
            if cfg.use_maskiou:
                loss, maskiou_targets = ret
            else:
                loss = ret
            losses.update(loss)

            if cfg.mask_proto_loss is not None:
                if cfg.mask_proto_loss == 'l1':
                    losses['P'] = torch.mean(torch.abs(proto_data)) / self.l1_expected_area * self.l1_alpha
                elif cfg.mask_proto_loss == 'disj':
                    losses['P'] = -torch.mean(torch.max(F.log_softmax(proto_data, dim=-1), dim=-1)[0])

    # Confidence loss
    if cfg.use_focal_loss:
        if cfg.use_sigmoid_focal_loss:
            losses['C'] = self.focal_conf_sigmoid_loss(conf_data, conf_t)
        elif cfg.use_objectness_score:
            losses['C'] = self.focal_conf_objectness_loss(conf_data, conf_t)
        else:
            losses['C'] = self.focal_conf_loss(conf_data, conf_t)
    else:
        if cfg.use_objectness_score:
            losses['C'] = self.conf_objectness_loss(conf_data, conf_t, batch_size, loc_p, loc_t, priors)
        else:
            losses['C'] = self.ohem_conf_loss(conf_data, conf_t, pos, batch_size)

    # Mask IoU Loss
    if cfg.use_maskiou and maskiou_targets is not None:
        losses['I'] = self.mask_iou_loss(net, maskiou_targets)

    # These losses also don't depend on anchors
    if cfg.use_class_existence_loss:
        losses['E'] = self.class_existence_loss(predictions['classes'], class_existence_t)
    if cfg.use_semantic_segmentation_loss:
        losses['S'] = self.semantic_segmentation_loss(predictions['segm'], masks, labels)

    # Divide all losses by the number of positives.
    # Don't do it for loss[P] because that doesn't depend on the anchors.
    total_num_pos = num_pos.data.sum().float()
    for k in losses:
        if k not in ('P', 'E', 'S'):
            losses[k] /= total_num_pos
        else:
            losses[k] /= batch_size

    # Loss Key:
    #  - B: Box Localization Loss
    #  - C: Class Confidence Loss
    #  - M: Mask Loss
    #  - P: Prototype Loss
    #  - D: Coefficient Diversity Loss
    #  - E: Class Existence Loss
    #  - S: Semantic Segmentation Loss
    #  - I: Mask IoU Loss
    return losses


def lincomb_mask_loss(self, pos, idx_t, loc_data, mask_data, priors, proto_data, masks, gt_box_t, score_data, inst_data, labels, interpolation_mode='bilinear'):
    """Mask loss for the linear-combination (prototypes x coefficients) mask head.

    Returns a dict with 'M' (and 'D' when cfg.mask_proto_coeff_diversity_loss).
    When cfg.use_maskiou, returns ``(losses, maskiou_targets)`` instead, where
    maskiou_targets is ``[maskiou_net_input, maskiou_t, label_t]`` or None if
    every mask was discarded by cfg.discard_mask_area.
    """
    # proto_data: torch.Size([8, 138, 138, 32])
    mask_h = proto_data.size(1)
    mask_w = proto_data.size(2)

    process_gt_bboxes = cfg.mask_proto_normalize_emulate_roi_pooling or cfg.mask_proto_crop

    if cfg.mask_proto_remove_empty_masks:
        # Make sure to store a copy of this because we edit it to get rid of all-zero masks
        pos = pos.clone()

    loss_m = 0
    loss_d = 0  # Coefficient diversity loss

    maskiou_t_list = []
    maskiou_net_input_list = []
    label_t_list = []

    # mask_data: torch.Size([8, 19248, 32])
    for idx in range(mask_data.size(0)):
        with torch.no_grad():
            # masks[0]: torch.Size([33, 550, 550]) -> downsampled_masks: torch.Size([33, 138, 138])
            downsampled_masks = F.interpolate(masks[idx].unsqueeze(0), (mask_h, mask_w),
                                              mode=interpolation_mode, align_corners=False).squeeze(0)
            # Channels last: torch.Size([138, 138, 33])
            downsampled_masks = downsampled_masks.permute(1, 2, 0).contiguous()

            if cfg.mask_proto_binarize_downsampled_gt:
                # gt(0.5): 1.0 where strictly greater than 0.5, otherwise 0.0
                downsampled_masks = downsampled_masks.gt(0.5).float()

        # cur_pos: torch.Size([19248])
        cur_pos = pos[idx]
        # pos_idx_t: torch.Size([91]); index of the matched GT object for each positive prior
        pos_idx_t = idx_t[idx, cur_pos]

        if process_gt_bboxes:
            # Note: this is in point-form
            if cfg.mask_proto_crop_with_pred_box:
                pos_gt_box_t = decode(loc_data[idx, :, :], priors.data, cfg.use_yolo_regressors)[cur_pos]
            else:
                # pos_gt_box_t: torch.Size([91, 4])
                pos_gt_box_t = gt_box_t[idx, cur_pos]

        # proto_masks: torch.Size([138, 138, 32])
        proto_masks = proto_data[idx]
        # proto_coef: torch.Size([91, 32])
        proto_coef = mask_data[idx, cur_pos, :]

        if cfg.use_mask_scoring:
            # Gather the scores of the same positives so the subsampling below can index them.
            mask_scores = score_data[idx, cur_pos, :]

        # If we have over the allowed number of masks, select a random sample
        old_num_pos = proto_coef.size(0)
        if old_num_pos > cfg.masks_to_train:
            perm = torch.randperm(proto_coef.size(0))
            select = perm[:cfg.masks_to_train]

            proto_coef = proto_coef[select, :]
            pos_idx_t = pos_idx_t[select]

            if process_gt_bboxes:
                pos_gt_box_t = pos_gt_box_t[select, :]
            if cfg.use_mask_scoring:
                mask_scores = mask_scores[select, :]

        num_pos = proto_coef.size(0)
        # mask_t: torch.Size([138, 138, 91]); GT mask of each selected positive
        mask_t = downsampled_masks[:, :, pos_idx_t]
        # label_t: torch.Size([91])
        label_t = labels[idx][pos_idx_t]

        # Mask assembly, pred_masks: torch.Size([138, 138, 91])
        # cfg.mask_proto_mask_activation: activation_func.sigmoid
        pred_masks = proto_masks @ proto_coef.t()
        pred_masks = cfg.mask_proto_mask_activation(pred_masks)

        if cfg.mask_proto_crop:
            pred_masks = crop(pred_masks, pos_gt_box_t)

        if cfg.mask_proto_mask_activation == activation_func.sigmoid:
            pre_loss = F.binary_cross_entropy(torch.clamp(pred_masks, 0, 1), mask_t, reduction='none')
        else:
            pre_loss = F.smooth_l1_loss(pred_masks, mask_t, reduction='none')

        if cfg.mask_proto_normalize_emulate_roi_pooling:
            # weight = 138 * 138 = 19044 when cropping
            weight = mask_h * mask_w if cfg.mask_proto_crop else 1
            pos_gt_csize = center_size(pos_gt_box_t)
            gt_box_width = pos_gt_csize[:, 2] * mask_w
            gt_box_height = pos_gt_csize[:, 3] * mask_h
            # Sum each mask's loss and divide by its GT box area in prototype pixels
            pre_loss = pre_loss.sum(dim=(0, 1)) / gt_box_width / gt_box_height * weight

        # If the number of masks were limited scale the loss accordingly
        if old_num_pos > num_pos:
            pre_loss *= old_num_pos / num_pos

        loss_m += torch.sum(pre_loss)

        if cfg.use_maskiou:
            if cfg.discard_mask_area > 0:
                gt_mask_area = torch.sum(mask_t, dim=(0, 1))
                select = gt_mask_area > cfg.discard_mask_area

                if torch.sum(select) < 1:
                    continue

                # pos_gt_box_t only exists when GT boxes were processed above
                if process_gt_bboxes:
                    pos_gt_box_t = pos_gt_box_t[select, :]
                pred_masks = pred_masks[:, :, select]
                mask_t = mask_t[:, :, select]
                label_t = label_t[select]

            maskiou_net_input = pred_masks.permute(2, 0, 1).contiguous().unsqueeze(1)
            pred_masks = pred_masks.gt(0.5).float()
            maskiou_t = self._mask_iou(pred_masks, mask_t)

            maskiou_net_input_list.append(maskiou_net_input)
            maskiou_t_list.append(maskiou_t)
            label_t_list.append(label_t)

    losses = {'M': loss_m * cfg.mask_alpha / mask_h / mask_w}

    if cfg.mask_proto_coeff_diversity_loss:
        losses['D'] = loss_d

    if cfg.use_maskiou:
        # discard_mask_area discarded every mask in the batch, so nothing to do here
        if len(maskiou_t_list) == 0:
            return losses, None

        maskiou_t = torch.cat(maskiou_t_list)
        label_t = torch.cat(label_t_list)
        maskiou_net_input = torch.cat(maskiou_net_input_list)

        num_samples = maskiou_t.size(0)
        if cfg.maskious_to_train > 0 and num_samples > cfg.maskious_to_train:
            perm = torch.randperm(num_samples)
            # Cut to the same limit the guard above checks (maskious_to_train,
            # not the mask-loss limit masks_to_train).
            select = perm[:cfg.maskious_to_train]
            maskiou_t = maskiou_t[select]
            label_t = label_t[select]
            maskiou_net_input = maskiou_net_input[select]

        return losses, [maskiou_net_input, maskiou_t, label_t]

    return losses