Skip to content

Instantly share code, notes, and snippets.

@sssa2000
Last active December 26, 2015 05:59
Show Gist options
  • Save sssa2000/7104189 to your computer and use it in GitHub Desktop.
Save sssa2000/7104189 to your computer and use it in GitHub Desktop.
高斯消元法
// guass_elim.cpp : 定义控制台应用程序的入口点。
//
#include "stdafx.h"
#include <windows.h>
#include <iostream>
#include <stdarg.h>
#include <assert.h>
#include <math.h>
#include <memory>
#include "equation.h"
// Solves a linear system A*x = b by Gaussian elimination with partial (row)
// pivoting followed by back substitution.
//
// The solver works in place on the augmented matrix [A | b] owned by the
// equation_group passed to the constructor: elimination overwrites that
// matrix, and the solution is written into the array returned by
// equation_group::get_solution(). Row swaps are not performed physically;
// they are recorded in a matrix_proxy that remaps logical row indices.
//
// Pivots are taken from the diagonal (row i, column i) for every row, so the
// system is assumed to be square (equation count == unknown count).
class Guass_Elim_Slover
{
public:
    // _eq: the system to solve. Not owned; it must outlive this solver.
    explicit Guass_Elim_Slover(equation_group* _eq)
    {
        eq=_eq;
        _orgMatProxy=std::make_shared<matrix_proxy>(eq->get_matrix());
    }
    // Runs forward elimination and, if it succeeds, back substitution.
    // Returns false when a pivot is (numerically) zero, i.e. the coefficient
    // matrix is singular; in that case the solution array is not written.
    bool Slove()
    {
        const bool ok=guass_elim();
        if(ok)
            back_substitution();
        return ok;
    }
private:
    equation_group* eq;
    std::shared_ptr<matrix_proxy> _orgMatProxy;

    // A pivot whose magnitude is at or below this threshold is treated as zero.
    // The magnitude is what matters: negative pivots are perfectly valid.
    static bool is_negligible(float v)
    {
        return fabs(v)<=0.000001f;
    }

    // Partial pivoting: among logical rows current_pivot..n-1, find the row whose
    // entry in column current_pivot has the largest absolute value and swap it
    // (through the proxy) into position current_pivot.
    void select_pivot(int current_pivot)
    {
        float best=fabs(_orgMatProxy->get_mat_item(current_pivot,current_pivot));
        int max_row=current_pivot;
        const int rowNum=_orgMatProxy->getrn();
        for (int ridx=current_pivot+1;ridx<rowNum;++ridx)
        {
            const float candidate=fabs(_orgMatProxy->get_mat_item(ridx,current_pivot));
            if(candidate>best)
            {
                // Keep the running maximum up to date; otherwise the loop would
                // select the last row that beats the first value, not the largest.
                best=candidate;
                max_row=ridx;
            }
        }
        if(max_row!=current_pivot)
            _orgMatProxy->swap_row(max_row,current_pivot);
    }

    // Forward elimination on the augmented matrix [A | b].
    // Returns false if a (near-)zero pivot is met, i.e. there is no unique
    // solution. Entries below the diagonal are never written: they are
    // conceptually zero and back_substitution() never reads them.
    bool guass_elim()
    {
        const int rowNum=_orgMatProxy->getrn();
        const int colNum=_orgMatProxy->getcn();
        if(rowNum<=0)
            return false;
        // The last row has nothing below it to eliminate, so stop at rowNum-1.
        for (int pivotidx=0;pivotidx<rowNum-1;++pivotidx)
        {
            select_pivot(pivotidx);
            const float pivot=_orgMatProxy->get_mat_item(pivotidx,pivotidx);
            if(is_negligible(pivot))
                return false;
            // Subtract a multiple of the pivot row from every row below it.
            for (int ridx=pivotidx+1;ridx<rowNum;++ridx)
            {
                const float factor=_orgMatProxy->get_mat_item(ridx,pivotidx)/pivot;
                // Column pivotidx itself would become 0; it is skipped because it
                // is never read again.
                for (int cidx=pivotidx+1;cidx<colNum;++cidx)
                {
                    const float fv=_orgMatProxy->get_mat_item(ridx,cidx)-_orgMatProxy->get_mat_item(pivotidx,cidx)*factor;
                    _orgMatProxy->set_mat_item(ridx,cidx,fv);
                }
            }
        }
        // The loop above never inspects the final diagonal entry.
        return !is_negligible(_orgMatProxy->get_mat_item(rowNum-1,rowNum-1));
    }

