Created
November 18, 2023 04:43
-
-
Save jedjoud10/47565fffbcb94cf821b5d732b208e059 to your computer and use it in GitHub Desktop.
Specialize SpecConstants in SPIR-V (Rust).
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Specialize spec constants ourselves cause there's no other way to do it (fuck) | |
fn specialize_spec_constants(binary: &mut [u32], constants: &Constants) { | |
// Converts a SpecConstant op code to it's specialized variant (Constant) | |
fn specialize(op_code_index: usize, binary: &mut [u32], defined: SpecConstant) { | |
// Get the op code of the spec constant | |
let op_code = binary[op_code_index] & 0x0000ffff; | |
// Get the index of the spec constant literal | |
let literal_index = match op_code { | |
48 | 49 => op_code_index + 2, | |
50 => op_code_index + 3, | |
_ => panic!(), | |
}; | |
// Write to the literal value if it's not a boolean | |
let literal = &mut binary[literal_index]; | |
*literal = match defined { | |
SpecConstant::I32(val) => bytemuck::cast(val), | |
SpecConstant::U32(val) => val, | |
SpecConstant::F32(val) => bytemuck::cast(val), | |
_ => *literal, | |
}; | |
// Update the OpCode of the spec constant to the proper one in case it's a boolean | |
let new = match op_code { | |
48 | 49 => { | |
if let SpecConstant::BOOL(val) = defined { | |
match val { | |
true => 48, | |
false => 49, | |
} | |
} else { | |
panic!() | |
} | |
} | |
x => x, | |
}; | |
// Write new op code heheheha | |
binary[op_code_index] &= 0xffff0000; | |
binary[op_code_index] |= new; | |
} | |
// List of op codes that we must change | |
let spec_consts_op_codes = [48u32, 49, 50]; | |
// Contains the SpecId decorations for each constant | |
let mut spec_ids = AHashMap::<u32, u32>::default(); | |
// Loop till we find an OpSpecConstant | |
// TODO: Atm this will just look into the spirv and find words that match up with the predefined ones | |
// Very unsafe cause it could lead to corrupted data. Must find a way to make it safe | |
for i in 0..binary.len() { | |
// Get the op code and argument count | |
let word = binary[i]; | |
let op = word & 0x0000ffff; | |
let count = (word & 0xffff0000) >> 16; | |
// Keep track of SpecId decorations | |
if op == 71 && count == 4 && binary[i + 2] == 1 { | |
spec_ids.insert(binary[i + 1], binary[i + 3]); | |
} | |
// For now, we only support 32 bit types | |
if spec_consts_op_codes.contains(&op) && (count == 4 || count == 3) { | |
// Get the literal value that this spec constant is defaulted to | |
let id = binary[i + 2]; | |
let spec_id = spec_ids.get(&id).unwrap(); | |
// Get the value specified by the user | |
if let Some(value) = constants.get(spec_id) { | |
specialize(i, binary, *value); | |
} | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment