Created
December 5, 2017 21:41
-
-
Save LaurentBerger/4934cd09db05088b8ff2eacae7539a00 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
int train_backprop( const Mat& inputs, const Mat& outputs, const Mat& _sw, TermCriteria termCrit ) | |
{ | |
int i, j, k; | |
double prev_E = DBL_MAX*0.5, E = 0; | |
int itype = inputs.type(), otype = outputs.type(); | |
int count = inputs.rows; | |
int iter = -1, max_iter = 100;//termCrit.maxCount*count; | |
double epsilon = termCrit.epsilon*count; | |
int l_count = layer_count(); | |
int ivcount = layer_sizes[0]; | |
int ovcount = layer_sizes.back(); | |
// allocate buffers | |
vector<vector<double> > x(l_count); | |
vector<vector<double> > df(l_count); | |
vector<Mat> dw(l_count); | |
for( i = 0; i < l_count; i++ ) | |
{ | |
int n = layer_sizes[i]; | |
x[i].resize(n+1); | |
df[i].resize(n); | |
dw[i] = Mat::zeros(weights[i].size(), CV_64F); | |
} | |
Mat _idx_m(1, count, CV_32S); | |
int* _idx = _idx_m.ptr<int>(); | |
for( i = 0; i < count; i++ ) | |
_idx[i] = i; | |
AutoBuffer<double> _buf(max_lsize*2); | |
double* buf[] = { _buf, (double*)_buf + max_lsize }; | |
const double* sw = _sw.empty() ? 0 : _sw.ptr<double>(); | |
// run back-propagation loop | |
/* | |
y_i = w_i*x_{i-1} | |
x_i = f(y_i) | |
E = 1/2*||u - x_N||^2 | |
grad_N = (x_N - u)*f'(y_i) | |
dw_i(t) = momentum*dw_i(t-1) + dw_scale*x_{i-1}*grad_i | |
w_i(t+1) = w_i(t) + dw_i(t) | |
grad_{i-1} = w_i^t*grad_i | |
*/ | |
for( iter = 0; iter < max_iter; iter++ ) | |
{ | |
int idx = iter % count; | |
double sweight = sw ? count*sw[idx] : 1.; | |
if( idx == 0 ) | |
{ | |
//printf("%d. E = %g\n", iter/count, E); | |
if( fabs(prev_E - E) < epsilon ) | |
break; | |
prev_E = E; | |
E = 0; | |
// shuffle indices | |
for( i = 0; i < 0*count; i++ )// DISABLE SHUFFLE = NO RANDOM NUMBER | |
{ | |
j = rng.uniform(0, count); | |
k = rng.uniform(0, count); | |
std::swap(_idx[j], _idx[k]); | |
} | |
} | |
idx = _idx[idx]; | |
const uchar* x0data_p = inputs.ptr(idx); | |
const float* x0data_f = (const float*)x0data_p; | |
const double* x0data_d = (const double*)x0data_p; | |
double* w = weights[0].ptr<double>(); | |
for( j = 0; j < ivcount; j++ ) | |
x[0][j] = (itype == CV_32F ? (double)x0data_f[j] : x0data_d[j])*w[j*2] + w[j*2 + 1]; | |
Mat x1( 1, ivcount, CV_64F, &x[0][0] ); | |
// forward pass, compute y[i]=w*x[i-1], x[i]=f(y[i]), df[i]=f'(y[i]) | |
for( i = 1; i < l_count; i++ ) | |
{ | |
int n = layer_sizes[i]; | |
Mat x2(1, n, CV_64F, &x[i][0] ); | |
Mat _w = weights[i].rowRange(0, x1.cols); | |
gemm(x1, _w, 1, noArray(), 0, x2); | |
Mat _df(1, n, CV_64F, &df[i][0] ); | |
calc_activ_func_deriv( x2, _df, weights[i] ); | |
x1 = x2; | |
} | |
Mat grad1( 1, ovcount, CV_64F, buf[l_count&1] ); | |
w = weights[l_count+1].ptr<double>(); | |
// calculate error | |
const uchar* udata_p = outputs.ptr(idx); | |
const float* udata_f = (const float*)udata_p; | |
const double* udata_d = (const double*)udata_p; | |
double* gdata = grad1.ptr<double>(); | |
for( k = 0; k < ovcount; k++ ) | |
{ | |
double t = (otype == CV_32F ? (double)udata_f[k] : udata_d[k])*w[k*2] + w[k*2+1] - x[l_count-1][k]; | |
gdata[k] = t*sweight; | |
E += t*t; | |
} | |
E *= sweight; | |
// backward pass, update weights | |
for( i = l_count-1; i > 0; i-- ) | |
{ | |
int n1 = layer_sizes[i-1], n2 = layer_sizes[i]; | |
Mat _df(1, n2, CV_64F, &df[i][0]); | |
multiply( grad1, _df, grad1 ); | |
Mat _x(n1+1, 1, CV_64F, &x[i-1][0]); | |
x[i-1][n1] = 1.; | |
gemm( _x, grad1, params.bpDWScale, dw[i], params.bpMomentScale, dw[i] ); | |
add( weights[i], dw[i], weights[i] ); | |
if( i > 1 ) | |
{ | |
Mat grad2(1, n1, CV_64F, buf[i&1]); | |
Mat _w = weights[i].rowRange(0, n1); | |
gemm( grad1, _w, 1, noArray(), 0, grad2, GEMM_2_T ); | |
grad1 = grad2; | |
} | |
} | |
} | |
iter /= count; | |
return 1; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment