@joyhuang9473
Last active February 25, 2019 15:26
- [detection_train.py](https://github.com/TuSimple/simpledet/blob/master/detection_train.py)
-
```
pGen, pKv, pRpn, pRoi, pBbox, pDataset, pModel, pOpt, pTest, \
    transform, data_name, label_name, metric_list = config.get_config(is_train=True)
```
[link](https://github.com/TuSimple/simpledet/blob/master/detection_train.py#L20)
- [config/tridentnet_r101v2c4_c5_1x.py](https://github.com/TuSimple/simpledet/blob/master/config/tridentnet_r101v2c4_c5_1x.py#L11)
- from models.tridentnet.builder import TridentFasterRcnn as Detector [link](https://github.com/TuSimple/simpledet/blob/master/config/tridentnet_r50v2c4_c5_1x.py#L1)
-
```
class Trident:
    num_branch = 3
    train_scaleaware = True
    test_scaleaware = True
    branch_ids = range(num_branch)
    branch_dilates = [1, 2, 3]
    valid_ranges = [(0, 90), (30, 160), (90, -1)]
    valid_ranges_on_origin = True
    branch_bn_shared = True
    branch_conv_shared = True
    branch_deform = False
```
[link](https://github.com/TuSimple/simpledet/blob/master/config/tridentnet_r50v2c4_c5_1x.py#L19)
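The three `valid_ranges` entries line up with the three branches. Below is a minimal sketch of how scale-aware assignment could work. It assumes that a box's scale is `sqrt(width * height)`, that an upper bound of `-1` means "no upper bound", and that both bounds are inclusive. The helper `branches_for_box` is hypothetical and is not part of simpledet.

```python
import math

# Values copied from the Trident config above.
valid_ranges = [(0, 90), (30, 160), (90, -1)]


def branches_for_box(width, height, ranges=valid_ranges):
    """Return the indices of branches whose valid range contains the box scale.

    Assumptions: scale = sqrt(width * height); an upper bound of -1 means
    unbounded; both bounds are inclusive.
    """
    scale = math.sqrt(width * height)
    selected = []
    for branch_id, (low, high) in enumerate(ranges):
        upper_ok = high == -1 or scale <= high
        if scale >= low and upper_ok:
            selected.append(branch_id)
    return selected


print(branches_for_box(20, 20))    # scale 20: small-object branch only
print(branches_for_box(50, 50))    # scale 50: ranges overlap, two branches
print(branches_for_box(200, 200))  # scale 200: large-object branch only
```

Because the ranges overlap, a medium-sized box can contribute to more than one branch under these assumptions.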
-
```
class BackboneParam:
    fp16 = General.fp16
    depth = General.depth
    normalizer = NormalizeParam.normalizer
    num_branch = Trident.num_branch
    branch_ids = Trident.branch_ids
    branch_dilates = Trident.branch_dilates
    branch_bn_shared = Trident.branch_bn_shared
    branch_conv_shared = Trident.branch_conv_shared
    branch_deform = Trident.branch_deform
```
- backbone = Backbone(BackboneParam) [link](https://github.com/TuSimple/simpledet/blob/master/config/tridentnet_r50v2c4_c5_1x.py#L124)
- from models.tridentnet.builder import TridentMXNetResNetV2 as Backbone [link](https://github.com/TuSimple/simpledet/blob/master/config/tridentnet_r50v2c4_c5_1x.py#L2)
- [models/tridentnet/builder.py](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/builder.py#L175)
-
```
def __init__(self, pBackbone):
    super(TridentMXNetResNetV2, self).__init__(pBackbone)
    p = pBackbone
    b = TridentResNetV2Builder()
    self.symbol = b.get_backbone("mxnet", p.depth, "c4", p.normalizer, p.fp16,
                                 p.num_branch, p.branch_dilates, p.branch_ids,
                                 p.branch_bn_shared, p.branch_conv_shared, p.branch_deform)
```
- from models.tridentnet.resnet_v2_for_paper import TridentResNetV2Builder [link](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/builder.py#L8)
- [models/tridentnet/resnet_v2_for_paper.py](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/resnet_v2_for_paper.py#L10)
-
```
def get_backbone(self, variant, depth, endpoint, normalizer, fp16,
                 num_branch, branch_dilates, branch_ids, branch_bn_shared, branch_conv_shared, branch_deform):
```
[link](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/resnet_v2_for_paper.py#L260)
- factory = self.resnet_c4_factory [link](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/resnet_v2_for_paper.py#L274)
- def resnet_c4_factory(cls, depth, use_3x3_conv0, use_bn_preprocess, [link](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/resnet_v2_for_paper.py#L239)
- c1, c2, c3, c4, c5 = cls.resnet_factory(depth, use_3x3_conv0, use_bn_preprocess, [link](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/resnet_v2_for_paper.py#L242)
- def resnet_factory(cls, depth, use_3x3_conv0, use_bn_preprocess, [link](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/resnet_v2_for_paper.py#L219)
-
```
c4 = cls.resnet_trident_c4(c3, num_c4_unit, 2, branch_dilates, norm_type, norm_mom, ndev,
                           num_branch, branch_ids, branch_bn_shared, branch_conv_shared, branch_deform)
```
[link](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/resnet_v2_for_paper.py#L230)
- def resnet_trident_c4(cls, data, num_block, stride, dilate, norm_type, norm_mom, ndev, [link](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/resnet_v2_for_paper.py#L212)
- return cls.resnet_trident_stage( [link](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/resnet_v2_for_paper.py#L214)
- def resnet_trident_stage(cls, data, name, num_block, filter, stride, dilate, norm_type, norm_mom, ndev, [link](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/resnet_v2_for_paper.py#L171)
-
```
data = cls.resnet_unit(data, "{}_unit1".format(name), filter, stride, 1, True, norm_type, norm_mom, ndev)
data = [data] * num_branch
```
[link](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/resnet_v2_for_paper.py#L194)
- [mxnext.backbone/resnet_v2.py](https://github.com/RogerChern/mxnext/blob/master/backbone/resnet_v2.py#L20)
- data = cls.resnet_trident_unit( [link](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/resnet_v2_for_paper.py#L205)
- def resnet_trident_unit(cls, data, name, filter, stride, dilate, proj, norm_type, norm_mom, ndev, [link](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/resnet_v2_for_paper.py#L104)
- bn1 = cls.bn_shared( [link](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/resnet_v2_for_paper.py#L128)
-
```
conv2 = cls.conv_shared(
    relu2, name=name + "_conv2", num_filter=filter // 4, kernel=(3, 3),
    pad=dilate, stride=stride, dilate=dilate,
    branch_ids=branch_ids, share_weight=branch_conv_shared)
```
[link](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/resnet_v2_for_paper.py#L139)
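Setting `pad=dilate` for a 3x3 kernel keeps every branch's output the same spatial size, so the branches differ only in receptive field. The sketch below checks this with the standard convolution output-size formula. The input size of 38 and the stride of 1 are illustrative assumptions, and `conv_output_size` is a hypothetical helper, not a simpledet or MXNet function.

```python
def conv_output_size(size, kernel, stride, pad, dilate):
    """Standard convolution output-size formula along one spatial axis."""
    effective_kernel = dilate * (kernel - 1) + 1
    return (size + 2 * pad - effective_kernel) // stride + 1


# branch_dilates from the Trident config; the input size 38 is arbitrary.
sizes = []
for dilate in [1, 2, 3]:
    out = conv_output_size(size=38, kernel=3, stride=1, pad=dilate, dilate=dilate)
    sizes.append(out)
    print("dilate={} -> output size {}".format(dilate, out))
```

Since the kernel shape is identical across branches, the same 3x3 weights can be shared (`branch_conv_shared = True`) while each branch applies them at a different dilation.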
- from mxnext.backbone.resnet_v2 import Builder [link](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/resnet_v2_for_paper.py#L5)
- [mxnext.backbone/resnet_v2.py](https://github.com/RogerChern/mxnext/blob/master/backbone/resnet_v2.py#L11)
- def get_backbone(self, variant, depth, endpoint, normalizer, fp16): [link](https://github.com/RogerChern/mxnext/blob/master/backbone/resnet_v2.py#L181)
- sym = pModel.train_symbol [link](https://github.com/TuSimple/simpledet/blob/master/detection_train.py#L50)
- class ModelParam: [link](https://github.com/TuSimple/simpledet/blob/master/config/tridentnet_r101v2c4_c5_1x.py#L141)
- train_sym = detector.get_train_symbol( [link](https://github.com/TuSimple/simpledet/blob/master/config/tridentnet_r101v2c4_c5_1x.py#L131)
- [models/tridentnet/builder.py](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/builder.py#L16)
-
```
rpn_feat = backbone.get_rpn_feature()
rcnn_feat = backbone.get_rcnn_feature()
rpn_feat = neck.get_rpn_feature(rpn_feat)
rcnn_feat = neck.get_rcnn_feature(rcnn_feat)
```
[link](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/builder.py#L33)
- rpn_loss = rpn_head.get_loss(rpn_feat, rpn_cls_label, rpn_reg_target, rpn_reg_weight) [link](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/builder.py#L38)
- proposal, bbox_cls, bbox_target, bbox_weight = rpn_head.get_sampled_proposal_with_filter(rpn_feat, gt_bbox, im_info, valid_ranges) [link](https://github.com/TuSimple/simpledet/blob/master/models/tridentnet/builder.py#L40)
-
```
roi_feat = roi_extractor.get_roi_feature(rcnn_feat, proposal)
bbox_loss = bbox_head.get_loss(roi_feat, bbox_cls, bbox_target, bbox_weight)
return X.group(rpn_loss + bbox_loss)
```
- image_sets = pDataset.image_set [link](https://github.com/TuSimple/simpledet/blob/master/detection_train.py#L60)
-
```
class DatasetParam:
    if is_train:
        image_set = ("coco_train2014", "coco_valminusminival2014")
    else:
        image_set = ("coco_minival2014", )
```
[link](https://github.com/TuSimple/simpledet/blob/master/config/tridentnet_r101v2c4_c5_1x.py#L118)
- roidbs = [pkl.load(open("data/cache/{}.roidb".format(i), "rb"), encoding="latin1") for i in image_sets] [link](https://github.com/TuSimple/simpledet/blob/master/detection_train.py#L61)
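The one-liner above leaves the file handles for the garbage collector to close. A minimal sketch of the same loading step with explicit context managers follows. The `load_roidbs` helper and the toy record keys are hypothetical; only the `data/cache/{}.roidb` path pattern and the `latin1` encoding come from the trace.

```python
import os
import pickle as pkl
import tempfile


def load_roidbs(image_sets, cache_dir="data/cache"):
    """Load one pickled roidb per image set, closing each file after reading."""
    roidbs = []
    for name in image_sets:
        path = os.path.join(cache_dir, "{}.roidb".format(name))
        with open(path, "rb") as f:
            roidbs.append(pkl.load(f, encoding="latin1"))
    return roidbs


# Demo with a throwaway cache directory and a made-up record.
with tempfile.TemporaryDirectory() as tmp:
    with open(os.path.join(tmp, "toy.roidb"), "wb") as f:
        pkl.dump([{"image_id": 1, "h": 480, "w": 640}], f)
    loaded = load_roidbs(("toy",), cache_dir=tmp)

print(len(loaded), loaded[0][0]["w"])
```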
-
```
from core.detection_input import AnchorLoader
train_data = AnchorLoader(
    roidb=roidb,
    transform=transform,
    data_name=data_name,
    label_name=label_name,
    batch_size=input_batch_size,
    shuffle=True,
    kv=kv
)
```
[link](https://github.com/TuSimple/simpledet/blob/master/detection_train.py#L73)
- [core/detection_input.py](https://github.com/TuSimple/simpledet/blob/master/core/detection_input.py#L715)
- class AnchorLoader(mx.io.DataIter):
- v_roidb, h_roidb = self.roidb_aspect_group(roidb) [link](https://github.com/TuSimple/simpledet/blob/master/core/detection_input.py#L721)
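Grouping by aspect ratio lets a batch contain only portrait or only landscape images, which reduces padding. The sketch below shows one way such a split might look. It is written as a free function rather than a method, and it assumes each record stores its image height and width under `"h"` and `"w"`. Neither detail is confirmed by the trace, so check `core/detection_input.py` for the actual logic.

```python
def roidb_aspect_group(roidb):
    """Split records into vertical (h >= w) and horizontal (h < w) groups.

    Assumption: each record is a dict with "h" and "w" keys.
    """
    vertical, horizontal = [], []
    for record in roidb:
        if record["h"] >= record["w"]:
            vertical.append(record)
        else:
            horizontal.append(record)
    return vertical, horizontal


toy_roidb = [{"h": 480, "w": 640}, {"h": 640, "w": 480}, {"h": 500, "w": 500}]
v_roidb, h_roidb = roidb_aspect_group(toy_roidb)
print(len(v_roidb), len(h_roidb))
```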
-
```
mod = DetModule(sym, data_names=data_names, label_names=label_names,
                logger=logger, context=ctx, fixed_param_prefix=fixed_param_prefix)
```
[link](https://github.com/TuSimple/simpledet/blob/master/detection_train.py#L141)
- [core/detection_module.py](https://github.com/TuSimple/simpledet/blob/master/core/detection_module.py#L47)
- class DetModule(BaseModule):
- eval_metrics = mx.metric.CompositeEvalMetric(metric_list) [link](https://github.com/TuSimple/simpledet/blob/master/detection_train.py#L144)
-
```
rpn_acc_metric = metric.AccWithIgnore(
    "RpnAcc",
    ["rpn_cls_loss_output"],
    ["rpn_cls_label"]
)
rpn_l1_metric = metric.L1(
    "RpnL1",
    ["rpn_reg_loss_output"],
    ["rpn_cls_label"]
)
# for bbox, the label is generated in network so it is an output
box_acc_metric = metric.AccWithIgnore(
    "RcnnAcc",
    ["bbox_cls_loss_output", "bbox_label_blockgrad_output"],
    []
)
box_l1_metric = metric.L1(
    "RcnnL1",
    ["bbox_reg_loss_output", "bbox_label_blockgrad_output"],
    []
)
metric_list = [rpn_acc_metric, rpn_l1_metric, box_acc_metric, box_l1_metric]
```
[link](https://github.com/TuSimple/simpledet/blob/master/config/tridentnet_r101v2c4_c5_1x.py#L271)
- [config/tridentnet_r101v2c4_c5_1x.py](https://github.com/TuSimple/simpledet/blob/master/config/tridentnet_r101v2c4_c5_1x.py#L269)
- class AccWithIgnore(LossWithIgnore): [link](https://github.com/TuSimple/simpledet/blob/master/core/detection_metric.py#L23)
- class L1(FgLossWithIgnore): [link](https://github.com/TuSimple/simpledet/blob/master/core/detection_metric.py#L117)
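As a rough illustration of what an "accuracy with ignore" metric computes, here is a plain-Python sketch that skips samples carrying an ignore label. The function `accuracy_with_ignore` and the ignore value `-1` are assumptions made for illustration. The real `AccWithIgnore` class in `core/detection_metric.py` works on MXNet arrays and may differ in detail.

```python
def accuracy_with_ignore(pred_labels, labels, ignore_label=-1):
    """Fraction of correct predictions among samples whose label is not ignored."""
    kept = [(p, t) for p, t in zip(pred_labels, labels) if t != ignore_label]
    if not kept:
        return 0.0  # nothing to score
    correct = sum(1 for p, t in kept if p == t)
    return correct / len(kept)


# The third sample is ignored, so 2 of the remaining 3 predictions are correct.
acc = accuracy_with_ignore([1, 0, 1, 1], [1, 1, -1, 1])
print(acc)
```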
-
```
mod.fit(
    train_data=train_data,
    eval_metric=eval_metrics,
    epoch_end_callback=epoch_end_callback,
    batch_end_callback=batch_end_callback,
    kvstore=kv,
    optimizer=pOpt.optimizer.type,
    optimizer_params=optimizer_params,
    initializer=init,
    allow_missing=True,
    arg_params=arg_params,
    aux_params=aux_params,
    begin_epoch=begin_epoch,
    num_epoch=end_epoch
)
```
[link](https://github.com/TuSimple/simpledet/blob/master/detection_train.py#L200)
- [core/detection_module.py](https://github.com/TuSimple/simpledet/blob/master/core/detection_module.py#L882)
- self.bind(data_shapes=train_data.provide_data, label_shapes=train_data.provide_label, [link](https://github.com/TuSimple/simpledet/blob/master/core/detection_module.py#L958)
-
```
self.init_params(initializer=initializer, arg_params=arg_params, aux_params=aux_params,
                 allow_missing=allow_missing, force_init=force_init)
self.init_optimizer(kvstore=kvstore, optimizer=optimizer,
                    optimizer_params=optimizer_params)
```
[link](https://github.com/TuSimple/simpledet/blob/master/core/detection_module.py#L963)
- for epoch in range(begin_epoch, num_epoch): [link](https://github.com/TuSimple/simpledet/blob/master/core/detection_module.py#L976)