Created
March 23, 2022 01:12
-
-
Save SharanSMenon/ee13eb74eb997c565488beb76cc85f2d to your computer and use it in GitHub Desktop.
Vulkano 0.29 vector-scalar product program. It multiplies a vector of about 65K numbers by 12 on the GPU. Requires the `vulkano` and `vulkano-shaders` crates.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
extern crate vulkano; | |
extern crate vulkano_shaders; | |
use vulkano::{ | |
sync, | |
buffer::{BufferUsage, CpuAccessibleBuffer}, | |
command_buffer::{AutoCommandBufferBuilder, CommandBufferUsage}, | |
device::{ | |
physical::{PhysicalDevice}, | |
QueueCreateInfo, | |
}, | |
Version, | |
instance::{Instance, InstanceCreateInfo, InstanceExtensions}, | |
descriptor_set::{PersistentDescriptorSet, WriteDescriptorSet}, | |
pipeline::{Pipeline, PipelineBindPoint}, | |
sync::{GpuFuture} | |
}; | |
fn main() { | |
let instance_info = InstanceCreateInfo { | |
max_api_version: Some(Version::V1_2), | |
enabled_extensions: InstanceExtensions::none(), | |
enabled_layers: vec![], | |
..InstanceCreateInfo::application_from_cargo_toml() | |
}; | |
let instance = Instance::new(instance_info).expect("failed to create instance"); | |
let physical = PhysicalDevice::enumerate(&instance) | |
.next() | |
.expect("failed to enumerate physical devices"); | |
for family in physical.queue_families() { | |
println!("Found a queue family with {:?} queue(s)", family.queues_count()); | |
} | |
let queue_family = physical.queue_families() | |
.find(|&q| q.supports_graphics()) | |
.expect("couldn't find a graphical queue family"); | |
use vulkano::device::{Device, DeviceExtensions, DeviceCreateInfo}; | |
let device_extensions = DeviceExtensions { | |
..DeviceExtensions::none() | |
}; | |
let (device, mut queues) = Device::new( | |
physical, | |
DeviceCreateInfo { | |
// here we pass the desired queue families that we want to use | |
enabled_extensions: physical | |
.required_extensions() | |
.union(&device_extensions), | |
queue_create_infos: vec![QueueCreateInfo::family(queue_family)], | |
..Default::default() | |
}, | |
) | |
.expect("failed to create device"); | |
let queue = queues.next().expect("failed to find associated queue"); | |
let data_iter = 0..65535; | |
let data_buffer = CpuAccessibleBuffer::from_iter( | |
device.clone(), | |
BufferUsage::all(), | |
false, | |
data_iter, | |
).unwrap(); | |
let shader = cs::load(device.clone()).expect("failed to create shader module"); | |
use vulkano::pipeline::ComputePipeline; | |
let compute_pipeline = ComputePipeline::new( | |
device.clone(), | |
shader.entry_point("main").unwrap(), | |
&(), | |
None, | |
|_| {}, | |
).expect("failed to create compute pipeline"); | |
let layout = compute_pipeline.layout().set_layouts() | |
.get(0).unwrap(); | |
let set = PersistentDescriptorSet::new( | |
layout.clone(), | |
[WriteDescriptorSet::buffer(0, data_buffer.clone())], // 0 is the binding | |
).unwrap(); | |
let mut builder = AutoCommandBufferBuilder::primary(device.clone(), | |
queue.family(), | |
CommandBufferUsage::OneTimeSubmit).unwrap(); | |
builder | |
.bind_pipeline_compute(compute_pipeline.clone()) | |
.bind_descriptor_sets( | |
PipelineBindPoint::Compute, | |
compute_pipeline.layout().clone(), | |
0, | |
set, | |
) | |
.dispatch([1024, 1, 1]).unwrap(); | |
let command_buffer = builder.build().unwrap(); | |
let future = sync::now(device.clone()) | |
.then_execute(queue.clone(), command_buffer) | |
.unwrap() | |
.then_signal_fence_and_flush() | |
.unwrap(); | |
future.wait(None).unwrap(); | |
let content = data_buffer.read().unwrap(); | |
for (n, val) in content.iter().enumerate() { | |
assert_eq!(*val, n as u32 * 12); | |
} | |
println!("Everything succeeded!"); | |
} | |
mod cs { | |
vulkano_shaders::shader! { | |
ty: "compute", | |
src: " | |
#version 450 | |
layout(local_size_x=64, local_size_y=1, local_size_z=1) in; | |
layout(set=0, binding=0) buffer Data { | |
uint data[]; | |
} data; | |
void main() { | |
uint idx = gl_GlobalInvocationID.x; | |
data.data[idx] *= 12; | |
}" | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Here is the `Cargo.toml`: