diff --git a/helix-core/src/syntax.rs b/helix-core/src/syntax.rs index ca06e2dd5..8f62bead2 100644 --- a/helix-core/src/syntax.rs +++ b/helix-core/src/syntax.rs @@ -244,16 +244,55 @@ pub struct TextObjectQuery { pub query: Query, } +pub enum CapturedNode<'a> { + Single(Node<'a>), + Grouped(Vec>), +} + +impl<'a> CapturedNode<'a> { + pub fn start_byte(&self) -> usize { + match self { + Self::Single(n) => n.start_byte(), + Self::Grouped(ns) => ns[0].start_byte(), + } + } + + pub fn end_byte(&self) -> usize { + match self { + Self::Single(n) => n.end_byte(), + Self::Grouped(ns) => ns.last().unwrap().end_byte(), + } + } + + pub fn byte_range(&self) -> std::ops::Range { + self.start_byte()..self.end_byte() + } +} + impl TextObjectQuery { /// Run the query on the given node and return sub nodes which match given /// capture ("function.inside", "class.around", etc). + /// + /// Captures may contain multiple nodes by using quantifiers (+, *, etc), + /// and support for this is partial and could use improvement. + /// + /// ```query + /// ;; supported: + /// (comment)+ @capture + /// + /// ;; unsupported: + /// ( + /// (comment)+ + /// (function) + /// ) @capture + /// ``` pub fn capture_nodes<'a>( &'a self, capture_name: &str, node: Node<'a>, slice: RopeSlice<'a>, cursor: &'a mut QueryCursor, - ) -> Option>> { + ) -> Option>> { self.capture_nodes_any(&[capture_name], node, slice, cursor) } @@ -265,17 +304,28 @@ impl TextObjectQuery { node: Node<'a>, slice: RopeSlice<'a>, cursor: &'a mut QueryCursor, - ) -> Option>> { + ) -> Option>> { let capture_idx = capture_names .iter() .find_map(|cap| self.query.capture_index_for_name(cap))?; - let captures = cursor.captures(&self.query, node, RopeProvider(slice)); + let captures = cursor.matches(&self.query, node, RopeProvider(slice)); - captures - .filter_map(move |(mat, idx)| { - (mat.captures[idx].index == capture_idx).then(|| mat.captures[idx].node) - }) - .into() + let nodes = captures.flat_map(move |mat| { + let captures = mat.captures.iter().filter(move |c| c.index == capture_idx); + let nodes = captures.map(|c| c.node); + let pattern_idx = mat.pattern_index; + let quantifier = self.query.capture_quantifiers(pattern_idx)[capture_idx as usize]; + + let iter: Box> = match quantifier { + CaptureQuantifier::OneOrMore | CaptureQuantifier::ZeroOrMore => { + Box::new(std::iter::once(CapturedNode::Grouped(nodes.collect()))) + } + _ => Box::new(nodes.map(CapturedNode::Single)), + }; + + iter + }); + Some(nodes) } } @@ -1075,8 +1125,8 @@ pub(crate) fn generate_edits( use std::sync::atomic::{AtomicUsize, Ordering}; use std::{iter, mem, ops, str, usize}; use tree_sitter::{ - Language as Grammar, Node, Parser, Point, Query, QueryCaptures, QueryCursor, QueryError, - QueryMatch, Range, TextProvider, Tree, + CaptureQuantifier, Language as Grammar, Node, Parser, Point, Query, QueryCaptures, QueryCursor, + QueryError, QueryMatch, Range, TextProvider, Tree, }; const CANCELLATION_CHECK_INTERVAL: usize = 100; @@ -1928,6 +1978,50 @@ mod test { use super::*; use crate::{Rope, Transaction}; + #[test] + fn test_textobject_queries() { + let query_str = r#" + (line_comment)+ @quantified_nodes + ((line_comment)+) @quantified_nodes_grouped + ((line_comment) (line_comment)) @multiple_nodes_grouped + "#; + let source = Rope::from_str( + r#" +/// a comment on +/// mutiple lines + "#, + ); + + let loader = Loader::new(Configuration { language: vec![] }); + let language = get_language(&crate::RUNTIME_DIR, "Rust").unwrap(); + + let query = Query::new(language, query_str).unwrap(); + let textobject = TextObjectQuery { query }; + let mut cursor = QueryCursor::new(); + + let config = HighlightConfiguration::new(language, "", "", "").unwrap(); + let syntax = Syntax::new(&source, Arc::new(config), Arc::new(loader)); + + let root = syntax.tree().root_node(); + let mut test = |capture, range| { + let matches: Vec<_> = textobject + .capture_nodes(capture, root, source.slice(..), &mut cursor) + .unwrap() + .collect(); + + assert_eq!( + matches[0].byte_range(), + range, + "@{capture} expected {range:?}" + ) + }; + + test("quantified_nodes", 1..35); + // NOTE: Enable after implementing proper node group capturing + // test("quantified_nodes_grouped", 1..35); + // test("multiple_nodes_grouped", 1..35); + } + #[test] fn test_parser() { let highlight_names: Vec = [