Skip to content

Instantly share code, notes, and snippets.

@zhangce
Created August 14, 2014 06:15
Show Gist options
  • Select an option

  • Save zhangce/44bd80968cf3df3a55f8 to your computer and use it in GitHub Desktop.

Select an option

Save zhangce/44bd80968cf3df3a55f8 to your computer and use it in GitHub Desktop.
push!(LOAD_PATH, "/Users/czhang/Desktop/Projects/dw_/julia/")
import DimmWitted
DimmWitted.set_libpath("/Users/czhang/Desktop/Projects/dw_/libdw_julia")
######################################
# The following function creates a
# synthetic data set:
# - Data type is Cdouble
# - Model type is Array{Cdouble}
#
# Layout: each of the nexp rows holds nfeat feature values (all 1.0)
# followed by a 0/1 class label in the last column. Roughly 20% of
# rows get label 0 and 80% get label 1.
# NOTE(review): `Array(Cdouble, nexp, nfeat+1)` is pre-0.4 Julia
# syntax; modern Julia spells this `Array{Cdouble}(undef, nexp, nfeat+1)`.
nexp = 100000
nfeat = 1024
examples = Array(Cdouble, nexp, nfeat+1)
for row = 1:nexp
for col = 1:nfeat
examples[row, col] = 1
end
# Label column: rand() > 0.8 --> 0 (~20%), otherwise 1 (~80%).
if rand() > 0.8
examples[row, nfeat+1] = 0
else
examples[row, nfeat+1] = 1
end
end
# Initial model: nfeat weights, all zero.
model = Cdouble[0 for i = 1:nfeat]
######################################
# Define the loss function and gradient
# function for logistic regression
#
# Logistic-regression negative log-likelihood for a single example.
#
# `row`   : nfeat feature values followed by the 0/1 label in its last slot.
# `model` : the nfeat weight vector (read-only here).
# Returns : -label * (w'x) + log(1 + exp(w'x)), a Cdouble.
function loss(row::Array{Cdouble,1}, model::Array{Cdouble,1})
    @inbounds begin
        label = row[length(row)]   # label lives in the extra last column
        nfeat = length(model)
        d = 0.0
        for i = 1:nfeat
            # BUG FIX: the original wrote `d = row[i]*model[i]`, which
            # overwrites d each iteration and yields only the LAST term
            # instead of the dot product w'x. Accumulate instead.
            d += row[i] * model[i]
        end
    end
    # Also dropped the original's unused intermediate `e` and the
    # local `const` declarations (not valid in local scope in modern Julia).
    return (-label * d + log(exp(d) + 1.0))
end
# One stochastic-gradient step of logistic regression for a single example.
#
# `row`   : nfeat feature values followed by the 0/1 label in its last slot.
# `model` : the nfeat weight vector, UPDATED IN PLACE.
#           (Name kept without a `!` suffix because DimmWitted.register_row
#           is handed this exact function elsewhere in the script.)
# Returns : 1.0 (DimmWitted row-function convention in this script; the
#           useful output is the mutation of `model`).
function grad(row::Array{Cdouble,1}, model::Array{Cdouble,1})
    @inbounds begin
        label = row[length(row)]
        nfeat = length(model)
        d = 0.0
        for i = 1:nfeat
            # BUG FIX: original used `d = ...`, keeping only the last
            # term; the gradient needs the full dot product w'x.
            d += row[i] * model[i]
        end
        d = exp(-d)
        # Z = stepsize * dLoss/d(w'x); 0.00001 is the hard-coded
        # learning rate from the original script.
        Z = 0.00001 * (-label + 1.0 / (1.0 + d))
        for i = 1:nfeat
            model[i] = model[i] - row[i] * Z
        end
    end
    return 1.0
end
######################################
# Create a DimmWitted object using data
# and model. You do not need to specify
# the type; they are inferred by the
# open() function, which is parametric.
#
DimmWitted.open(examples, model)
######################################
# Register functions.
#
# register_row returns an opaque handle later passed to exec().
handle_loss = DimmWitted.register_row(loss)
handle_grad = DimmWitted.register_row(grad)
######################################
# Run 10 epochs.
#
# Each epoch prints the average loss over all rows, then runs one
# full SGD pass via the registered grad function.
# NOTE(review): exec() presumably reduces the per-row return values
# (sum, given the rs/nexp averaging) — confirm against DimmWitted docs.
for iepoch = 1:10
rs = DimmWitted.exec(handle_loss)
println("LOSS: ", rs/nexp)
rs = DimmWitted.exec(handle_grad)
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment