Created
June 11, 2019 04:39
-
-
Save f0nzie/4d2760ff6502291da4610aa5ba64583d to your computer and use it in GitHub Desktop.
PyTorch - Linear Regression in R - rsuite version
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# One chunk - Linear Regression in R
Source: https://medium.com/dsnet/linear-regression-with-pytorch-3dde91d60b50
```{r}
# Pin reticulate to the project-local conda interpreter so the chunk never
# falls back to a system Python.
# NOTE(review): assumes `script_path` is set by the surrounding rsuite
# project scaffolding — confirm before running standalone.
local_python <- if (.Platform$OS.type == "unix") {
  file.path(script_path, "..", "conda", "bin", "python3")
} else {
  file.path(script_path, "..", "conda")
}
reticulate::use_python(python = local_python, require = TRUE)
```
```{r}
library(reticulate)

# Bind the Python modules we drive from R.
torch <- import("torch")
np <- import("numpy")

torch$manual_seed(0)
device <- torch$device("cpu")

# Training inputs: one row per sample of (temperature, rainfall, humidity).
inputs <- np$array(list(list(73, 67, 43),
                        list(91, 88, 64),
                        list(87, 134, 58),
                        list(102, 43, 37),
                        list(69, 96, 70)), dtype = "float32")

# Regression targets: one row per sample of (apples, oranges).
targets <- np$array(list(list(56, 70),
                         list(81, 101),
                         list(119, 133),
                         list(22, 37),
                         list(103, 119)), dtype = "float32")

# Wrap the NumPy arrays as torch tensors (dtype stays float32 here).
inputs <- torch$from_numpy(inputs)
targets <- torch$from_numpy(targets)
print(inputs)
print(targets)

# Randomly initialise the weight matrix (2x3) and bias vector (2) with
# gradients enabled; float64 becomes the default dtype from here on.
torch$set_default_dtype(torch$float64)
w <- torch$randn(2L, 3L, requires_grad = TRUE)
b <- torch$randn(2L, requires_grad = TRUE)
print(w)
print(b)
# Linear model: x %*% t(w) + b, expressed with torch ops so autograd
# tracks the computation. Reads `w` and `b` from the enclosing scope.
model <- function(x) {
  torch$add(torch$mm(x, w$t()), b)
}
# Predictions from the untrained model, side by side with the targets.
preds <- model(inputs)
print(preds)
print(targets)
# Mean squared error between two tensors, built from torch primitives so
# the result stays inside the autograd graph.
mse <- function(t1, t2) {
  err <- torch$sub(t1, t2)
  sq_sum <- torch$sum(torch$mul(err, err))
  torch$div(sq_sum, err$numel())
}
# Loss at the random initialisation (expected to be large).
loss <- mse(preds, targets)
print(loss)

# Backpropagate: fills in w$grad and b$grad.
loss$backward()
print(w)
print(w$grad)

# Zero the accumulated gradients in place before the next backward pass.
w$grad$zero_()
b$grad$zero_()
print(w$grad)
print(b$grad)

# Forward pass, loss, and backward pass once more with clean gradients.
preds <- model(inputs)
print(preds)
loss <- mse(preds, targets)
print(loss)
loss$backward()
print(w$grad)
print(b$grad)

# One manual SGD step (lr = 1e-5) with autograd disabled; updating $data
# in place leaves the requires_grad attribute intact.
with(torch$no_grad(), {
  print(w); print(b)
  w$data <- torch$sub(w$data, torch$mul(w$grad$data, torch$scalar_tensor(1e-5)))
  b$data <- torch$sub(b$data, torch$mul(b$grad$data, torch$scalar_tensor(1e-5)))
  print(w$grad$data$zero_())
  print(b$grad$data$zero_())
})
print(w)
print(b)

# Loss after the single update step — should have dropped.
preds <- model(inputs)
loss <- mse(preds, targets)
print(loss)
# Train for 100 epochs: forward pass, loss, backward pass, manual SGD
# step (lr = 1e-5), then zero the gradients for the next iteration.
for (epoch in seq_len(100)) {
  preds <- model(inputs)
  loss <- mse(preds, targets)
  loss$backward()
  with(torch$no_grad(), {
    w$data <- torch$sub(w$data, torch$mul(w$grad, torch$scalar_tensor(1e-5)))
    b$data <- torch$sub(b$data, torch$mul(b$grad, torch$scalar_tensor(1e-5)))
    w$grad$zero_()
    b$grad$zero_()
  })
}

# Final loss, then predictions vs. targets for visual comparison.
preds <- model(inputs)
loss <- mse(preds, targets)
print(loss)
preds
targets
```
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment