Last active
May 28, 2025 21:59
-
-
Save ek0/bde6c616cec6da9120e6a35016142019 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::fs; | |
use tree_sitter::{InputEdit, Language, Node, Parser, Point}; | |
use tree_sitter_cpp; | |
fn find_matching_parenthesis(node: &Node) -> Option<usize> { | |
for index in (0..node.child_count()).rev() { | |
println!("{}", node.child(index).unwrap()); | |
let current = node.child(index).unwrap(); | |
if current.kind() == ")" { | |
return Some(current.end_byte()); | |
} | |
} | |
None | |
} | |
fn print_if_statement<'a>(node: Node<'a>, source_code: &'a [u8]) { | |
assert_eq!(node.kind(), "if_statement"); | |
let condition_clause = node.child(1).unwrap(); | |
let text = &source_code[condition_clause.byte_range()]; | |
println!("Condition clause: {}", std::str::from_utf8(text).unwrap()); | |
if node.child_count() == 4 { | |
// else clause is present? | |
let else_statement = node.child(3).unwrap(); | |
let text = &source_code[else_statement.byte_range()]; | |
println!("Else clause: {}", std::str::from_utf8(text).unwrap()); | |
} | |
} | |
fn print_for_range_loop<'a>(node: Node<'a>, source_code: &'a [u8]) { | |
assert_eq!(node.kind(), "for_range_loop"); | |
let index = find_matching_parenthesis(&node).unwrap(); | |
let for_statement_range = node.child(0).unwrap().start_byte()..index; | |
let text = &source_code[for_statement_range]; | |
println!("For range clause: {}", std::str::from_utf8(text).unwrap()); | |
} | |
fn print_for_statement<'a>(node: Node<'a>, source_code: &'a [u8]) { | |
assert_eq!(node.kind(), "for_statement"); | |
let index = find_matching_parenthesis(&node).unwrap(); | |
let for_statement_range = node.child(0).unwrap().start_byte()..index; | |
let text = &source_code[for_statement_range]; | |
println!("For clause: {}", std::str::from_utf8(text).unwrap()); | |
} | |
fn print_while_statement<'a>(node: Node<'a>, source_code: &'a [u8]) { | |
assert_eq!(node.kind(), "while_statement"); | |
let condition_clause = node.child(1).unwrap(); | |
let text = &source_code[condition_clause.byte_range()]; | |
println!("While clause: while {}", std::str::from_utf8(text).unwrap()); | |
} | |
fn print_switch_statement<'a>(node: Node<'a>, source_code: &'a [u8]) { | |
assert_eq!(node.kind(), "switch_statement"); | |
println!("switch child count: {}", node.child_count()); | |
let condition_clause = node.child(1).unwrap(); | |
let text = &source_code[condition_clause.byte_range()]; | |
println!( | |
"Switch clause: switch {}", | |
std::str::from_utf8(text).unwrap() | |
); | |
} | |
fn print_goto_statement(node: Node, source_code: &[u8]) { | |
assert_eq!(node.kind(), "goto_statement"); | |
let text = &source_code[node.byte_range()]; | |
println!( | |
"Goto: {}", | |
std::str::from_utf8(text).unwrap() | |
); | |
} | |
fn extract_compound_statements<'a>( | |
node: Node<'a>, | |
source_code: &'a [u8], | |
results: &mut Vec<&'a str>, | |
) { | |
println!("Node kind: {}", node.kind()); | |
//println!("{}", std::str::from_utf8(&source_code[node.byte_range()]).unwrap()); | |
//if node.kind() == "compound_statement" { | |
// let text = &source_code[node.byte_range()]; | |
// results.push(std::str::from_utf8(text).unwrap()); | |
//} | |
if node.kind() == "if_statement" { | |
print_if_statement(node, source_code); | |
let condition_clause = node.child(1).unwrap(); | |
//let text = &source_code[node.byte_range()]; | |
//results.push(std::str::from_utf8(text).unwrap()); | |
let text = &source_code[condition_clause.byte_range()]; | |
results.push(std::str::from_utf8(text).unwrap()); | |
} else if node.kind() == "for_range_loop" { | |
print_for_range_loop(node, source_code); | |
} else if node.kind() == "for_statement" { | |
print_for_statement(node, source_code); | |
} else if node.kind() == "while_statement" { | |
print_while_statement(node, source_code); | |
} else if node.kind() == "switch_statement" { | |
print_switch_statement(node, source_code); | |
} else if node.kind() == "goto_statement" { | |
print_goto_statement(node, source_code); | |
} | |
for child in node.children(&mut node.walk()) { | |
extract_compound_statements(child, source_code, results); | |
} | |
} | |
fn main() { | |
//let source_code = r#" | |
// int add(int a, int b) { | |
// int result = a + b; | |
// if (result > 0) { | |
// result += 10; | |
// } | |
// return result; | |
// } | |
//"#; | |
let args: Vec<String> = std::env::args().collect(); | |
let path = args[1].clone(); | |
let contents = fs::read_to_string(path).expect("Should have been able to read the file"); | |
let mut parser = Parser::new(); | |
parser | |
.set_language(&tree_sitter_cpp::LANGUAGE.into()) | |
.expect("Error loading C++ grammar"); | |
let tree = parser | |
.parse(contents.as_str(), None) | |
.expect("Failed to parse code"); | |
let root_node = tree.root_node(); | |
let mut results = Vec::new(); | |
extract_compound_statements(root_node, contents.as_str().as_bytes(), &mut results); | |
//for (i, stmt) in results.iter().enumerate() { | |
// println!("Compound Statement {}:\n{}\n", i + 1, stmt); | |
//} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment