Last active
February 25, 2019 15:26
-
-
Save joyhuang9473/ab16c86f91ca75e41010c932a5bae59b to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| - [detection_train.py](https://github.com/TuSimple/simpledet/blob/master/detection_train.py) | |
| - | |
| ``` | |
| pGen, pKv, pRpn, pRoi, pBbox, pDataset, pModel, pOpt, pTest, \ | |
| transform, data_name, label_name, metric_list = config.get_config(is_train=True) | |
| ``` | |
| [link](https://github.com/TuSimple/simpledet/blob/master/detection_train.py#L20) | |
| - [config/tridentnet_r101v2c4_c5_1x.py](https://github.com/TuSimple/simpledet/blob/master/config/tridentnet_r101v2c4_c5_1x.py#L11) | |
| - from models.tridentnet.builder import TridentFasterRcnn as Detector [link](https://github.com/TuSimple/simpledet/blob/master/config/tridentnet_r50v2c4_c5_1x.py#L1) | |
| - | |
| ``` | |
| class Trident: | |
| num_branch = 3 | |
| train_scaleaware = True | |
| test_scaleaware = True | |
| branch_ids = range(num_branch) | |
| branch_dilates = [1, 2, 3] | |
| valid_ranges = [(0, 90), (30, 160), (90, -1)] | |
| valid_ranges_on_origin = True | |
| branch_bn_shared = True | |
| branch_conv_shared = True | |
| branch_deform = False | |
| ``` | |
| [link](https://github.com/TuSimple/simpledet/blob/master/config/tridentnet_r50v2c4_c5_1x.py#L19) | |
| - | |
| ``` | |
| class BackboneParam: | |
| fp16 = General.fp16 | |
| depth = General.depth | |
| normalizer = NormalizeParam.normalizer | |
| num_branch = Trident.num_branch | |
| branch_ids = Trident.branch_ids | |
| branch_dilates = Trident.branch_dilates | |
| branch_bn_shared = Trident.branch_bn_shared | |
| branch_conv_shared = Trident.branch_conv_shared | |
| branch_deform = Trident.branch_deform | |
| ``` | |
| - backbone = Backbone(BackboneParam) [link](https://github.com/TuSimple/simpledet/blob/master/config/tridentnet_r50v2c4_c5_1x.py#L124) | |
| - from models.tridentnet.builder import TridentMXNetResNetV2 as Backbone [link](https://github.com/TuSimple/simpledet/blob/master/config/tridentnet_r50v2c4_c5_1x.py#L2) | |
| - [models/tridentnet/builder.py](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/builder.py#L175) | |
| - | |
| ``` | |
| def __init__(self, pBackbone): | |
| super(TridentMXNetResNetV2, self).__init__(pBackbone) | |
| p = pBackbone | |
| b = TridentResNetV2Builder() | |
| self.symbol = b.get_backbone("mxnet", p.depth, "c4", p.normalizer, p.fp16, | |
| p.num_branch, p.branch_dilates, p.branch_ids, | |
| p.branch_bn_shared, p.branch_conv_shared, p.branch_deform) | |
| ``` | |
| - from models.tridentnet.resnet_v2_for_paper import TridentResNetV2Builder [link](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/builder.py#L8) | |
| - [models/tridentnet/resnet_v2_for_paper.py](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/resnet_v2_for_paper.py#L10) | |
| - | |
| ``` | |
| def get_backbone(self, variant, depth, endpoint, normalizer, fp16, | |
| num_branch, branch_dilates, branch_ids, branch_bn_shared, branch_conv_shared, branch_deform): | |
| ``` | |
| [link](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/resnet_v2_for_paper.py#L260) | |
| - factory = self.resnet_c4_factory [link](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/resnet_v2_for_paper.py#L274) | |
| - def resnet_c4_factory(cls, depth, use_3x3_conv0, use_bn_preprocess, [link](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/resnet_v2_for_paper.py#L239) | |
| - c1, c2, c3, c4, c5 = cls.resnet_factory(depth, use_3x3_conv0, use_bn_preprocess, [link](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/resnet_v2_for_paper.py#L242) | |
| - def resnet_factory(cls, depth, use_3x3_conv0, use_bn_preprocess, [link](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/resnet_v2_for_paper.py#L219) | |
| - | |
| ``` | |
| c4 = cls.resnet_trident_c4(c3, num_c4_unit, 2, branch_dilates, norm_type, norm_mom, ndev, | |
| num_branch, branch_ids, branch_bn_shared, branch_conv_shared, branch_deform) | |
| ``` | |
| [link](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/resnet_v2_for_paper.py#L230) | |
| - def resnet_trident_c4(cls, data, num_block, stride, dilate, norm_type, norm_mom, ndev, [link](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/resnet_v2_for_paper.py#L212) | |
| - return cls.resnet_trident_stage( [link](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/resnet_v2_for_paper.py#L214) | |
| - def resnet_trident_stage(cls, data, name, num_block, filter, stride, dilate, norm_type, norm_mom, ndev, [link](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/resnet_v2_for_paper.py#L171) | |
| - | |
| ``` | |
| data = cls.resnet_unit(data, "{}_unit1".format(name), filter, stride, 1, True, norm_type, norm_mom, ndev) | |
| data = [data] * num_branch | |
| ``` | |
| [link](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/resnet_v2_for_paper.py#L194) | |
| - [mxnext.backbone/resnet_v2.py](https://github.com/RogerChern/mxnext/blob/master/backbone/resnet_v2.py#L20) | |
| - data = cls.resnet_trident_unit( [link](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/resnet_v2_for_paper.py#L205) | |
| - def resnet_trident_unit(cls, data, name, filter, stride, dilate, proj, norm_type, norm_mom, ndev, [link](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/resnet_v2_for_paper.py#L104) | |
| - bn1 = cls.bn_shared( [link](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/resnet_v2_for_paper.py#L128) | |
| - | |
| ``` | |
| conv2 = cls.conv_shared( | |
| relu2, name=name + "_conv2", num_filter=filter // 4, kernel=(3, 3), | |
| pad=dilate, stride=stride, dilate=dilate, | |
| branch_ids=branch_ids, share_weight=branch_conv_shared) | |
| ``` | |
| [link](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/resnet_v2_for_paper.py#L139) | |
| - from mxnext.backbone.resnet_v2 import Builder [link](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/resnet_v2_for_paper.py#L5) | |
| - [mxnext.backbone/resnet_v2.py](https://github.com/RogerChern/mxnext/blob/master/backbone/resnet_v2.py#L11) | |
| - def get_backbone(self, variant, depth, endpoint, normalizer, fp16): [link](https://github.com/RogerChern/mxnext/blob/master/backbone/resnet_v2.py#L181) | |
| - sym = pModel.train_symbol [link](https://github.com/TuSimple/simpledet/blob/master/detection_train.py#L50) | |
| - class ModelParam: [link](https://github.com/TuSimple/simpledet/blob/master/config/tridentnet_r101v2c4_c5_1x.py#L141) | |
| - train_sym = detector.get_train_symbol( [link](https://github.com/TuSimple/simpledet/blob/master/config/tridentnet_r101v2c4_c5_1x.py#L131) | |
| - [models/tridentnet/builder.py](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/builder.py#L16) | |
| - | |
| ``` | |
| rpn_feat = backbone.get_rpn_feature() | |
| rcnn_feat = backbone.get_rcnn_feature() | |
| rpn_feat = neck.get_rpn_feature(rpn_feat) | |
| rcnn_feat = neck.get_rcnn_feature(rcnn_feat) | |
| ``` | |
| [link](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/builder.py#L33) | |
| - rpn_loss = rpn_head.get_loss(rpn_feat, rpn_cls_label, rpn_reg_target, rpn_reg_weight) [link](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/builder.py#L38) | |
| - proposal, bbox_cls, bbox_target, bbox_weight = rpn_head.get_sampled_proposal_with_filter(rpn_feat, gt_bbox, im_info, valid_ranges) [link](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/builder.py#L40) | |
| - | |
| ``` | |
| roi_feat = roi_extractor.get_roi_feature(rcnn_feat, proposal) | |
| bbox_loss = bbox_head.get_loss(roi_feat, bbox_cls, bbox_target, bbox_weight) | |
| return X.group(rpn_loss + bbox_loss) | |
| ``` | |
| - image_sets = pDataset.image_set [link](https://github.com/TuSimple/simpledet/blob/master/detection_train.py#L60) | |
| - | |
| ``` | |
| class DatasetParam: | |
| if is_train: | |
| image_set = ("coco_train2014", "coco_valminusminival2014") | |
| else: | |
| image_set = ("coco_minival2014", ) | |
| ``` | |
| [link](https://github.com/TuSimple/simpledet/blob/master/config/tridentnet_r101v2c4_c5_1x.py#L118) | |
| - roidbs = [pkl.load(open("data/cache/{}.roidb".format(i), "rb"), encoding="latin1") for i in image_sets] [link](https://github.com/TuSimple/simpledet/blob/master/detection_train.py#L61) | |
| - | |
| ``` | |
| from core.detection_input import AnchorLoader | |
| train_data = AnchorLoader( | |
| roidb=roidb, | |
| transform=transform, | |
| data_name=data_name, | |
| label_name=label_name, | |
| batch_size=input_batch_size, | |
| shuffle=True, | |
| kv=kv | |
| ) | |
| ``` | |
| [link](https://github.com/TuSimple/simpledet/blob/master/detection_train.py#L73) | |
| - [core/detection_input.py](https://github.com/TuSimple/simpledet/blob/master/core/detection_input.py#L715) | |
| - class AnchorLoader(mx.io.DataIter): | |
| - v_roidb, h_roidb = self.roidb_aspect_group(roidb) [link](https://github.com/TuSimple/simpledet/blob/master/core/detection_input.py#L721) | |
| - | |
| ``` | |
| mod = DetModule(sym, data_names=data_names, label_names=label_names, | |
| logger=logger, context=ctx, fixed_param_prefix=fixed_param_prefix) | |
| ``` | |
| [link](https://github.com/TuSimple/simpledet/blob/master/detection_train.py#L141) | |
| - [core/detection_module.py](https://github.com/TuSimple/simpledet/blob/master/core/detection_module.py#L47) | |
| - class DetModule(BaseModule): | |
| - eval_metrics = mx.metric.CompositeEvalMetric(metric_list) [link](https://github.com/TuSimple/simpledet/blob/master/detection_train.py#L144) | |
| - | |
| ``` | |
| rpn_acc_metric = metric.AccWithIgnore( | |
| "RpnAcc", | |
| ["rpn_cls_loss_output"], | |
| ["rpn_cls_label"] | |
| ) | |
| rpn_l1_metric = metric.L1( | |
| "RpnL1", | |
| ["rpn_reg_loss_output"], | |
| ["rpn_cls_label"] | |
| ) | |
| # for bbox, the label is generated in network so it is an output | |
| box_acc_metric = metric.AccWithIgnore( | |
| "RcnnAcc", | |
| ["bbox_cls_loss_output", "bbox_label_blockgrad_output"], | |
| [] | |
| ) | |
| box_l1_metric = metric.L1( | |
| "RcnnL1", | |
| ["bbox_reg_loss_output", "bbox_label_blockgrad_output"], | |
| [] | |
| ) | |
| metric_list = [rpn_acc_metric, rpn_l1_metric, box_acc_metric, box_l1_metric] | |
| ``` | |
| [link](https://github.com/TuSimple/simpledet/blob/master/config/tridentnet_r101v2c4_c5_1x.py#L271) | |
| - [config/tridentnet_r101v2c4_c5_1x.py](https://github.com/TuSimple/simpledet/blob/master/config/tridentnet_r101v2c4_c5_1x.py#L269) | |
| - class AccWithIgnore(LossWithIgnore): [link](https://github.com/TuSimple/simpledet/blob/master/core/detection_metric.py#L23) | |
| - class L1(FgLossWithIgnore): [link](https://github.com/TuSimple/simpledet/blob/master/core/detection_metric.py#L117) | |
| - | |
| ``` | |
| mod.fit( | |
| train_data=train_data, | |
| eval_metric=eval_metrics, | |
| epoch_end_callback=epoch_end_callback, | |
| batch_end_callback=batch_end_callback, | |
| kvstore=kv, | |
| optimizer=pOpt.optimizer.type, | |
| optimizer_params=optimizer_params, | |
| initializer=init, | |
| allow_missing=True, | |
| arg_params=arg_params, | |
| aux_params=aux_params, | |
| begin_epoch=begin_epoch, | |
| num_epoch=end_epoch | |
| ) | |
| ``` | |
| [link](https://github.com/TuSimple/simpledet/blob/master/detection_train.py#L200) | |
| - [core/detection_module.py](https://github.com/TuSimple/simpledet/blob/master/core/detection_module.py#L882) | |
| - self.bind(data_shapes=train_data.provide_data, label_shapes=train_data.provide_label, [link](https://github.com/TuSimple/simpledet/blob/master/core/detection_module.py#L958) | |
| - | |
| ``` | |
| self.init_params(initializer=initializer, arg_params=arg_params, aux_params=aux_params, | |
| allow_missing=allow_missing, force_init=force_init) | |
| self.init_optimizer(kvstore=kvstore, optimizer=optimizer, | |
| optimizer_params=optimizer_params) | |
| ``` | |
| [link](https://github.com/TuSimple/simpledet/blob/master/core/detection_module.py#L963) | |
| - for epoch in range(begin_epoch, num_epoch): [link](https://github.com/TuSimple/simpledet/blob/master/core/detection_module.py#L976) | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment