Skip to content

Instantly share code, notes, and snippets.

@antoinefortin
Created September 25, 2024 02:41
Show Gist options
  • Save antoinefortin/f127ab89372b97a38eeafdb043a3dc14 to your computer and use it in GitHub Desktop.
Save antoinefortin/f127ab89372b97a38eeafdb043a3dc14 to your computer and use it in GitHub Desktop.
Compute shader template
#define STB_IMAGE_IMPLEMENTATION
#include "stbimage.h"
#include <iostream> // for console I/O
#include <fstream> // for file I/O
#include <sstream> // for stringstream
#include <string> // for std::string
#include <vector> // for std::vector
#include <GL/glew.h> // for OpenGL and GLEW
#include <GLFW/glfw3.h> // for GLFW
/// Reads the entire file at `filepath` and returns its contents as a string.
///
/// @param filepath Path of the shader source file to read.
/// @return The complete text of the file.
///
/// If the file cannot be opened, an error is printed to stderr and the
/// process is terminated with exit code -1.
std::string loadShaderSource(const std::string& filepath) {
    std::ifstream input(filepath);
    if (input.is_open()) {
        // Streaming the underlying buffer copies the whole file in one step.
        std::stringstream contents;
        contents << input.rdbuf();
        input.close();
        return contents.str();
    }

    std::cerr << "Failed to open shader file: " << filepath << std::endl;
    exit(-1);
}
/// Writes RGB pixel data to a text file, one pixel per line as "R-G-B".
///
/// Every component is multiplied by 255 and truncated to an int, so a
/// component of 1.0 is written as 255 and 0.5 as 127.
///
/// @param rgbData  Interleaved components; pixel i occupies indices
///                 3*i, 3*i+1 and 3*i+2.
/// @param width    Image width in pixels.
/// @param height   Image height in pixels.
/// @param filepath Destination file (created or overwritten).
///
/// If the file cannot be opened, an error is printed to stderr and the
/// process is terminated with exit code -1. A non-positive width or height
/// yields an empty file. If `rgbData` holds fewer than width*height complete
/// pixels, a warning is printed and only the available pixels are written
/// instead of reading past the end of the vector.
void writeRGBDataToFile(const std::vector<float>& rgbData, int width, int height, const std::string& filepath) {
    std::ofstream outputFile(filepath);
    if (!outputFile.is_open()) {
        std::cerr << "Failed to open output file: " << filepath << std::endl;
        exit(-1);
    }

    // Do the multiplication in size_t so large images cannot overflow int.
    std::size_t pixelCount = 0;
    if (width > 0 && height > 0) {
        pixelCount = static_cast<std::size_t>(width) * static_cast<std::size_t>(height);
    }

    const std::size_t availablePixels = rgbData.size() / 3;
    if (pixelCount > availablePixels) {
        std::cerr << "Warning: RGB buffer holds " << availablePixels << " pixels but "
                  << pixelCount << " were requested; writing only the available pixels."
                  << std::endl;
        pixelCount = availablePixels;
    }

    for (std::size_t i = 0; i < pixelCount; ++i) {
        const std::size_t base = i * 3;
        // '\n' instead of std::endl: flushing once per pixel is needlessly slow;
        // close() below flushes the stream once.
        outputFile << static_cast<int>(rgbData[base + 0] * 255) << '-'
                   << static_cast<int>(rgbData[base + 1] * 255) << '-'
                   << static_cast<int>(rgbData[base + 2] * 255) << '\n';
    }
    outputFile.close();
}
/// Reports a shader compilation failure on stderr.
///
/// Queries GL_COMPILE_STATUS for `shader`; when compilation did not succeed,
/// prints "ERROR::SHADER::COMPILATION_FAILED" followed by the shader's info
/// log. Nothing is printed on success and the function never aborts.
///
/// @param shader Name of the shader object to inspect.
void checkShaderCompileStatus(GLuint shader) {
    // Initialised so that a failing query (e.g. an invalid shader name, which
    // leaves the output untouched) is treated as a failure instead of reading
    // an indeterminate value.
    GLint success = GL_FALSE;
    glGetShaderiv(shader, GL_COMPILE_STATUS, &success);
    if (success) {
        return;
    }

    // Size the buffer from GL_INFO_LOG_LENGTH (which includes the terminating
    // NUL) so long compiler logs are not cut off at a fixed size.
    GLint logLength = 0;
    glGetShaderiv(shader, GL_INFO_LOG_LENGTH, &logLength);
    std::vector<char> infoLog(logLength > 1 ? static_cast<std::size_t>(logLength) : 1, '\0');
    glGetShaderInfoLog(shader, static_cast<GLsizei>(infoLog.size()), nullptr, infoLog.data());
    std::cerr << "ERROR::SHADER::COMPILATION_FAILED\n" << infoLog.data() << std::endl;
}
/// Compiles `shaderSource` as a compute shader and links it into a program.
///
/// Compilation problems are reported through checkShaderCompileStatus();
/// link problems are reported on stderr together with the full program info
/// log.
///
/// @param shaderSource GLSL source text of the compute shader.
/// @return The name of the linked program, or 0 if linking failed (the
///         unusable program object is deleted before returning).
GLuint createComputeShader(const std::string& shaderSource) {
    const GLuint shader = glCreateShader(GL_COMPUTE_SHADER);
    const char* sourceCStr = shaderSource.c_str();
    glShaderSource(shader, 1, &sourceCStr, nullptr);
    glCompileShader(shader);
    checkShaderCompileStatus(shader);

    const GLuint program = glCreateProgram();
    glAttachShader(program, shader);
    glLinkProgram(program);

    // Initialised so a failing query counts as "not linked" rather than
    // reading an indeterminate value.
    GLint linkStatus = GL_FALSE;
    glGetProgramiv(program, GL_LINK_STATUS, &linkStatus);
    if (linkStatus == GL_FALSE) {
        // GL_INFO_LOG_LENGTH includes the terminating NUL; sizing the buffer
        // from it avoids truncating long linker logs.
        GLint logLength = 0;
        glGetProgramiv(program, GL_INFO_LOG_LENGTH, &logLength);
        std::vector<char> infoLog(logLength > 1 ? static_cast<std::size_t>(logLength) : 1, '\0');
        glGetProgramInfoLog(program, static_cast<GLsizei>(infoLog.size()), nullptr, infoLog.data());
        std::cerr << "ERROR::PROGRAM::LINKING_FAILED\n" << infoLog.data() << std::endl;
    }

    // A shader that is only flagged for deletion stays alive while it is still
    // attached, so detach it first to release it immediately.
    glDetachShader(program, shader);
    glDeleteShader(shader);

    if (linkStatus == GL_FALSE) {
        // Do not hand an unlinked program back to the caller.
        glDeleteProgram(program);
        return 0;
    }
    return program;
}
int main() {
// Initialize GLFW
if (!glfwInit()) {
std::cerr << "Failed to initialize GLFW" << std::endl;
return -1;
}
// Create a windowed mode window and its OpenGL context
GLFWwindow* window = glfwCreateWindow(800, 600, "Compute Shader Image Processing", NULL, NULL);
if (!window) {
glfwTerminate();
return -1;
}
// Make the window's context current
glfwMakeContextCurrent(window);
// Initialize GLEW
glewExperimental = GL_TRUE;
if (glewInit() != GLEW_OK) {
std::cerr << "Failed to initialize GLEW" << std::endl;
return -1;
}
// Load an image using stb_image (force RGBA format)
int width, height, channels;
unsigned char* imageData = stbi_load("input_image.png", &width, &height, &channels, 4); // Load as RGBA
if (imageData == nullptr) {
std::cerr << "Failed to load image. Ensure the file path is correct." << std::endl;
return -1;
}
std::cout << "Image Loaded: " << width << "x" << height << std::endl;
// Create texture from image data
GLuint texture;
glGenTextures(1, &texture);
glBindTexture(GL_TEXTURE_2D, texture);
// Set texture parameters
glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_CLAMP_TO_EDGE);
glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE);
glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_LINEAR);
glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_LINEAR);
// Load texture data into OpenGL (RGBA format)
glTexImage2D(GL_TEXTURE_2D, 0, GL_RGBA8, width, height, 0, GL_RGBA, GL_UNSIGNED_BYTE, imageData);
// Check for OpenGL errors
GLenum err;
while ((err = glGetError()) != GL_NO_ERROR) {
std::cerr << "OpenGL error: " << err << std::endl;
}
std::cout << "Texture created successfully." << std::endl;
// Free the image memory after loading it to GPU
stbi_image_free(imageData);
// Create a buffer to store the output RGB data
GLuint ssbo;
glGenBuffers(1, &ssbo);
glBindBuffer(GL_SHADER_STORAGE_BUFFER, ssbo);
glBufferData(GL_SHADER_STORAGE_BUFFER, width * height * 3 * sizeof(float), NULL, GL_DYNAMIC_COPY);
glBindBufferBase(GL_SHADER_STORAGE_BUFFER, 1, ssbo);
// Load the compute shader source from an external file
std::string shaderSource = loadShaderSource("extract_rgb.glsl");
// Create and use the compute shader
GLuint computeProgram = createComputeShader(shaderSource);
glUseProgram(computeProgram);
// Bind the texture to image unit 0
glBindImageTexture(0, texture, 0, GL_FALSE, 0, GL_READ_ONLY, GL_RGBA8);
// Dispatch the compute shader (workgroup size of the image dimensions)
glDispatchCompute(width, height, 1);
// Ensure that the GPU has finished processing
glMemoryBarrier(GL_SHADER_STORAGE_BARRIER_BIT);
// Retrieve the results (RGB data)
std::vector<float> rgbData(width * height * 3);
glBindBuffer(GL_SHADER_STORAGE_BUFFER, ssbo);
float* ptr = (float*)glMapBuffer(GL_SHADER_STORAGE_BUFFER, GL_READ_ONLY);
std::copy(ptr, ptr + width * height * 3, rgbData.begin());
// Write the RGB data to the output file
writeRGBDataToFile(rgbData, width, height, "output_rgb.txt");
// Unmap and clean up
glUnmapBuffer(GL_SHADER_STORAGE_BUFFER);
glDeleteBuffers(1, &ssbo);
glDeleteProgram(computeProgram);
// Terminate GLFW
glfwDestroyWindow(window);
glfwTerminate();
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment