Skip to content

Instantly share code, notes, and snippets.

@sssa2000
Last active December 26, 2015 05:59
Show Gist options
  • Save sssa2000/7104220 to your computer and use it in GitHub Desktop.
Save sssa2000/7104220 to your computer and use it in GitHub Desktop.
LU分解解方程算法
// LUDecompress.cpp : 定义控制台应用程序的入口点。
//
#include "stdafx.h"
#include <cmath>
#include <iostream>
#include <memory>
#include <vector>
#include "equation.h"
// Solves a square linear system A*x = b with LU decomposition and partial
// pivoting (P*A = L*U). The system comes from an equation_group whose
// augmented matrix is [A | b] (n rows, n + 1 columns).
//
// Steps performed by Slove():
//   1. lu_decompose():            Gaussian elimination of A in place, turning
//                                 it into U; the elimination factors are
//                                 stored below the unit diagonal of L.
//                                 Row swaps go through a matrix_proxy, so they
//                                 also permute b (the last column).
//   2. lu_forward_substitution(): solve L*y = P*b.
//   3. lu_back_substitution():    solve U*x = y, writing x into
//                                 eq->get_solution().
//
// NOTE: the augmented matrix owned by the equation_group is modified in place
// (it ends up holding U), so call Slove() at most once per equation_group.
class LU_Decompose_Slover
{
public:
    // eq is not owned and must outlive this solver.
    LU_Decompose_Slover(equation_group* eq)
        : _eq(eq)
    {
        // Clamp so a degenerate group (fewer than one column) cannot produce
        // a negative size below; Slove() rejects such systems anyway.
        const int varnum = eq->get_var_count() > 0 ? eq->get_var_count() : 0;
        _orgMatProxy = std::make_shared<matrix_proxy>(eq->get_matrix());
        _LMat = std::make_shared<matrix>(varnum, varnum);
        _LMat->make_identity();
        forward_sb_sol.resize(static_cast<std::size_t>(varnum));
    }

    // Solves the system and stores x in the equation group's solution array.
    // Returns false (without touching the solution array) when the system is
    // not a non-empty square system, or when an exactly-zero pivot shows that
    // A is singular.
    bool Slove()
    {
        const int varnum = _eq->get_var_count();
        if (varnum <= 0 || _orgMatProxy->getrn() != varnum)
            return false;
        if (!lu_decompose())
            return false;
        lu_forward_substitution();
        lu_back_substitution();
        return true;
    }

private:
    equation_group* _eq;                         // not owned
    std::shared_ptr<matrix_proxy> _orgMatProxy;  // row-permuted view of [A | b]; A becomes U
    std::shared_ptr<matrix> _LMat;               // unit lower-triangular L, indexed by permuted row
    std::vector<float> forward_sb_sol;           // y, the solution of L*y = P*b

    // Swaps the elimination factors already stored in L (columns
    // 0..pivot-1) between rows row_a and row_b. Each factor then stays with
    // the row it was computed for after a pivot swap. The unit diagonal of L
    // is left where it is.
    void swap_l_factors(int row_a, int row_b, int pivot)
    {
        for (int col = 0; col < pivot; ++col)
        {
            const float tmp = _LMat->get_item(row_a, col);
            _LMat->set_item(row_a, col, _LMat->get_item(row_b, col));
            _LMat->set_item(row_b, col, tmp);
        }
    }

    // Partial pivoting: among rows current_pivot..n-1, pick the row whose
    // entry in column current_pivot has the largest magnitude and move it to
    // position current_pivot (in both the augmented matrix and L).
    void select_pivot(int current_pivot)
    {
        const int rows = _orgMatProxy->getrn();
        int max_row = current_pivot;
        float max_abs = std::fabs(_orgMatProxy->get_mat_item(current_pivot, current_pivot));
        for (int row = current_pivot + 1; row < rows; ++row)
        {
            const float candidate = std::fabs(_orgMatProxy->get_mat_item(row, current_pivot));
            if (candidate > max_abs)
            {
                // Keep the running maximum up to date. Otherwise the last
                // row that beats the initial pivot would win, not the
                // largest one.
                max_abs = candidate;
                max_row = row;
            }
        }
        if (max_row != current_pivot)
        {
            _orgMatProxy->swap_row(max_row, current_pivot);
            swap_l_factors(max_row, current_pivot, current_pivot);
        }
    }

    // Factorizes P*A = L*U. A (columns 0..n-1 of the augmented matrix) is
    // overwritten with U. b (column n) is only permuted, not eliminated;
    // forward substitution handles it.
    // Returns false if a pivot is exactly zero (A is singular). The exact
    // comparison is intentional: any non-zero pivot can be divided by.
    bool lu_decompose()
    {
        const int n = _orgMatProxy->getrn();
        for (int pivot = 0; pivot < n - 1; ++pivot)
        {
            select_pivot(pivot);
            const float pivot_value = _orgMatProxy->get_mat_item(pivot, pivot);
            if (pivot_value == 0.0f)
                return false;
            for (int row = pivot + 1; row < n; ++row)
            {
                // Only the pivot column of L changes at this step.
                const float factor = _orgMatProxy->get_mat_item(row, pivot) / pivot_value;
                _LMat->set_item(row, pivot, factor);
                for (int col = pivot; col < n; ++col)
                {
                    const float updated = _orgMatProxy->get_mat_item(row, col)
                                        - factor * _orgMatProxy->get_mat_item(pivot, col);
                    _orgMatProxy->set_mat_item(row, col, updated);
                }
            }
        }
        // The last diagonal entry of U is never used as an elimination pivot,
        // but back substitution divides by it.
        return _orgMatProxy->get_mat_item(n - 1, n - 1) != 0.0f;
    }

    // Solves L*y = P*b and stores y in forward_sb_sol. L has a unit
    // diagonal, so no division is needed. P*b is read from the last column
    // through the row-permuted proxy.
    void lu_forward_substitution()
    {
        const int n = _eq->get_var_count();
        for (int i = 0; i < n; ++i)
        {
            float sum = _orgMatProxy->get_mat_item(i, n);
            for (int j = 0; j < i; ++j)
                sum -= _LMat->get_item(i, j) * forward_sb_sol[j];
            forward_sb_sol[i] = sum;
        }
    }

    // Solves U*x = y (U is the upper triangle of the eliminated matrix) and
    // writes x into the equation group's solution array.
    void lu_back_substitution()
    {
        const int n = _eq->get_var_count();
        float* sol = _eq->get_solution();
        for (int i = n - 1; i >= 0; --i)
        {
            float sum = forward_sb_sol[i];
            for (int j = n - 1; j > i; --j)
                sum -= _orgMatProxy->get_mat_item(i, j) * sol[j];
            sol[i] = sum / _orgMatProxy->get_mat_item(i, i);
        }
    }
};
// Demo/check for LU_Decompose_Slover on the 3x3 system
//    3.0 x1 - 0.1 x2 - 0.2 x3 =   7.85
//    0.1 x1 + 7.0 x2 - 0.3 x3 = -19.3
//    0.3 x1 - 0.2 x2 + 10  x3 =  71.4
// whose exact solution is (3, -2.5, 7). Prints the computed solution and
// whether it matches within a float-precision tolerance.
// Expected factors (no row swaps are needed for this system):
//         1          0         0                3   -0.1      -0.2
//   L =   0.0333333  1         0          U =   0    7.00333  -0.293333
//         0.1       -0.02713   1                0    0        10.0120
void test1()
{
    float coff[] = {
        3.0f, -0.1f, -0.2f,   7.85f,
        0.1f,  7.0f, -0.3f, -19.3f,
        0.3f, -0.2f, 10.0f,  71.4f};
    equation_group eq(coff, 3, 4);
    LU_Decompose_Slover slov(&eq);
    if (!slov.Slove())
    {
        std::cout << "test1: solver failed (singular or non-square system)" << '\n';
        return;
    }

    const float expected[] = {3.0f, -2.5f, 7.0f};
    const float* res = eq.get_solution();
    bool ok = true;
    for (int i = 0; i < 3; ++i)
    {
        std::cout << "x" << (i + 1) << " = " << res[i] << '\n';
        // Single-precision elimination gives e.g. 7.00003 instead of 7.
        if (std::fabs(res[i] - expected[i]) > 1e-3f)
            ok = false;
    }
    std::cout << "test1: " << (ok ? "PASS" : "FAIL") << '\n';
}
// Console entry point (MSVC _tmain): runs the LU solver demo, then invokes
// the shell's "pause" command so the console window stays open, and exits
// with status 0.
int _tmain(int argc, _TCHAR* argv[])
{
    (void)argc; // command-line arguments are not used
    (void)argv;

    test1();

    system("pause");
    return 0;
}
/********************************************************************
created: 2013/10/20
created: 20:10:2013 21:35
filename: E:\MyPassage\linear_algebra\LUDecompress\equation.h
file path: E:\MyPassage\linear_algebra\LUDecompress
file base: equation
file ext: h
author: sssa2000
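purpose: Declares matrix (dense float matrix), equation_group (linear system stored as an augmented matrix [A | b]) and matrix_proxy (row-permuted view of a matrix) used by the LU decomposition solver.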
*********************************************************************/
#pragma once
// Dense r x c matrix of floats, stored row-major in a heap buffer it owns.
class matrix
{
public:
    // Allocates an r x c matrix with every element set to 0.
    matrix(int r,int c);
    ~matrix();
    // The class owns a raw heap buffer. Compiler-generated copies would share
    // the pointer and delete[] it twice, so copying is disabled.
    matrix(const matrix&) = delete;
    matrix& operator=(const matrix&) = delete;
    // Copies rows*cols values (row-major) from data into the matrix; the
    // caller guarantees data holds at least that many values.
    void set(float* data);
    // Element access by (row, column); no bounds checking.
    float get_item(int r,int c);
    void set_item(int r,int c,float v);
    int getrn();   // number of rows
    int getcn();   // number of columns
    // Zeroes the matrix and sets the main diagonal to 1.
    void make_identity();
private:
    float* _data;  // row-major buffer of _row_num * _col_num floats (owned)
    int _row_num;
    int _col_num;
};
class matrix;  // forward declaration; only pointers to matrix are used here

// A system of linear equations stored as an augmented matrix [A | b]:
// row = number of equations, col = number of unknowns + 1.
class equation_group
{
public:
    // coff: row * col coefficients in row-major order; they are copied.
    equation_group(float* coff,int row,int col);
    ~equation_group();
    // Owns two raw heap allocations; compiler-generated copies would free
    // them twice, so copying is disabled.
    equation_group(const equation_group&) = delete;
    equation_group& operator=(const equation_group&) = delete;
    // Augmented matrix [A | b], owned by this object.
    matrix* get_matrix();
    // Array that receives the solution, owned by this object.
    float* get_solution();
    // Number of equations (rows of the augmented matrix).
    int get_equat_count();
    // Number of unknowns (columns of the augmented matrix minus the b column).
    int get_var_count();
private:
    matrix* _augmented_matrix; // augmented matrix [A | b] (owned)
    int _row_num;
    int _col_num;
    float* _solution;          // array holding the solution (owned)
};
class matrix;  // forward declaration; only a pointer to matrix is stored

// Row-permuted view of a matrix, used to record pivot (row) swaps.
// swap_row() only exchanges entries of an index map, so the underlying matrix
// data never moves. The wrapped matrix is not owned.
class matrix_proxy
{
public:
    // Starts with the identity mapping (logical row i -> physical row i).
    explicit matrix_proxy(matrix* mat);
    ~matrix_proxy();
    // Owns the raw index map; compiler-generated copies would delete[] it
    // twice, so copying is disabled.
    matrix_proxy(const matrix_proxy&) = delete;
    matrix_proxy& operator=(const matrix_proxy&) = delete;
    // Swaps which physical rows the two logical rows refer to.
    void swap_row(int rowidx1,int rowidx2);
    // Element access through the row mapping (row is a logical index).
    float get_mat_item(int row,int col);
    void set_mat_item(int row,int col,float v);
    int getrn();
    int getcn();
private:
    int* _rerange_res;  // logical row -> physical row (owned)
    int _row_num;
    matrix* _mat_ref;   // wrapped matrix (not owned)
};
#include "equation.h"
#include <stdarg.h>
#include <math.h>
#include <iostream>
// Allocates an r x c matrix stored row-major with every element set to 0.
// The trailing () value-initializes the array, which zero-fills it without
// needing memset (this file does not include <cstring>, which declares it).
matrix::matrix(int r,int c)
    : _data(new float[r * c]()), _row_num(r), _col_num(c)
{
}
// Releases the element buffer allocated by the constructor.
matrix::~matrix()
{
    float* const buffer = _data;
    delete[] buffer;
}
//设置某一行的所有元素 调用者必须保证传入的数据不会越界
void matrix::set(float* data)
{
for (int row=0;row<_row_num;++row)
{
for (int col=0;col<_col_num;++col)
{
set_item(row,col,data[row*_col_num+col]);
}
}
}
// Returns the element at row r, column c (row-major storage, no bounds
// checking).
float matrix::get_item(int r,int c)
{
    const int offset = r * _col_num + c;
    return _data[offset];
}
// Stores v at row r, column c (row-major storage, no bounds checking).
void matrix::set_item(int r,int c,float v)
{
    float* const row_start = _data + r * _col_num;
    row_start[c] = v;
}
int matrix::getrn(){return _row_num;}
int matrix::getcn(){return _col_num;}
// Zeroes every element and sets the main diagonal entries (i, i) to 1.
// A non-square matrix has only min(rows, cols) diagonal entries. The loop is
// bounded by that so it stays inside the buffer when rows > cols.
// The buffer is cleared with a plain loop because this file does not include
// <cstring> (needed for memset).
void matrix::make_identity()
{
    const int total = _row_num * _col_num;
    for (int i = 0; i < total; ++i)
        _data[i] = 0.0f;

    const int diag = _row_num < _col_num ? _row_num : _col_num;
    for (int i = 0; i < diag; ++i)
        set_item(i, i, 1.0f);
}
// Builds a linear system from row * col coefficients given in row-major
// order (the augmented matrix [A | b]); the values are copied.
// The solution array has col entries (one more than the number of unknowns)
// and is zero-initialized, so it holds defined values even before a solver
// writes to it.
equation_group::equation_group(float* coff,int row,int col)
    : _augmented_matrix(nullptr), _row_num(row), _col_num(col), _solution(nullptr)
{
    _solution = new float[_col_num]();
    try
    {
        _augmented_matrix = new matrix(row, col);
    }
    catch (...)
    {
        // The destructor does not run for a partially constructed object,
        // so release the first allocation here before propagating.
        delete[] _solution;
        throw;
    }
    _augmented_matrix->set(coff);
}
// Frees the solution array and then the augmented matrix, both owned by this
// object.
equation_group::~equation_group()
{
    delete[] this->_solution;
    delete this->_augmented_matrix;
}
// Gives access to the augmented matrix [A | b]; ownership stays with this
// object.
matrix* equation_group::get_matrix()
{
    matrix* const augmented = this->_augmented_matrix;
    return augmented;
}
float* equation_group::get_solution()
{
return _solution;
}
//得到方程数量
int equation_group::get_equat_count()
{
return _row_num;
}
//得到未知变量的数量
int equation_group::get_var_count()
{
return _col_num-1;//因为是增广矩阵所以变量数量比列数少1
}
// Wraps mat (not owned) with an identity row mapping: logical row i refers to
// physical row i until swap_row() is called.
matrix_proxy::matrix_proxy(matrix* mat)
    : _rerange_res(nullptr), _row_num(0), _mat_ref(mat)
{
    _row_num = _mat_ref->getrn();
    _rerange_res = new int[_row_num];
    int idx = 0;
    while (idx < _row_num)
    {
        _rerange_res[idx] = idx;
        ++idx;
    }
}
// Frees the row index map; the wrapped matrix is not owned and is left alone.
matrix_proxy::~matrix_proxy()
{
    int* const mapping = _rerange_res;
    delete[] mapping;
}
// Exchanges the physical rows that logical rows rowidx1 and rowidx2 refer
// to; the matrix data itself is not moved.
void matrix_proxy::swap_row(int rowidx1,int rowidx2)
{
    int* const mapping = _rerange_res;
    const int first = mapping[rowidx1];
    mapping[rowidx1] = mapping[rowidx2];
    mapping[rowidx2] = first;
}
// Reads element (row, col), where row is a logical (permuted) row index.
float matrix_proxy::get_mat_item(int row,int col)
{
    const int physical_row = _rerange_res[row];
    return _mat_ref->get_item(physical_row, col);
}
// Writes v to element (row, col), where row is a logical (permuted) row
// index.
void matrix_proxy::set_mat_item(int row,int col,float v)
{
    const int physical_row = _rerange_res[row];
    _mat_ref->set_item(physical_row, col, v);
}
int matrix_proxy::getrn()
{
return _mat_ref->getrn();
}
int matrix_proxy::getcn()
{
return _mat_ref->getcn();
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment