This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" Fuse conv-bn pattern in torch.Module, an example for torch.fx | |
see: https://pytorch.org/tutorials/intermediate/fx_conv_bn_fuser.html | |
""" | |
import copy | |
from typing import Tuple, Dict, Any | |
import torch | |
import torch.fx as fx | |
import torch.nn as nn |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
class CharDataset(object): | |
def __init__(self, path): | |
if not os.path.exists(path): | |
raise RuntimeError('Cannot open the file: {}'.format(path)) | |
self.raw_data = open(path, 'r').read() | |
self.chars = list(set(self.raw_data)) | |
self.data_size = len(self.raw_data) | |
print('There are {} characters in the file'.format(self.data_size)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
def init_weight(m): | |
classname = m.__class__.__name__ | |
if classname.find('Conv') != -1: | |
m.weight.data.normal_(0., 0.02) | |
elif classname.find('BatchNorm') != -1: | |
m.weight.data.normal_(1., 0.02) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch.nn as nn | |
import math | |
## the model definition | |
# see HeKaiming's implementation using torch: | |
# https://github.com/KaimingHe/resnet-1k-layers/blob/master/README.md | |
class Bottleneck(nn.Module): | |
expansion = 4 # # output cahnnels / # input channels |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Enter your network definition here. | |
# Use Shift+Enter to update the visualization. | |
layer { | |
name: "CustomData1" | |
type: "CustomData" | |
top: "blob0" | |
top: "blob1" | |
top: "blob2" | |
top: "blob3" | |
include { |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: UTF-8 -*- | |
# File: logger.py | |
# Author: Yuxin Wu <[email protected]> | |
from __future__ import print_function | |
import logging | |
import os | |
import errno | |
import shutil | |
import os.path | |
from datetime import datetime |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# convert yolo cfg file to caffe prototxt file | |
import sys | |
import argparse | |
TYPES = ["Conv", "ReLU", "Pool", "Route", "Reorg", "Region"] | |
ACTIVATION_TYPES = ['leaky', 'linear'] | |
layer_names = []; | |
def HasConflictNameError(conflicted_name): | |
print 'Error! The layer name \"{}\" has been in the list.'.format(conflicted_name) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import torch | |
import torch.nn as nn | |
from torch.autograd import Variable | |
import torchvision.datasets as dset | |
import torchvision.transforms as transforms | |
import torch.nn.functional as F | |
import torch.optim as optim | |
## load mnist dataset | |
use_cuda = torch.cuda.is_available() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <vector> | |
#include "gtest/gtest.h" | |
#include "caffe/common.hpp" | |
#include "caffe/blob.hpp" | |
#include "caffe/layers/reorg_layer.hpp" | |
#include "caffe/test/test_caffe_main.hpp" | |
namespace caffe { | |
template <typename TypeParam> | |
class ReorgLayerTest : public MultiDeviceTest<TypeParam> { | |
typedef typename TypeParam::Dtype Dtype; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# has 3 lstm layer | |
name: "ocr" | |
layer { | |
name: "data" | |
type: "OCRData" | |
top: "data" | |
top: "label" | |
image_data_param { | |
is_color: false |
NewerOlder