diff --git a/Cargo.lock b/Cargo.lock index aa8b613cd..610f7e650 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1435,11 +1435,14 @@ dependencies = [ "bitflags 2.6.0", "hashbrown 0.14.5", "helix-stdx", + "libloading", "log", "once_cell", "regex", + "regex-cursor", "ropey", "slotmap", + "thiserror", "tree-sitter", ] diff --git a/helix-syntax/Cargo.toml b/helix-syntax/Cargo.toml index 3ba12ddd1..c2fadb0a4 100644 --- a/helix-syntax/Cargo.toml +++ b/helix-syntax/Cargo.toml @@ -26,3 +26,9 @@ bitflags = "2.4" ahash = "0.8.9" hashbrown = { version = "0.14.3", features = ["raw"] } log = "0.4" +regex-cursor = "0.1.4" +libloading = "0.8.3" +thiserror = "1.0.59" + +[build-dependencies] +cc = "1.0.95" diff --git a/helix-syntax/build.rs b/helix-syntax/build.rs new file mode 100644 index 000000000..49a0dc593 --- /dev/null +++ b/helix-syntax/build.rs @@ -0,0 +1,28 @@ +use std::path::PathBuf; +use std::{env, fs}; + +fn main() { + if env::var_os("DISABLED_TS_BUILD").is_some() { + return; + } + let mut config = cc::Build::new(); + + let manifest_path = PathBuf::from(env::var_os("CARGO_MANIFEST_DIR").unwrap()); + let include_path = manifest_path.join("../vendor/tree-sitter/include"); + let src_path = manifest_path.join("../vendor/tree-sitter/src"); + for entry in fs::read_dir(&src_path).unwrap() { + let entry = entry.unwrap(); + let path = src_path.join(entry.file_name()); + println!("cargo:rerun-if-changed={}", path.to_str().unwrap()); + } + + config + .flag_if_supported("-std=c11") + .flag_if_supported("-fvisibility=hidden") + .flag_if_supported("-Wshadow") + .flag_if_supported("-Wno-unused-parameter") + .include(&src_path) + .include(&include_path) + .file(src_path.join("lib.c")) + .compile("tree-sitter"); +} diff --git a/helix-syntax/src/injections_tree.rs b/helix-syntax/src/injections_tree.rs index e181a7754..fdc53e491 100644 --- a/helix-syntax/src/injections_tree.rs +++ b/helix-syntax/src/injections_tree.rs @@ -3,7 +3,7 @@ use std::iter::Peekable; use std::sync::Arc; use hashbrown::HashMap; -use slotmap::{new_key_type, HopSlotMap, SlotMap}; +use slotmap::{new_key_type, SlotMap}; use tree_sitter::Tree; use crate::parse::LayerUpdateFlags; diff --git a/helix-syntax/src/lib.rs b/helix-syntax/src/lib.rs index 915d8df5b..074f87272 100644 --- a/helix-syntax/src/lib.rs +++ b/helix-syntax/src/lib.rs @@ -1,6 +1,6 @@ use ::ropey::RopeSlice; -use slotmap::{new_key_type, HopSlotMap}; -use tree_sitter::{Node, Parser, Point, Query, QueryCursor, Range, Tree}; +use ::tree_sitter::{Node, Parser, Point, Query, QueryCursor, Range, Tree}; +use slotmap::HopSlotMap; use std::borrow::Cow; use std::cell::RefCell; @@ -26,6 +26,7 @@ mod parse; mod pretty_print; mod ropey; mod tree_cursor; +pub mod tree_sitter; #[derive(Debug)] pub struct Syntax { @@ -321,7 +322,7 @@ fn byte_range_to_str(range: std::ops::Range, source: RopeSlice) -> Cow, } diff --git a/helix-syntax/src/tree_sitter.rs b/helix-syntax/src/tree_sitter.rs new file mode 100644 index 000000000..d75c5b245 --- /dev/null +++ b/helix-syntax/src/tree_sitter.rs @@ -0,0 +1,27 @@ +mod grammar; +mod parser; +mod query; +mod ropey; +mod syntax_tree; +mod syntax_tree_node; + +pub use grammar::Grammar; +pub use parser::{Parser, ParserInputRaw}; +pub use syntax_tree::{InputEdit, SyntaxTree}; +pub use syntax_tree_node::SyntaxTreeNode; + +#[repr(C)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] +pub struct Point { + pub row: u32, + pub column: u32, +} + +#[repr(C)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] +pub struct Range { + pub start_point: Point, + pub end_point: Point, + pub start_byte: u32, + pub end_byte: u32, +} diff --git a/helix-syntax/src/tree_sitter/grammar.rs b/helix-syntax/src/tree_sitter/grammar.rs new file mode 100644 index 000000000..a97769248 --- /dev/null +++ b/helix-syntax/src/tree_sitter/grammar.rs @@ -0,0 +1,101 @@ +use std::fmt; +use std::path::{Path, PathBuf}; +use std::ptr::NonNull; + +use libloading::{Library, Symbol}; + +/// supported TS versions, WARNING: update when updating vendored c sources +pub const MIN_COMPATIBLE_ABI_VERSION: u32 = 13; +pub const ABI_VERSION: u32 = 14; + +// opaque pointer +enum GrammarData {} + +#[repr(transparent)] +#[derive(Clone, Copy, PartialEq, Eq, Hash)] +pub struct Grammar { + ptr: NonNull, +} + +unsafe impl Send for Grammar {} +unsafe impl Sync for Grammar {} + +impl std::fmt::Debug for Grammar { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Grammar").finish_non_exhaustive() + } +} + +impl Grammar { + pub unsafe fn new(name: &str, library_path: &Path) -> Result { + let library = unsafe { + Library::new(&library_path).map_err(|err| Error::DlOpen { + err, + path: library_path.to_owned(), + })? + }; + let language_fn_name = format!("tree_sitter_{}", name.replace('-', "_")); + let grammar = unsafe { + let language_fn: Symbol NonNull> = library + .get(language_fn_name.as_bytes()) + .map_err(|err| Error::DlSym { + err, + symbol: name.to_owned(), + })?; + Grammar { ptr: language_fn() } + }; + let version = grammar.version(); + if MIN_COMPATIBLE_ABI_VERSION <= version && version <= ABI_VERSION { + std::mem::forget(library); + Ok(grammar) + } else { + Err(Error::IncompatibleVersion { version }) + } + } + pub fn version(self) -> u32 { + unsafe { ts_language_version(self) } + } +} + +#[derive(thiserror::Error, Debug)] +pub enum Error { + #[error("Error opening dynamic library {path:?}")] + DlOpen { + #[source] + err: libloading::Error, + path: PathBuf, + }, + #[error("Failed to load symbol {symbol}")] + DlSym { + #[source] + err: libloading::Error, + symbol: String, + }, + #[error("Tried to load grammar with incompatible ABI {version}.")] + IncompatibleVersion { version: u32 }, +} + +/// An error that occurred when trying to assign an incompatible [`Grammar`] to +/// a [`Parser`]. +#[derive(Debug, PartialEq, Eq)] +pub struct IncompatibleGrammarError { + version: u32, +} + +impl fmt::Display for IncompatibleGrammarError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "Tried to load grammar with incompatible ABI {}.", + self.version, + ) + } +} +impl std::error::Error for IncompatibleGrammarError {} + +extern "C" { + /// Get the ABI version number for this language. This version number + /// is used to ensure that languages were generated by a compatible version of + /// Tree-sitter. See also [`ts_parser_set_language`]. + pub fn ts_language_version(grammar: Grammar) -> u32; +} diff --git a/helix-syntax/src/tree_sitter/parser.rs b/helix-syntax/src/tree_sitter/parser.rs new file mode 100644 index 000000000..611637390 --- /dev/null +++ b/helix-syntax/src/tree_sitter/parser.rs @@ -0,0 +1,204 @@ +use std::os::raw::c_void; +use std::panic::catch_unwind; +use std::ptr::NonNull; +use std::{fmt, ptr}; + +use crate::tree_sitter::syntax_tree::{SyntaxTree, SyntaxTreeData}; +use crate::tree_sitter::{Grammar, Point, Range}; + +// opaque data +enum ParserData {} + +/// A stateful object that this is used to produce a [`Tree`] based on some +/// source code. +pub struct Parser { + ptr: NonNull, +} + +impl Parser { + /// Create a new parser. + #[must_use] + pub fn new() -> Parser { + Parser { + ptr: unsafe { ts_parser_new() }, + } + } + + /// Set the language that the parser should use for parsing. + pub fn set_language(&mut self, grammar: Grammar) { + unsafe { ts_parser_set_language(self.ptr, grammar) }; + } + + /// Set the ranges of text that the parser should include when parsing. By default, the parser + /// will always include entire documents. This function allows you to parse only a *portion* + /// of a document but still return a syntax tree whose ranges match up with the document as a + /// whole. You can also pass multiple disjoint ranges. + /// + /// `ranges` must be non-overlapping and sorted. + pub fn set_included_ranges(&mut self, ranges: &[Range]) -> Result<(), InvalidRangesErrror> { + // TODO: save some memory by only storing byte ranges and converting them to TS ranges in an + // internal buffer here. Points are not used by TS. Alternatively we can path the TS C code + // to accept a simple pair (struct with two fields) of byte positions here instead of a full + // tree sitter range + let success = unsafe { + ts_parser_set_included_ranges(self.ptr, ranges.as_ptr(), ranges.len() as u32) + }; + if success { + Ok(()) + } else { + Err(InvalidRangesErrror) + } + } + + #[must_use] + pub fn parse( + &mut self, + input: impl IntoParserInput, + old_tree: Option<&SyntaxTree>, + ) -> Option { + let mut input = input.into_parser_input(); + unsafe extern "C" fn read( + payload: NonNull, + byte_index: u32, + _position: Point, + bytes_read: &mut u32, + ) -> *const u8 { + match catch_unwind(|| { + let cursor: &mut C = payload.cast().as_mut(); + cursor.read(byte_index as usize) + }) { + Ok(slice) => { + *bytes_read = slice.len() as u32; + slice.as_ptr() + } + Err(_) => { + *bytes_read = 0; + ptr::null() + } + } + } + let input = ParserInputRaw { + payload: NonNull::from(&mut input).cast(), + read: read::, + // utf8 + encoding: 0, + }; + unsafe { + let old_tree = old_tree.map(|tree| tree.as_raw()); + let new_tree = ts_parser_parse(self.ptr, old_tree, input); + new_tree.map(|raw| SyntaxTree::from_raw(raw)) + } + } +} + +impl Default for Parser { + fn default() -> Self { + Self::new() + } +} + +unsafe impl Sync for Parser {} +unsafe impl Send for Parser {} +impl Drop for Parser { + fn drop(&mut self) { + unsafe { ts_parser_delete(self.ptr) } + } +} + +/// An error that occurred when trying to assign an incompatible [`Grammar`] to +/// a [`Parser`]. +#[derive(Debug, PartialEq, Eq)] +pub struct InvalidRangesErrror; + +impl fmt::Display for InvalidRangesErrror { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "include ranges are overlap or are not sorted",) + } +} +impl std::error::Error for InvalidRangesErrror {} + +type TreeSitterReadFn = unsafe extern "C" fn( + payload: NonNull, + byte_index: u32, + position: Point, + bytes_read: &mut u32, +) -> *const u8; + +#[repr(C)] +#[derive(Debug)] +pub struct ParserInputRaw { + pub payload: NonNull, + pub read: TreeSitterReadFn, + pub encoding: u32, +} + +pub trait ParserInput { + fn read(&mut self, offset: usize) -> &[u8]; +} + +pub trait IntoParserInput { + type ParserInput; + fn into_parser_input(self) -> Self::ParserInput; +} + +extern "C" { + /// Create a new parser + fn ts_parser_new() -> NonNull; + /// Delete the parser, freeing all of the memory that it used. + fn ts_parser_delete(parser: NonNull); + /// Set the language that the parser should use for parsing. Returns a boolean indicating + /// whether or not the language was successfully assigned. True means assignment + /// succeeded. False means there was a version mismatch: the language was generated with + /// an incompatible version of the Tree-sitter CLI. Check the language's version using + /// [`ts_language_version`] and compare it to this library's [`TREE_SITTER_LANGUAGE_VERSION`] + /// and [`TREE_SITTER_MIN_COMPATIBLE_LANGUAGE_VERSION`] constants. + fn ts_parser_set_language(parser: NonNull, language: Grammar) -> bool; + /// Set the ranges of text that the parser should include when parsing. By default, the parser + /// will always include entire documents. This function allows you to parse only a *portion* + /// of a document but still return a syntax tree whose ranges match up with the document as a + /// whole. You can also pass multiple disjoint ranges. The second and third parameters specify + /// the location and length of an array of ranges. The parser does *not* take ownership of + /// these ranges; it copies the data, so it doesn't matter how these ranges are allocated. + /// If `count` is zero, then the entire document will be parsed. Otherwise, the given ranges + /// must be ordered from earliest to latest in the document, and they must not overlap. That + /// is, the following must hold for all: `i < count - 1`: `ranges[i].end_byte <= ranges[i + + /// 1].start_byte` If this requirement is not satisfied, the operation will fail, the ranges + /// will not be assigned, and this function will return `false`. On success, this function + /// returns `true` + fn ts_parser_set_included_ranges( + parser: NonNull, + ranges: *const Range, + count: u32, + ) -> bool; + + /// Use the parser to parse some source code and create a syntax tree. If you are parsing this + /// document for the first time, pass `NULL` for the `old_tree` parameter. Otherwise, if you + /// have already parsed an earlier version of this document and the document has since been + /// edited, pass the previous syntax tree so that the unchanged parts of it can be reused. + /// This will save time and memory. For this to work correctly, you must have already edited + /// the old syntax tree using the [`ts_tree_edit`] function in a way that exactly matches + /// the source code changes. The [`TSInput`] parameter lets you specify how to read the text. + /// It has the following three fields: 1. [`read`]: A function to retrieve a chunk of text + /// at a given byte offset and (row, column) position. The function should return a pointer + /// to the text and write its length to the [`bytes_read`] pointer. The parser does not + /// take ownership of this buffer; it just borrows it until it has finished reading it. The + /// function should write a zero value to the [`bytes_read`] pointer to indicate the end of the + /// document. 2. [`payload`]: An arbitrary pointer that will be passed to each invocation of + /// the [`read`] function. 3. [`encoding`]: An indication of how the text is encoded. Either + /// `TSInputEncodingUTF8` or `TSInputEncodingUTF16`. This function returns a syntax tree + /// on success, and `NULL` on failure. There are three possible reasons for failure: 1. The + /// parser does not have a language assigned. Check for this using the [`ts_parser_language`] + /// function. 2. Parsing was cancelled due to a timeout that was set by an earlier call to the + /// [`ts_parser_set_timeout_micros`] function. You can resume parsing from where the parser + /// left out by calling [`ts_parser_parse`] again with the same arguments. Or you can start + /// parsing from scratch by first calling [`ts_parser_reset`]. 3. Parsing was cancelled using + /// a cancellation flag that was set by an earlier call to [`ts_parser_set_cancellation_flag`]. + /// You can resume parsing from where the parser left out by calling [`ts_parser_parse`] again + /// with the same arguments. [`read`]: TSInput::read [`payload`]: TSInput::payload [`encoding`]: + /// TSInput::encoding [`bytes_read`]: TSInput::read + fn ts_parser_parse( + parser: NonNull, + old_tree: Option>, + input: ParserInputRaw, + ) -> Option>; +} diff --git a/helix-syntax/src/tree_sitter/query.rs b/helix-syntax/src/tree_sitter/query.rs new file mode 100644 index 000000000..44a7fa3c3 --- /dev/null +++ b/helix-syntax/src/tree_sitter/query.rs @@ -0,0 +1,574 @@ +use std::fmt::Display; +use std::iter::zip; +use std::path::{Path, PathBuf}; +use std::ptr::NonNull; +use std::{slice, str}; + +use regex_cursor::engines::meta::Regex; + +use crate::tree_sitter::Grammar; + +macro_rules! bail { + ($($args:tt)*) => {{ + return Err(format!($($args)*)) + }} +} + +macro_rules! ensure { + ($cond: expr, $($args:tt)*) => {{ + if !$cond { + return Err(format!($($args)*)) + } + }} +} + +#[derive(Debug)] +enum TextPredicateCaptureKind { + EqString(u32), + EqCapture(u32), + MatchString(Regex), + AnyString(Box<[Box]>), +} + +struct TextPredicateCapture { + capture_idx: u32, + kind: TextPredicateCaptureKind, + negated: bool, + match_all: bool, +} + +pub enum QueryData {} +pub struct Query { + raw: NonNull, + num_captures: u32, +} + +impl Query { + /// Create a new query from a string containing one or more S-expression + /// patterns. + /// + /// The query is associated with a particular grammar, and can only be run + /// on syntax nodes parsed with that grammar. References to Queries can be + /// shared between multiple threads. + pub fn new(grammar: Grammar, source: &str, path: impl AsRef) -> Result { + assert!( + source.len() <= i32::MAX as usize, + "TreeSitter queries must be smaller then 2 GiB (is {})", + source.len() as f64 / 1024.0 / 1024.0 / 1024.0 + ); + let mut error_offset = 0u32; + let mut error_kind = RawQueryError::None; + let bytes = source.as_bytes(); + + // Compile the query. + let ptr = unsafe { + ts_query_new( + grammar, + bytes.as_ptr(), + bytes.len() as u32, + &mut error_offset, + &mut error_kind, + ) + }; + + let Some(raw) = ptr else { + let offset = error_offset as usize; + let error_word = || { + source[offset..] + .chars() + .take_while(|&c| c.is_alphanumeric() || matches!(c, '_' | '-')) + .collect() + }; + let err = match error_kind { + RawQueryError::NodeType => { + let node: String = error_word(); + ParseError::InvalidNodeType { + location: ParserErrorLocation::new( + source, + path.as_ref(), + offset, + node.chars().count(), + ), + node, + } + } + RawQueryError::Field => { + let field = error_word(); + ParseError::InvalidFieldName { + location: ParserErrorLocation::new( + source, + path.as_ref(), + offset, + field.chars().count(), + ), + field, + } + } + RawQueryError::Capture => { + let capture = error_word(); + ParseError::InvalidCaptureName { + location: ParserErrorLocation::new( + source, + path.as_ref(), + offset, + capture.chars().count(), + ), + capture, + } + } + RawQueryError::Syntax => ParseError::SyntaxError(ParserErrorLocation::new( + source, + path.as_ref(), + offset, + 0, + )), + RawQueryError::Structure => ParseError::ImpossiblePattern( + ParserErrorLocation::new(source, path.as_ref(), offset, 0), + ), + RawQueryError::None => { + unreachable!("tree-sitter returned a null pointer but did not set an error") + } + RawQueryError::Language => unreachable!("should be handled at grammar load"), + }; + return Err(err) + }; + + // I am not going to bother with safety comments here, all of these are + // safe as long as TS is not buggy because raw is a properly constructed query + let num_captures = unsafe { ts_query_capture_count(raw) }; + + Ok(Query { raw, num_captures }) + } + + fn parse_predicates(&mut self) { + let pattern_count = unsafe { ts_query_pattern_count(self.raw) }; + + let mut text_predicates = Vec::with_capacity(pattern_count as usize); + let mut property_predicates = Vec::with_capacity(pattern_count as usize); + let mut property_settings = Vec::with_capacity(pattern_count as usize); + let mut general_predicates = Vec::with_capacity(pattern_count as usize); + + for i in 0..pattern_count {} + } + + fn parse_predicate(&self, pattern_index: u32) -> Result<(), String> { + let mut text_predicates = Vec::new(); + let mut property_predicates = Vec::new(); + let mut property_settings = Vec::new(); + let mut general_predicates = Vec::new(); + for predicate in self.predicates(pattern_index) { + let predicate = unsafe { Predicate::new(self, predicate)? }; + + // Build a predicate for each of the known predicate function names. + match predicate.operator_name { + "eq?" | "not-eq?" | "any-eq?" | "any-not-eq?" => { + predicate.check_arg_count(2)?; + let capture_idx = predicate.get_arg(0, PredicateArg::Capture)?; + let (arg2, arg2_kind) = predicate.get_any_arg(1); + + let negated = matches!(predicate.operator_name, "not-eq?" | "not-any-eq?"); + let match_all = matches!(predicate.operator_name, "eq?" | "not-eq?"); + let kind = match arg2_kind { + PredicateArg::Capture => TextPredicateCaptureKind::EqCapture(arg2), + PredicateArg::String => TextPredicateCaptureKind::EqString(arg2), + }; + text_predicates.push(TextPredicateCapture { + capture_idx, + kind, + negated, + match_all, + }); + } + + "match?" | "not-match?" | "any-match?" | "any-not-match?" => { + predicate.check_arg_count(2)?; + let capture_idx = predicate.get_arg(0, PredicateArg::Capture)?; + let regex = predicate.get_str_arg(1)?; + + let negated = + matches!(predicate.operator_name, "not-match?" | "any-not-match?"); + let match_all = matches!(predicate.operator_name, "match?" | "not-match?"); + let regex = match Regex::new(regex) { + Ok(regex) => regex, + Err(err) => bail!("invalid regex '{regex}', {err}"), + }; + text_predicates.push(TextPredicateCapture { + capture_idx, + kind: TextPredicateCaptureKind::MatchString(regex), + negated, + match_all, + }); + } + + "set!" => property_settings.push(Self::parse_property( + row, + operator_name, + &capture_names, + &string_values, + &p[1..], + )?), + + "is?" | "is-not?" => property_predicates.push(( + Self::parse_property( + row, + operator_name, + &capture_names, + &string_values, + &p[1..], + )?, + operator_name == "is?", + )), + + "any-of?" | "not-any-of?" => { + if p.len() < 2 { + return Err(predicate_error(row, format!( + "Wrong number of arguments to #any-of? predicate. Expected at least 1, got {}.", + p.len() - 1 + ))); + } + if p[1].type_ != TYPE_CAPTURE { + return Err(predicate_error(row, format!( + "First argument to #any-of? predicate must be a capture name. Got literal \"{}\".", + string_values[p[1].value_id as usize], + ))); + } + + let is_positive = operator_name == "any-of?"; + let mut values = Vec::new(); + for arg in &p[2..] { + if arg.type_ == TYPE_CAPTURE { + return Err(predicate_error(row, format!( + "Arguments to #any-of? predicate must be literals. Got capture @{}.", + capture_names[arg.value_id as usize], + ))); + } + values.push(string_values[arg.value_id as usize]); + } + text_predicates.push(TextPredicateCapture::AnyString( + p[1].value_id, + values + .iter() + .map(|x| (*x).to_string().into()) + .collect::>() + .into(), + is_positive, + )); + } + + _ => general_predicates.push(QueryPredicate { + operator: operator_name.to_string().into(), + args: p[1..] + .iter() + .map(|a| { + if a.type_ == TYPE_CAPTURE { + QueryPredicateArg::Capture(a.value_id) + } else { + QueryPredicateArg::String( + string_values[a.value_id as usize].to_string().into(), + ) + } + }) + .collect(), + }), + } + } + + text_predicates_vec.push(text_predicates.into()); + property_predicates_vec.push(property_predicates.into()); + property_settings_vec.push(property_settings.into()); + general_predicates_vec.push(general_predicates.into()); + } + + fn predicates<'a>( + &'a self, + pattern_index: u32, + ) -> impl Iterator + 'a { + let predicate_steps = unsafe { + let mut len = 0u32; + let raw_predicates = ts_query_predicates_for_pattern(self.raw, pattern_index, &mut len); + (len != 0) + .then(|| slice::from_raw_parts(raw_predicates, len as usize)) + .unwrap_or_default() + }; + predicate_steps + .split(|step| step.kind == PredicateStepKind::Done) + .filter(|predicate| !predicate.is_empty()) + } + + /// Safety: value_idx must be a valid string id (in bounds) for this query and pattern_index + unsafe fn get_pattern_string(&self, value_id: u32) -> &str { + unsafe { + let mut len = 0; + let ptr = ts_query_string_value_for_id(self.raw, value_id, &mut len); + let data = slice::from_raw_parts(ptr, len as usize); + // safety: we only allow passing valid str(ings) as arguments to query::new + // name is always a substring of that. Treesitter does proper utf8 segmentation + // so any substrings it produces are codepoint aligned and therefore valid utf8 + str::from_utf8_unchecked(data) + } + } + + #[inline] + pub fn capture_name(&self, capture_idx: u32) -> &str { + // this one needs an assertions because the ts c api is inconsisent + // and unsafe, other functions do have checks and would return null + assert!(capture_idx <= self.num_captures, "invalid capture index"); + let mut length = 0; + unsafe { + let ptr = ts_query_capture_name_for_id(self.raw, capture_idx, &mut length); + let name = slice::from_raw_parts(ptr, length as usize); + // safety: we only allow passing valid str(ings) as arguments to query::new + // name is always a substring of that. Treesitter does proper utf8 segmentation + // so any substrings it produces are codepoint aligned and therefore valid utf8 + str::from_utf8_unchecked(name) + } + } +} + +#[derive(Debug, PartialEq, Eq)] +pub struct ParserErrorLocation { + pub path: PathBuf, + /// at which line the error occured + pub line: usize, + /// at which codepoints/columns the errors starts in the line + pub column: usize, + /// how many codepoints/columns the error takes up + pub len: usize, + line_content: String, +} + +impl ParserErrorLocation { + pub fn new(source: &str, path: &Path, offset: usize, len: usize) -> ParserErrorLocation { + let (line, line_content) = source[..offset] + .split('\n') + .map(|line| line.strip_suffix('\r').unwrap_or(line)) + .enumerate() + .last() + .unwrap_or((0, "")); + let column = line_content.chars().count(); + ParserErrorLocation { + path: path.to_owned(), + line, + column, + len, + line_content: line_content.to_owned(), + } + } +} + +impl Display for ParserErrorLocation { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + writeln!( + f, + " --> {}:{}:{}", + self.path.display(), + self.line, + self.column + )?; + let line = self.line.to_string(); + let prefix = format_args!(" {:width$} |", "", width = line.len()); + writeln!(f, "{prefix}"); + writeln!(f, " {line} | {}", self.line_content)?; + writeln!( + f, + "{prefix}{:width$}{:^ { + operator_name: &'a str, + args: &'a [PredicateStep], + query: &'a Query, +} + +impl<'a> Predicate<'a> { + unsafe fn new( + query: &'a Query, + predicate: &'a [PredicateStep], + ) -> Result, String> { + ensure!( + predicate[0].kind == PredicateStepKind::String, + "expected predicate to start with a function name. Got @{}.", + query.capture_name(predicate[0].value_id) + ); + let operator_name = query.get_pattern_string(predicate[0].value_id); + Ok(Predicate { + operator_name, + args: &predicate[1..], + query, + }) + } + pub fn check_arg_count(&self, n: usize) -> Result<(), String> { + ensure!( + self.args.len() == n, + "expected {n} arguments for #{}, got {}", + self.operator_name, + self.args.len() + ); + Ok(()) + } + + pub fn get_arg(&self, i: usize, expect: PredicateArg) -> Result { + let (val, actual) = self.get_any_arg(i); + match (actual, expect) { + (PredicateArg::Capture, PredicateArg::String) => bail!( + "{i}. argument to #{} expected a capture, got literal {val:?}", + self.operator_name + ), + (PredicateArg::String, PredicateArg::Capture) => bail!( + "{i}. argument to #{} must be a literal, got capture @{val:?}", + self.operator_name + ), + _ => (), + }; + Ok(val) + } + pub fn get_str_arg(&self, i: usize) -> Result<&'a str, String> { + let arg = self.get_arg(i, PredicateArg::String)?; + unsafe { Ok(self.query.get_pattern_string(arg)) } + } + + pub fn get_any_arg(&self, i: usize) -> (u32, PredicateArg) { + match self.args[i].kind { + PredicateStepKind::String => unsafe { (self.args[i].value_id, PredicateArg::String) }, + PredicateStepKind::Capture => (self.args[i].value_id, PredicateArg::Capture), + PredicateStepKind::Done => unreachable!(), + } + } +} + +enum PredicateArg { + Capture, + String, +} + +extern "C" { + /// Create a new query from a string containing one or more S-expression + /// patterns. The query is associated with a particular language, and can + /// only be run on syntax nodes parsed with that language. If all of the + /// given patterns are valid, this returns a [`TSQuery`]. If a pattern is + /// invalid, this returns `NULL`, and provides two pieces of information + /// about the problem: 1. The byte offset of the error is written to + /// the `error_offset` parameter. 2. The type of error is written to the + /// `error_type` parameter. + pub fn ts_query_new( + grammar: Grammar, + source: *const u8, + source_len: u32, + error_offset: &mut u32, + error_type: &mut RawQueryError, + ) -> Option>; + + /// Delete a query, freeing all of the memory that it used. + pub fn ts_query_delete(query: NonNull); + + /// Get the number of patterns, captures, or string literals in the query. + pub fn ts_query_pattern_count(query: NonNull) -> u32; + pub fn ts_query_capture_count(query: NonNull) -> u32; + pub fn ts_query_string_count(query: NonNull) -> u32; + + /// Get the byte offset where the given pattern starts in the query's + /// source. This can be useful when combining queries by concatenating their + /// source code strings. + pub fn ts_query_start_byte_for_pattern(query: NonNull, pattern_index: u32) -> u32; + + /// Get all of the predicates for the given pattern in the query. The + /// predicates are represented as a single array of steps. There are three + /// types of steps in this array, which correspond to the three legal values + /// for the `type` field: - `TSQueryPredicateStepTypeCapture` - Steps with + /// this type represent names of captures. Their `value_id` can be used + /// with the [`ts_query_capture_name_for_id`] function to obtain the name + /// of the capture. - `TSQueryPredicateStepTypeString` - Steps with this + /// type represent literal strings. Their `value_id` can be used with the + /// [`ts_query_string_value_for_id`] function to obtain their string value. + /// - `TSQueryPredicateStepTypeDone` - Steps with this type are *sentinels* + /// that represent the end of an individual predicate. If a pattern has two + /// predicates, then there will be two steps with this `type` in the array. + pub fn ts_query_predicates_for_pattern( + query: NonNull, + pattern_index: u32, + step_count: &mut u32, + ) -> *const PredicateStep; + + pub fn ts_query_is_pattern_rooted(query: NonNull, pattern_index: u32) -> bool; + pub fn ts_query_is_pattern_non_local(query: NonNull, pattern_index: u32) -> bool; + pub fn ts_query_is_pattern_guaranteed_at_step( + query: NonNull, + byte_offset: u32, + ) -> bool; + /// Get the name and length of one of the query's captures, or one of the + /// query's string literals. Each capture and string is associated with a + /// numeric id based on the order that it appeared in the query's source. + pub fn ts_query_capture_name_for_id( + query: NonNull, + index: u32, + length: &mut u32, + ) -> *const u8; + + pub fn ts_query_string_value_for_id( + self_: NonNull, + index: u32, + length: &mut u32, + ) -> *const u8; +} diff --git a/helix-syntax/src/tree_sitter/ropey.rs b/helix-syntax/src/tree_sitter/ropey.rs new file mode 100644 index 000000000..aa59b2f27 --- /dev/null +++ b/helix-syntax/src/tree_sitter/ropey.rs @@ -0,0 +1,38 @@ +use regex_cursor::{Cursor, RopeyCursor}; +use ropey::RopeSlice; + +use crate::tree_sitter::parser::{IntoParserInput, ParserInput}; + +pub struct RopeParserInput<'a> { + src: RopeSlice<'a>, + cursor: regex_cursor::RopeyCursor<'a>, +} + +impl<'a> IntoParserInput for RopeSlice<'a> { + type ParserInput = RopeParserInput<'a>; + + fn into_parser_input(self) -> Self::ParserInput { + RopeParserInput { + src: self, + cursor: RopeyCursor::new(self), + } + } +} + +impl ParserInput for RopeParserInput<'_> { + fn read(&mut self, offset: usize) -> &[u8] { + // this cursor is optimized for contigous reads which are by far the most common during parsing + // very far jumps (like injections at the other end of the document) are handelde + // by restarting a new cursor (new chunks iterator) + if offset < self.cursor.offset() && self.cursor.offset() - offset > 4906 { + self.cursor = regex_cursor::RopeyCursor::at(self.src, offset); + } else { + while self.cursor.offset() + self.cursor.chunk().len() >= offset { + if !self.cursor.advance() { + return &[]; + } + } + } + self.cursor.chunk() + } +} diff --git a/helix-syntax/src/tree_sitter/syntax_tree.rs b/helix-syntax/src/tree_sitter/syntax_tree.rs new file mode 100644 index 000000000..f1c608d6f --- /dev/null +++ b/helix-syntax/src/tree_sitter/syntax_tree.rs @@ -0,0 +1,80 @@ +use std::fmt; +use std::ptr::NonNull; + +use crate::tree_sitter::syntax_tree_node::{SyntaxTreeNode, SyntaxTreeNodeRaw}; +use crate::tree_sitter::Point; + +// opaque pointers +pub(super) enum SyntaxTreeData {} + +pub struct SyntaxTree { + ptr: NonNull, +} + +impl SyntaxTree { + pub(super) unsafe fn from_raw(raw: NonNull) -> SyntaxTree { + SyntaxTree { ptr: raw } + } + + pub(super) fn as_raw(&self) -> NonNull { + self.ptr + } + + pub fn root_node(&self) -> SyntaxTreeNode<'_> { + unsafe { SyntaxTreeNode::from_raw(ts_tree_root_node(self.ptr)).unwrap() } + } + + pub fn edit(&mut self, edit: &InputEdit) { + unsafe { ts_tree_edit(self.ptr, edit) } + } +} + +impl fmt::Debug for SyntaxTree { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{{Tree {:?}}}", self.root_node()) + } +} + +impl Drop for SyntaxTree { + fn drop(&mut self) { + unsafe { ts_tree_delete(self.ptr) } + } +} + +impl Clone for SyntaxTree { + fn clone(&self) -> Self { + unsafe { + SyntaxTree { + ptr: ts_tree_copy(self.ptr), + } + } + } +} + +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct InputEdit { + pub start_byte: u32, + pub old_end_byte: u32, + pub new_end_byte: u32, + pub start_point: Point, + pub old_end_point: Point, + pub new_end_point: Point, +} + +extern "C" { + /// Create a shallow copy of the syntax tree. This is very fast. You need to + /// copy a syntax tree in order to use it on more than one thread at a time, + /// as syntax trees are not thread safe. + fn ts_tree_copy(self_: NonNull) -> NonNull; + /// Delete the syntax tree, freeing all of the memory that it used. + fn ts_tree_delete(self_: NonNull); + /// Get the root node of the syntax tree. + fn ts_tree_root_node<'tree>(self_: NonNull) -> SyntaxTreeNodeRaw; + /// Edit the syntax tree to keep it in sync with source code that has been + /// edited. + /// + /// You must describe the edit both in terms of byte offsets and in terms of + /// row/column coordinates. + fn ts_tree_edit(self_: NonNull, edit: &InputEdit); +} diff --git a/helix-syntax/src/tree_sitter/syntax_tree_node.rs b/helix-syntax/src/tree_sitter/syntax_tree_node.rs new file mode 100644 index 000000000..b50c79696 --- /dev/null +++ b/helix-syntax/src/tree_sitter/syntax_tree_node.rs @@ -0,0 +1,291 @@ +use std::ffi::c_void; +use std::marker::PhantomData; +use std::ops::Range; +use std::ptr::NonNull; + +use crate::tree_sitter::syntax_tree::SyntaxTree; +use crate::tree_sitter::Grammar; + +#[repr(C)] +#[derive(Debug, Clone, Copy)] +pub(super) struct SyntaxTreeNodeRaw { + context: [u32; 4], + id: *const c_void, + tree: *const c_void, +} + +impl From> for SyntaxTreeNodeRaw { + fn from(node: SyntaxTreeNode) -> SyntaxTreeNodeRaw { + SyntaxTreeNodeRaw { + context: node.context, + id: node.id.as_ptr(), + tree: node.tree.as_ptr(), + } + } +} + +#[derive(Debug, Clone)] +pub struct SyntaxTreeNode<'tree> { + context: [u32; 4], + id: NonNull, + tree: NonNull, + _phantom: PhantomData<&'tree SyntaxTree>, +} + +impl<'tree> SyntaxTreeNode<'tree> { + #[inline] + pub(super) unsafe fn from_raw(raw: SyntaxTreeNodeRaw) -> Option { + Some(SyntaxTreeNode { + context: raw.context, + id: NonNull::new(raw.id as *mut _)?, + tree: unsafe { NonNull::new_unchecked(raw.tree as *mut _) }, + _phantom: PhantomData, + }) + } + + #[inline] + fn as_raw(&self) -> SyntaxTreeNodeRaw { + SyntaxTreeNodeRaw { + context: self.context, + id: self.id.as_ptr(), + tree: self.tree.as_ptr(), + } + } + + /// Get this node's type as a numerical id. + #[inline] + pub fn kind_id(&self) -> u16 { + unsafe { ts_node_symbol(self.as_raw()) } + } + + /// Get the [`Language`] that was used to parse this node's syntax tree. + #[inline] + pub fn grammar(&self) -> Grammar { + unsafe { ts_node_language(self.as_raw()) } + } + + /// Check if this node is *named*. + /// + /// Named nodes correspond to named rules in the grammar, whereas + /// *anonymous* nodes correspond to string literals in the grammar. + #[inline] + pub fn is_named(&self) -> bool { + unsafe { ts_node_is_named(self.as_raw()) } + } + + /// Check if this node is *missing*. + /// + /// Missing nodes are inserted by the parser in order to recover from + /// certain kinds of syntax errors. + #[inline] + pub fn is_missing(&self) -> bool { + unsafe { ts_node_is_missing(self.as_raw()) } + } + /// Get the byte offsets where this node starts. + #[inline] + pub fn start_byte(&self) -> usize { + unsafe { ts_node_start_byte(self.as_raw()) as usize } + } + + /// Get the byte offsets where this node end. + #[inline] + pub fn end_byte(&self) -> usize { + unsafe { ts_node_end_byte(self.as_raw()) as usize } + } + + /// Get the byte range of source code that this node represents. + // TODO: use helix_stdx::Range once available + #[inline] + pub fn byte_range(&self) -> Range { + self.start_byte()..self.end_byte() + } + + /// Get the node's child at the given index, where zero represents the first + /// child. + /// + /// This method is fairly fast, but its cost is technically log(i), so if + /// you might be iterating over a long list of children, you should use + /// [`SyntaxTreeNode::children`] instead. + #[inline] + pub fn child(&self, i: usize) -> Option> { + unsafe { SyntaxTreeNode::from_raw(ts_node_child(self.as_raw(), i as u32)) } + } + + /// Get this node's number of children. + #[inline] + pub fn child_count(&self) -> usize { + unsafe { ts_node_child_count(self.as_raw()) as usize } + } + + /// Get this node's *named* child at the given index. + /// + /// See also [`SyntaxTreeNode::is_named`]. + /// This method is fairly fast, but its cost is technically log(i), so if + /// you might be iterating over a long list of children, you should use + /// [`SyntaxTreeNode::named_children`] instead. + #[inline] + pub fn named_child(&self, i: usize) -> Option> { + unsafe { SyntaxTreeNode::from_raw(ts_node_named_child(self.as_raw(), i as u32)) } + } + + /// Get this node's number of *named* children. + /// + /// See also [`SyntaxTreeNode::is_named`]. + #[inline] + pub fn named_child_count(&self) -> usize { + unsafe { ts_node_named_child_count(self.as_raw()) as usize } + } + + #[inline] + unsafe fn map( + &self, + f: unsafe extern "C" fn(SyntaxTreeNodeRaw) -> SyntaxTreeNodeRaw, + ) -> Option> { + SyntaxTreeNode::from_raw(f(self.as_raw())) + } + + /// Get this node's immediate parent. + #[inline] + pub fn parent(&self) -> Option { + unsafe { self.map(ts_node_parent) } + } + + /// Get this node's next sibling. + #[inline] + pub fn next_sibling(&self) -> Option { + unsafe { self.map(ts_node_next_sibling) } + } + + /// Get this node's previous sibling. + #[inline] + pub fn prev_sibling(&self) -> Option { + unsafe { self.map(ts_node_prev_sibling) } + } + + /// Get this node's next named sibling. + #[inline] + pub fn next_named_sibling(&self) -> Option { + unsafe { self.map(ts_node_next_named_sibling) } + } + + /// Get this node's previous named sibling. + #[inline] + pub fn prev_named_sibling(&self) -> Option { + unsafe { self.map(ts_node_prev_named_sibling) } + } + + /// Get the smallest node within this node that spans the given range. + #[inline] + pub fn descendant_for_byte_range(&self, start: usize, end: usize) -> Option { + unsafe { + Self::from_raw(ts_node_descendant_for_byte_range( + self.as_raw(), + start as u32, + end as u32, + )) + } + } + + /// Get the smallest named node within this node that spans the given range. + #[inline] + pub fn named_descendant_for_byte_range(&self, start: usize, end: usize) -> Option { + unsafe { + Self::from_raw(ts_node_named_descendant_for_byte_range( + self.as_raw(), + start as u32, + end as u32, + )) + } + } + // /// Iterate over this node's children. + // /// + // /// A [`TreeCursor`] is used to retrieve the children efficiently. Obtain + // /// a [`TreeCursor`] by calling [`Tree::walk`] or [`SyntaxTreeNode::walk`]. To avoid + // /// unnecessary allocations, you should reuse the same cursor for + // /// subsequent calls to this method. + // /// + // /// If you're walking the tree recursively, you may want to use the + // /// [`TreeCursor`] APIs directly instead. + // pub fn children<'cursor>( + // &self, + // cursor: &'cursor mut TreeCursor<'tree>, + // ) -> impl ExactSizeIterator> + 'cursor { + // cursor.reset(self.to_raw()); + // cursor.goto_first_child(); + // (0..self.child_count()).map(move |_| { + // let result = cursor.node(); + // cursor.goto_next_sibling(); + // result + // }) + // } +} + +unsafe impl Send for SyntaxTreeNode<'_> {} +unsafe impl Sync for SyntaxTreeNode<'_> {} + +extern "C" { + /// Get the node's type as a numerical id. + fn ts_node_symbol(node: SyntaxTreeNodeRaw) -> u16; + + /// Get the node's language. + fn ts_node_language(node: SyntaxTreeNodeRaw) -> Grammar; + + /// Check if the node is *named*. Named nodes correspond to named rules in + /// the grammar, whereas *anonymous* nodes correspond to string literals in + /// the grammar + fn ts_node_is_named(node: SyntaxTreeNodeRaw) -> bool; + + /// Check if the node is *missing*. Missing nodes are inserted by the parser + /// in order to recover from certain kinds of syntax errors + fn ts_node_is_missing(node: SyntaxTreeNodeRaw) -> bool; + + /// Get the node's immediate parent + fn ts_node_parent(node: SyntaxTreeNodeRaw) -> SyntaxTreeNodeRaw; + + /// Get the node's child at the given index, where zero represents the first + /// child + fn ts_node_child(node: SyntaxTreeNodeRaw, child_index: u32) -> SyntaxTreeNodeRaw; + + /// Get the node's number of children + fn ts_node_child_count(node: SyntaxTreeNodeRaw) -> u32; + + /// Get the node's *named* child at the given index. See also + /// [`ts_node_is_named`] + fn ts_node_named_child(node: SyntaxTreeNodeRaw, child_index: u32) -> SyntaxTreeNodeRaw; + + /// Get the node's number of *named* children. See also [`ts_node_is_named`] + fn ts_node_named_child_count(node: SyntaxTreeNodeRaw) -> u32; + + /// Get the node's next sibling + fn ts_node_next_sibling(node: SyntaxTreeNodeRaw) -> SyntaxTreeNodeRaw; + + fn ts_node_prev_sibling(node: SyntaxTreeNodeRaw) -> SyntaxTreeNodeRaw; + + /// Get the node's next *named* sibling + fn ts_node_next_named_sibling(node: SyntaxTreeNodeRaw) -> SyntaxTreeNodeRaw; + + fn ts_node_prev_named_sibling(node: SyntaxTreeNodeRaw) -> SyntaxTreeNodeRaw; + + /// Get the smallest node within this node that spans the given range of + /// bytes or (row, column) positions + fn ts_node_descendant_for_byte_range( + node: SyntaxTreeNodeRaw, + + start: u32, + end: u32, + ) -> SyntaxTreeNodeRaw; + + /// Get the smallest named node within this node that spans the given range + /// of bytes or (row, column) positions + fn ts_node_named_descendant_for_byte_range( + node: SyntaxTreeNodeRaw, + start: u32, + end: u32, + ) -> SyntaxTreeNodeRaw; + + /// Get the node's start byte. + fn ts_node_start_byte(self_: SyntaxTreeNodeRaw) -> u32; + + /// Get the node's end byte. + fn ts_node_end_byte(node: SyntaxTreeNodeRaw) -> u32; +}