From 26a55dcefd9774235cd8c40eb2724daa0808e376 Mon Sep 17 00:00:00 2001 From: Dmitry Sharshakov Date: Fri, 13 Aug 2021 17:56:46 +0300 Subject: [PATCH] dap: make transport IO-agnostic --- helix-dap/src/client.rs | 2 +- helix-dap/src/transport.rs | 543 ++++++++++++++++++------------------- 2 files changed, 272 insertions(+), 273 deletions(-) diff --git a/helix-dap/src/client.rs b/helix-dap/src/client.rs index e8ad2c6fb..08f7704eb 100644 --- a/helix-dap/src/client.rs +++ b/helix-dap/src/client.rs @@ -280,7 +280,7 @@ impl Client { let writer = BufWriter::new(process.stdin.take().expect("Failed to open stdin")); let reader = BufReader::new(process.stdout.take().expect("Failed to open stdout")); - let (server_rx, server_tx) = Transport::start(reader, writer, id); + let (server_rx, server_tx) = Transport::start(Box::new(reader), Box::new(writer), id); let client = Self { id, diff --git a/helix-dap/src/transport.rs b/helix-dap/src/transport.rs index 2686183b4..0a40d9fec 100644 --- a/helix-dap/src/transport.rs +++ b/helix-dap/src/transport.rs @@ -1,272 +1,271 @@ -use crate::{Error, Result}; -use anyhow::Context; -use log::{error, info}; -use serde::{Deserialize, Serialize}; -use serde_json::Value; -use std::collections::HashMap; -use std::sync::Arc; -use tokio::{ - io::{AsyncBufRead, AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader, BufWriter}, - process::{ChildStdin, ChildStdout}, - sync::{ - mpsc::{unbounded_channel, Sender, UnboundedReceiver, UnboundedSender}, - Mutex, - }, -}; - -#[derive(Debug, Clone, Deserialize, Serialize)] -pub struct Request { - #[serde(skip)] - pub back_ch: Option>>, - pub seq: u64, - #[serde(rename = "type")] - pub msg_type: String, - pub command: String, - pub arguments: Option, -} - -#[derive(Debug, PartialEq, Clone, Deserialize, Serialize)] -pub struct Response { - pub seq: u64, - #[serde(rename = "type")] - pub msg_type: String, - pub request_seq: u64, - pub success: bool, - pub command: String, - pub message: Option, - pub body: Option, -} - -#[derive(Debug, PartialEq, Clone, Deserialize, Serialize)] -pub struct Event { - pub seq: u64, - #[serde(rename = "type")] - pub msg_type: String, - pub event: String, - pub body: Option, -} - -#[derive(Debug, Clone, Deserialize, Serialize)] -#[serde(untagged)] -pub enum Payload { - // type = "event" - Event(Event), - // type = "response" - Response(Response), - // type = "request" - Request(Request), -} - -#[derive(Debug)] -pub struct Transport { - id: usize, - pending_requests: Mutex>>>, -} - -impl Transport { - pub fn start( - server_stdout: BufReader, - server_stdin: BufWriter, - id: usize, - ) -> (UnboundedReceiver, UnboundedSender) { - let (client_tx, rx) = unbounded_channel(); - let (tx, client_rx) = unbounded_channel(); - - let transport = Self { - id, - pending_requests: Mutex::new(HashMap::default()), - }; - - let transport = Arc::new(transport); - - tokio::spawn(Self::recv(transport.clone(), server_stdout, client_tx)); - tokio::spawn(Self::send(transport, server_stdin, client_rx)); - - (rx, tx) - } - - async fn recv_server_message( - reader: &mut (impl AsyncBufRead + Unpin + Send), - buffer: &mut String, - ) -> Result { - let mut content_length = None; - loop { - buffer.truncate(0); - reader.read_line(buffer).await?; - let header = buffer.trim(); - - if header.is_empty() { - break; - } - - let mut parts = header.split(": "); - - match (parts.next(), parts.next(), parts.next()) { - (Some("Content-Length"), Some(value), None) => { - content_length = Some(value.parse().context("invalid content length")?); - } - (Some(_), Some(_), None) => {} - _ => { - return Err(std::io::Error::new( - std::io::ErrorKind::Other, - "Failed to parse header", - ) - .into()); - } - } - } - - let content_length = content_length.context("missing content length")?; - - //TODO: reuse vector - let mut content = vec![0; content_length]; - reader.read_exact(&mut content).await?; - let msg = std::str::from_utf8(&content).context("invalid utf8 from server")?; - - info!("<- DAP {}", msg); - - // try parsing as output (server response) or call (server request) - let output: serde_json::Result = serde_json::from_str(msg); - - Ok(output?) - } - - async fn send_payload_to_server( - &self, - server_stdin: &mut BufWriter, - req: Request, - ) -> Result<()> { - let json = serde_json::to_string(&req)?; - if let Some(back) = req.back_ch { - self.pending_requests.lock().await.insert(req.seq, back); - } - self.send_string_to_server(server_stdin, json).await - } - - async fn send_string_to_server( - &self, - server_stdin: &mut BufWriter, - request: String, - ) -> Result<()> { - info!("-> DAP {}", request); - - // send the headers - server_stdin - .write_all(format!("Content-Length: {}\r\n\r\n", request.len()).as_bytes()) - .await?; - - // send the body - server_stdin.write_all(request.as_bytes()).await?; - - server_stdin.flush().await?; - - Ok(()) - } - - async fn process_server_message( - &self, - client_tx: &UnboundedSender, - msg: Payload, - ) -> Result<()> { - let (id, result) = match msg { - Payload::Response(Response { - success: true, - seq, - request_seq, - .. - }) => { - info!("<- DAP success ({}, in response to {})", seq, request_seq); - if let Payload::Response(val) = msg { - (request_seq, Ok(val)) - } else { - unreachable!(); - } - } - Payload::Response(Response { - success: false, - message, - body, - request_seq, - command, - .. - }) => { - error!( - "<- DAP error {:?} ({:?}) for command #{} {}", - message, body, request_seq, command - ); - ( - request_seq, - Err(Error::Other(anyhow::format_err!("{:?}", body))), - ) - } - Payload::Request(Request { - ref command, - ref seq, - .. - }) => { - info!("<- DAP request {} #{}", command, seq); - client_tx.send(msg).expect("Failed to send"); - return Ok(()); - } - Payload::Event(Event { - ref event, ref seq, .. - }) => { - info!("<- DAP event {} #{}", event, seq); - client_tx.send(msg).expect("Failed to send"); - return Ok(()); - } - }; - - let tx = self - .pending_requests - .lock() - .await - .remove(&id) - .expect("pending_request with id not found!"); - - match tx.send(result).await { - Ok(_) => (), - Err(_) => error!( - "Tried sending response into a closed channel (id={:?}), original request likely timed out", - id - ), - }; - - Ok(()) - } - - async fn recv( - transport: Arc, - mut server_stdout: BufReader, - client_tx: UnboundedSender, - ) { - let mut recv_buffer = String::new(); - loop { - match Self::recv_server_message(&mut server_stdout, &mut recv_buffer).await { - Ok(msg) => { - transport - .process_server_message(&client_tx, msg) - .await - .unwrap(); - } - Err(err) => { - error!("err: <- {:?}", err); - break; - } - } - } - } - - async fn send( - transport: Arc, - mut server_stdin: BufWriter, - mut client_rx: UnboundedReceiver, - ) { - while let Some(req) = client_rx.recv().await { - transport - .send_payload_to_server(&mut server_stdin, req) - .await - .unwrap() - } - } -} +use crate::{Error, Result}; +use anyhow::Context; +use log::{error, info}; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use std::collections::HashMap; +use std::sync::Arc; +use tokio::{ + io::{AsyncBufRead, AsyncBufReadExt, AsyncReadExt, AsyncWrite, AsyncWriteExt}, + sync::{ + mpsc::{unbounded_channel, Sender, UnboundedReceiver, UnboundedSender}, + Mutex, + }, +}; + +#[derive(Debug, Clone, Deserialize, Serialize)] +pub struct Request { + #[serde(skip)] + pub back_ch: Option>>, + pub seq: u64, + #[serde(rename = "type")] + pub msg_type: String, + pub command: String, + pub arguments: Option, +} + +#[derive(Debug, PartialEq, Clone, Deserialize, Serialize)] +pub struct Response { + pub seq: u64, + #[serde(rename = "type")] + pub msg_type: String, + pub request_seq: u64, + pub success: bool, + pub command: String, + pub message: Option, + pub body: Option, +} + +#[derive(Debug, PartialEq, Clone, Deserialize, Serialize)] +pub struct Event { + pub seq: u64, + #[serde(rename = "type")] + pub msg_type: String, + pub event: String, + pub body: Option, +} + +#[derive(Debug, Clone, Deserialize, Serialize)] +#[serde(untagged)] +pub enum Payload { + // type = "event" + Event(Event), + // type = "response" + Response(Response), + // type = "request" + Request(Request), +} + +#[derive(Debug)] +pub struct Transport { + id: usize, + pending_requests: Mutex>>>, +} + +impl Transport { + pub fn start( + server_stdout: Box, + server_stdin: Box, + id: usize, + ) -> (UnboundedReceiver, UnboundedSender) { + let (client_tx, rx) = unbounded_channel(); + let (tx, client_rx) = unbounded_channel(); + + let transport = Self { + id, + pending_requests: Mutex::new(HashMap::default()), + }; + + let transport = Arc::new(transport); + + tokio::spawn(Self::recv(transport.clone(), server_stdout, client_tx)); + tokio::spawn(Self::send(transport, server_stdin, client_rx)); + + (rx, tx) + } + + async fn recv_server_message( + reader: &mut Box, + buffer: &mut String, + ) -> Result { + let mut content_length = None; + loop { + buffer.truncate(0); + reader.read_line(buffer).await?; + let header = buffer.trim(); + + if header.is_empty() { + break; + } + + let mut parts = header.split(": "); + + match (parts.next(), parts.next(), parts.next()) { + (Some("Content-Length"), Some(value), None) => { + content_length = Some(value.parse().context("invalid content length")?); + } + (Some(_), Some(_), None) => {} + _ => { + return Err(std::io::Error::new( + std::io::ErrorKind::Other, + "Failed to parse header", + ) + .into()); + } + } + } + + let content_length = content_length.context("missing content length")?; + + //TODO: reuse vector + let mut content = vec![0; content_length]; + reader.read_exact(&mut content).await?; + let msg = std::str::from_utf8(&content).context("invalid utf8 from server")?; + + info!("<- DAP {}", msg); + + // try parsing as output (server response) or call (server request) + let output: serde_json::Result = serde_json::from_str(msg); + + Ok(output?) + } + + async fn send_payload_to_server( + &self, + server_stdin: &mut Box, + req: Request, + ) -> Result<()> { + let json = serde_json::to_string(&req)?; + if let Some(back) = req.back_ch { + self.pending_requests.lock().await.insert(req.seq, back); + } + self.send_string_to_server(server_stdin, json).await + } + + async fn send_string_to_server( + &self, + server_stdin: &mut Box, + request: String, + ) -> Result<()> { + info!("-> DAP {}", request); + + // send the headers + server_stdin + .write_all(format!("Content-Length: {}\r\n\r\n", request.len()).as_bytes()) + .await?; + + // send the body + server_stdin.write_all(request.as_bytes()).await?; + + server_stdin.flush().await?; + + Ok(()) + } + + async fn process_server_message( + &self, + client_tx: &UnboundedSender, + msg: Payload, + ) -> Result<()> { + let (id, result) = match msg { + Payload::Response(Response { + success: true, + seq, + request_seq, + .. + }) => { + info!("<- DAP success ({}, in response to {})", seq, request_seq); + if let Payload::Response(val) = msg { + (request_seq, Ok(val)) + } else { + unreachable!(); + } + } + Payload::Response(Response { + success: false, + message, + body, + request_seq, + command, + .. + }) => { + error!( + "<- DAP error {:?} ({:?}) for command #{} {}", + message, body, request_seq, command + ); + ( + request_seq, + Err(Error::Other(anyhow::format_err!("{:?}", body))), + ) + } + Payload::Request(Request { + ref command, + ref seq, + .. + }) => { + info!("<- DAP request {} #{}", command, seq); + client_tx.send(msg).expect("Failed to send"); + return Ok(()); + } + Payload::Event(Event { + ref event, ref seq, .. + }) => { + info!("<- DAP event {} #{}", event, seq); + client_tx.send(msg).expect("Failed to send"); + return Ok(()); + } + }; + + let tx = self + .pending_requests + .lock() + .await + .remove(&id) + .expect("pending_request with id not found!"); + + match tx.send(result).await { + Ok(_) => (), + Err(_) => error!( + "Tried sending response into a closed channel (id={:?}), original request likely timed out", + id + ), + }; + + Ok(()) + } + + async fn recv( + transport: Arc, + mut server_stdout: Box, + client_tx: UnboundedSender, + ) { + let mut recv_buffer = String::new(); + loop { + match Self::recv_server_message(&mut server_stdout, &mut recv_buffer).await { + Ok(msg) => { + transport + .process_server_message(&client_tx, msg) + .await + .unwrap(); + } + Err(err) => { + error!("err: <- {:?}", err); + break; + } + } + } + } + + async fn send( + transport: Arc, + mut server_stdin: Box, + mut client_rx: UnboundedReceiver, + ) { + while let Some(req) = client_rx.recv().await { + transport + .send_payload_to_server(&mut server_stdin, req) + .await + .unwrap() + } + } +}