Last active
February 25, 2019 15:26
-
-
Save joyhuang9473/ab16c86f91ca75e41010c932a5bae59b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
- [detection_train.py](https://github.com/TuSimple/simpledet/blob/master/detection_train.py)
-
```python
pGen, pKv, pRpn, pRoi, pBbox, pDataset, pModel, pOpt, pTest, \ | |
transform, data_name, label_name, metric_list = config.get_config(is_train=True | |
```
[link](https://github.com/TuSimple/simpledet/blob/master/detection_train.py#L20)
- [config/tridentnet_r101v2c4_c5_1x.py](https://github.com/TuSimple/simpledet/blob/master/config/tridentnet_r101v2c4_c5_1x.py#L11)
- `from models.tridentnet.builder import TridentFasterRcnn as Detector` [link](https://github.com/TuSimple/simpledet/blob/master/config/tridentnet_r50v2c4_c5_1x.py#L1)
-
```python
class Trident: | |
num_branch = 3 | |
train_scaleaware = True | |
test_scaleaware = True | |
branch_ids = range(num_branch) | |
branch_dilates = [1, 2, 3] | |
valid_ranges = [(0, 90), (30, 160), (90, -1)] | |
valid_ranges_on_origin = True | |
branch_bn_shared = True | |
branch_conv_shared = True | |
branch_deform = False | |
```
[link](https://github.com/TuSimple/simpledet/blob/master/config/tridentnet_r50v2c4_c5_1x.py#L19)
-
```python
class BackboneParam: | |
fp16 = General.fp16 | |
depth = General.depth | |
normalizer = NormalizeParam.normalizer | |
num_branch = Trident.num_branch | |
branch_ids = Trident.branch_ids | |
branch_dilates = Trident.branch_dilates | |
branch_bn_shared = Trident.branch_bn_shared | |
branch_conv_shared = Trident.branch_conv_shared | |
branch_deform = Trident.branch_deform | |
```
- `backbone = Backbone(BackboneParam)` [link](https://github.com/TuSimple/simpledet/blob/master/config/tridentnet_r50v2c4_c5_1x.py#L124)
- `from models.tridentnet.builder import TridentMXNetResNetV2 as Backbone` [link](https://github.com/TuSimple/simpledet/blob/master/config/tridentnet_r50v2c4_c5_1x.py#L2)
- [models/tridentnet/builder.py](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/builder.py#L175)
-
```python
def __init__(self, pBackbone): | |
super(TridentMXNetResNetV2, self).__init__(pBackbone) | |
p = pBackbone | |
b = TridentResNetV2Builder() | |
self.symbol = b.get_backbone("mxnet", p.depth, "c4", p.normalizer, p.fp16, | |
p.num_branch, p.branch_dilates, p.branch_ids, | |
p.branch_bn_shared, p.branch_conv_shared, p.branch_deform) | |
```
- `from models.tridentnet.resnet_v2_for_paper import TridentResNetV2Builder` [link](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/builder.py#L8)
- [models/tridentnet/resnet_v2_for_paper.py](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/resnet_v2_for_paper.py#L10)
-
```python
def get_backbone(self, variant, depth, endpoint, normalizer, fp16, | |
num_branch, branch_dilates, branch_ids, branch_bn_shared, branch_conv_shared, branch_deform): | |
```
[link](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/resnet_v2_for_paper.py#L260)
- `factory = self.resnet_c4_factory` [link](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/resnet_v2_for_paper.py#L274)
- `def resnet_c4_factory(cls, depth, use_3x3_conv0, use_bn_preprocess, ...)` [link](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/resnet_v2_for_paper.py#L239)
- `c1, c2, c3, c4, c5 = cls.resnet_factory(depth, use_3x3_conv0, use_bn_preprocess, ...)` [link](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/resnet_v2_for_paper.py#L242)
- `def resnet_factory(cls, depth, use_3x3_conv0, use_bn_preprocess, ...)` [link](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/resnet_v2_for_paper.py#L219)
-
```python
c4 = cls.resnet_trident_c4(c3, num_c4_unit, 2, branch_dilates, norm_type, norm_mom, ndev, | |
num_branch, branch_ids, branch_bn_shared, branch_conv_shared, branch_deform) | |
```
[link](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/resnet_v2_for_paper.py#L230)
- `def resnet_trident_c4(cls, data, num_block, stride, dilate, norm_type, norm_mom, ndev, ...)` [link](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/resnet_v2_for_paper.py#L212)
- `return cls.resnet_trident_stage(...)` [link](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/resnet_v2_for_paper.py#L214)
- `def resnet_trident_stage(cls, data, name, num_block, filter, stride, dilate, norm_type, norm_mom, ndev, ...)` [link](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/resnet_v2_for_paper.py#L171)
-
```python
data = cls.resnet_unit(data, "{}_unit1".format(name), filter, stride, 1, True, norm_type, norm_mom, ndev) | |
data = [data] * num_branch | |
```
[link](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/resnet_v2_for_paper.py#L194)
- [mxnext/backbone/resnet_v2.py](https://github.com/RogerChern/mxnext/blob/master/backbone/resnet_v2.py#L20)
- `data = cls.resnet_trident_unit(...)` [link](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/resnet_v2_for_paper.py#L205)
- `def resnet_trident_unit(cls, data, name, filter, stride, dilate, proj, norm_type, norm_mom, ndev, ...)` [link](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/resnet_v2_for_paper.py#L104)
- `bn1 = cls.bn_shared(...)` [link](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/resnet_v2_for_paper.py#L128)
-
```python
conv2 = cls.conv_shared( | |
relu2, name=name + "_conv2", num_filter=filter // 4, kernel=(3, 3), | |
pad=dilate, stride=stride, dilate=dilate, | |
branch_ids=branch_ids, share_weight=branch_conv_shared) | |
```
[link](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/resnet_v2_for_paper.py#L139)
- `from mxnext.backbone.resnet_v2 import Builder` [link](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/resnet_v2_for_paper.py#L5)
- [mxnext/backbone/resnet_v2.py](https://github.com/RogerChern/mxnext/blob/master/backbone/resnet_v2.py#L11)
- `def get_backbone(self, variant, depth, endpoint, normalizer, fp16):` [link](https://github.com/RogerChern/mxnext/blob/master/backbone/resnet_v2.py#L181)
- `sym = pModel.train_symbol` [link](https://github.com/TuSimple/simpledet/blob/master/detection_train.py#L50)
- `class ModelParam:` [link](https://github.com/TuSimple/simpledet/blob/master/config/tridentnet_r101v2c4_c5_1x.py#L141)
- `train_sym = detector.get_train_symbol(...)` [link](https://github.com/TuSimple/simpledet/blob/master/config/tridentnet_r101v2c4_c5_1x.py#L131)
- [models/tridentnet/builder.py](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/builder.py#L16)
-
```python
rpn_feat = backbone.get_rpn_feature() | |
rcnn_feat = backbone.get_rcnn_feature() | |
rpn_feat = neck.get_rpn_feature(rpn_feat) | |
rcnn_feat = neck.get_rcnn_feature(rcnn_feat) | |
```
[link](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/builder.py#L33)
- `rpn_loss = rpn_head.get_loss(rpn_feat, rpn_cls_label, rpn_reg_target, rpn_reg_weight)` [link](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/builder.py#L38)
- `proposal, bbox_cls, bbox_target, bbox_weight = rpn_head.get_sampled_proposal_with_filter(rpn_feat, gt_bbox, im_info, valid_ranges)` [link](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/builder.py#L40)
-
```python
roi_feat = roi_extractor.get_roi_feature(rcnn_feat, proposal) | |
bbox_loss = bbox_head.get_loss(roi_feat, bbox_cls, bbox_target, bbox_weight) | |
return X.group(rpn_loss + bbox_loss) | |
```
- `image_sets = pDataset.image_set` [link](https://github.com/TuSimple/simpledet/blob/master/detection_train.py#L60)
-
```python
class DatasetParam: | |
if is_train: | |
image_set = ("coco_train2014", "coco_valminusminival2014") | |
else: | |
image_set = ("coco_minival2014", ) | |
```
[link](https://github.com/TuSimple/simpledet/blob/master/config/tridentnet_r101v2c4_c5_1x.py#L118)
- `roidbs = [pkl.load(open("data/cache/{}.roidb".format(i), "rb"), encoding="latin1") for i in image_sets]` [link](https://github.com/TuSimple/simpledet/blob/master/detection_train.py#L61)
-
```python
from core.detection_input import AnchorLoader | |
train_data = AnchorLoader( | |
roidb=roidb, | |
transform=transform, | |
data_name=data_name, | |
label_name=label_name, | |
batch_size=input_batch_size, | |
shuffle=True, | |
kv=kv | |
) | |
```
[link](https://github.com/TuSimple/simpledet/blob/master/detection_train.py#L73)
- [core/detection_input.py](https://github.com/TuSimple/simpledet/blob/master/core/detection_input.py#L715)
- `class AnchorLoader(mx.io.DataIter):`
- `v_roidb, h_roidb = self.roidb_aspect_group(roidb)` [link](https://github.com/TuSimple/simpledet/blob/master/core/detection_input.py#L721)
-
```python
mod = DetModule(sym, data_names=data_names, label_names=label_names, | |
logger=logger, context=ctx, fixed_param_prefix=fixed_param_prefix) | |
```
[link](https://github.com/TuSimple/simpledet/blob/master/detection_train.py#L141)
- [core/detection_module.py](https://github.com/TuSimple/simpledet/blob/master/core/detection_module.py#L47)
- `class DetModule(BaseModule):`
- `eval_metrics = mx.metric.CompositeEvalMetric(metric_list)` [link](https://github.com/TuSimple/simpledet/blob/master/detection_train.py#L144)
-
```python
rpn_acc_metric = metric.AccWithIgnore( | |
"RpnAcc", | |
["rpn_cls_loss_output"], | |
["rpn_cls_label"] | |
) | |
rpn_l1_metric = metric.L1( | |
"RpnL1", | |
["rpn_reg_loss_output"], | |
["rpn_cls_label"] | |
) | |
# for bbox, the label is generated in network so it is an output | |
box_acc_metric = metric.AccWithIgnore( | |
"RcnnAcc", | |
["bbox_cls_loss_output", "bbox_label_blockgrad_output"], | |
[] | |
) | |
box_l1_metric = metric.L1( | |
"RcnnL1", | |
["bbox_reg_loss_output", "bbox_label_blockgrad_output"], | |
[] | |
) | |
metric_list = [rpn_acc_metric, rpn_l1_metric, box_acc_metric, box_l1_metric] | |
```
[link](https://github.com/TuSimple/simpledet/blob/master/config/tridentnet_r101v2c4_c5_1x.py#L271)
- [config/tridentnet_r101v2c4_c5_1x.py](https://github.com/TuSimple/simpledet/blob/master/config/tridentnet_r101v2c4_c5_1x.py#L269)
- `class AccWithIgnore(LossWithIgnore):` [link](https://github.com/TuSimple/simpledet/blob/master/core/detection_metric.py#L23)
- `class L1(FgLossWithIgnore):` [link](https://github.com/TuSimple/simpledet/blob/master/core/detection_metric.py#L117)
-
```python
mod.fit( | |
train_data=train_data, | |
eval_metric=eval_metrics, | |
epoch_end_callback=epoch_end_callback, | |
batch_end_callback=batch_end_callback, | |
kvstore=kv, | |
optimizer=pOpt.optimizer.type, | |
optimizer_params=optimizer_params, | |
initializer=init, | |
allow_missing=True, | |
arg_params=arg_params, | |
aux_params=aux_params, | |
begin_epoch=begin_epoch, | |
num_epoch=end_epoch | |
) | |
```
[link](https://github.com/TuSimple/simpledet/blob/master/detection_train.py#L200)
- [core/detection_module.py](https://github.com/TuSimple/simpledet/blob/master/core/detection_module.py#L882)
- `self.bind(data_shapes=train_data.provide_data, label_shapes=train_data.provide_label, ...)` [link](https://github.com/TuSimple/simpledet/blob/master/core/detection_module.py#L958)
-
```python
self.init_params(initializer=initializer, arg_params=arg_params, aux_params=aux_params, | |
allow_missing=allow_missing, force_init=force_init) | |
self.init_optimizer(kvstore=kvstore, optimizer=optimizer, | |
optimizer_params=optimizer_params) | |
```
[link](https://github.com/TuSimple/simpledet/blob/master/core/detection_module.py#L963)
- `for epoch in range(begin_epoch, num_epoch):` [link](https://github.com/TuSimple/simpledet/blob/master/core/detection_module.py#L976)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment