Created December 26, 2017 02:14
Save lancelet/2acddfbfdecc811993b853b860e395fa to your computer and use it in GitHub Desktop.
OpenCL + Vector + Kernel Quasiquoting
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| {-# LANGUAGE QuasiQuotes #-} | |
| module Main where | |
| import Control.Parallel.OpenCL | |
| import Data.Vector.Storable (Vector) | |
| import qualified Data.Vector.Storable as V | |
| import Data.Vector.Storable.Mutable (IOVector) | |
| import qualified Data.Vector.Storable.Mutable as MV | |
| import Foreign (castPtr, nullPtr, sizeOf) | |
| import Foreign.C.Types (CFloat) | |
| import Language.C.Quote.OpenCL (cfun) | |
| import Text.PrettyPrint.Mainland (prettyCompact) | |
| import Text.PrettyPrint.Mainland.Class (ppr) | |
-- | OpenCL C source for the @doubleArray@ kernel.
--
-- Each work-item @i@ (from @get_global_id(0)@) stores @2 * in[i]@ into
-- @out[i]@. The kernel does no bounds check, so callers must launch exactly
-- one work-item per element.
--
-- The kernel is written as a 'cfun' quasiquote, so its syntax is parsed when
-- the Haskell module is compiled. It is then pretty-printed back into a
-- 'String' suitable for 'clCreateProgramWithSource'.
kernelSource :: String
kernelSource = prettyCompact (ppr doubleArrayFn)
  where
    -- The kernel as a C AST; layout inside the quote does not affect the
    -- rendered text, which is regenerated from the AST.
    doubleArrayFn =
      [cfun|
        kernel void doubleArray(global float *in, global float *out) {
          int i = get_global_id(0);
          out[i] = 2 * in[i];
        }
      |]
-- | Run the @doubleArray@ kernel on the first GPU device of the first OpenCL
-- platform and print the resulting vector.
--
-- The steps are:
--
--   * pick a platform and a GPU device;
--   * create a context and a command queue;
--   * build 'kernelSource';
--   * upload the floats @0 .. 1000@;
--   * launch one work-item per element;
--   * read the output back with a blocking read;
--   * release every OpenCL object that was created.
--
-- Raises an 'IOError' with a descriptive message if the platform or device
-- query returns an empty list. Errors raised by the OpenCL binding itself
-- (e.g. a failed program build) propagate unchanged.
main :: IO ()
main = do
  putStrLn "Hello World"

  -- fetch a platform (clear error instead of a pattern-match failure)
  platform <- firstOr "No OpenCL platform found" =<< clGetPlatformIDs
  platName <- clGetPlatformInfo platform CL_PLATFORM_NAME
  putStrLn $ "Platform name: " ++ platName

  -- fetch a GPU device
  dev <-
    firstOr "No OpenCL GPU device found"
      =<< clGetDeviceIDs platform CL_DEVICE_TYPE_GPU
  devName <- clGetDeviceName dev
  putStrLn $ "Device name: " ++ devName

  -- create a context (notifications from the implementation are printed)
  context <- clCreateContext [CL_CONTEXT_PLATFORM platform] [dev] print
  q <- clCreateCommandQueue context dev []

  -- initialize the kernel
  program <- clCreateProgramWithSource context kernelSource
  clBuildProgram program [dev] ""
  kernel <- clCreateKernel program "doubleArray"

  -- create some data
  let inputData :: Vector CFloat
      inputData = V.fromList [0 .. 1000]
      nElems :: Int
      nElems = V.length inputData
      nBytes :: Int
      nBytes = nElems * sizeOf (0 :: CFloat)

  outVec <- V.unsafeWith inputData $ \inptr -> do
    -- CL_MEM_COPY_HOST_PTR copies the host data when the buffer is created,
    -- so inptr only has to stay valid for this call.
    mem_in <-
      clCreateBuffer
        context
        [CL_MEM_READ_ONLY, CL_MEM_COPY_HOST_PTR]
        (nBytes, castPtr inptr)
    mem_out <- clCreateBuffer context [CL_MEM_WRITE_ONLY] (nBytes, nullPtr)
    clSetKernelArgSto kernel 0 mem_in
    clSetKernelArgSto kernel 1 mem_out

    -- execute the kernel: exactly one work-item per element (the kernel has
    -- no bounds check); the local size is left to the implementation
    eventExec <- clEnqueueNDRangeKernel q kernel [nElems] [] []

    -- get result: blocking read (True) that waits on the kernel's event
    outMutableVec <- MV.unsafeNew nElems :: IO (IOVector CFloat)
    eventRead <- MV.unsafeWith outMutableVec $ \outptr ->
      clEnqueueReadBuffer q mem_out True 0 nBytes (castPtr outptr) [eventExec]

    -- release per-run objects; they were previously leaked. The blocking
    -- read has completed, so the device is finished with both buffers.
    _ <- clReleaseEvent eventRead
    _ <- clReleaseEvent eventExec
    _ <- clReleaseMemObject mem_in
    _ <- clReleaseMemObject mem_out
    V.freeze outMutableVec

  -- release the remaining OpenCL objects, in reverse order of creation
  _ <- clReleaseKernel kernel
  _ <- clReleaseProgram program
  _ <- clReleaseCommandQueue q
  _ <- clReleaseContext context
  putStrLn $ "Result vector = " ++ show outVec
  where
    -- Return the first element of a list, or raise an 'IOError' carrying
    -- the given message when the list is empty.
    firstOr :: String -> [a] -> IO a
    firstOr msg [] = ioError (userError msg)
    firstOr _ (x : _) = pure x
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.