diff --git a/helix-dap/examples/dap-basic.rs b/helix-dap/examples/dap-basic.rs index 82e147025..76fc0dd3f 100644 --- a/helix-dap/examples/dap-basic.rs +++ b/helix-dap/examples/dap-basic.rs @@ -56,8 +56,12 @@ pub async fn main() -> Result<()> { .read_line(&mut _in) .expect("Failed to read line"); + let mut stopped_event = client.listen_for_event("stopped".to_owned()).await; + println!("configurationDone: {:?}", client.configuration_done().await); - println!("stopped: {:?}", client.wait_for_stopped().await); + + println!("stopped: {:?}", stopped_event.recv().await); + println!("threads: {:#?}", client.threads().await); let bt = client.stack_trace(1).await.expect("expected stack trace"); println!("stack trace: {:#?}", bt); diff --git a/helix-dap/src/client.rs b/helix-dap/src/client.rs index 03e587572..f648b2482 100644 --- a/helix-dap/src/client.rs +++ b/helix-dap/src/client.rs @@ -2,15 +2,22 @@ use crate::{ transport::{Event, Payload, Request, Response, Transport}, Result, }; +use log::{error, info}; use serde::{Deserialize, Serialize}; use serde_json::{from_value, to_value, Value}; -use std::process::Stdio; -use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::{ + atomic::{AtomicU64, Ordering}, + Arc, +}; +use std::{collections::HashMap, process::Stdio}; use tokio::{ io::{AsyncBufRead, AsyncWrite, BufReader, BufWriter}, net::TcpStream, process::{Child, Command}, - sync::mpsc::{channel, UnboundedReceiver, UnboundedSender}, + sync::{ + mpsc::{channel, Receiver, Sender, UnboundedReceiver, UnboundedSender}, + Mutex, + }, }; #[derive(Debug, PartialEq, Clone, Deserialize, Serialize)] @@ -252,9 +259,9 @@ pub struct Client { id: usize, _process: Option, server_tx: UnboundedSender, - server_rx: UnboundedReceiver, request_counter: AtomicU64, capabilities: Option, + awaited_events: Arc>>>, } impl Client { @@ -270,14 +277,52 @@ impl Client { id, _process: process, server_tx, - server_rx, request_counter: AtomicU64::new(0), capabilities: None, + awaited_events: Arc::new(Mutex::new(HashMap::default())), }; + tokio::spawn(Self::recv(Arc::clone(&client.awaited_events), server_rx)); + Ok(client) } + async fn recv( + awaited_events: Arc>>>, + mut server_rx: UnboundedReceiver, + ) { + while let Some(msg) = server_rx.recv().await { + match msg { + Payload::Event(ev) => { + let name = ev.event.clone(); + let tx = awaited_events.lock().await.remove(&name); + + match tx { + Some(tx) => match tx.send(ev).await { + Ok(_) => (), + Err(_) => error!( + "Tried sending event into a closed channel (name={:?})", + name + ), + }, + None => { + info!("unhandled event"); + // client_tx.send(Payload::Event(ev)).expect("Failed to send"); + } + } + } + Payload::Response(_) => unreachable!(), + Payload::Request(_) => todo!(), + } + } + } + + pub async fn listen_for_event(&self, name: String) -> Receiver { + let (rx, tx) = channel(1); + self.awaited_events.lock().await.insert(name.clone(), rx); + tx + } + pub async fn tcp(addr: std::net::SocketAddr, id: usize) -> Result { let stream = TcpStream::connect(addr).await?; let (rx, tx) = stream.into_split(); @@ -373,45 +418,25 @@ impl Client { } pub async fn launch(&mut self, args: impl Serialize) -> Result<()> { + let mut initialized = self.listen_for_event("initialized".to_owned()).await; + self.request("launch".to_owned(), to_value(args).ok()) .await?; - match self - .server_rx - .recv() - .await - .expect("Expected initialized event") - { - Payload::Event(Event { event, .. }) => { - if event == *"initialized" { - Ok(()) - } else { - unreachable!() - } - } - _ => unreachable!(), - } + initialized.recv().await; + + Ok(()) } pub async fn attach(&mut self, args: impl Serialize) -> Result<()> { + let mut initialized = self.listen_for_event("initialized".to_owned()).await; + self.request("attach".to_owned(), to_value(args).ok()) .await?; - match self - .server_rx - .recv() - .await - .expect("Expected initialized event") - { - Payload::Event(Event { event, .. }) => { - if event == *"initialized" { - Ok(()) - } else { - unreachable!() - } - } - _ => unreachable!(), - } + initialized.recv().await; + + Ok(()) } pub async fn set_breakpoints( @@ -447,19 +472,6 @@ impl Client { Ok(()) } - pub async fn wait_for_stopped(&mut self) -> Result<()> { - match self.server_rx.recv().await.expect("Expected stopped event") { - Payload::Event(Event { event, .. }) => { - if event == *"stopped" { - Ok(()) - } else { - unreachable!() - } - } - _ => unreachable!(), - } - } - pub async fn continue_thread(&mut self, thread_id: usize) -> Result> { let args = ContinueArguments { thread_id };