diff --git a/crates/postgresql-cst-parser/src/tree_sitter.rs b/crates/postgresql-cst-parser/src/tree_sitter.rs index 85adab5..791ca14 100644 --- a/crates/postgresql-cst-parser/src/tree_sitter.rs +++ b/crates/postgresql-cst-parser/src/tree_sitter.rs @@ -98,6 +98,28 @@ impl std::fmt::Display for Range { } } +impl Range { + pub fn extended_by(&self, other: &Self) -> Self { + Range { + start_byte: self.start_byte.min(other.start_byte), + end_byte: self.end_byte.max(other.end_byte), + + start_position: Point { + row: self.start_position.row.min(other.start_position.row), + column: self.start_position.column.min(other.start_position.column), + }, + end_position: Point { + row: self.end_position.row.max(other.end_position.row), + column: self.end_position.column.max(other.end_position.column), + }, + } + } + + pub fn is_adjacent(&self, other: &Self) -> bool { + self.end_byte == other.start_byte || self.start_byte == other.end_byte + } +} + impl<'a> Node<'a> { pub fn walk(&self) -> TreeCursor<'a> { TreeCursor { @@ -144,6 +166,20 @@ impl<'a> Node<'a> { } } + pub fn children(&self) -> Vec> { + if let Some(node) = self.node_or_token.as_node() { + node.children_with_tokens() + .map(|node| Node { + input: self.input, + range_map: Rc::clone(&self.range_map), + node_or_token: node, + }) + .collect() + } else { + vec![] + } + } + pub fn next_sibling(&self) -> Option> { self.node_or_token .next_sibling_or_token() @@ -154,6 +190,16 @@ impl<'a> Node<'a> { }) } + pub fn prev_sibling(&self) -> Option> { + self.node_or_token + .prev_sibling_or_token() + .map(|sibling| Node { + input: self.input, + range_map: Rc::clone(&self.range_map), + node_or_token: sibling, + }) + } + pub fn parent(&self) -> Option> { self.node_or_token.parent().map(|parent| Node { input: self.input, @@ -165,6 +211,82 @@ impl<'a> Node<'a> { pub fn is_comment(&self) -> bool { matches!(self.kind(), SyntaxKind::C_COMMENT | SyntaxKind::SQL_COMMENT) } + + /// Return the rightmost token in the subtree of this node + /// this is not tree-sitter's API + pub fn last_node(&self) -> Option> { + match &self.node_or_token { + NodeOrToken::Node(node) => node.last_token().map(|token| Node { + input: self.input, + range_map: Rc::clone(&self.range_map), + node_or_token: NodeOrToken::Token(token), + }), + NodeOrToken::Token(token) => Some(Node { + input: self.input, + range_map: Rc::clone(&self.range_map), + node_or_token: NodeOrToken::Token(token), + }), + } + } + + /// Returns the next token in the tree. + /// This is not necessarily a direct sibling of this node/token, + /// but will always be further right in the tree. + /// this is not tree-sitter's API + pub fn next_token(&self) -> Option> { + match &self.node_or_token { + NodeOrToken::Token(token) => token.next_token().map(|next_token| Node { + input: self.input, + range_map: Rc::clone(&self.range_map), + node_or_token: NodeOrToken::Token(next_token), + }), + NodeOrToken::Node(node) => { + // For a node, find its last token and then get the next token + node.last_token() + .and_then(|last_token| last_token.next_token()) + .map(|next_token| Node { + input: self.input, + range_map: Rc::clone(&self.range_map), + node_or_token: NodeOrToken::Token(next_token), + }) + } + } + } + + /// Returns an iterator over all descendant nodes (including tokens) + /// this is not tree-sitter's API + pub fn descendants(&self) -> impl Iterator> + '_ { + struct Descendants<'a> { + iter: Box> + 'a>, + } + + impl<'a> Iterator for Descendants<'a> { + type Item = Node<'a>; + + fn next(&mut self) -> Option { + self.iter.next() + } + } + + if let Some(node) = self.node_or_token.as_node() { + let input = self.input; + let range_map = Rc::clone(&self.range_map); + Descendants { + iter: Box::new( + node.descendants_with_tokens() + .map(move |node_or_token| Node { + input, + range_map: Rc::clone(&range_map), + node_or_token, + }), + ), + } + } else { + Descendants { + iter: Box::new(std::iter::empty()), + } + } + } } impl<'a> From> for TreeCursor<'a> { @@ -214,6 +336,15 @@ impl<'a> TreeCursor<'a> { } } + pub fn goto_prev_sibling(&mut self) -> bool { + if let Some(sibling) = self.node_or_token.prev_sibling_or_token() { + self.node_or_token = sibling; + true + } else { + false + } + } + pub fn is_comment(&self) -> bool { matches!( self.node_or_token.kind(), @@ -462,4 +593,53 @@ from assert_eq!(stmt_count, 2); } + + #[test] + fn test_last_node_returns_rightmost_node() { + let src = "SELECT u.*, (v).id, name;"; + let tree = parse(src).unwrap(); + let root = tree.root_node(); + + let target_list = root + .descendants() + .find(|node| node.kind() == SyntaxKind::target_list) + .expect("should find target_list"); + + // last node of the target_list is returned + let last_node = target_list.last_node().expect("should have last node"); + assert_eq!(last_node.text(), "name"); + + let target_els = target_list + .children() + .into_iter() + .filter(|node| node.kind() == SyntaxKind::target_el) + .collect::>(); + + let mut last_nodes = target_els + .iter() + .map(|node| node.last_node().expect("should have last node")); + + // last node of each target_el is returned + assert_eq!(last_nodes.next().unwrap().text(), "*"); + assert_eq!(last_nodes.next().unwrap().text(), "id"); + assert_eq!(last_nodes.next().unwrap().text(), "name"); + assert!(last_nodes.next().is_none()); + } + + #[test] + fn test_next_token() { + let src = "SELECT tbl.name as n from TBL;"; + let tree = parse(src).unwrap(); + let root = tree.root_node(); + + let name = root + .descendants() + .find(|node| node.kind() == SyntaxKind::NAME_P) + .expect("should find NAME_P"); + + // Even if not a direct sibling or not belonging to the same subtree, the next_token can retrieve the next token. + let next_token = name.next_token().expect("should have next token"); + assert_eq!(next_token.text(), "as"); + assert_eq!(next_token.kind(), SyntaxKind::AS); + } }