Created April 9, 2010 11:16
-
-
Save yatsuta/361063 to your computer and use it in GitHub Desktop.
PRML Max-sum Implementation for Erlang
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-module(prml). | |
-compile(export_all). | |
%% fn, calc | |
%% Construct a function value: the pair {VarArg, Body}, where VarArg lists
%% the variables the function depends on and Body is a fun that receives
%% the list of their domain values in the same order.
make_fn(VarArg, Body) ->
    Fn = {VarArg, Body},
    Fn.
%% Turn a list of variable names into the list of values bound to them in
%% VarBinding (a list of tuples keyed on the variable), keeping VarArg's
%% order.
map_var_arg(VarBinding, VarArg) ->
    lists:map(fun(Var) -> lookup_dom_val(Var, VarBinding) end, VarArg).
%% Return the domain value bound to Var in VarBinding, a list of
%% {Var, DomVal} pairs. Crashes with {badmatch, false} if Var is unbound.
lookup_dom_val(Var, VarBinding) ->
    %% lists:keyfind/3 returns the whole matching tuple, so destructure it to
    %% get the value itself. (The former keysearch-based code bound DomVal to
    %% the entire {Var, DomVal} tuple and returned that instead.)
    {_, DomVal} = lists:keyfind(Var, 1, VarBinding),
    DomVal.
%% Evaluate the function value {VarArg, Body} with its variables bound as
%% in VarBinding.
calc({VarArg, Body}, VarBinding) ->
    Arg = map_var_arg(VarBinding, VarArg),
    Body(Arg).
%% konst | |
%% konst(N) -> | |
%% make_fn([], fun(_) -> N end). | |
%% add, mult, prod (utilities) | |
%% Binary addition and multiplication, usable as funs (fun add/2, fun mult/2).
add(X, Y) ->
    X + Y.
mult(X, Y) ->
    X * Y.
%% Product of all numbers in L, folded left to right; the empty product is 1.
prod(L) ->
    lists:foldl(fun(Elem, Acc) -> Elem * Acc end, 1, L).
%% map_reduce (utilities) | |
%% Apply Map to every element of L and pass the list of results, in the
%% order of L, to Reduce.
%%
%% The first argument selects the evaluation strategy:
%%   function - Map runs sequentially in the calling process.
%%   process  - each Map(X) runs in its own monitored process; the caller
%%              collects the replies in the order of L. If a worker dies
%%              before replying, the caller exits with {map_failed, Reason}
%%              instead of waiting forever.
map_reduce(function, Map, Reduce, L) ->
    Reduce([Map(X) || X <- L]);
map_reduce(process, Map, Reduce, L) ->
    Parent = self(),
    %% A fresh reference tags this call's replies, so unrelated messages
    %% already in the caller's mailbox are never mistaken for results.
    Tag = make_ref(),
    Workers = [spawn_monitor(fun() -> Parent ! {Tag, self(), Map(X)} end)
               || X <- L],
    %% Waiting on the workers in spawn order yields the results in the order
    %% of L without any sorting.
    Reduce([map_collect(Tag, Worker) || Worker <- Workers]).

%% Wait for the reply of one map_reduce/4 worker. The monitor turns a worker
%% crash into an exit of the caller rather than a hang.
map_collect(Tag, {Pid, MRef}) ->
    receive
        {Tag, Pid, Value} ->
            %% The worker exits right after replying; discard its 'DOWN'.
            erlang:demonitor(MRef, [flush]),
            Value;
        {'DOWN', MRef, process, Pid, Reason} ->
            exit({map_failed, Reason})
    end.
%% sum_fn, sum_for, add_fn
%% prod_fun, prod_for, mult_fn | |
%% maxi_fn, maxi_for, max_fn | |
%% Evaluation strategy handed to map_reduce/4 by the combinators below:
%% `process` runs each map step in its own process (`function` would run
%% them sequentially in the caller).
type() ->
    process.
%% Union of several variable lists, returned sorted and without duplicates.
var_arg_union(VarArgs) ->
    lists:usort([Var || VarArg <- VarArgs, Var <- VarArg]).
%% Zero-based index of the first element of L that exactly matches E
%% (=:=), or false when E does not occur in L.
pos(E, L) ->
    pos(E, L, 0).

pos(_, [], _) ->
    false;
pos(E, [E | _], Index) ->
    Index;
pos(E, [_ | Rest], Index) ->
    pos(E, Rest, Index + 1).
%% Element of L at zero-based position N (the inverse of pos/2).
val_by_pos(L, N) ->
    lists:nth(N + 1, L).
%% Value at zero-based position Pos of an argument list.
get_val(Arg, Pos) ->
    val_by_pos(Arg, Pos).
%% Zero-based position of Var within VarArg, or false when it is absent.
get_pos(Var, VarArg) ->
    pos(Var, VarArg).
%% Project SourceArg, whose values are laid out according to SourceVarArg,
%% onto VarArg: for each variable in VarArg, in order, return its value in
%% SourceArg.
extract_arg(VarArg, SourceVarArg, SourceArg) ->
    lists:map(fun(Var) ->
                      get_val(SourceArg, get_pos(Var, SourceVarArg))
              end,
              VarArg).
%% Split a list of {VarArg, Body} function values into {VarArgs, Bodies}.
split_into_var_args_and_bodies(Fns) ->
    {VarArgs, Bodies} = lists:unzip(Fns),
    {VarArgs, Bodies}.
%% Merge the function values Fns into one function value over the union of
%% their variables (var_arg_union/1). Evaluating the merged function runs
%% every member on its own variables, picked out of the merged argument
%% list, via map_reduce/4 with strategy type(). It then hands the list of
%% member results (in the order of Fns) to Reduce.
reduce_fns(Reduce, Fns) ->
    {VarArgs, _} = split_into_var_args_and_bodies(Fns),
    VarArgFolded = var_arg_union(VarArgs),
    %% A member's variable positions inside VarArgFolded depend only on the
    %% variable lists, so resolve them once here instead of on every
    %% evaluation of the merged body.
    Members = [{[get_pos(Var, VarArgFolded) || Var <- VarArg], Body}
               || {VarArg, Body} <- Fns],
    BodyFolded =
        fun(ArgFolded) ->
                Map = fun({Positions, Body}) ->
                              Body([get_val(ArgFolded, Pos) || Pos <- Positions])
                      end,
                map_reduce(type(), Map, Reduce, Members)
        end,
    make_fn(VarArgFolded, BodyFolded).
%% Pointwise combinators over function values.
%%   *_fn/1  merges a list of function values with reduce_fns/2;
%%   *_for/2 builds one function value per element of L with F and merges
%%           them;
%%   the binary forms merge exactly two function values.
%% sum adds the member values, prod multiplies them (prod/1), and maxi
%% takes the largest (lists:max/1).
sum_fn(Fns) ->
    reduce_fns(fun lists:sum/1, Fns).
sum_for(L, F) ->
    combine_for(L, F, fun ?MODULE:sum_fn/1).
add_fn(Fn1, Fn2) ->
    sum_fn([Fn1, Fn2]).

prod_fn(Fns) ->
    reduce_fns(fun ?MODULE:prod/1, Fns).
prod_for(L, F) ->
    combine_for(L, F, fun ?MODULE:prod_fn/1).
mult_fn(Fn1, Fn2) ->
    prod_fn([Fn1, Fn2]).

maxi_fn(Fns) ->
    reduce_fns(fun lists:max/1, Fns).
maxi_for(L, F) ->
    combine_for(L, F, fun ?MODULE:maxi_fn/1).
max_fn(Fn1, Fn2) ->
    maxi_fn([Fn1, Fn2]).

%% Shared body of the *_for/2 combinators: map F over L, then merge the
%% resulting function values with Combine.
combine_for(L, F, Combine) ->
    map_reduce(type(), F, Combine, L).
%% log_of, exp_of | |
%% Post-compose F with a function value: the result keeps the same variables
%% and returns F applied to the original body's result.
compose_fn(F, {VarArg, Body}) ->
    Composed = fun(Arg) -> F(Body(Arg)) end,
    {VarArg, Composed}.

%% Pointwise natural logarithm / exponential of a function value.
log_of(Fn) ->
    compose_fn(fun math:log/1, Fn).
exp_of(Fn) ->
    compose_fn(fun math:exp/1, Fn).
%% sum_vars, maxi_vars | |
%% Drop the element at zero-based position VarPos from VarArg.
%% Crashes with badmatch if VarPos is not a valid index.
remove_var(VarPos, VarArg) ->
    {Prefix, [_Dropped | Suffix]} = lists:split(VarPos, VarArg),
    lists:append(Prefix, Suffix).
%% Insert DomVal into Arg so that it ends up at zero-based position Pos.
insert_dom_val(Pos, Arg, DomVal) ->
    {Prefix, Suffix} = lists:split(Pos, Arg),
    lists:append(Prefix, [DomVal | Suffix]).
%% Fix Var to DomVal in the function value {VarArg, Body}. The result no
%% longer takes Var as an argument; before calling Body it re-inserts DomVal
%% at Var's original position.
partial_fn(Var, DomVal, {VarArg, Body}) ->
    Pos = pos(Var, VarArg),
    make_fn(remove_var(Pos, VarArg),
            fun(Rest) -> Body(insert_dom_val(Pos, Rest, DomVal)) end).
%% Sum Fn over every value of Var (taken from domain_of/1). The resulting
%% function value no longer has Var among its variables.
sum_var(Var, Fn) ->
    FixVar = fun(DomVal) -> partial_fn(Var, DomVal, Fn) end,
    sum_for(domain_of(Var), FixVar).

%% Apply sum_var/2 for every variable in Vars, starting from the last one.
sum_vars(Vars, Fn) ->
    lists:foldl(fun ?MODULE:sum_var/2, Fn, lists:reverse(Vars)).
%% Maximize Fn over every value of Var (taken from domain_of/1). The
%% resulting function value no longer has Var among its variables.
maxi_var(Var, Fn) ->
    FixVar = fun(DomVal) -> partial_fn(Var, DomVal, Fn) end,
    maxi_for(domain_of(Var), FixVar).

%% Apply maxi_var/2 for every variable in Vars, starting from the last one.
maxi_vars(Vars, Fn) ->
    lists:foldl(fun ?MODULE:maxi_var/2, Fn, lists:reverse(Vars)).
%% max-sum | |
%% L with every element exactly equal (=:=) to E removed.
except(L, E) ->
    lists:filter(fun(Elem) -> Elem =/= E end, L).
%% Max-sum message (log domain) from factor F to variable X. It adds
%% log(fn_of(F)) to the messages from F's other neighbouring variables,
%% then maximizes over those other variables.
ms_mu_f_x(F, X) ->
    OtherVars = except(ne(F), X),
    Incoming = sum_for(OtherVars, fun(X2) -> ms_mu_x_f(X2, F) end),
    maxi_vars(OtherVars, add_fn(log_of(fn_of(F)), Incoming)).
%% Max-sum message (log domain) from variable X to factor F: the sum of the
%% messages X receives from all of its other neighbouring factors.
ms_mu_x_f(X, F) ->
    OtherFactors = except(ne(X), F),
    sum_for(OtherFactors, fun(F2) -> ms_mu_f_x(F2, X) end).
%% Take X as the root: sum every factor-to-variable message arriving at X,
%% maximize over X, and exponentiate back out of the log domain. The result
%% is a function value; main/0 evaluates it with an empty binding.
pmax(X) ->
    Incoming = sum_for(ne(X), fun(F) -> ms_mu_f_x(F, X) end),
    exp_of(maxi_var(X, Incoming)).
%% Factor-graph adjacency for the chain x1 - f1 - x2 - f2 - x3. A factor
%% maps to the variables it depends on; a variable maps to the factors it
%% appears in.
ne(f1) -> [x1, x2];
ne(f2) -> [x2, x3];
ne(x1) -> [f1];
ne(x2) -> [f1, f2];
ne(x3) -> [f2].
%% Every variable of the model is binary.
domain_of(Var) when Var =:= x1; Var =:= x2; Var =:= x3 ->
    [0, 1].
%% Distribution of x1 as a function value over [x1]:
%% p(x1 = 0) = 0.3, p(x1 = 1) = 1.0 - 0.3.
p_x1() ->
    Table = fun([0]) -> 0.3;
               ([1]) -> 1.0 - 0.3
            end,
    make_fn([x1], Table).
%% Conditional p(x2 | x1) as a function value over [x2, x1]; the body
%% receives [X2, X1].
p_x2_given_x1() ->
    Table = fun([0, 0]) -> 0.7;
               ([1, 0]) -> 1.0 - 0.7;
               ([0, 1]) -> 0.4;
               ([1, 1]) -> 1.0 - 0.4
            end,
    make_fn([x2, x1], Table).
%% Conditional p(x3 | x2) as a function value over [x3, x2]; the body
%% receives [X3, X2].
%% NOTE(review): p(x3 = 0 | x2 = 0) is 0.0000001 rather than 0, presumably
%% because log_of/1 applies math:log/1, which raises badarith on 0 — confirm.
p_x3_given_x2() ->
    Table = fun([0, 0]) -> 0.0000001;
               ([1, 0]) -> 1.0 - 0.0000001;
               ([0, 1]) -> 0.6;
               ([1, 1]) -> 1.0 - 0.6
            end,
    make_fn([x3, x2], Table).
%% Factor functions of the graph: f2 = p(x3 | x2) and
%% f1 = p(x2 | x1) * p(x1).
fn_of(f2) ->
    p_x3_given_x2();
fn_of(f1) ->
    mult_fn(p_x2_given_x1(), p_x1()).
%% main | |
%% Print the value of pmax/1 computed with each of x1, x2 and x3 as the root.
main() ->
    Format = "pmax(x1): ~w~n"
             "pmax(x2): ~w~n"
             "pmax(x3): ~w~n",
    Values = [calc(pmax(Var), []) || Var <- [x1, x2, x3]],
    io:format(Format, Values).
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment