Skip to content

Instantly share code, notes, and snippets.

@ek0
Last active May 28, 2025 21:59
Show Gist options
  • Save ek0/bde6c616cec6da9120e6a35016142019 to your computer and use it in GitHub Desktop.
Save ek0/bde6c616cec6da9120e6a35016142019 to your computer and use it in GitHub Desktop.
use std::fs;
use tree_sitter::{InputEdit, Language, Node, Parser, Point};
use tree_sitter_cpp;
fn find_matching_parenthesis(node: &Node) -> Option<usize> {
for index in (0..node.child_count()).rev() {
println!("{}", node.child(index).unwrap());
let current = node.child(index).unwrap();
if current.kind() == ")" {
return Some(current.end_byte());
}
}
None
}
/// Prints the condition clause of an `if_statement` and, when present, its `else` clause.
///
/// Children are looked up by grammar field name (`condition`, `alternative`) rather than by
/// position. In tree-sitter-cpp an optional `constexpr` token may sit between `if` and the
/// condition. A positional lookup would then print `constexpr` as the condition. It would
/// also report the consequence of a 4-child `if constexpr (...) stmt;` as an else clause.
///
/// Clauses that are absent from the tree (for example in error-recovered parses) are
/// simply not printed.
///
/// # Panics
/// Panics if `node` is not an `if_statement`, or if a clause's bytes are not valid UTF-8.
fn print_if_statement<'a>(node: Node<'a>, source_code: &'a [u8]) {
    assert_eq!(node.kind(), "if_statement");
    if let Some(condition_clause) = node.child_by_field_name("condition") {
        let text = condition_clause.utf8_text(source_code).unwrap();
        println!("Condition clause: {}", text);
    }
    if let Some(else_clause) = node.child_by_field_name("alternative") {
        let text = else_clause.utf8_text(source_code).unwrap();
        println!("Else clause: {}", text);
    }
}
/// Prints the header of a range-based `for` loop (everything up to and including the
/// closing `)`, without the body), labelled `For range clause`.
///
/// # Panics
/// Panics if `node` is not a `for_range_loop`.
fn print_for_range_loop<'a>(node: Node<'a>, source_code: &'a [u8]) {
    assert_eq!(node.kind(), "for_range_loop");
    print_loop_header(node, source_code, "For range clause");
}

/// Prints the header of a classic `for` loop (everything up to and including the closing
/// `)`, without the body), labelled `For clause`.
///
/// # Panics
/// Panics if `node` is not a `for_statement`.
fn print_for_statement<'a>(node: Node<'a>, source_code: &'a [u8]) {
    assert_eq!(node.kind(), "for_statement");
    print_loop_header(node, source_code, "For clause");
}

/// Shared by both `for` printers: prints `"<label>: <text>"`, where `<text>` spans from the
/// start of `node` to the end of its last direct `)` child.
///
/// If no such `)` child exists, nothing is printed instead of panicking.
fn print_loop_header(node: Node, source_code: &[u8], label: &str) {
    if let Some(header_end) = find_matching_parenthesis(&node) {
        // A node starts exactly where its first child (the `for` keyword) starts.
        let header = &source_code[node.start_byte()..header_end];
        println!("{}: {}", label, std::str::from_utf8(header).unwrap());
    }
}
/// Prints the condition of a `while_statement` as `While clause: while <condition>`.
///
/// Child 1 of the node is taken to be the condition clause (presumably child 0 is the
/// `while` keyword); its exact source text, parentheses included, is printed.
///
/// # Panics
/// Panics if `node` is not a `while_statement`, if it has no child at index 1, or if the
/// condition text is not valid UTF-8.
fn print_while_statement<'a>(node: Node<'a>, source_code: &'a [u8]) {
    assert_eq!(node.kind(), "while_statement");
    let condition = node
        .child(1)
        .map(|clause| std::str::from_utf8(&source_code[clause.byte_range()]).unwrap())
        .unwrap();
    println!("While clause: while {}", condition);
}
/// Prints the condition clause of a `switch_statement` as `Switch clause: switch <condition>`.
///
/// The condition is looked up by its grammar field name (`condition`), matching how the
/// `if` printer should locate its clause. Nothing is printed if the field is absent, for
/// example in an error-recovered tree.
///
/// # Panics
/// Panics if `node` is not a `switch_statement`, or if the condition text is not valid UTF-8.
fn print_switch_statement<'a>(node: Node<'a>, source_code: &'a [u8]) {
    assert_eq!(node.kind(), "switch_statement");
    if let Some(condition_clause) = node.child_by_field_name("condition") {
        let text = condition_clause.utf8_text(source_code).unwrap();
        println!("Switch clause: switch {}", text);
    }
}
/// Prints the complete source text of a `goto_statement` node, prefixed with `Goto: `.
///
/// # Panics
/// Panics if `node` is not a `goto_statement`, or if its text is not valid UTF-8.
fn print_goto_statement(node: Node, source_code: &[u8]) {
    assert_eq!(node.kind(), "goto_statement");
    let statement = node.utf8_text(source_code).unwrap();
    println!("Goto: {}", statement);
}
/// Walks the subtree rooted at `node` in pre-order (recursively).
///
/// For every node it prints `Node kind: <kind>`. Control-flow statements are then dispatched
/// to the matching `print_*` helper. For each `if_statement`, the source text of its
/// condition clause is also appended to `results`; no other node kind contributes to
/// `results`. Despite the function's name, `compound_statement` nodes are not collected.
///
/// # Panics
/// Panics if extracted text is not valid UTF-8, or if a `print_*` helper panics.
fn extract_compound_statements<'a>(
    node: Node<'a>,
    source_code: &'a [u8],
    results: &mut Vec<&'a str>,
) {
    println!("Node kind: {}", node.kind());
    match node.kind() {
        "if_statement" => {
            print_if_statement(node, source_code);
            // Use the `condition` field rather than child 1: in `if constexpr (...)` the
            // `constexpr` token occupies child 1.
            if let Some(condition_clause) = node.child_by_field_name("condition") {
                results.push(condition_clause.utf8_text(source_code).unwrap());
            }
        }
        "for_range_loop" => print_for_range_loop(node, source_code),
        "for_statement" => print_for_statement(node, source_code),
        "while_statement" => print_while_statement(node, source_code),
        "switch_statement" => print_switch_statement(node, source_code),
        "goto_statement" => print_goto_statement(node, source_code),
        _ => {}
    }
    let mut cursor = node.walk();
    for child in node.children(&mut cursor) {
        extract_compound_statements(child, source_code, results);
    }
}
/// Entry point: parses the C++ file named by the first command-line argument and prints the
/// control-flow statements found in it.
///
/// With no argument, a usage message is printed to stderr and the process exits with status
/// 2 (instead of panicking on an out-of-bounds index). If the file cannot be read, the path
/// and OS error are printed and the process exits with status 1.
fn main() {
    let mut args = std::env::args();
    let program = args.next().unwrap_or_else(|| String::from("program"));
    let path = match args.next() {
        Some(path) => path,
        None => {
            eprintln!("usage: {} <source-file>", program);
            std::process::exit(2);
        }
    };
    let contents = match fs::read_to_string(&path) {
        Ok(contents) => contents,
        Err(err) => {
            eprintln!("error: cannot read {}: {}", path, err);
            std::process::exit(1);
        }
    };

    let mut parser = Parser::new();
    // Failure here means the grammar/runtime versions are incompatible: a build bug, not input.
    parser
        .set_language(&tree_sitter_cpp::LANGUAGE.into())
        .expect("Error loading C++ grammar");
    let tree = parser
        .parse(contents.as_str(), None)
        .expect("Failed to parse code");

    // `results` collects `if` condition texts; nothing reports them yet.
    let mut results = Vec::new();
    extract_compound_statements(tree.root_node(), contents.as_bytes(), &mut results);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment