Created
October 24, 2011 10:06
-
-
Save jonifreeman/1308712 to your computer and use it in GitHub Desktop.
Function to submit the ml-class week 2 exercises
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- LinExtras.hs: some hmatrix helpers
module LinExtras where | |
import Numeric.LinearAlgebra | |
import Foreign.Storable (Storable) | |
-- | Build a 'Vector Double' from a list (a monomorphic 'fromList').
vector :: [Double] -> Vector Double
vector = fromList
-- | Build a matrix from a list of columns: the first inner list becomes
-- column 0, the second column 1, and so on (the row-wise 'fromLists'
-- result is transposed).
matrix :: [[Double]] -> Matrix Double
matrix = trans . fromLists
-- | @m \@! i@ is column @i@ of @m@, counting from 0.
--
-- An index outside @[0, cols m)@ raises an error that names the index and
-- the column count, rather than an uninformative empty-list failure.
-- No fixity declaration: the operator keeps the default @infixl 9@.
(@!) :: Matrix Double -> Int -> Vector Double
m @! idx
  | idx < 0 || idx >= cols m =
      error $ "LinExtras.(@!): column index " ++ show idx
        ++ " out of range for a matrix with " ++ show (cols m) ++ " columns"
  | otherwise = toColumns m !! idx
-- | Assemble one matrix from a grid of sub-matrices given row by row
-- (a monomorphic 'fromBlocks'). Example usage:
--
-- > blocks [[1, 2, 3], [5, 6, 7]]
blocks :: [[Matrix Double]] -> Matrix Double
blocks = fromBlocks
-- | @ones nr nc@ is an @nr@-by-@nc@ matrix whose entries are all 1.
ones :: Int -> Int -> Matrix Double
ones nr nc = konst 1 (nr, nc)
-- | Concatenate two vectors. Being left-associative, it chains naturally,
-- which makes spelling out vectors inline convenient, e.g.
--
-- > 19 # constant 3 5 # (-2) # 11
infixl 9 #
(#) :: Vector Double -> Vector Double -> Vector Double
(#) u v = join [u, v]
-- | Turn a list into a single-row ('row') or single-column ('col')
-- matrix, convenient in expressions such as
--
-- > m + row [10, 20 .. 50]
row, col :: [Double] -> Matrix Double
row xs = asRow (vector xs)
col xs = asColumn (vector xs)
-- Submit.hs: implement module Ex2, then call the function 'submit'.
import System.IO | |
import Control.Exception | |
import Numeric.LinearAlgebra | |
import Data.Digest.Pure.SHA | |
import Data.ByteString.Lazy.Char8 as BS8 (pack) | |
import Data.List (sort) | |
import System.Random (randomRIO) | |
import Network.Curl | |
import Text.Printf (printf) | |
import Data.List.Split (splitOn) | |
import Data.Char (isSpace) | |
import Control.Monad (when) | |
import LinExtras | |
import Ex2 -- The excercises are implemented in this module | |
-- | Interactive entry point for submitting one part of the exercise.
--
-- The flow is:
--
-- * ask which part to submit and for the login and password;
-- * request a challenge for the login;
-- * if any of the three returned fields is empty, print the first field
--   as the error text and stop;
-- * otherwise compute the challenge response, post the part's output,
--   and print the server's reply.
submit :: IO ()
submit = do
  putStrLn $ "==\n== [ml-class] Submitting Solutions | Programming Exercise " ++ homeworkId
  partId <- promptPart
  (login, pass) <- loginPrompt
  putStrLn "\n== Connecting to ml-class ... "
  -- The first field of the challenge reply replaces the typed login from
  -- here on (it is what gets submitted, or printed on error).
  (serverLogin, ch, signature) <- getChallenge login
  if any null [serverLogin, ch, signature]
    then putStrLn $ "\n!! Error: " ++ serverLogin ++ "\n\n"
    else submitAnswer partId serverLogin pass ch signature
  where
    -- Sign the challenge, post the selected part, and report the result.
    submitAnswer partId login pass ch signature = do
      chResp <- challengeResponse login pass ch
      putStrLn $ "SHA1: " ++ show chResp
      result <- submitSolution login chResp partId (output partId) (source partId) signature
      putStrLn $ "\n== [ml-class] Submitted Homework " ++ homeworkId ++ " - Part " ++ show partId ++ " - " ++ (validParts !! (partId - 1))
      putStrLn $ "== " ++ result
-- | Request a submission challenge for @login@ from 'challengeUrl'.
--
-- The reply body is trimmed and split on @|@. Its first three fields are
-- returned in order; the caller treats them as login, challenge and
-- signature. Missing fields come back as empty strings. A short reply
-- (e.g. a plain message without any @|@) therefore reaches the caller's
-- empty-field check instead of failing with a list-index error.
getChallenge :: String -> IO (String, String, String)
getChallenge login = withCurlDo $ do
  curl <- initialize
  resp <- do_curl_ curl challengeUrl (CurlPostFields ["email_address=" ++ formEncode login] : method_POST) :: IO CurlResponse
  let elems = splitOn "|" (trim (respBody resp))
  putStrLn $ "== Get challenge " ++ show elems
  return (firstThree elems)
  where
    -- Total version of (elems !! 0, elems !! 1, elems !! 2).
    firstThree (a : b : c : _) = (a, b, c)
    firstThree [a, b] = (a, b, "")
    firstThree [a] = (a, "", "")
    firstThree [] = ("", "", "")

-- | POST one part's answer to 'submitUrl' and return the response body.
--
-- Every field value is passed through 'formEncode' before the fields are
-- handed to curl.
submitSolution :: String -> String -> Int -> String -> String -> String -> IO String
submitSolution login chResp partId answer src signature = withCurlDo $ do
  curl <- initialize
  resp <- do_curl_ curl submitUrl (CurlPostFields fields : method_POST) :: IO CurlResponse
  return (respBody resp)
  where
    fields =
      [ "homework=" ++ formEncode homeworkId
      , "part=" ++ formEncode (show partId)
      , "email=" ++ formEncode login
      , "output=" ++ formEncode answer
      , "source=" ++ formEncode src
      , "challenge_response=" ++ formEncode chResp
      , "signature=" ++ formEncode signature
      ]

-- | Percent-encode a value for an @application/x-www-form-urlencoded@ body.
--
-- POST field data is handed to libcurl verbatim; libcurl does not encode
-- it. Without this, a @+@, @&@, @=@ or space inside a value (e.g. an
-- e-mail address like @a+b\@x.org@) would be misread by the server.
-- RFC 3986 unreserved characters (ASCII letters, digits and @-._~@) are
-- kept as-is; every other character becomes @%XX@.
--
-- NOTE(review): only the low byte of each character is encoded.
-- Characters above U+00FF are not UTF-8 encoded first.
formEncode :: String -> String
formEncode = concatMap encodeChar
  where
    encodeChar c
      | c `elem` unreserved = [c]
      | otherwise = printf "%%%02X" (fromEnum c `mod` 256)
    unreserved = ['A' .. 'Z'] ++ ['a' .. 'z'] ++ ['0' .. '9'] ++ "-._~"
-- | Build the challenge response sent along with a submission.
--
-- The digest is @showDigest (sha1 (challenge ++ showDigest (sha1 (salt ++
-- login ++ passwd))))@, with strings packed via 'BS8.pack'. From it,
-- sixteen distinct positions are picked at random and sorted ascending,
-- and the characters at those positions are returned.
challengeResponse :: String -> String -> String -> IO String
challengeResponse login passwd challenge =
  fmap (\perm -> select (sort (take 16 perm)) digest)
       (randperm [0 .. length digest - 1])
  where
    salt = ")~/|]QMB3[!W`?OVt7qC\"@+}"
    sha1Hex = showDigest . sha1 . BS8.pack
    digest = sha1Hex (challenge ++ sha1Hex (salt ++ login ++ passwd))
-- | Print the menu of submittable parts and read the user's choice.
--
-- Keeps asking until the answer parses as a whole number from 1 to the
-- number of entries in 'validParts'. The result is therefore always a
-- valid 1-based index into 'validParts'. Previously, non-numeric input
-- crashed in 'read' and out-of-range numbers crashed later.
promptPart :: IO Int
promptPart = do
  putStrLn $ "== Select which part(s) to submit: " ++ homeworkId
  mapM_ putStrLn $ zipWith (\i p -> "== " ++ show i ++ " [" ++ p ++ "]") [1 :: Int ..] validParts
  askChoice
  where
    partCount = length validParts
    askChoice = do
      putStrLn "Enter your choice: "
      answer <- getLine
      case reads answer of
        -- Accept only a single integer, optionally surrounded by whitespace.
        [(n, rest)] | all isSpace rest, n >= 1, n <= partCount -> return n
        _ -> do
          putStrLn $ "!! Invalid choice, enter a number from 1 to " ++ show partCount
          askChoice
-- | Ask for the login (e-mail address) and then the password. The password
-- line is read with terminal echo switched off.
loginPrompt :: IO (String, String)
loginPrompt = do
  email <- putStrLn "Login (Email address): " >> getLine
  secret <- putStrLn "Password: " >> withEcho False getLine
  return (email, secret)
-- | Endpoint posted to by 'getChallenge' to obtain a challenge.
challengeUrl :: String
challengeUrl = "http://www.ml-class.org/course/homework/challenge"

-- | Endpoint posted to by 'submitSolution' with a part's answer.
submitUrl :: String
submitUrl = "http://www.ml-class.org/course/homework/submit"
-- | Source text attached to a submission. Currently always empty, whatever
-- the part.
-- TODO: decide how to obtain the sources.
source :: Int -> String
source _ = ""
-- | Serialise a matrix column by column (each column via 'outputVector'),
-- with single spaces between columns.
outputMatrix :: Matrix Double -> String
outputMatrix = unwords . map outputVector . toColumns

-- | Serialise a vector's entries (each via 'outputDouble'), space separated.
outputVector :: Vector Double -> String
outputVector = unwords . map outputDouble . toList

-- | Format a number with the @"%0.5f"@ printf pattern (five decimals).
outputDouble :: Double -> String
outputDouble d = printf "%0.5f" d
-- General stuff | |
-- | Run an action with stdin's echo setting switched to @echo@. The setting
-- in effect beforehand is restored when the action finishes, even if it
-- throws (via 'bracket_').
withEcho :: Bool -> IO a -> IO a
withEcho echo action =
  hGetEcho stdin >>= \previous ->
    bracket_ (hSetEcho stdin echo) (hSetEcho stdin previous) action
-- | Characters of @s@ at the given 0-based positions, in the order the
-- positions are listed. Each position must be a valid index into @s@.
select :: [Int] -> String -> String
select idxs s = [s !! i | i <- idxs]
-- | Random permutation of a list. It repeatedly draws a position with
-- 'randomRIO' from the elements still remaining, moves that element to
-- the output, and continues with the rest.
randperm :: [a] -> IO [a]
randperm = go
  where
    go [] = return []
    go pool = do
      pick <- randomRIO (0, length pool - 1)
      let (before, chosen : after) = splitAt pick pool
      rest <- go (before ++ after)
      return (chosen : rest)
-- | Remove leading and trailing characters for which 'isSpace' holds.
trim :: String -> String
trim = reverse . stripFront . reverse . stripFront
  where
    stripFront = dropWhile isSpace
-- Homework specific stuff | |
-- | Exercise identifier; printed in messages and sent as the @homework@
-- field.
homeworkId :: String
homeworkId = "2"

-- | Display names of the parts. Part @n@ (1-based) is element @n - 1@.
validParts :: [String]
validParts =
  [ "Sigmoid Function "
  , "Logistic Regression Cost"
  , "Logistic Regression Gradient"
  , "Predict"
  , "Regularized Logistic Regression Cost"
  , "Regularized Logistic Regression Gradient"
  ]
-- | The answer string submitted for a part.
--
-- Each part evaluates the corresponding Ex2 function on a fixed 20-sample
-- data set @x@ with three columns:
--
-- * column 0: ones;
-- * column 1: @e * sin t@;
-- * column 2: @e^0.5 * cos t@, for @t = 1..20@.
--
-- Labels @y@ are 1 where @sin (column 0 + column 1) > 0@ and 0 otherwise.
-- The parameter vector is @[0.25, 0.5, -0.5]@. The regularized parts pass
-- an extra @0.1@ (presumably the regularization strength; see Ex2).
--
-- Part ids outside 1..6 raise an error naming the id.
output :: Int -> String
output partId = case partId of
  1 -> outputMatrix $ sigmoid x
  2 -> outputDouble $ fst $ costFunction theta x y
  3 -> outputMatrix $ snd $ costFunction theta x y
  4 -> outputVector $ predict theta x
  5 -> outputDouble $ fst $ costFunctionReg theta x y 0.1
  6 -> outputMatrix $ snd $ costFunctionReg theta x y 0.1
  _ -> error $ "output: no test case for part " ++ show partId
  where
    ts = [1 .. 20]
    x = matrix [replicate 20 1, [exp 1 * sin t | t <- ts], [exp 0.5 * cos t | t <- ts]]
    y = mapVector (\v -> if v > 0 then 1 else 0) $ sin (x @! 0 + x @! 1)
    theta = fromList [0.25, 0.5, -0.5] :: Vector Double
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment