diff --git a/Cargo.lock b/Cargo.lock index 20dd123fe..3b47b717e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4332,8 +4332,6 @@ version = "0.1.0" dependencies = [ "anyhow", "async-channel 2.2.0", - "async-io", - "async-net", "bytemuck", "chrono", "clap", @@ -4347,11 +4345,9 @@ dependencies = [ "fluent-templates", "fontdb", "futures", - "futures-lite 2.3.0", "gilrs", "image 0.25.1", "isahc", - "macro_rules_attribute", "os_info", "rfd", "ruffle_core", @@ -4360,7 +4356,6 @@ dependencies = [ "ruffle_render_wgpu", "ruffle_video_software", "sys-locale", - "tokio", "toml_edit 0.22.9", "tracing", "tracing-appender", @@ -4381,14 +4376,22 @@ name = "ruffle_frontend_utils" version = "0.1.0" dependencies = [ "async-channel 2.2.0", + "async-io", + "async-net", + "futures", + "futures-lite 2.3.0", + "isahc", + "macro_rules_attribute", "ruffle_core", "slotmap", "tempfile", "thiserror", + "tokio", "toml_edit 0.22.9", "tracing", "url", "urlencoding", + "webbrowser", "zip", ] diff --git a/desktop/Cargo.toml b/desktop/Cargo.toml index 8631b998b..b6a25b7f3 100644 --- a/desktop/Cargo.toml +++ b/desktop/Cargo.toml @@ -42,9 +42,6 @@ wgpu = { workspace = true } futures = { workspace = true } chrono = { workspace = true } fluent-templates = "0.9.1" -futures-lite = "2.3.0" -async-io = "2.3.2" -async-net = "2.0.0" async-channel = { workspace = true } toml_edit = { version = "0.22.9", features = ["parse"] } gilrs = "0.10" @@ -77,6 +74,3 @@ render_trace = ["ruffle_render_wgpu/render_trace"] # sandboxing sandbox = [] -[dev-dependencies] -macro_rules_attribute = "0.2.0" -tokio = { version = "1.36.0", features = ["macros", "rt"] } diff --git a/desktop/src/backends.rs b/desktop/src/backends.rs index ca7ae70bd..49cbcaae8 100644 --- a/desktop/src/backends.rs +++ b/desktop/src/backends.rs @@ -7,5 +7,5 @@ mod ui; pub use audio::CpalAudioBackend; pub use external_interface::DesktopExternalInterfaceProvider; pub use fscommand::DesktopFSCommandProvider; -pub use navigator::{ExternalNavigatorBackend, RfdNavigatorInterface}; +pub use navigator::RfdNavigatorInterface; pub use ui::DesktopUiBackend; diff --git a/desktop/src/backends/navigator.rs b/desktop/src/backends/navigator.rs index 013c9e883..74924c213 100644 --- a/desktop/src/backends/navigator.rs +++ b/desktop/src/backends/navigator.rs @@ -1,45 +1,10 @@ -//! Navigator backend for web - -use async_channel::{Receiver, Sender, TryRecvError}; -use async_io::Timer; -use async_net::TcpStream; -use futures::future::select; -use futures::{AsyncReadExt, AsyncWriteExt}; -use futures_lite::FutureExt; -use isahc::http::{HeaderName, HeaderValue}; -use isahc::{ - config::RedirectPolicy, prelude::*, AsyncBody, HttpClient, Request as IsahcRequest, - Response as IsahcResponse, -}; use rfd::{AsyncMessageDialog, MessageButtons, MessageDialog, MessageDialogResult, MessageLevel}; -use ruffle_core::backend::navigator::{ - async_return, create_fetch_error, ErrorResponse, NavigationMethod, NavigatorBackend, - OpenURLMode, OwnedFuture, Request, SocketMode, SuccessResponse, -}; -use ruffle_core::indexmap::IndexMap; -use ruffle_core::loader::Error; -use ruffle_core::socket::{ConnectionState, SocketAction, SocketHandle}; -use ruffle_frontend_utils::backends::executor::FutureSpawner; -use ruffle_frontend_utils::content::PlayingContent; -use std::collections::HashSet; +use ruffle_frontend_utils::backends::navigator::NavigatorInterface; use std::fs::File; use std::io; use std::io::ErrorKind; use std::path::Path; -use std::rc::Rc; -use std::str::FromStr; -use std::sync::{Arc, Mutex}; -use std::time::Duration; -use tracing::warn; -use url::{ParseError, Url}; - -pub trait NavigatorInterface: Clone + 'static { - fn confirm_website_navigation(&self, url: &Url) -> bool; - - fn open_file(&self, path: &Path) -> io::Result; - - async fn confirm_socket(&self, host: &str, port: u16) -> bool; -} +use url::Url; #[derive(Clone)] pub struct RfdNavigatorInterface; @@ -88,836 +53,3 @@ impl NavigatorInterface for RfdNavigatorInterface { .await == MessageDialogResult::Yes } } - -/// Implementation of `NavigatorBackend` for non-web environments that can call -/// out to a web browser. -pub struct ExternalNavigatorBackend { - /// Sink for tasks sent to us through `spawn_future`. - future_spawner: F, - - /// The url to use for all relative fetches. - base_url: Url, - - // Client to use for network requests - client: Option>, - - socket_allowed: HashSet, - - socket_mode: SocketMode, - - upgrade_to_https: bool, - - open_url_mode: OpenURLMode, - - content: Rc, - - interface: I, -} - -impl ExternalNavigatorBackend { - /// Construct a navigator backend with fetch and async capability. - #[allow(clippy::too_many_arguments)] - pub fn new( - mut base_url: Url, - future_spawner: F, - proxy: Option, - upgrade_to_https: bool, - open_url_mode: OpenURLMode, - socket_allowed: HashSet, - socket_mode: SocketMode, - content: Rc, - interface: I, - ) -> Self { - let proxy = proxy.and_then(|url| url.as_str().parse().ok()); - let builder = HttpClient::builder() - .proxy(proxy) - .cookies() - .redirect_policy(RedirectPolicy::Follow); - - let client = builder.build().ok().map(Rc::new); - - // Force replace the last segment with empty. // - - if let Ok(mut base_url) = base_url.path_segments_mut() { - base_url.pop().pop_if_empty().push(""); - } - - Self { - future_spawner, - client, - base_url, - upgrade_to_https, - open_url_mode, - socket_allowed, - socket_mode, - content, - interface, - } - } -} - -impl NavigatorBackend for ExternalNavigatorBackend { - fn navigate_to_url( - &self, - url: &str, - _target: &str, - vars_method: Option<(NavigationMethod, IndexMap)>, - ) { - //TODO: Should we return a result for failed opens? Does Flash care? - - //NOTE: Flash desktop players / projectors ignore the window parameter, - // unless it's a `_layer`, and we shouldn't handle that anyway. - let mut parsed_url = match self.resolve_url(url) { - Ok(parsed_url) => parsed_url, - Err(e) => { - tracing::error!( - "Could not parse URL because of {}, the corrupt URL was: {}", - e, - url - ); - return; - } - }; - - let modified_url = match vars_method { - Some((_, query_pairs)) => { - { - //lifetime limiter because we don't have NLL yet - let mut modifier = parsed_url.query_pairs_mut(); - - for (k, v) in query_pairs.iter() { - modifier.append_pair(k, v); - } - } - - parsed_url - } - None => parsed_url, - }; - - if modified_url.scheme() == "javascript" { - tracing::warn!( - "SWF tried to run a script on desktop, but javascript calls are not allowed" - ); - return; - } - - if self.open_url_mode == OpenURLMode::Confirm { - if !self.interface.confirm_website_navigation(&modified_url) { - tracing::info!("SWF tried to open a website, but the user declined the request"); - return; - } - } else if self.open_url_mode == OpenURLMode::Deny { - tracing::warn!("SWF tried to open a website, but opening a website is not allowed"); - return; - } - - // If the user confirmed or if in Allow mode, open the website - - // TODO: This opens local files in the browser while flash opens them - // in the default program for the respective filetype. - // This especially includes mailto links. Ruffle opens the browser which opens - // the preferred program while flash opens the preferred program directly. - match webbrowser::open(modified_url.as_ref()) { - Ok(_output) => {} - Err(e) => tracing::error!("Could not open URL {}: {}", modified_url.as_str(), e), - }; - } - - fn fetch(&self, request: Request) -> OwnedFuture, ErrorResponse> { - enum DesktopResponseBody { - /// The response's body comes from a file. - File(Result, std::io::Error>), - - /// The response's body comes from the network. - /// - /// This has to be stored in shared ownership so that we can return - /// owned futures. A synchronous lock is used here as we do not - /// expect contention on this lock. - Network(Arc>>), - } - - struct DesktopResponse { - url: String, - response_body: DesktopResponseBody, - status: u16, - redirected: bool, - } - - impl SuccessResponse for DesktopResponse { - fn url(&self) -> std::borrow::Cow { - std::borrow::Cow::Borrowed(&self.url) - } - - #[allow(clippy::await_holding_lock)] - fn body(self: Box) -> OwnedFuture, Error> { - match self.response_body { - DesktopResponseBody::File(file) => { - Box::pin(async move { file.map_err(|e| Error::FetchError(e.to_string())) }) - } - DesktopResponseBody::Network(response) => Box::pin(async move { - let mut body = vec![]; - response - .lock() - .expect("working lock during fetch body read") - .copy_to(&mut body) - .await - .map_err(|e| Error::FetchError(e.to_string()))?; - - Ok(body) - }), - } - } - - fn status(&self) -> u16 { - self.status - } - - fn redirected(&self) -> bool { - self.redirected - } - - #[allow(clippy::await_holding_lock)] - fn next_chunk(&mut self) -> OwnedFuture>, Error> { - match &mut self.response_body { - DesktopResponseBody::File(file) => { - let res = file - .as_mut() - .map(std::mem::take) - .map_err(|e| Error::FetchError(e.to_string())); - - Box::pin(async move { - match res { - Ok(bytes) if !bytes.is_empty() => Ok(Some(bytes)), - Ok(_) => Ok(None), - Err(e) => Err(e), - } - }) - } - DesktopResponseBody::Network(response) => { - let response = response.clone(); - - Box::pin(async move { - let mut buf = vec![0; 4096]; - let lock = response.try_lock(); - if matches!(lock, Err(std::sync::TryLockError::WouldBlock)) { - return Err(Error::FetchError( - "Concurrent read operations on the same stream are not supported." - .to_string(), - )); - } - - let result = lock - .expect("desktop network lock") - .body_mut() - .read(&mut buf) - .await; - - match result { - Ok(count) if count > 0 => { - buf.resize(count, 0); - Ok(Some(buf)) - } - Ok(_) => Ok(None), - Err(e) => Err(Error::FetchError(e.to_string())), - } - }) - } - } - } - - fn expected_length(&self) -> Result, Error> { - match &self.response_body { - DesktopResponseBody::File(file) => { - Ok(file.as_ref().map(|file| file.len() as u64).ok()) - } - DesktopResponseBody::Network(response) => { - let response = response.lock().expect("no recursive locks"); - let content_length = response.headers().get("Content-Length"); - - if let Some(len) = content_length { - Ok(Some( - len.to_str() - .map_err(|_| Error::InvalidHeaderValue)? - .parse::()?, - )) - } else { - Ok(None) - } - } - } - } - } - - // TODO: honor sandbox type (local-with-filesystem, local-with-network, remote, ...) - let mut processed_url = match self.resolve_url(request.url()) { - Ok(url) => url, - Err(e) => { - return async_return(create_fetch_error(request.url(), e)); - } - }; - - let client = self.client.clone(); - - match processed_url.scheme() { - "file" => { - let content = self.content.clone(); - let interface = self.interface.clone(); - Box::pin(async move { - // We send the original url (including query parameters) - // back to ruffle_core in the `Response` - let response_url = processed_url.clone(); - // Flash supports query parameters with local urls. - // SwfMovie takes care of exposing those to ActionScript - - // when we actually load a filesystem url, strip them out. - processed_url.set_query(None); - - let contents = - content.get_local_file(&processed_url, |path| interface.open_file(path)); - - let response: Box = Box::new(DesktopResponse { - url: response_url.to_string(), - response_body: DesktopResponseBody::File(contents), - status: 0, - redirected: false, - }); - - Ok(response) - }) - } - _ => Box::pin(async move { - let client = client.ok_or_else(|| ErrorResponse { - url: processed_url.to_string(), - error: Error::FetchError("Network unavailable".to_string()), - })?; - - let mut isahc_request = match request.method() { - NavigationMethod::Get => IsahcRequest::get(processed_url.to_string()), - NavigationMethod::Post => IsahcRequest::post(processed_url.to_string()), - }; - let (body_data, mime) = request.body().clone().unwrap_or_default(); - if let Some(headers) = isahc_request.headers_mut() { - for (name, val) in request.headers().iter() { - headers.insert( - HeaderName::from_str(name).map_err(|e| ErrorResponse { - url: processed_url.to_string(), - error: Error::FetchError(e.to_string()), - })?, - HeaderValue::from_str(val).map_err(|e| ErrorResponse { - url: processed_url.to_string(), - error: Error::FetchError(e.to_string()), - })?, - ); - } - headers.insert( - "Content-Type", - HeaderValue::from_str(&mime).map_err(|e| ErrorResponse { - url: processed_url.to_string(), - error: Error::FetchError(e.to_string()), - })?, - ); - } - - let body = isahc_request.body(body_data).map_err(|e| ErrorResponse { - url: processed_url.to_string(), - error: Error::FetchError(e.to_string()), - })?; - - let response = client.send_async(body).await.map_err(|e| { - let inner = match e.kind() { - isahc::error::ErrorKind::NameResolution => { - Error::InvalidDomain(processed_url.to_string()) - } - _ => Error::FetchError(e.to_string()), - }; - ErrorResponse { - url: processed_url.to_string(), - error: inner, - } - })?; - - let url = if let Some(uri) = response.effective_uri() { - uri.to_string() - } else { - processed_url.into() - }; - - let status = response.status().as_u16(); - let redirected = response.effective_uri().is_some(); - if !response.status().is_success() { - let error = Error::HttpNotOk( - format!("HTTP status is not ok, got {}", response.status()), - status, - redirected, - response.body().len().unwrap_or(0), - ); - return Err(ErrorResponse { url, error }); - } - - let response: Box = Box::new(DesktopResponse { - url, - response_body: DesktopResponseBody::Network(Arc::new(Mutex::new(response))), - status, - redirected, - }); - Ok(response) - }), - } - } - - fn resolve_url(&self, url: &str) -> Result { - match self.base_url.join(url) { - Ok(url) => Ok(self.pre_process_url(url)), - Err(error) => Err(error), - } - } - - fn spawn_future(&mut self, future: OwnedFuture<(), Error>) { - self.future_spawner.spawn(future); - } - - fn pre_process_url(&self, mut url: Url) -> Url { - if self.upgrade_to_https && url.scheme() == "http" && url.set_scheme("https").is_err() { - tracing::error!("Url::set_scheme failed on: {}", url); - } - url - } - - fn connect_socket( - &mut self, - host: String, - port: u16, - timeout: Duration, - handle: SocketHandle, - receiver: Receiver>, - sender: Sender, - ) { - let addr = format!("{}:{}", host, port); - let is_allowed = self.socket_allowed.contains(&addr); - let socket_mode = self.socket_mode; - let interface = self.interface.clone(); - - let future = Box::pin(async move { - match (is_allowed, socket_mode) { - (false, SocketMode::Allow) | (true, _) => {} // the process is allowed to continue. just dont do anything. - (false, SocketMode::Deny) => { - // Just fail the connection. - sender - .try_send(SocketAction::Connect(handle, ConnectionState::Failed)) - .expect("working channel send"); - - tracing::warn!( - "SWF tried to open a socket, but opening a socket is not allowed" - ); - - return Ok(()); - } - (false, SocketMode::Ask) => { - let attempt_sandbox_connect = interface.confirm_socket(&host, port).await; - - if !attempt_sandbox_connect { - // fail the connection. - sender - .try_send(SocketAction::Connect(handle, ConnectionState::Failed)) - .expect("working channel send"); - - return Ok(()); - } - } - } - - let host2 = host.clone(); - - let timeout = async { - Timer::after(timeout).await; - Result::::Err(io::Error::new(ErrorKind::TimedOut, "")) - }; - - let stream = match TcpStream::connect((host, port)).or(timeout).await { - Err(e) if e.kind() == ErrorKind::TimedOut => { - warn!("Connection to {}:{} timed out", host2, port); - sender - .try_send(SocketAction::Connect(handle, ConnectionState::TimedOut)) - .expect("working channel send"); - return Ok(()); - } - Ok(stream) => { - sender - .try_send(SocketAction::Connect(handle, ConnectionState::Connected)) - .expect("working channel send"); - - stream - } - Err(err) => { - warn!("Failed to connect to {}:{}, error: {}", host2, port, err); - sender - .try_send(SocketAction::Connect(handle, ConnectionState::Failed)) - .expect("working channel send"); - return Ok(()); - } - }; - - let sender = sender; - //NOTE: We clone the sender here as we cant share it between async tasks. - let sender2 = sender.clone(); - let (mut read, mut write) = stream.split(); - - let read = std::pin::pin!(async move { - loop { - let mut buffer = [0; 4096]; - - match read.read(&mut buffer).await { - Err(e) if e.kind() == ErrorKind::TimedOut => {} // try again later. - Err(_) | Ok(0) => { - sender - .try_send(SocketAction::Close(handle)) - .expect("working channel send"); - drop(read); - break; - } - Ok(read) => { - let buffer = buffer.into_iter().take(read).collect::>(); - - sender - .try_send(SocketAction::Data(handle, buffer)) - .expect("working channel send"); - } - }; - } - }); - - let write = std::pin::pin!(async move { - let mut pending_write = vec![]; - - loop { - let close_connection = loop { - match receiver.try_recv() { - Ok(val) => { - pending_write.extend(val); - } - Err(TryRecvError::Empty) => break false, - Err(TryRecvError::Closed) => { - //NOTE: Channel sender has been dropped. - // This means we have to close the connection, - // but not here, as we might have a pending write. - break true; - } - } - }; - - if !pending_write.is_empty() { - match write.write(&pending_write).await { - Err(e) if e.kind() == ErrorKind::TimedOut => {} // try again later. - Err(_) => { - sender2 - .try_send(SocketAction::Close(handle)) - .expect("working channel send"); - drop(write); - return; - } - Ok(written) => { - let _ = pending_write.drain(..written); - } - } - } else if close_connection { - drop(write); - return; - } else { - // Receiver is empty and there's no pending data, - // we may block here and wait for new data. - match receiver.recv().await { - Ok(val) => { - pending_write.extend(val); - } - Err(_) => { - // Ignore the error here, it will be - // reported again in try_recv. - } - } - } - } - }); - - //NOTE: If one future exits, this will take the other one down too. - select(read, write).await; - - Ok(()) - }); - - self.spawn_future(future); - } -} - -#[cfg(test)] -#[allow(clippy::unwrap_used)] -mod tests { - use async_net::TcpListener; - use ruffle_core::socket::SocketAction::{Close, Connect, Data}; - use std::net::SocketAddr; - use tokio::task; - - use super::*; - - const TIMEOUT_ZERO: Duration = Duration::ZERO; - // The timeout has to be large enough to allow "instantaneous" actions - // and local IO to execute, but small enough to fail tests quickly. - const TIMEOUT: Duration = Duration::from_secs(1); - - struct TestFutureSpawner; - - impl FutureSpawner for TestFutureSpawner { - fn spawn(&self, future: OwnedFuture<(), Error>) { - task::spawn_local(future); - } - } - - macro_rules! async_timeout { - () => { - async { - Timer::after(TIMEOUT).await; - panic!("An action which should complete timed out") - } - }; - } - - macro_rules! async_test { - ( - async fn $test_name:ident() $content:block - ) => { - #[tokio::test(flavor = "current_thread")] - async fn $test_name() { - task::LocalSet::new().run_until(async move $content).await; - } - } - } - - macro_rules! dummy_handle { - () => { - SocketHandle::default() - }; - } - - macro_rules! assert_next_socket_actions { - ($receiver:expr;) => { - // no more actions - }; - ($receiver:expr; $action:expr, $($more:expr,)*) => { - assert_eq!($receiver.recv().or(async_timeout!()).await.expect("receive action"), $action); - assert_next_socket_actions!($receiver; $($more,)*); - }; - } - - fn new_test_backend( - socket_allow: bool, - ) -> ExternalNavigatorBackend { - let url = Url::parse("https://example.com/path/").unwrap(); - ExternalNavigatorBackend::new( - url.clone(), - TestFutureSpawner, - None, - false, - OpenURLMode::Allow, - Default::default(), - if socket_allow { - SocketMode::Allow - } else { - SocketMode::Deny - }, - Rc::new(PlayingContent::DirectFile(url)), - RfdNavigatorInterface, - ) - } - - async fn start_test_server() -> (task::JoinHandle, SocketAddr) { - let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); - let addr = listener.local_addr().unwrap(); - - let accept_task = task::spawn_local(async move { - let (socket, _) = listener.accept().or(async_timeout!()).await.unwrap(); - socket - }); - (accept_task, addr) - } - - fn connect_test_socket( - addr: SocketAddr, - timeout: Duration, - socket_allow: bool, - ) -> (Sender>, Receiver) { - let mut backend = new_test_backend(socket_allow); - - let (write, receiver) = async_channel::unbounded(); - let (sender, read) = async_channel::unbounded(); - - backend.connect_socket( - addr.ip().to_string(), - addr.port(), - timeout, - dummy_handle!(), - receiver, - sender, - ); - - (write, read) - } - - async fn write_server(server_socket: &mut TcpStream, data: &str) { - server_socket - .write(data.as_bytes()) - .or(async_timeout!()) - .await - .expect("server write"); - } - - async fn read_server(server_socket: &mut TcpStream) -> String { - let mut buffer = [0; 4096]; - - let read = match server_socket.read(&mut buffer).await { - Err(e) => { - panic!("server read error: {}", e); - } - Ok(read) => read, - }; - - let buffer = buffer.into_iter().take(read).collect::>(); - String::from_utf8(buffer).unwrap() - } - - async fn write_client(client_write: &Sender>, data: &str) { - client_write - .send(data.as_bytes().to_vec()) - .or(async_timeout!()) - .await - .expect("client write"); - } - - #[macro_rules_attribute::apply(async_test)] - async fn test_socket_timeout() { - let (_accept_task, addr) = start_test_server().await; - let (_client_write, client_read) = connect_test_socket(addr, TIMEOUT_ZERO, true); - assert_next_socket_actions!( - client_read; - Connect(dummy_handle!(), ConnectionState::TimedOut), - ); - } - - #[macro_rules_attribute::apply(async_test)] - async fn test_socket_connect() { - let (accept_task, addr) = start_test_server().await; - let (_client_write, client_read) = connect_test_socket(addr, TIMEOUT, true); - let _server_socket = accept_task.await.unwrap(); - assert_next_socket_actions!( - client_read; - Connect(dummy_handle!(), ConnectionState::Connected), - ); - } - - #[macro_rules_attribute::apply(async_test)] - async fn test_socket_deny() { - let (_accept_task, addr) = start_test_server().await; - let (_client_write, client_read) = connect_test_socket(addr, TIMEOUT, false); - - assert_next_socket_actions!( - client_read; - Connect(dummy_handle!(), ConnectionState::Failed), - ); - } - - #[macro_rules_attribute::apply(async_test)] - async fn test_socket_fail() { - let addr = SocketAddr::from_str("[100::]:42").expect("black hole address"); - let (_client_write, client_read) = connect_test_socket(addr, TIMEOUT, true); - - assert_next_socket_actions!( - client_read; - Connect(dummy_handle!(), ConnectionState::Failed), - ); - } - - #[macro_rules_attribute::apply(async_test)] - async fn test_socket_server_close() { - let (accept_task, addr) = start_test_server().await; - let (_client_write, client_read) = connect_test_socket(addr, TIMEOUT, true); - - let server_socket = accept_task.await.unwrap(); - assert_next_socket_actions!( - client_read; - Connect(dummy_handle!(), ConnectionState::Connected), - ); - - drop(server_socket); - - assert_next_socket_actions!( - client_read; - Close(dummy_handle!()), - ); - } - - #[macro_rules_attribute::apply(async_test)] - async fn test_socket_client_close() { - let (accept_task, addr) = start_test_server().await; - let (client_write, client_read) = connect_test_socket(addr, TIMEOUT, true); - - let mut server_socket = accept_task.await.unwrap(); - assert_next_socket_actions!( - client_read; - Connect(dummy_handle!(), ConnectionState::Connected), - ); - - drop(client_write); - - assert_eq!(read_server(&mut server_socket).await, ""); - } - - #[macro_rules_attribute::apply(async_test)] - async fn test_socket_basic_communication() { - let (accept_task, addr) = start_test_server().await; - let (client_write, client_read) = connect_test_socket(addr, TIMEOUT, true); - - let mut server_socket = accept_task.await.unwrap(); - assert_next_socket_actions!( - client_read; - Connect(dummy_handle!(), ConnectionState::Connected), - ); - - write_server(&mut server_socket, "Hello ").await; - write_server(&mut server_socket, "World!").await; - - assert_next_socket_actions!( - client_read; - Data(dummy_handle!(), "Hello World!".as_bytes().to_vec()), - ); - - write_client(&client_write, "Hello from").await; - write_client(&client_write, " client").await; - - assert_eq!(read_server(&mut server_socket).await, "Hello from client"); - - write_server(&mut server_socket, "from server 2").await; - write_client(&client_write, "from client 2").await; - - assert_next_socket_actions!( - client_read; - Data(dummy_handle!(), "from server 2".as_bytes().to_vec()), - ); - assert_eq!(read_server(&mut server_socket).await, "from client 2"); - } - - #[macro_rules_attribute::apply(async_test)] - async fn test_socket_flush_before_close() { - let (accept_task, addr) = start_test_server().await; - let (client_write, client_read) = connect_test_socket(addr, TIMEOUT, true); - - let mut server_socket = accept_task.await.unwrap(); - assert_next_socket_actions!( - client_read; - Connect(dummy_handle!(), ConnectionState::Connected), - ); - - write_client(&client_write, "Sending some").await; - write_client(&client_write, " data").await; - client_write.close(); - - assert_eq!(read_server(&mut server_socket).await, "Sending some data"); - } -} diff --git a/desktop/src/player.rs b/desktop/src/player.rs index dbb506cc6..832a70af6 100644 --- a/desktop/src/player.rs +++ b/desktop/src/player.rs @@ -1,6 +1,6 @@ use crate::backends::{ CpalAudioBackend, DesktopExternalInterfaceProvider, DesktopFSCommandProvider, DesktopUiBackend, - ExternalNavigatorBackend, RfdNavigatorInterface, + RfdNavigatorInterface, }; use crate::custom_event::RuffleEvent; use crate::gui::MovieView; @@ -15,6 +15,7 @@ use ruffle_core::{ StageScaleMode, }; use ruffle_frontend_utils::backends::executor::{AsyncExecutor, PollRequester}; +use ruffle_frontend_utils::backends::navigator::ExternalNavigatorBackend; use ruffle_frontend_utils::bundle::source::BundleSourceError; use ruffle_frontend_utils::bundle::{Bundle, BundleError}; use ruffle_frontend_utils::content::PlayingContent; diff --git a/frontend-utils/Cargo.toml b/frontend-utils/Cargo.toml index bfa2da0bc..9a760e47f 100644 --- a/frontend-utils/Cargo.toml +++ b/frontend-utils/Cargo.toml @@ -20,6 +20,14 @@ urlencoding = "2.1.3" ruffle_core = { path = "../core", default-features = false } async-channel = { workspace = true } slotmap = { workspace = true } +isahc = { version = "1.7.2", features = ["cookies"] } +futures = { workspace = true } +async-io = "2.3.2" +async-net = "2.0.0" +futures-lite = "2.3.0" +webbrowser = "0.8.14" [dev-dependencies] tempfile = "3" +tokio = { version = "1.36.0", features = ["macros", "rt"] } +macro_rules_attribute = "0.2.0" diff --git a/frontend-utils/src/backends.rs b/frontend-utils/src/backends.rs index f3c21109f..6b062080e 100644 --- a/frontend-utils/src/backends.rs +++ b/frontend-utils/src/backends.rs @@ -1,2 +1,4 @@ pub mod executor; pub mod storage; + +pub mod navigator; diff --git a/frontend-utils/src/backends/navigator.rs b/frontend-utils/src/backends/navigator.rs new file mode 100644 index 000000000..851bbf945 --- /dev/null +++ b/frontend-utils/src/backends/navigator.rs @@ -0,0 +1,890 @@ +//! Navigator backend for web + +use crate::backends::executor::FutureSpawner; +use crate::content::PlayingContent; +use async_channel::{Receiver, Sender, TryRecvError}; +use async_io::Timer; +use async_net::TcpStream; +use futures::future::select; +use futures::{AsyncReadExt, AsyncWriteExt}; +use futures_lite::FutureExt; +use isahc::http::{HeaderName, HeaderValue}; +use isahc::{ + config::RedirectPolicy, prelude::*, AsyncBody, HttpClient, Request as IsahcRequest, + Response as IsahcResponse, +}; +use ruffle_core::backend::navigator::{ + async_return, create_fetch_error, ErrorResponse, NavigationMethod, NavigatorBackend, + OpenURLMode, OwnedFuture, Request, SocketMode, SuccessResponse, +}; +use ruffle_core::indexmap::IndexMap; +use ruffle_core::loader::Error; +use ruffle_core::socket::{ConnectionState, SocketAction, SocketHandle}; +use std::collections::HashSet; +use std::fs::File; +use std::io; +use std::io::ErrorKind; +use std::path::Path; +use std::rc::Rc; +use std::str::FromStr; +use std::sync::{Arc, Mutex}; +use std::time::Duration; +use tracing::warn; +use url::{ParseError, Url}; + +pub trait NavigatorInterface: Clone + 'static { + fn confirm_website_navigation(&self, url: &Url) -> bool; + + fn open_file(&self, path: &Path) -> io::Result; + + fn confirm_socket( + &self, + host: &str, + port: u16, + ) -> impl std::future::Future + Send; +} + +/// Implementation of `NavigatorBackend` for non-web environments that can call +/// out to a web browser. +pub struct ExternalNavigatorBackend { + /// Sink for tasks sent to us through `spawn_future`. + future_spawner: F, + + /// The url to use for all relative fetches. + base_url: Url, + + // Client to use for network requests + client: Option>, + + socket_allowed: HashSet, + + socket_mode: SocketMode, + + upgrade_to_https: bool, + + open_url_mode: OpenURLMode, + + content: Rc, + + interface: I, +} + +impl ExternalNavigatorBackend { + /// Construct a navigator backend with fetch and async capability. + #[allow(clippy::too_many_arguments)] + pub fn new( + mut base_url: Url, + future_spawner: F, + proxy: Option, + upgrade_to_https: bool, + open_url_mode: OpenURLMode, + socket_allowed: HashSet, + socket_mode: SocketMode, + content: Rc, + interface: I, + ) -> Self { + let proxy = proxy.and_then(|url| url.as_str().parse().ok()); + let builder = HttpClient::builder() + .proxy(proxy) + .cookies() + .redirect_policy(RedirectPolicy::Follow); + + let client = builder.build().ok().map(Rc::new); + + // Force replace the last segment with empty. // + + if let Ok(mut base_url) = base_url.path_segments_mut() { + base_url.pop().pop_if_empty().push(""); + } + + Self { + future_spawner, + client, + base_url, + upgrade_to_https, + open_url_mode, + socket_allowed, + socket_mode, + content, + interface, + } + } +} + +impl NavigatorBackend for ExternalNavigatorBackend { + fn navigate_to_url( + &self, + url: &str, + _target: &str, + vars_method: Option<(NavigationMethod, IndexMap)>, + ) { + //TODO: Should we return a result for failed opens? Does Flash care? + + //NOTE: Flash desktop players / projectors ignore the window parameter, + // unless it's a `_layer`, and we shouldn't handle that anyway. + let mut parsed_url = match self.resolve_url(url) { + Ok(parsed_url) => parsed_url, + Err(e) => { + tracing::error!( + "Could not parse URL because of {}, the corrupt URL was: {}", + e, + url + ); + return; + } + }; + + let modified_url = match vars_method { + Some((_, query_pairs)) => { + { + //lifetime limiter because we don't have NLL yet + let mut modifier = parsed_url.query_pairs_mut(); + + for (k, v) in query_pairs.iter() { + modifier.append_pair(k, v); + } + } + + parsed_url + } + None => parsed_url, + }; + + if modified_url.scheme() == "javascript" { + tracing::warn!( + "SWF tried to run a script on desktop, but javascript calls are not allowed" + ); + return; + } + + if self.open_url_mode == OpenURLMode::Confirm { + if !self.interface.confirm_website_navigation(&modified_url) { + tracing::info!("SWF tried to open a website, but the user declined the request"); + return; + } + } else if self.open_url_mode == OpenURLMode::Deny { + tracing::warn!("SWF tried to open a website, but opening a website is not allowed"); + return; + } + + // If the user confirmed or if in Allow mode, open the website + + // TODO: This opens local files in the browser while flash opens them + // in the default program for the respective filetype. + // This especially includes mailto links. Ruffle opens the browser which opens + // the preferred program while flash opens the preferred program directly. + match webbrowser::open(modified_url.as_ref()) { + Ok(_output) => {} + Err(e) => tracing::error!("Could not open URL {}: {}", modified_url.as_str(), e), + }; + } + + fn fetch(&self, request: Request) -> OwnedFuture, ErrorResponse> { + enum DesktopResponseBody { + /// The response's body comes from a file. + File(Result, std::io::Error>), + + /// The response's body comes from the network. + /// + /// This has to be stored in shared ownership so that we can return + /// owned futures. A synchronous lock is used here as we do not + /// expect contention on this lock. + Network(Arc>>), + } + + struct DesktopResponse { + url: String, + response_body: DesktopResponseBody, + status: u16, + redirected: bool, + } + + impl SuccessResponse for DesktopResponse { + fn url(&self) -> std::borrow::Cow { + std::borrow::Cow::Borrowed(&self.url) + } + + #[allow(clippy::await_holding_lock)] + fn body(self: Box) -> OwnedFuture, Error> { + match self.response_body { + DesktopResponseBody::File(file) => { + Box::pin(async move { file.map_err(|e| Error::FetchError(e.to_string())) }) + } + DesktopResponseBody::Network(response) => Box::pin(async move { + let mut body = vec![]; + response + .lock() + .expect("working lock during fetch body read") + .copy_to(&mut body) + .await + .map_err(|e| Error::FetchError(e.to_string()))?; + + Ok(body) + }), + } + } + + fn status(&self) -> u16 { + self.status + } + + fn redirected(&self) -> bool { + self.redirected + } + + #[allow(clippy::await_holding_lock)] + fn next_chunk(&mut self) -> OwnedFuture>, Error> { + match &mut self.response_body { + DesktopResponseBody::File(file) => { + let res = file + .as_mut() + .map(std::mem::take) + .map_err(|e| Error::FetchError(e.to_string())); + + Box::pin(async move { + match res { + Ok(bytes) if !bytes.is_empty() => Ok(Some(bytes)), + Ok(_) => Ok(None), + Err(e) => Err(e), + } + }) + } + DesktopResponseBody::Network(response) => { + let response = response.clone(); + + Box::pin(async move { + let mut buf = vec![0; 4096]; + let lock = response.try_lock(); + if matches!(lock, Err(std::sync::TryLockError::WouldBlock)) { + return Err(Error::FetchError( + "Concurrent read operations on the same stream are not supported." + .to_string(), + )); + } + + let result = lock + .expect("desktop network lock") + .body_mut() + .read(&mut buf) + .await; + + match result { + Ok(count) if count > 0 => { + buf.resize(count, 0); + Ok(Some(buf)) + } + Ok(_) => Ok(None), + Err(e) => Err(Error::FetchError(e.to_string())), + } + }) + } + } + } + + fn expected_length(&self) -> Result, Error> { + match &self.response_body { + DesktopResponseBody::File(file) => { + Ok(file.as_ref().map(|file| file.len() as u64).ok()) + } + DesktopResponseBody::Network(response) => { + let response = response.lock().expect("no recursive locks"); + let content_length = response.headers().get("Content-Length"); + + if let Some(len) = content_length { + Ok(Some( + len.to_str() + .map_err(|_| Error::InvalidHeaderValue)? + .parse::()?, + )) + } else { + Ok(None) + } + } + } + } + } + + // TODO: honor sandbox type (local-with-filesystem, local-with-network, remote, ...) + let mut processed_url = match self.resolve_url(request.url()) { + Ok(url) => url, + Err(e) => { + return async_return(create_fetch_error(request.url(), e)); + } + }; + + let client = self.client.clone(); + + match processed_url.scheme() { + "file" => { + let content = self.content.clone(); + let interface = self.interface.clone(); + Box::pin(async move { + // We send the original url (including query parameters) + // back to ruffle_core in the `Response` + let response_url = processed_url.clone(); + // Flash supports query parameters with local urls. + // SwfMovie takes care of exposing those to ActionScript - + // when we actually load a filesystem url, strip them out. + processed_url.set_query(None); + + let contents = + content.get_local_file(&processed_url, |path| interface.open_file(path)); + + let response: Box = Box::new(DesktopResponse { + url: response_url.to_string(), + response_body: DesktopResponseBody::File(contents), + status: 0, + redirected: false, + }); + + Ok(response) + }) + } + _ => Box::pin(async move { + let client = client.ok_or_else(|| ErrorResponse { + url: processed_url.to_string(), + error: Error::FetchError("Network unavailable".to_string()), + })?; + + let mut isahc_request = match request.method() { + NavigationMethod::Get => IsahcRequest::get(processed_url.to_string()), + NavigationMethod::Post => IsahcRequest::post(processed_url.to_string()), + }; + let (body_data, mime) = request.body().clone().unwrap_or_default(); + if let Some(headers) = isahc_request.headers_mut() { + for (name, val) in request.headers().iter() { + headers.insert( + HeaderName::from_str(name).map_err(|e| ErrorResponse { + url: processed_url.to_string(), + error: Error::FetchError(e.to_string()), + })?, + HeaderValue::from_str(val).map_err(|e| ErrorResponse { + url: processed_url.to_string(), + error: Error::FetchError(e.to_string()), + })?, + ); + } + headers.insert( + "Content-Type", + HeaderValue::from_str(&mime).map_err(|e| ErrorResponse { + url: processed_url.to_string(), + error: Error::FetchError(e.to_string()), + })?, + ); + } + + let body = isahc_request.body(body_data).map_err(|e| ErrorResponse { + url: processed_url.to_string(), + error: Error::FetchError(e.to_string()), + })?; + + let response = client.send_async(body).await.map_err(|e| { + let inner = match e.kind() { + isahc::error::ErrorKind::NameResolution => { + Error::InvalidDomain(processed_url.to_string()) + } + _ => Error::FetchError(e.to_string()), + }; + ErrorResponse { + url: processed_url.to_string(), + error: inner, + } + })?; + + let url = if let Some(uri) = response.effective_uri() { + uri.to_string() + } else { + processed_url.into() + }; + + let status = response.status().as_u16(); + let redirected = response.effective_uri().is_some(); + if !response.status().is_success() { + let error = Error::HttpNotOk( + format!("HTTP status is not ok, got {}", response.status()), + status, + redirected, + response.body().len().unwrap_or(0), + ); + return Err(ErrorResponse { url, error }); + } + + let response: Box = Box::new(DesktopResponse { + url, + response_body: DesktopResponseBody::Network(Arc::new(Mutex::new(response))), + status, + redirected, + }); + Ok(response) + }), + } + } + + fn resolve_url(&self, url: &str) -> Result { + match self.base_url.join(url) { + Ok(url) => Ok(self.pre_process_url(url)), + Err(error) => Err(error), + } + } + + fn spawn_future(&mut self, future: OwnedFuture<(), Error>) { + self.future_spawner.spawn(future); + } + + fn pre_process_url(&self, mut url: Url) -> Url { + if self.upgrade_to_https && url.scheme() == "http" && url.set_scheme("https").is_err() { + tracing::error!("Url::set_scheme failed on: {}", url); + } + url + } + + fn connect_socket( + &mut self, + host: String, + port: u16, + timeout: Duration, + handle: SocketHandle, + receiver: Receiver>, + sender: Sender, + ) { + let addr = format!("{}:{}", host, port); + let is_allowed = self.socket_allowed.contains(&addr); + let socket_mode = self.socket_mode; + let interface = self.interface.clone(); + + let future = Box::pin(async move { + match (is_allowed, socket_mode) { + (false, SocketMode::Allow) | (true, _) => {} // the process is allowed to continue. just dont do anything. + (false, SocketMode::Deny) => { + // Just fail the connection. + sender + .try_send(SocketAction::Connect(handle, ConnectionState::Failed)) + .expect("working channel send"); + + tracing::warn!( + "SWF tried to open a socket, but opening a socket is not allowed" + ); + + return Ok(()); + } + (false, SocketMode::Ask) => { + let attempt_sandbox_connect = interface.confirm_socket(&host, port).await; + + if !attempt_sandbox_connect { + // fail the connection. + sender + .try_send(SocketAction::Connect(handle, ConnectionState::Failed)) + .expect("working channel send"); + + return Ok(()); + } + } + } + + let host2 = host.clone(); + + let timeout = async { + Timer::after(timeout).await; + Result::::Err(io::Error::new(ErrorKind::TimedOut, "")) + }; + + let stream = match TcpStream::connect((host, port)).or(timeout).await { + Err(e) if e.kind() == ErrorKind::TimedOut => { + warn!("Connection to {}:{} timed out", host2, port); + sender + .try_send(SocketAction::Connect(handle, ConnectionState::TimedOut)) + .expect("working channel send"); + return Ok(()); + } + Ok(stream) => { + sender + .try_send(SocketAction::Connect(handle, ConnectionState::Connected)) + .expect("working channel send"); + + stream + } + Err(err) => { + warn!("Failed to connect to {}:{}, error: {}", host2, port, err); + sender + .try_send(SocketAction::Connect(handle, ConnectionState::Failed)) + .expect("working channel send"); + return Ok(()); + } + }; + + let sender = sender; + //NOTE: We clone the sender here as we cant share it between async tasks. + let sender2 = sender.clone(); + let (mut read, mut write) = stream.split(); + + let read = std::pin::pin!(async move { + loop { + let mut buffer = [0; 4096]; + + match read.read(&mut buffer).await { + Err(e) if e.kind() == ErrorKind::TimedOut => {} // try again later. + Err(_) | Ok(0) => { + sender + .try_send(SocketAction::Close(handle)) + .expect("working channel send"); + drop(read); + break; + } + Ok(read) => { + let buffer = buffer.into_iter().take(read).collect::>(); + + sender + .try_send(SocketAction::Data(handle, buffer)) + .expect("working channel send"); + } + }; + } + }); + + let write = std::pin::pin!(async move { + let mut pending_write = vec![]; + + loop { + let close_connection = loop { + match receiver.try_recv() { + Ok(val) => { + pending_write.extend(val); + } + Err(TryRecvError::Empty) => break false, + Err(TryRecvError::Closed) => { + //NOTE: Channel sender has been dropped. + // This means we have to close the connection, + // but not here, as we might have a pending write. + break true; + } + } + }; + + if !pending_write.is_empty() { + match write.write(&pending_write).await { + Err(e) if e.kind() == ErrorKind::TimedOut => {} // try again later. + Err(_) => { + sender2 + .try_send(SocketAction::Close(handle)) + .expect("working channel send"); + drop(write); + return; + } + Ok(written) => { + let _ = pending_write.drain(..written); + } + } + } else if close_connection { + drop(write); + return; + } else { + // Receiver is empty and there's no pending data, + // we may block here and wait for new data. + match receiver.recv().await { + Ok(val) => { + pending_write.extend(val); + } + Err(_) => { + // Ignore the error here, it will be + // reported again in try_recv. + } + } + } + } + }); + + //NOTE: If one future exits, this will take the other one down too. + select(read, write).await; + + Ok(()) + }); + + self.spawn_future(future); + } +} + +#[cfg(test)] +#[allow(clippy::unwrap_used)] +mod tests { + use async_net::TcpListener; + use ruffle_core::socket::SocketAction::{Close, Connect, Data}; + use std::net::SocketAddr; + use tokio::task; + + use super::*; + + impl NavigatorInterface for () { + fn confirm_website_navigation(&self, _url: &Url) -> bool { + true + } + + fn open_file(&self, path: &Path) -> io::Result { + File::open(path) + } + + async fn confirm_socket(&self, _host: &str, _port: u16) -> bool { + true + } + } + + const TIMEOUT_ZERO: Duration = Duration::ZERO; + // The timeout has to be large enough to allow "instantaneous" actions + // and local IO to execute, but small enough to fail tests quickly. + const TIMEOUT: Duration = Duration::from_secs(1); + + struct TestFutureSpawner; + + impl FutureSpawner for TestFutureSpawner { + fn spawn(&self, future: OwnedFuture<(), Error>) { + task::spawn_local(future); + } + } + + macro_rules! async_timeout { + () => { + async { + Timer::after(TIMEOUT).await; + panic!("An action which should complete timed out") + } + }; + } + + macro_rules! async_test { + ( + async fn $test_name:ident() $content:block + ) => { + #[tokio::test(flavor = "current_thread")] + async fn $test_name() { + task::LocalSet::new().run_until(async move $content).await; + } + } + } + + macro_rules! dummy_handle { + () => { + SocketHandle::default() + }; + } + + macro_rules! assert_next_socket_actions { + ($receiver:expr;) => { + // no more actions + }; + ($receiver:expr; $action:expr, $($more:expr,)*) => { + assert_eq!($receiver.recv().or(async_timeout!()).await.expect("receive action"), $action); + assert_next_socket_actions!($receiver; $($more,)*); + }; + } + + fn new_test_backend(socket_allow: bool) -> ExternalNavigatorBackend { + let url = Url::parse("https://example.com/path/").unwrap(); + ExternalNavigatorBackend::new( + url.clone(), + TestFutureSpawner, + None, + false, + OpenURLMode::Allow, + Default::default(), + if socket_allow { + SocketMode::Allow + } else { + SocketMode::Deny + }, + Rc::new(PlayingContent::DirectFile(url)), + (), + ) + } + + async fn start_test_server() -> (task::JoinHandle, SocketAddr) { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let accept_task = task::spawn_local(async move { + let (socket, _) = listener.accept().or(async_timeout!()).await.unwrap(); + socket + }); + (accept_task, addr) + } + + fn connect_test_socket( + addr: SocketAddr, + timeout: Duration, + socket_allow: bool, + ) -> (Sender>, Receiver) { + let mut backend = new_test_backend(socket_allow); + + let (write, receiver) = async_channel::unbounded(); + let (sender, read) = async_channel::unbounded(); + + backend.connect_socket( + addr.ip().to_string(), + addr.port(), + timeout, + dummy_handle!(), + receiver, + sender, + ); + + (write, read) + } + + async fn write_server(server_socket: &mut TcpStream, data: &str) { + server_socket + .write(data.as_bytes()) + .or(async_timeout!()) + .await + .expect("server write"); + } + + async fn read_server(server_socket: &mut TcpStream) -> String { + let mut buffer = [0; 4096]; + + let read = match server_socket.read(&mut buffer).await { + Err(e) => { + panic!("server read error: {}", e); + } + Ok(read) => read, + }; + + let buffer = buffer.into_iter().take(read).collect::>(); + String::from_utf8(buffer).unwrap() + } + + async fn write_client(client_write: &Sender>, data: &str) { + client_write + .send(data.as_bytes().to_vec()) + .or(async_timeout!()) + .await + .expect("client write"); + } + + #[macro_rules_attribute::apply(async_test)] + async fn test_socket_timeout() { + let (_accept_task, addr) = start_test_server().await; + let (_client_write, client_read) = connect_test_socket(addr, TIMEOUT_ZERO, true); + assert_next_socket_actions!( + client_read; + Connect(dummy_handle!(), ConnectionState::TimedOut), + ); + } + + #[macro_rules_attribute::apply(async_test)] + async fn test_socket_connect() { + let (accept_task, addr) = start_test_server().await; + let (_client_write, client_read) = connect_test_socket(addr, TIMEOUT, true); + let _server_socket = accept_task.await.unwrap(); + assert_next_socket_actions!( + client_read; + Connect(dummy_handle!(), ConnectionState::Connected), + ); + } + + #[macro_rules_attribute::apply(async_test)] + async fn test_socket_deny() { + let (_accept_task, addr) = start_test_server().await; + let (_client_write, client_read) = connect_test_socket(addr, TIMEOUT, false); + + assert_next_socket_actions!( + client_read; + Connect(dummy_handle!(), ConnectionState::Failed), + ); + } + + #[macro_rules_attribute::apply(async_test)] + async fn test_socket_fail() { + let addr = SocketAddr::from_str("[100::]:42").expect("black hole address"); + let (_client_write, client_read) = connect_test_socket(addr, TIMEOUT, true); + + assert_next_socket_actions!( + client_read; + Connect(dummy_handle!(), ConnectionState::Failed), + ); + } + + #[macro_rules_attribute::apply(async_test)] + async fn test_socket_server_close() { + let (accept_task, addr) = start_test_server().await; + let (_client_write, client_read) = connect_test_socket(addr, TIMEOUT, true); + + let server_socket = accept_task.await.unwrap(); + assert_next_socket_actions!( + client_read; + Connect(dummy_handle!(), ConnectionState::Connected), + ); + + drop(server_socket); + + assert_next_socket_actions!( + client_read; + Close(dummy_handle!()), + ); + } + + #[macro_rules_attribute::apply(async_test)] + async fn test_socket_client_close() { + let (accept_task, addr) = start_test_server().await; + let (client_write, client_read) = connect_test_socket(addr, TIMEOUT, true); + + let mut server_socket = accept_task.await.unwrap(); + assert_next_socket_actions!( + client_read; + Connect(dummy_handle!(), ConnectionState::Connected), + ); + + drop(client_write); + + assert_eq!(read_server(&mut server_socket).await, ""); + } + + #[macro_rules_attribute::apply(async_test)] + async fn test_socket_basic_communication() { + let (accept_task, addr) = start_test_server().await; + let (client_write, client_read) = connect_test_socket(addr, TIMEOUT, true); + + let mut server_socket = accept_task.await.unwrap(); + assert_next_socket_actions!( + client_read; + Connect(dummy_handle!(), ConnectionState::Connected), + ); + + write_server(&mut server_socket, "Hello ").await; + write_server(&mut server_socket, "World!").await; + + assert_next_socket_actions!( + client_read; + Data(dummy_handle!(), "Hello World!".as_bytes().to_vec()), + ); + + write_client(&client_write, "Hello from").await; + write_client(&client_write, " client").await; + + assert_eq!(read_server(&mut server_socket).await, "Hello from client"); + + write_server(&mut server_socket, "from server 2").await; + write_client(&client_write, "from client 2").await; + + assert_next_socket_actions!( + client_read; + Data(dummy_handle!(), "from server 2".as_bytes().to_vec()), + ); + assert_eq!(read_server(&mut server_socket).await, "from client 2"); + } + + #[macro_rules_attribute::apply(async_test)] + async fn test_socket_flush_before_close() { + let (accept_task, addr) = start_test_server().await; + let (client_write, client_read) = connect_test_socket(addr, TIMEOUT, true); + + let mut server_socket = accept_task.await.unwrap(); + assert_next_socket_actions!( + client_read; + Connect(dummy_handle!(), ConnectionState::Connected), + ); + + write_client(&client_write, "Sending some").await; + write_client(&client_write, " data").await; + client_write.close(); + + assert_eq!(read_server(&mut server_socket).await, "Sending some data"); + } +}