Allow capturing multiple nodes in textobject queries

Treesitter captures can contain multiple nodes like so:

```
(line_comment)+ @comment
```

This would match each line in a comment as a separate
`@comment` capture when what we actually want is the
whole set of contiguous `line_comment` nodes to be
captured under the `@comment` capture. This commit enables
this behaviour.
pull/1726/head
Gokul Soumya 2 years ago committed by Blaž Hrastnik
parent 78d37fd332
commit e6c36e82cf

@ -244,16 +244,55 @@ pub struct TextObjectQuery {
pub query: Query,
}
pub enum CapturedNode<'a> {
Single(Node<'a>),
Grouped(Vec<Node<'a>>),
}
impl<'a> CapturedNode<'a> {
pub fn start_byte(&self) -> usize {
match self {
Self::Single(n) => n.start_byte(),
Self::Grouped(ns) => ns[0].start_byte(),
}
}
pub fn end_byte(&self) -> usize {
match self {
Self::Single(n) => n.end_byte(),
Self::Grouped(ns) => ns.last().unwrap().end_byte(),
}
}
pub fn byte_range(&self) -> std::ops::Range<usize> {
self.start_byte()..self.end_byte()
}
}
impl TextObjectQuery {
/// Run the query on the given node and return sub nodes which match given
/// capture ("function.inside", "class.around", etc).
///
/// Captures may contain multiple nodes by using quantifiers (+, *, etc),
/// and support for this is partial and could use improvement.
///
/// ```query
/// ;; supported:
/// (comment)+ @capture
///
/// ;; unsupported:
/// (
/// (comment)+
/// (function)
/// ) @capture
/// ```
pub fn capture_nodes<'a>(
&'a self,
capture_name: &str,
node: Node<'a>,
slice: RopeSlice<'a>,
cursor: &'a mut QueryCursor,
) -> Option<impl Iterator<Item = Node<'a>>> {
) -> Option<impl Iterator<Item = CapturedNode<'a>>> {
self.capture_nodes_any(&[capture_name], node, slice, cursor)
}
@ -265,17 +304,28 @@ impl TextObjectQuery {
node: Node<'a>,
slice: RopeSlice<'a>,
cursor: &'a mut QueryCursor,
) -> Option<impl Iterator<Item = Node<'a>>> {
) -> Option<impl Iterator<Item = CapturedNode<'a>>> {
let capture_idx = capture_names
.iter()
.find_map(|cap| self.query.capture_index_for_name(cap))?;
let captures = cursor.captures(&self.query, node, RopeProvider(slice));
let captures = cursor.matches(&self.query, node, RopeProvider(slice));
captures
.filter_map(move |(mat, idx)| {
(mat.captures[idx].index == capture_idx).then(|| mat.captures[idx].node)
})
.into()
let nodes = captures.flat_map(move |mat| {
let captures = mat.captures.iter().filter(move |c| c.index == capture_idx);
let nodes = captures.map(|c| c.node);
let pattern_idx = mat.pattern_index;
let quantifier = self.query.capture_quantifiers(pattern_idx)[capture_idx as usize];
let iter: Box<dyn Iterator<Item = CapturedNode>> = match quantifier {
CaptureQuantifier::OneOrMore | CaptureQuantifier::ZeroOrMore => {
Box::new(std::iter::once(CapturedNode::Grouped(nodes.collect())))
}
_ => Box::new(nodes.map(CapturedNode::Single)),
};
iter
});
Some(nodes)
}
}
@ -1075,8 +1125,8 @@ pub(crate) fn generate_edits(
use std::sync::atomic::{AtomicUsize, Ordering};
use std::{iter, mem, ops, str, usize};
use tree_sitter::{
Language as Grammar, Node, Parser, Point, Query, QueryCaptures, QueryCursor, QueryError,
QueryMatch, Range, TextProvider, Tree,
CaptureQuantifier, Language as Grammar, Node, Parser, Point, Query, QueryCaptures, QueryCursor,
QueryError, QueryMatch, Range, TextProvider, Tree,
};
const CANCELLATION_CHECK_INTERVAL: usize = 100;
@ -1928,6 +1978,50 @@ mod test {
use super::*;
use crate::{Rope, Transaction};
#[test]
fn test_textobject_queries() {
let query_str = r#"
(line_comment)+ @quantified_nodes
((line_comment)+) @quantified_nodes_grouped
((line_comment) (line_comment)) @multiple_nodes_grouped
"#;
let source = Rope::from_str(
r#"
/// a comment on
/// mutiple lines
"#,
);
let loader = Loader::new(Configuration { language: vec![] });
let language = get_language(&crate::RUNTIME_DIR, "Rust").unwrap();
let query = Query::new(language, query_str).unwrap();
let textobject = TextObjectQuery { query };
let mut cursor = QueryCursor::new();
let config = HighlightConfiguration::new(language, "", "", "").unwrap();
let syntax = Syntax::new(&source, Arc::new(config), Arc::new(loader));
let root = syntax.tree().root_node();
let mut test = |capture, range| {
let matches: Vec<_> = textobject
.capture_nodes(capture, root, source.slice(..), &mut cursor)
.unwrap()
.collect();
assert_eq!(
matches[0].byte_range(),
range,
"@{capture} expected {range:?}"
)
};
test("quantified_nodes", 1..35);
// NOTE: Enable after implementing proper node group capturing
// test("quantified_nodes_grouped", 1..35);
// test("multiple_nodes_grouped", 1..35);
}
#[test]
fn test_parser() {
let highlight_names: Vec<String> = [

Loading…
Cancel
Save