    // Back substitution on the upper-triangular system left by guass_elim().
    // Column cn (one past the last unknown) holds the right-hand side b.
    // Writes x[0..cn-1] into eq->get_solution().
    void back_substitution()
    {
        float* sul=eq->get_solution();
        const int rn=eq->get_equat_count();
        const int cn=eq->get_var_count();
        for (int rowidx=rn-1;rowidx>=0;--rowidx)
        {
            float sum=_orgMatProxy->get_mat_item(rowidx,cn);
            // Only the columns right of the diagonal (already-solved unknowns)
            // contribute.
            for (int colidx=cn-1;colidx>rowidx;--colidx)
                sum-=_orgMatProxy->get_mat_item(rowidx,colidx)*sul[colidx];
            sul[rowidx]=sum/_orgMatProxy->get_mat_item(rowidx,rowidx);
        }
    }
};
void test1()
{
float coff[]={
1,2,1,3,
3,-1,-3,-1,
2,3,1,4};
equation_group eq(coff,3,4);
Guass_Elim_Slover slov(&eq);
slov.Slove();
float* res=eq.get_solution(); //3,-2,4
return;
}
void test2()
{
float coff[]={
6,-4,2,-2,
4,2,1,4,
2,-1,1,-1};
equation_group eq(coff,3,4);
Guass_Elim_Slover slov(&eq);
slov.Slove();
float* res=eq.get_solution();//1,1,-2
return;
}
void test3()
{
//测试主元=0的情况
float coff[]={
0,3,1,1,
1,2,-2,7,
2,5,4,-1};
equation_group eq(coff,3,4);
Guass_Elim_Slover slov(&eq);
slov.Slove();
float* res=eq.get_solution(); //1,1,-2
return;
}
void test4()
{
//测试二元方程组
float coff[]={
1,2,10,
1.1,2,10.4};
equation_group eq(coff,2,3);
Guass_Elim_Slover slov(&eq);
slov.Slove();
float* res=eq.get_solution(); //4,3
return;
}
// Singular 2x2 system: the second equation is exactly twice the first, so
// there is no unique solution and Slove() must report failure.
// (The old "4,3" expectation was copied from test4 and did not apply here.)
void test5()
{
    float coff[]={
        1,2,10,
        2,4,20};
    equation_group eq(coff,2,3);
    Guass_Elim_Slover slov(&eq);
    const bool solved=slov.Slove();
    assert(!solved);
    (void)solved; // keeps NDEBUG builds free of unused-variable warnings
}
// Singular 3x3 system: row 3 is 3 x row 1, so the coefficient matrix has no
// inverse and Slove() must report failure.
// (The old "1,1,-2" expectation was copied from test2/test3 and did not apply here.)
void test6()
{
    float coff[]={
        1,2,1,3,
        2,3,2,6,
        3,6,3,9};
    equation_group eq(coff,3,4);
    Guass_Elim_Slover slov(&eq);
    const bool solved=slov.Slove();
    assert(!solved);
    (void)solved; // keeps NDEBUG builds free of unused-variable warnings
}
int _tmain(int argc, _TCHAR* argv[])
{
test1();
test2();
test3();
test4();
test5();
test6();
system("pause");
return 0;
}
/********************************************************************
created: 2013/10/20
created: 20:10:2013 21:35
filename: E:\MyPassage\linear_algebra\LUDecompress\equation.h
file path: E:\MyPassage\linear_algebra\LUDecompress
file base: equation
file ext: h
author: sssa2000
purpose:
*********************************************************************/
#pragma once
// Dense row-major matrix of floats.
// It owns a heap buffer, so copying is disabled: the implicitly generated copy
// operations would share the raw pointer and free it twice.
class matrix
{
public:
    // r rows by c columns.
    matrix(int r,int c);
    ~matrix();
    matrix(const matrix&) = delete;
    matrix& operator=(const matrix&) = delete;
    // Copies getrn()*getcn() floats from data, read in row-major order.
    void set(float* data);
    // Element access; indices are not range-checked.
    float get_item(int r,int c);
    void set_item(int r,int c,float v);
    int getrn(); // number of rows
    int getcn(); // number of columns
    void make_identity();
private:
    float* _data;   // owned buffer of _row_num*_col_num floats
    int _row_num;
    int _col_num;
};
// A system of linear equations stored as an augmented matrix [A | b], plus
// the array that receives the solution.
// It owns both heap objects, so copying is disabled to avoid a double delete.
class equation_group
{
public:
    // coff: row*col floats in row-major order; col includes the RHS column.
    equation_group(float* coff,int row,int col);
    ~equation_group();
    equation_group(const equation_group&) = delete;
    equation_group& operator=(const equation_group&) = delete;
    // Non-owning pointer to the augmented matrix.
    matrix* get_matrix();
    // Non-owning pointer to the solution array.
    float* get_solution();
    // Number of equations.
    int get_equat_count();
    // Number of unknowns.
    int get_var_count();
private:
    matrix* _augmented_matrix; // augmented matrix (owned)
    int _row_num;
    int _col_num;
    float* _solution;          // array that receives the solution (owned)
};
// Records the row swaps made during pivoting without moving matrix data:
// logical row i maps to physical row _rerange_res[i] of the wrapped matrix.
// It owns the permutation array, so copying is disabled to avoid a double delete.
class matrix_proxy
{
public:
    // mat is not owned and must outlive the proxy.
    explicit matrix_proxy(matrix* mat);
    ~matrix_proxy();
    matrix_proxy(const matrix_proxy&) = delete;
    matrix_proxy& operator=(const matrix_proxy&) = delete;
    void swap_row(int rowidx1,int rowidx2);
    // Element access by logical (permuted) row index.
    float get_mat_item(int row,int col);
    void set_mat_item(int row,int col,float v);
    int getrn();
    int getcn();
private:
    int* _rerange_res; // permutation: logical row -> physical row (owned)
    int _row_num;
    matrix* _mat_ref;  // wrapped matrix (not owned)
};
#include "equation.h"
#include <math.h>
#include <stdarg.h>
#include <string.h>
#include <iostream>
// Allocates an r x c row-major buffer. The trailing "()" value-initialises
// every element, so the matrix starts out all zeros.
matrix::matrix(int r,int c)
    : _data(new float[r*c]())
    , _row_num(r)
    , _col_num(c)
{
}

// Frees the buffer allocated by the constructor.
matrix::~matrix()
{
    delete[] _data;
}
// Fills the whole matrix (not just one row) from data, read in row-major order:
// element (row, col) is taken from data[row*getcn()+col].
// The caller must supply at least getrn()*getcn() floats; no bounds check is made.
void matrix::set(float* data)
{
    for (int row=0;row<_row_num;++row)
    {
        for (int col=0;col<_col_num;++col)
        {
            set_item(row,col,data[row*_col_num+col]);
        }
    }
}
// Reads element (r, c) from row-major storage. Indices are not range-checked.
float matrix::get_item(int r,int c)
{
    const int offset=r*_col_num+c;
    return _data[offset];
}

// Overwrites element (r, c) with v. Indices are not range-checked.
void matrix::set_item(int r,int c,float v)
{
    const int offset=r*_col_num+c;
    _data[offset]=v;
}

// Number of rows.
int matrix::getrn()
{
    return _row_num;
}

// Number of columns.
int matrix::getcn()
{
    return _col_num;
}
// Resets the matrix to the identity: 1 on the main diagonal, 0 everywhere else.
// For a non-square matrix only the leading min(rows, cols) diagonal entries
// are set to 1.
void matrix::make_identity()
{
    memset(_data,0,_row_num*_col_num*sizeof(float));
    // Clamp to the shorter dimension. When rows > cols, (i, i) with i >= cols
    // is not a valid element, and for the last rows it lies past the end of _data.
    const int diag=_row_num<_col_num?_row_num:_col_num;
    for (int i=0;i<diag;++i)
    {
        set_item(i,i,1);
    }
}
// Builds a system from a row-major augmented matrix with `row` equations and
// `col` columns (col-1 unknowns plus the right-hand-side column).
// coff must hold row*col floats; its contents are copied into an internal matrix.
equation_group::equation_group(float* coff,int row,int col)
{
    _row_num=row;
    _col_num=col;
    // Value-initialised so the solution reads as zeros, not indeterminate values,
    // if it is inspected after a failed solve.
    _solution=new float[_col_num]();
    _augmented_matrix=new matrix(row,col);
    _augmented_matrix->set(coff);
}

// Frees the solution array and the augmented matrix.
equation_group::~equation_group()
{
    delete[] _solution;
    delete _augmented_matrix;
}
// Non-owning access to the augmented matrix; equation_group keeps ownership.
matrix* equation_group::get_matrix() { return _augmented_matrix; }

// Non-owning access to the solution array.
float* equation_group::get_solution() { return _solution; }

// Number of equations (rows of the augmented matrix).
int equation_group::get_equat_count() { return _row_num; }

// Number of unknowns: one fewer than the column count, because the last
// column of the augmented matrix is the right-hand side.
int equation_group::get_var_count() { return _col_num-1; }
// Wraps mat (not owned) and starts with the identity permutation:
// logical row i maps to physical row i.
matrix_proxy::matrix_proxy(matrix* mat)
    : _rerange_res(new int[mat->getrn()])
    , _row_num(mat->getrn())
    , _mat_ref(mat)
{
    for (int logical=0;logical<_row_num;++logical)
        _rerange_res[logical]=logical;
}

// Frees the permutation array; the wrapped matrix is left alone.
matrix_proxy::~matrix_proxy()
{
    delete[] _rerange_res;
}
// Exchanges two logical rows by swapping their permutation entries;
// the underlying matrix data is not moved.
void matrix_proxy::swap_row(int rowidx1,int rowidx2)
{
    const int first=_rerange_res[rowidx1];
    _rerange_res[rowidx1]=_rerange_res[rowidx2];
    _rerange_res[rowidx2]=first;
}

// Reads element (row, col), where row is a logical (permuted) row index.
float matrix_proxy::get_mat_item(int row,int col)
{
    const int physical=_rerange_res[row];
    return _mat_ref->get_item(physical,col);
}

// Writes element (row, col), where row is a logical (permuted) row index.
void matrix_proxy::set_mat_item(int row,int col,float v)
{
    const int physical=_rerange_res[row];
    _mat_ref->set_item(physical,col,v);
}

// Row and column counts of the wrapped matrix.
int matrix_proxy::getrn() { return _mat_ref->getrn(); }
int matrix_proxy::getcn() { return _mat_ref->getcn(); }
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment